aboutsummaryrefslogtreecommitdiff
path: root/scripts/rpki/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/rpki/sql.py')
-rw-r--r--scripts/rpki/sql.py49
1 files changed, 27 insertions, 22 deletions
diff --git a/scripts/rpki/sql.py b/scripts/rpki/sql.py
index bab07e5e..2831a447 100644
--- a/scripts/rpki/sql.py
+++ b/scripts/rpki/sql.py
@@ -14,16 +14,15 @@ def connect(cfg, section="sql"):
class template(object):
"""SQL template generator."""
def __init__(self, table_name, *columns):
- index_column = columns[0]
- data_columns = columns[1:]
- self.table = table_name
- self.index = index_column
- self.columns = columns
- self.insert = "INSERT %s (%s) VALUES (%s)" % (table_name, ", ".join(data_columns), ", ".join("%(" + s + ")s" for s in data_columns))
- self.update = "UPDATE %s SET %s WHERE %s = %%(%s)s" % (table_name, ", ".join(s + " = %(" + s + ")s" for s in data_columns), index_column, index_column)
- self.delete = "DELETE FROM %s WHERE %s = %%s" % (table_name, index_column)
- self.select_all = "SELECT %s FROM %s" % (", ".join(columns), table_name)
- self.select_one = self.select_all + " WHERE " + index_column + " = %s"
+ index_column = columns[0]
+ data_columns = columns[1:]
+ self.table = table_name
+ self.index = index_column
+ self.columns = columns
+ self.select = "SELECT %s FROM %s" % (", ".join(columns), table_name)
+ self.insert = "INSERT %s (%s) VALUES (%s)" % (table_name, ", ".join(data_columns), ", ".join("%(" + s + ")s" for s in data_columns))
+ self.update = "UPDATE %s SET %s WHERE %s = %%(%s)s" % (table_name, ", ".join(s + " = %(" + s + ")s" for s in data_columns), index_column, index_column)
+ self.delete = "DELETE FROM %s WHERE %s = %%s" % (table_name, index_column)
## @var sql_cache
# Cache of objects pulled from SQL.
@@ -58,27 +57,33 @@ class sql_persistant(object):
@classmethod
def sql_fetch(cls, db, cur, id):
- key = (cls, id)
- if key in sql_cache:
- return sql_cache[key]
- cur.execute(cls.sql_template.select_one, id)
- row = cur.fetchone()
- if row is None:
+ results = cls.sql_fetch_where(db, cur, "WHERE %s = %s" % (cls.sql_template.index, id))
+ assert len(results) <= 1
+ if len(results) == 0:
return None
+ elif len(results) == 1:
+ return results[0]
else:
- return cls.sql_init(db, cur, row, key)
+ raise rpki.exceptions.DBConsistancyError, "Database contained multiple matches for %s.%s" % (cls.__name__, id)
@classmethod
def sql_fetch_all(cls, db, cur):
- cur.execute(cls.sql_template.select_all)
- all = []
+ return cls.sql_fetch_where(db, cur, None)
+
+ @classmethod
+ def sql_fetch_where(cls, db, cur, where):
+ if where is None:
+ cur.execute(cls.sql_template.select)
+ else:
+ cur.execute(cls.sql_template.select + where)
+ results = []
for row in cur.fetchall():
key = (cls, row[0])
if key in sql_cache:
- all.append(sql_cache[key])
+ results.append(sql_cache[key])
else:
- all.append(cls.sql_init(db, cur, row, key))
- return all
+ results.append(cls.sql_init(db, cur, row, key))
+ return results
@classmethod
def sql_init(cls, db, cur, row, key):