diff options
Diffstat (limited to 'scripts/rpki/sql.py')
-rw-r--r-- | scripts/rpki/sql.py | 49 |
1 files changed, 27 insertions, 22 deletions
diff --git a/scripts/rpki/sql.py b/scripts/rpki/sql.py index bab07e5e..2831a447 100644 --- a/scripts/rpki/sql.py +++ b/scripts/rpki/sql.py @@ -14,16 +14,15 @@ def connect(cfg, section="sql"): class template(object): """SQL template generator.""" def __init__(self, table_name, *columns): - index_column = columns[0] - data_columns = columns[1:] - self.table = table_name - self.index = index_column - self.columns = columns - self.insert = "INSERT %s (%s) VALUES (%s)" % (table_name, ", ".join(data_columns), ", ".join("%(" + s + ")s" for s in data_columns)) - self.update = "UPDATE %s SET %s WHERE %s = %%(%s)s" % (table_name, ", ".join(s + " = %(" + s + ")s" for s in data_columns), index_column, index_column) - self.delete = "DELETE FROM %s WHERE %s = %%s" % (table_name, index_column) - self.select_all = "SELECT %s FROM %s" % (", ".join(columns), table_name) - self.select_one = self.select_all + " WHERE " + index_column + " = %s" + index_column = columns[0] + data_columns = columns[1:] + self.table = table_name + self.index = index_column + self.columns = columns + self.select = "SELECT %s FROM %s" % (", ".join(columns), table_name) + self.insert = "INSERT %s (%s) VALUES (%s)" % (table_name, ", ".join(data_columns), ", ".join("%(" + s + ")s" for s in data_columns)) + self.update = "UPDATE %s SET %s WHERE %s = %%(%s)s" % (table_name, ", ".join(s + " = %(" + s + ")s" for s in data_columns), index_column, index_column) + self.delete = "DELETE FROM %s WHERE %s = %%s" % (table_name, index_column) ## @var sql_cache # Cache of objects pulled from SQL. @@ -58,27 +57,33 @@ class sql_persistant(object): @classmethod def sql_fetch(cls, db, cur, id): - key = (cls, id) - if key in sql_cache: - return sql_cache[key] - cur.execute(cls.sql_template.select_one, id) - row = cur.fetchone() - if row is None: + results = cls.sql_fetch_where(db, cur, "WHERE %s = %s" % (cls.sql_template.index, id)) + assert len(results) <= 1 + if len(results) == 0: return None + elif len(results) == 1: + return results[0] else: - return cls.sql_init(db, cur, row, key) + raise rpki.exceptions.DBConsistancyError, "Database contained multiple matches for %s.%s" % (cls.__name__, id) @classmethod def sql_fetch_all(cls, db, cur): - cur.execute(cls.sql_template.select_all) - all = [] + return cls.sql_fetch_where(db, cur, None) + + @classmethod + def sql_fetch_where(cls, db, cur, where): + if where is None: + cur.execute(cls.sql_template.select) + else: + cur.execute(cls.sql_template.select + where) + results = [] for row in cur.fetchall(): key = (cls, row[0]) if key in sql_cache: - all.append(sql_cache[key]) + results.append(sql_cache[key]) else: - all.append(cls.sql_init(db, cur, row, key)) - return all + results.append(cls.sql_init(db, cur, row, key)) + return results @classmethod def sql_init(cls, db, cur, row, key): |