aboutsummaryrefslogtreecommitdiff
path: root/scripts/rpki/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/rpki/sql.py')
-rw-r--r--scripts/rpki/sql.py36
1 files changed, 33 insertions, 3 deletions
diff --git a/scripts/rpki/sql.py b/scripts/rpki/sql.py
index e32dac48..02593930 100644
--- a/scripts/rpki/sql.py
+++ b/scripts/rpki/sql.py
@@ -11,6 +11,17 @@ def connect(cfg, section="sql"):
db = cfg.get(section, "sql-database"),
passwd = cfg.get(section, "sql-password"))
+## @var sql_cache
+# Cache of objects pulled from SQL.
+
+sql_cache = {}
+
+def cache_clear():
+ """Clear the object cache."""
+
+ sql_cache = {}
+
+
class sql_persistant(object):
"""Mixin for persistant class that needs to be stored in SQL.
"""
@@ -60,6 +71,24 @@ class sql_persistant(object):
# Command to DELETE this object from SQL
sql_delete_cmd = None
+ def sql_cache_add(self):
+ """Add self to the object cache."""
+
+ assert self.sql_id_name is not None
+ sql_cache[(self.__class__, self.sql_id_name)] = self
+
+ @classmethod
+ def sql_cache_find(*keys):
+ """Find an object in the object cache."""
+
+ return sql_cache.get(keys)
+
+ def cache_delete(*keys):
+ """Delete self from the object cache."""
+
+ assert self.sql_id_name is not None
+ del sql_cache[(self.__class__, self.sql_id_name)]
+
@classmethod
def sql_fetch(cls, db, cur=None, select_dict=None, sql_parent=None):
"""Fetch rows from SQL based on a canned query and a set of
@@ -81,13 +110,13 @@ class sql_persistant(object):
self = cls()
self.in_sql = True
self.sql_decode(sql_parent, *row)
+ if self.sql_id_name is not None:
+ cache_add(self, self.__class__, getattr(self, self.sql_id_name))
self_dict = self.sql_encode()
self.sql_fetch_hook(db, cur)
result.append(self)
for k,v in self.sql_children:
setattr(self, k, v.sql_fetch(db, cur, self_dict, self))
- if cls.sql_id_name is not None:
- result = dict((getattr(i, cls.sql_id_name), i) for i in result)
return result
def sql_store(self, db, cur=None):
@@ -99,6 +128,7 @@ class sql_persistant(object):
cur.execute(self.sql_insert_cmd, self.sql_encode())
if self.sql_id_name is not None:
setattr(self, self.sql_id_name, cur.lastrowid)
+ cache_add(self, self.__class__, getattr(self, self.sql_id_name))
self.sql_insert_hook(db, cur)
elif self.sql_dirty:
cur.execute(self.sql_update_cmd, self.sql_encode())
@@ -216,7 +246,7 @@ class ca_obj(sql_persistant):
WHERE ca_id = %(ca_id)s"""
sql_delete_cmd = """DELETE FROM ca WHERE ca_id = %(ca_id)s"""
- sql_children = (("ca_details", ca_detail_obj))
+ sql_children = (("ca_details", ca_detail_obj),)
def sql_decode(self, sql_parent, ca_id, last_crl_sn, next_crl_update, last_issued_sn, last_manifest_sn, next_manifest_update, sia_uri, parent_id):
assert isinstance(sql_parent, rpki.left_right.parent_elt)