aboutsummaryrefslogtreecommitdiff
path: root/ca/rpki-sql-setup
diff options
context:
space:
mode:
Diffstat (limited to 'ca/rpki-sql-setup')
-rwxr-xr-xca/rpki-sql-setup378
1 files changed, 189 insertions, 189 deletions
diff --git a/ca/rpki-sql-setup b/ca/rpki-sql-setup
index b209fec7..c399fcfe 100755
--- a/ca/rpki-sql-setup
+++ b/ca/rpki-sql-setup
@@ -39,218 +39,218 @@ ER_NO_SUCH_TABLE = 1146 # See mysqld_ername.h
class RootDB(object):
- """
- Class to wrap MySQL actions that require root-equivalent access so
- we can defer such actions until we're sure they're really needed.
- Overall goal here is to prompt the user for the root password once
- at most, and not at all when not necessary.
- """
-
- def __init__(self, mysql_defaults = None):
- self.initialized = False
- self.mysql_defaults = mysql_defaults
-
- def __getattr__(self, name):
- if self.initialized:
- raise AttributeError
- if self.mysql_defaults is None:
- self.db = MySQLdb.connect(db = "mysql",
- user = "root",
- passwd = getpass.getpass("Please enter your MySQL root password: "))
- else:
- mysql_cfg = rpki.config.parser(set_filename = self.mysql_defaults, section = "client")
- self.db = MySQLdb.connect(db = "mysql",
- user = mysql_cfg.get("user"),
- passwd = mysql_cfg.get("password"))
- self.cur = self.db.cursor()
- self.cur.execute("SHOW DATABASES")
- self.databases = set(d[0] for d in self.cur.fetchall())
- self.initialized = True
- return getattr(self, name)
-
- def close(self):
- if self.initialized:
- self.db.close()
+ """
+ Class to wrap MySQL actions that require root-equivalent access so
+ we can defer such actions until we're sure they're really needed.
+ Overall goal here is to prompt the user for the root password once
+ at most, and not at all when not necessary.
+ """
+
+ def __init__(self, mysql_defaults = None):
+ self.initialized = False
+ self.mysql_defaults = mysql_defaults
+
+ def __getattr__(self, name):
+ if self.initialized:
+ raise AttributeError
+ if self.mysql_defaults is None:
+ self.db = MySQLdb.connect(db = "mysql",
+ user = "root",
+ passwd = getpass.getpass("Please enter your MySQL root password: "))
+ else:
+ mysql_cfg = rpki.config.parser(set_filename = self.mysql_defaults, section = "client")
+ self.db = MySQLdb.connect(db = "mysql",
+ user = mysql_cfg.get("user"),
+ passwd = mysql_cfg.get("password"))
+ self.cur = self.db.cursor()
+ self.cur.execute("SHOW DATABASES")
+ self.databases = set(d[0] for d in self.cur.fetchall())
+ self.initialized = True
+ return getattr(self, name)
+
+ def close(self):
+ if self.initialized:
+ self.db.close()
class UserDB(object):
- """
- Class to wrap MySQL access parameters for a particular database.
-
- NB: The SQL definitions for the upgrade_version table is embedded in
- this class rather than being declared in any of the .sql files.
- This is deliberate: nothing but the upgrade system should ever touch
- this table, and it's simpler to keep everything in one place.
-
- We have to be careful about SQL commits here, because CREATE TABLE
- implies an automatic commit. So presence of the magic table per se
- isn't significant, only its content (or lack thereof).
- """
-
- upgrade_version_table_schema = """
- CREATE TABLE upgrade_version (
- version TEXT NOT NULL,
- updated DATETIME NOT NULL
- ) ENGINE=InnoDB
+ """
+ Class to wrap MySQL access parameters for a particular database.
+
+ NB: The SQL definitions for the upgrade_version table is embedded in
+ this class rather than being declared in any of the .sql files.
+ This is deliberate: nothing but the upgrade system should ever touch
+ this table, and it's simpler to keep everything in one place.
+
+ We have to be careful about SQL commits here, because CREATE TABLE
+ implies an automatic commit. So presence of the magic table per se
+ isn't significant, only its content (or lack thereof).
"""
- def __init__(self, name):
- self.name = name
- self.database = cfg.get("sql-database", section = name)
- self.username = cfg.get("sql-username", section = name)
- self.password = cfg.get("sql-password", section = name)
- self.db = None
- self.cur = None
-
- def open(self):
- self.db = MySQLdb.connect(db = self.database, user = self.username, passwd = self.password)
- self.db.autocommit(False)
- self.cur = self.db.cursor()
-
- def close(self):
- if self.cur is not None:
- self.cur.close()
- self.cur = None
- if self.db is not None:
- self.db.commit()
- self.db.close()
- self.db = None
-
- @property
- def exists_and_accessible(self):
- try:
- MySQLdb.connect(db = self.database, user = self.username, passwd = self.password).close()
- except: # pylint: disable=W0702
- return False
- else:
- return True
-
- @property
- def version(self):
- try:
- self.cur.execute("SELECT version FROM upgrade_version")
- v = self.cur.fetchone()
- return Version(None if v is None else v[0])
- except _mysql_exceptions.ProgrammingError, e:
- if e.args[0] != ER_NO_SUCH_TABLE:
- raise
- log("Creating upgrade_version table in %s" % self.name)
- self.cur.execute(self.upgrade_version_table_schema)
- return Version(None)
-
- @version.setter
- def version(self, v):
- if v > self.version:
- self.cur.execute("DELETE FROM upgrade_version")
- self.cur.execute("INSERT upgrade_version (version, updated) VALUES (%s, %s)", (v, datetime.datetime.now()))
- self.db.commit()
- log("Updated %s to %s" % (self.name, v))
+ upgrade_version_table_schema = """
+ CREATE TABLE upgrade_version (
+ version TEXT NOT NULL,
+ updated DATETIME NOT NULL
+ ) ENGINE=InnoDB
+ """
+
+ def __init__(self, name):
+ self.name = name
+ self.database = cfg.get("sql-database", section = name)
+ self.username = cfg.get("sql-username", section = name)
+ self.password = cfg.get("sql-password", section = name)
+ self.db = None
+ self.cur = None
+
+ def open(self):
+ self.db = MySQLdb.connect(db = self.database, user = self.username, passwd = self.password)
+ self.db.autocommit(False)
+ self.cur = self.db.cursor()
+
+ def close(self):
+ if self.cur is not None:
+ self.cur.close()
+ self.cur = None
+ if self.db is not None:
+ self.db.commit()
+ self.db.close()
+ self.db = None
+
+ @property
+ def exists_and_accessible(self):
+ try:
+ MySQLdb.connect(db = self.database, user = self.username, passwd = self.password).close()
+ except: # pylint: disable=W0702
+ return False
+ else:
+ return True
+
+ @property
+ def version(self):
+ try:
+ self.cur.execute("SELECT version FROM upgrade_version")
+ v = self.cur.fetchone()
+ return Version(None if v is None else v[0])
+ except _mysql_exceptions.ProgrammingError, e:
+ if e.args[0] != ER_NO_SUCH_TABLE:
+ raise
+ log("Creating upgrade_version table in %s" % self.name)
+ self.cur.execute(self.upgrade_version_table_schema)
+ return Version(None)
+
+ @version.setter
+ def version(self, v):
+ if v > self.version:
+ self.cur.execute("DELETE FROM upgrade_version")
+ self.cur.execute("INSERT upgrade_version (version, updated) VALUES (%s, %s)", (v, datetime.datetime.now()))
+ self.db.commit()
+ log("Updated %s to %s" % (self.name, v))
class Version(object):
- """
- A version number. This is a class in its own right to force the
- comparision and string I/O behavior we want.
- """
+ """
+ A version number. This is a class in its own right to force the
+ comparision and string I/O behavior we want.
+ """
- def __init__(self, v):
- if v is None:
- v = "0.0"
- self.v = tuple(v.lower().split("."))
+ def __init__(self, v):
+ if v is None:
+ v = "0.0"
+ self.v = tuple(v.lower().split("."))
- def __str__(self):
- return ".".join(self.v)
+ def __str__(self):
+ return ".".join(self.v)
- def __cmp__(self, other):
- return cmp(self.v, other.v)
+ def __cmp__(self, other):
+ return cmp(self.v, other.v)
class Upgrade(object):
- """
- One upgrade script. Really, just its filename and the Version
- object we parse from its filename, we don't need to read the script
- itself except when applying it, but we do need to sort all the
- available upgrade scripts into version order.
- """
-
- @classmethod
- def load_all(cls, name, dn):
- g = os.path.join(dn, "upgrade-%s-to-*.py" % name)
- for fn in glob.iglob(g):
- yield cls(g, fn)
-
- def __init__(self, g, fn):
- head, sep, tail = g.partition("*") # pylint: disable=W0612
- self.fn = fn
- self.version = Version(fn[len(head):-len(tail)])
-
- def __cmp__(self, other):
- return cmp(self.version, other.version)
-
- def apply(self, db):
- # db is an argument here primarily so the script we exec can get at it
- log("Applying %s to %s" % (self.fn, db.name))
- with open(self.fn, "r") as f:
- exec f # pylint: disable=W0122
+ """
+ One upgrade script. Really, just its filename and the Version
+ object we parse from its filename, we don't need to read the script
+ itself except when applying it, but we do need to sort all the
+ available upgrade scripts into version order.
+ """
+
+ @classmethod
+ def load_all(cls, name, dn):
+ g = os.path.join(dn, "upgrade-%s-to-*.py" % name)
+ for fn in glob.iglob(g):
+ yield cls(g, fn)
+
+ def __init__(self, g, fn):
+ head, sep, tail = g.partition("*") # pylint: disable=W0612
+ self.fn = fn
+ self.version = Version(fn[len(head):-len(tail)])
+
+ def __cmp__(self, other):
+ return cmp(self.version, other.version)
+
+ def apply(self, db):
+ # db is an argument here primarily so the script we exec can get at it
+ log("Applying %s to %s" % (self.fn, db.name))
+ with open(self.fn, "r") as f:
+ exec f # pylint: disable=W0122
def do_drop(name):
- db = UserDB(name)
- if db.database in root.databases:
- log("DROP DATABASE %s" % db.database)
- root.cur.execute("DROP DATABASE %s" % db.database)
- root.db.commit()
+ db = UserDB(name)
+ if db.database in root.databases:
+ log("DROP DATABASE %s" % db.database)
+ root.cur.execute("DROP DATABASE %s" % db.database)
+ root.db.commit()
def do_create(name):
- db = UserDB(name)
- log("CREATE DATABASE %s" % db.database)
- root.cur.execute("CREATE DATABASE %s" % db.database)
- log("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY ###" % (db.database, db.username))
- root.cur.execute("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY %%s" % (db.database, db.username),
- (db.password,))
- root.db.commit()
- db.open()
- db.version = current_version
- db.close()
+ db = UserDB(name)
+ log("CREATE DATABASE %s" % db.database)
+ root.cur.execute("CREATE DATABASE %s" % db.database)
+ log("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY ###" % (db.database, db.username))
+ root.cur.execute("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY %%s" % (db.database, db.username),
+ (db.password,))
+ root.db.commit()
+ db.open()
+ db.version = current_version
+ db.close()
def do_script_drop(name):
- db = UserDB(name)
- print "DROP DATABASE IF EXISTS %s;" % db.database
+ db = UserDB(name)
+ print "DROP DATABASE IF EXISTS %s;" % db.database
def do_drop_and_create(name):
- do_drop(name)
- do_create(name)
+ do_drop(name)
+ do_create(name)
def do_fix_grants(name):
- db = UserDB(name)
- if not db.exists_and_accessible:
- log("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY ###" % (db.database, db.username))
- root.cur.execute("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY %%s" % (db.database, db.username),
- (db.password,))
- root.db.commit()
+ db = UserDB(name)
+ if not db.exists_and_accessible:
+ log("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY ###" % (db.database, db.username))
+ root.cur.execute("GRANT ALL ON %s.* TO %s@localhost IDENTIFIED BY %%s" % (db.database, db.username),
+ (db.password,))
+ root.db.commit()
def do_create_if_missing(name):
- db = UserDB(name)
- if not db.exists_and_accessible:
- do_create(name)
+ db = UserDB(name)
+ if not db.exists_and_accessible:
+ do_create(name)
def do_apply_upgrades(name):
- upgrades = sorted(Upgrade.load_all(name, args.upgrade_scripts))
- if upgrades:
- db = UserDB(name)
- db.open()
- log("Current version of %s is %s" % (db.name, db.version))
- for upgrade in upgrades:
- if upgrade.version > db.version:
- upgrade.apply(db)
- db.version = upgrade.version
- db.version = current_version
- db.close()
+ upgrades = sorted(Upgrade.load_all(name, args.upgrade_scripts))
+ if upgrades:
+ db = UserDB(name)
+ db.open()
+ log("Current version of %s is %s" % (db.name, db.version))
+ for upgrade in upgrades:
+ if upgrade.version > db.version:
+ upgrade.apply(db)
+ db.version = upgrade.version
+ db.version = current_version
+ db.close()
def log(text):
- if args.verbose:
- print "#", text
+ if args.verbose:
+ print "#", text
parser = argparse.ArgumentParser(description = """\
Automated setup of all SQL stuff used by the RPKI CA tools. Pulls
@@ -291,13 +291,13 @@ parser.set_defaults(dispatch = do_create_if_missing)
args = parser.parse_args()
try:
- cfg = rpki.config.parser(set_filename = args.config, section = "myrpki")
- root = RootDB(args.mysql_defaults)
- current_version = Version(rpki.version.VERSION)
- for program_name in ("irdbd", "rpkid", "pubd"):
- if cfg.getboolean("start_" + program_name, False):
- args.dispatch(program_name)
- root.close()
+ cfg = rpki.config.parser(set_filename = args.config, section = "myrpki")
+ root = RootDB(args.mysql_defaults)
+ current_version = Version(rpki.version.VERSION)
+ for program_name in ("irdbd", "rpkid", "pubd"):
+ if cfg.getboolean("start_" + program_name, False):
+ args.dispatch(program_name)
+ root.close()
except Exception, e:
- #raise
- sys.exit(str(e))
+ #raise
+ sys.exit(str(e))