#!/usr/bin/python
# Copyright (c) 2008 Alon Swartz <alon@turnkeylinux.org> - all rights reserved

"""
Configure MySQL (sets MySQL password and optionally executes query)

Options:
    -u --user=    mysql username (default: root)
    -p --pass=    unless provided, will ask interactively via debconf

    --query=      optional query to execute after setting password
    --chroot=     path to chroot of the mysql daemon (default: /target)
                  note: to disable chrooting, specify --chroot=/

"""

import re
import sys
import time
import getopt

import common
from common import ExecError

# Option file passed to mysql/mysqladmin via --defaults-file; its
# "password = ..." lines are rewritten by main() when the password of the
# "debian-sys-maint" user is changed.
DEBIAN_CNF = "/etc/mysql/debian.cnf"

class Error(Exception):
    """Script-specific failure (raised when mysqld cannot be started)."""

def escape_chars(s):
    """Return *s* escaped for use inside nested quotes in a query.

    Substitutions are applied in order; the backslash must be handled first
    so that backslashes introduced by the later rules are not doubled again.
    """
    substitutions = (
        ("\\", "\\\\"),   # \  ->  \\
        ('"', '\\"'),     # "  ->  \"
        ("'", "'\\''"),   # '  ->  '\''
    )
    for needle, replacement in substitutions:
        s = s.replace(needle, replacement)
    return s

class MySQL:
    """Drive a MySQL server through shell commands.

    Commands are run through ``self.system``: ``common.Chroot(chroot_path).system``
    when a chroot path is given, ``common.system`` otherwise.  If no server
    answers ``mysqladmin ping`` at construction time, one is started with
    ``--skip-networking`` and shut down again when the object is destroyed.
    """

    def __init__(self, chroot_path=None):
        """Prepare /var/run/mysqld and make sure a server is running.

        chroot_path -- path handed to common.Chroot; a false value means the
                       commands are run through common.system instead.

        Raises Error if mysqld had to be started and did not come up.
        """
        # Assigned before anything that can raise: __del__ calls _stop(),
        # which reads this flag even when construction failed part-way.
        self.selfstarted = False

        if chroot_path:
            chroot = common.Chroot(chroot_path)
            self.system = chroot.system
        else:
            self.system = common.system

        self.system("mkdir -p /var/run/mysqld")
        self.system("chown mysql:root /var/run/mysqld")

        if not self._is_alive():
            self._start()
            self.selfstarted = True

    def _is_alive(self):
        """Return True if ``mysqladmin ping`` succeeds, False on ExecError."""
        try:
            self.system('mysqladmin -s ping >/dev/null 2>&1')
        except ExecError:
            return False

        return True

    def _start(self):
        """Launch mysqld in the background and wait for it to answer pings.

        Polls up to 6 times, one second apart; raises Error if the server
        never responds.
        """
        self.system("mysqld --skip-networking >/dev/null 2>&1 &")
        for _ in range(6):
            if self._is_alive():
                return

            time.sleep(1)

        raise Error("could not start mysqld")

    def _stop(self):
        """Shut the server down, but only if this object started it."""
        # getattr: tolerate instances whose __init__ never ran to completion.
        if getattr(self, "selfstarted", False):
            # Clear the flag first so a repeated call does not issue a
            # second shutdown.
            self.selfstarted = False
            self.system("mysqladmin --defaults-file=%s shutdown" % DEBIAN_CNF)

    def __del__(self):
        self._stop()

    def execute(self, query):
        """Run *query* through the mysql client in batch mode (-B).

        The query is placed between single quotes on the shell command line,
        so the caller is responsible for escaping it accordingly.
        """
        self.system("mysql --defaults-file=%s -B -e '%s'" % (DEBIAN_CNF, query))

def usage(s=None):
    """Print an optional error, the syntax line and the module help to
    stderr, then exit with status 1."""
    lines = []
    if s:
        lines.append("Error: %s" % (s,))
    lines.append("Syntax: %s [options]" % sys.argv[0])
    lines.append("%s" % (__doc__,))

    sys.stderr.write("\n".join(lines) + "\n")
    sys.exit(1)

def _rewrite_debian_cnf(password):
    """Replace the value of every "password = ..." line in DEBIAN_CNF."""
    # NOTE(review): DEBIAN_CNF is opened as given, i.e. not prefixed with the
    # chroot path, while the mysql commands may run inside a chroot -- confirm
    # this is the intended file.
    fh = open(DEBIAN_CNF)
    try:
        old = fh.read()
    finally:
        fh.close()

    # A callable replacement inserts the password verbatim.  Passing it as a
    # template string would let re.sub interpret backslashes and group
    # references (e.g. "\1") that happen to be part of the password.
    new = re.sub("password = (.*)\n",
                 lambda match: "password = %s\n" % password, old)

    fh = open(DEBIAN_CNF, "w")
    try:
        fh.write(new)
    finally:
        fh.close()

def main():
    """Parse the command line, set the MySQL password, run extra queries.

    Exits via usage() on bad options or --help, and with status 10 when the
    chroot target is not mounted.  When no password is given on the command
    line it is requested through common.Debconf.
    """
    try:
        opts, _ = getopt.gnu_getopt(sys.argv[1:], "hu:p:",
                     ['help', 'user=', 'pass=', 'query=', 'chroot='])

    except getopt.GetoptError:
        # sys.exc_info() fetches the exception without relying on
        # version-specific "except X, e" / "except X as e" syntax.
        usage(sys.exc_info()[1])

    username = "root"
    password = ""
    chroot_path = "/target"
    queries = []

    for opt, val in opts:
        if opt in ('-h', '--help'):
            usage()
        elif opt in ('-u', '--user'):
            username = val
        elif opt in ('-p', '--pass'):
            password = val
        # Plain equality: ('--query') is just a string, so "in" would be a
        # substring test rather than tuple membership.
        elif opt == '--query':
            queries.append(val)
        elif opt == '--chroot':
            chroot_path = val

    if chroot_path == '/':
        chroot_path = None
    elif not common.target_mounted(chroot_path):
        sys.exit(10) #return to menu

    if not password:
        db = common.Debconf()
        password = db.get_password('New password for the MySQL "%s" user' % username)

    m = MySQL(chroot_path)

    # set password; both values end up inside double quotes in the SQL
    # statement, so both are escaped
    m.execute('update mysql.user set Password=PASSWORD("%s") where User="%s"; flush privileges;'
              % (escape_chars(password), escape_chars(username)))

    # edge case: update DEBIAN_CNF
    if username == "debian-sys-maint":
        _rewrite_debian_cnf(password)

    # execute any adhoc specified queries
    for query in queries:
        m.execute(query)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

