diff options
Diffstat (limited to 'ca')
-rwxr-xr-x | ca/rpki-sql-setup | 466 |
1 files changed, 225 insertions, 241 deletions
diff --git a/ca/rpki-sql-setup b/ca/rpki-sql-setup index 297571a2..e282f887 100755 --- a/ca/rpki-sql-setup +++ b/ca/rpki-sql-setup @@ -18,289 +18,273 @@ # TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR # PERFORMANCE OF THIS SOFTWARE. +""" +Automated setup of SQL stuff used by the RPKI tools. Pulls +configuration from rpki.conf, prompts for SQL password when needed. +""" + import os +import pwd import sys -import glob import getpass import argparse -import datetime import rpki.config -import rpki.version -import rpki.autoconf -# This program implements its own schema versioning system as a poor -# substitute for schema migrations. Now that we're moving to Django -# ORM, this is pretty much useless, and should be removed at some point. +class Abstract_Driver(object): -from rpki.mysql_import import MySQLdb, _mysql_exceptions + # Kludge to make classes derived from this into singletons. Net + # of a Million Lies says this is Not Pythonic, but it seems to + # work, so long as one doesn't attempt to subclass the resulting + # driver classes. For our purposes, it will do. -ER_NO_SUCH_TABLE = 1146 # See mysqld_ername.h + __instance = None + def __new__(cls, *args, **kwargs): + if cls.__instance is None: + cls.__instance = object.__new__(cls, *args, **kwargs) + return cls.__instance -class RootDB(object): - """ - Class to wrap MySQL actions that require root-equivalent access so - we can defer such actions until we're sure they're really needed. - Overall goal here is to prompt the user for the root password once - at most, and not at all when not necessary. - """ - def __init__(self, mysql_defaults = None): - self.initialized = False - self.mysql_defaults = mysql_defaults - - def __getattr__(self, name): - # pylint: disable=W0201 - if self.initialized: - raise AttributeError - if self.mysql_defaults is None: - self.db = MySQLdb.connect(db = "mysql", - user = "root", - passwd = getpass.getpass("Please enter your MySQL root password: ")) +class MySQL_Driver(Abstract_Driver): + + _initialized = False + + def __init__(self, args): + from rpki.mysql_import import MySQLdb + self.driver = MySQLdb + self.args = args + + def _initialize(self): + if not self._initialized: + if self.args.mysql_defaults: + mysql_cfg = rpki.config.parser(set_filename = self.args.mysql_defaults, section = "client") + self._db = self.driver.connect(db = "mysql", + user = mysql_cfg.get("user"), + passwd = mysql_cfg.get("password")) + else: + self._db = self.driver.connect(db = "mysql", + user = "root", + passwd = getpass.getpass("Please enter your MySQL root password: ")) + self._cur = self._db.cursor() + self._initialized = True + + def _accessible(self, udb): + try: + self.driver.connect(db = udb.database, user = udb.username, passwd = udb.password).close() + except: + return False else: - mysql_cfg = rpki.config.parser(set_filename = self.mysql_defaults, section = "client") - self.db = MySQLdb.connect(db = "mysql", - user = mysql_cfg.get("user"), - passwd = mysql_cfg.get("password")) - self.cur = self.db.cursor() - self.cur.execute("SHOW DATABASES") - self.databases = set(d[0] for d in self.cur.fetchall()) - self.initialized = True - return getattr(self, name) + return True - def close(self): - if self.initialized: - self.db.close() # pylint: disable=E1101 + def _grant(self, udb): + self._cur.execute("GRANT ALL ON {0.database}.* TO {0.username}@localhost IDENTIFIED BY %s".format(udb), + (udb.password,)) + def create(self, udb): + if args.force or not self._accessible(udb): + self._initialize() + self._cur.execute("CREATE DATABASE IF NOT EXISTS {0.database}".format(udb)) + self._grant(udb) + self._db.commit() -class UserDB(object): - """ - Class to wrap MySQL access parameters for a particular database. + def drop(self, udb): + if args.force or self._accessible(udb): + self._initialize() + self._cur.execute("DROP DATABASE IF EXISTS {0.database}".format(udb)) + self._db.commit() - NB: The SQL definitions for the upgrade_version table is embedded in - this class rather than being declared in any of the .sql files. - This is deliberate: nothing but the upgrade system should ever touch - this table, and it's simpler to keep everything in one place. + def script_drop(self, udb): + self.args.script_output.write("DROP DATABASE IF EXISTS {};\n".format(udb.database)) - We have to be careful about SQL commits here, because CREATE TABLE - implies an automatic commit. So presence of the magic table per se - isn't significant, only its content (or lack thereof). - """ + def fix_grants(self, udb): + if args.force or not self._accessible(udb): + self._grant(udb) + self._db.commit() - upgrade_version_table_schema = """ - CREATE TABLE upgrade_version ( - version TEXT NOT NULL, - updated DATETIME NOT NULL - ) ENGINE=InnoDB - """ - def __init__(self, name): - self.name = name - self.database = cfg.get("sql-database", section = name) - self.username = cfg.get("sql-username", section = name) - self.password = cfg.get("sql-password", section = name) - self.db = None - self.cur = None - - def open(self): - self.db = MySQLdb.connect(db = self.database, user = self.username, passwd = self.password) - self.db.autocommit(False) - self.cur = self.db.cursor() - - def close(self): - if self.cur is not None: - self.cur.close() - self.cur = None - if self.db is not None: - # pylint: disable=E1101 - self.db.commit() - self.db.close() - self.db = None - - @property - def exists_and_accessible(self): - # pylint: disable=E1101 +class SQLite3_Driver(Abstract_Driver): + + def __init__(self, args): + import sqlite3 + self.driver = sqlite3 + self.args = args + self.can_chown = os.getuid() == 0 or os.geteuid() == 0 + + def _accessible(self, udb): try: - MySQLdb.connect(db = self.database, user = self.username, passwd = self.password).close() + self.driver.connect(udb.database).close() except: return False else: return True - @property - def version(self): + def _grant(self, udb): + if self.can_chown and udb.username: + pw = pwd.getpwnam(udb.username) + os.chown(udb.database, pw.pw_uid, pw.pw_gid) + + def create(self, udb): + if args.force or not self._accessible(udb): + self.driver.connect(udb.database).close() + self._grant(udb) + + def drop(self, udb): + if args.force or self._accessible(udb): + os.unlink(udb.database) + + def script_drop(self, udb): + pass + + def fix_grants(self, udb): + if args.force or not self._accessible(udb): + self._grant(udb) + + +class PostgreSQL_Driver(Abstract_Driver): + + def __init__(self, args): + import psycopg2 + self.driver = psycopg2 + self.args = args + if args.postgresql_root_username and (os.getuid() == 0 or os.geteuid() == 0): + self._pw = pwd.getpwnam(args.postgresql_root_username) + else: + self._pw = None + + def _execute(*sql_commands): + pid = None if self._pw is None else os.fork() + if pid == 0: + os.setgid(pw.pw_gid) + os.setuid(pw.pw_uid) + if not pid: + with self.driver.connect(database = self.args.postgresql_root_database) as db: + with db.cursor() as cur: + for sql_command in sql_commands: + cur.execute(command) + if pid == 0: + os._exit(0) + if pid: + os.waitpid(pid, 0) + + def _accessible(self, udb): try: - self.cur.execute("SELECT version FROM upgrade_version") - v = self.cur.fetchone() - return Version(None if v is None else v[0]) - except _mysql_exceptions.ProgrammingError, e: - if e.args[0] != ER_NO_SUCH_TABLE: - raise - log("Creating upgrade_version table in %s" % self.name) - self.cur.execute(self.upgrade_version_table_schema) - return Version(None) - - @version.setter - def version(self, v): - if v > self.version: - self.cur.execute("DELETE FROM upgrade_version") - self.cur.execute("INSERT upgrade_version (version, updated) VALUES (%s, %s)", (v, datetime.datetime.now())) - self.db.commit() # pylint: disable=E1101 - log("Updated %s to %s" % (self.name, v)) - - - -class Version(object): + self.driver.connect(database = udb.database, user = udb.username , password = usb.password).close() + except: + return False + else: + return True + + def create(self, udb): + if args.force or not self._accessible(udb): + # + # CREATE ROLE doesn't take a IF NOT EXISTS modifier, but we can fake it using plpgsql. + # http://stackoverflow.com/questions/8092086/create-postgresql-role-user-if-it-doesnt-exist + # + self._execute(''' + DO $$ BEGIN + IF NOT EXISTS (SELECT * FROM pg_catalog.pg_user WHERE usename = '{0.username}') THEN + CREATE ROLE {0.username} LOGIN PASSWORD '{0.password}'; + END IF; + END $$ + '''.format(udb), + "CREATE DATABASE IF NOT EXISTS {0.database} OWNER {0.username}".format(udb)) + + def drop(self, udb): + if args.force or self._accessible(udb): + self._execute("DROP DATABASE IF EXISTS {0.database}".format(udb)) + + def script_drop(self, udb): + self.args.script_output.write("DROP DATABASE IF EXISTS {};\n".format(udb.database)) + + def fix_grants(self, udb): + if args.force or not self._accessible(udb): + self._execute("ALTER DATABASE {0.database} OWNER TO {0.username}".format(udb), + "ALTER ROLE {0.username} WITH PASSWORD '{0.password}".format(udb)) + + +class UserDB(object): """ - A version number. This is a class in its own right to force the - comparision and string I/O behavior we want. + Class to wrap access parameters for a particular database. """ - def __init__(self, v): - if v is None: - v = "0.0" - self.v = tuple(v.lower().split(".")) + drivers = dict(sqlite3 = SQLite3_Driver, + mysql = MySQL_Driver, + postgresql = PostgreSQL_Driver) - def __str__(self): - return ".".join(self.v) + def __init__(self, args, name): + self.database = cfg.get("sql-database", section = name) + self.username = cfg.get("sql-username", section = name) + self.password = cfg.get("sql-password", section = name) + self.engine = cfg.get("sql-engine", section = name) + self.driver = self.drivers[self.engine](args) - def __cmp__(self, other): - return cmp(self.v, other.v) + def drop(self): + self.driver.drop(self) + def create(self): + self.driver.create(self) -class Upgrade(object): - """ - One upgrade script. Really, just its filename and the Version - object we parse from its filename, we don't need to read the script - itself except when applying it, but we do need to sort all the - available upgrade scripts into version order. - """ + def script_drop(self): + self.driver.script_drop(self) - @classmethod - def load_all(cls, name, dn): - g = os.path.join(dn, "upgrade-%s-to-*.py" % name) - for fn in glob.iglob(g): - yield cls(g, fn) - - def __init__(self, g, fn): - head, sep, tail = g.partition("*") # pylint: disable=W0612 - self.fn = fn - self.version = Version(fn[len(head):-len(tail)]) - - def __cmp__(self, other): - return cmp(self.version, other.version) - - def apply(self, db): - # db is an argument here primarily so the script we exec can get at it - log("Applying %s to %s" % (self.fn, db.name)) - with open(self.fn, "r") as f: - exec f # pylint: disable=W0122 - - -def do_drop(name): - db = UserDB(name) - if db.database in root.databases: - log("DROP DATABASE %s" % db.database) - root.cur.execute("DROP DATABASE %s" % db.database) - root.db.commit() # pylint: disable=E1101 - -def do_create(name): - db = UserDB(name) - log("CREATE DATABASE %s" % db.database) - root.cur.execute("CREATE DATABASE %s" % db.database) - log("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY ###" % (db.database, db.username)) - root.cur.execute("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY %%s" % (db.database, db.username), - (db.password,)) - root.db.commit() # pylint: disable=E1101 - db.open() - db.version = current_version - db.close() - -def do_script_drop(name): - db = UserDB(name) - print "DROP DATABASE IF EXISTS %s;" % db.database - -def do_drop_and_create(name): - do_drop(name) - do_create(name) - -def do_fix_grants(name): - db = UserDB(name) - if not db.exists_and_accessible: - log("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY ###" % (db.database, db.username)) - root.cur.execute("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY %%s" % (db.database, db.username), - (db.password,)) - root.db.commit() # pylint: disable=E1101 - -def do_create_if_missing(name): - db = UserDB(name) - if not db.exists_and_accessible: - do_create(name) - -def do_apply_upgrades(name): - upgrades = sorted(Upgrade.load_all(name, args.upgrade_scripts)) - if upgrades: - db = UserDB(name) - db.open() - log("Current version of %s is %s" % (db.name, db.version)) - for upgrade in upgrades: - if upgrade.version > db.version: - upgrade.apply(db) - db.version = upgrade.version - db.version = current_version - db.close() - -def log(text): - if args.verbose: - print "#", text - -parser = argparse.ArgumentParser(description = """\ -Automated setup of all SQL stuff used by the RPKI CA tools. Pulls -configuration from rpki.conf, prompts for MySQL password when needed. -""") -group = parser.add_mutually_exclusive_group() + def drop_and_create(self): + self.driver.drop(self) + self.driver.create(self) + + def fix_grants(self): + self.driver.fix_grants(self) + + +parser = argparse.ArgumentParser(description = __doc__) parser.add_argument("-c", "--config", help = "specify alternate location for rpki.conf") -parser.add_argument("-v", "--verbose", action = "store_true", +parser.add_argument("-d", "--debug", action = "store_true", + help = "enable debugging (eg, Python backtraces)") +parser.add_argument("-v", "--verbose", action = "store_true", help = "whistle while you work") +parser.add_argument("-f", "--force", action = "store_true", + help = "force database create, drop, or grant regardless of current state") + parser.add_argument("--mysql-defaults", help = "specify MySQL root access credentials via a configuration file") -parser.add_argument("--upgrade-scripts", - default = os.path.join(rpki.autoconf.datarootdir, "rpki", "upgrade-scripts"), - help = "override default location of upgrade scripts") -group.add_argument("--create", - action = "store_const", dest = "dispatch", const = do_create, - help = "create databases and load schemas") -group.add_argument("--drop", - action = "store_const", dest = "dispatch", const = do_drop, - help = "drop databases") -group.add_argument("--script-drop", - action = "store_const", dest = "dispatch", const = do_script_drop, - help = "send SQL commands to drop databases to standard output") -group.add_argument("--drop-and-create", - action = "store_const", dest = "dispatch", const = do_drop_and_create, - help = "drop databases then recreate them and load schemas") -group.add_argument("--fix-grants", - action = "store_const", dest = "dispatch", const = do_fix_grants, - help = "whack database access to match current configuration file") -group.add_argument("--create-if-missing", - action = "store_const", dest = "dispatch", const = do_create_if_missing, - help = "create databases and load schemas if they don't exist already") -group.add_argument("--apply-upgrades", - action = "store_const", dest = "dispatch", const = do_apply_upgrades, - help = "apply upgrade scripts to existing databases") -parser.set_defaults(dispatch = do_create_if_missing) + + +parser.add_argument("--postgresql-root-database", default = "postgres", + help = "name of PostgreSQL control database") +parser.add_argument("--postgresql-root-username", + help = "username of PostgreSQL control role") + +subparsers = parser.add_subparsers(title = "Commands", metavar = "", dest = "dispatch") + +subparsers.add_parser("create", + help = "create databases and load schemas") + +subparsers.add_parser("drop", + help = "drop databases") + +subparser = subparsers.add_parser("script-drop", + help = "show SQL commands to drop databases") +subparser.add_argument("script_output", + nargs = "?", type = argparse.FileType("w"), default = "-", + help = "destination for drop script") + +subparsers.add_parser("drop-and-create", + help = "drop databases then recreate them and load schemas") + +subparsers.add_parser("fix-grants", + help = "whack database to match configuration file") + args = parser.parse_args() try: cfg = rpki.config.parser(set_filename = args.config, section = "myrpki") - root = RootDB(args.mysql_defaults) - current_version = Version(rpki.version.VERSION) - for program_name in ("irdbd", "rpkid", "pubd"): - if cfg.getboolean("start_" + program_name, False): - args.dispatch(program_name) - root.close() + names = [name for name in ("irdbd", "rpkid", "pubd") if cfg.getboolean("start_" + name, False)] + names.append("rcynic") + for name in names: + getattr(UserDB(args = args, name = name), args.dispatch.replace("-", "_"))() except Exception, e: - #raise - sys.exit(str(e)) + if args.debug: + raise + else: + sys.exit(str(e)) |