aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--scripts/rpki/left_right.py12
-rw-r--r--scripts/rpki/sql.py36
2 files changed, 40 insertions, 8 deletions
diff --git a/scripts/rpki/left_right.py b/scripts/rpki/left_right.py
index e7f164e6..9dc721d3 100644
--- a/scripts/rpki/left_right.py
+++ b/scripts/rpki/left_right.py
@@ -175,11 +175,13 @@ class parent_elt(base_elt, rpki.sql.sql_persistant):
WHERE parent_id = %(parent_id)s"""
sql_delete_cmd = """DELETE FROM parent WHERE parent_id = %(parent_id)s"""
+ sql_children = (("cas", rpki.sql.ca_obj),)
+
def sql_decode(self, sql_parent, parent_id, ta, uri, sia_base, bsc_id, repos_id):
assert isinstance(sql_parent, self_elt)
self.self_obj = sql_parent
- self.bsc_obj = self.self_obj.bscs[bsc_id]
- self.repository_obj = self.self_obj.repos[repos_id]
+ self.bsc_obj = bsc_elt.sql_cache_find(bsc_id)
+ self.repository_obj = repository_elt.sql_cache_find(repos_id)
self.parent_id = parent_id
self.peer_contact = uri
self.peer_ta = rpki.x509.X509(DER=ta)
@@ -232,8 +234,8 @@ class child_elt(base_elt, rpki.sql.sql_persistant):
def sql_decode(self, sql_parent, child_id, ta, bsc_id):
assert isinstance(sql_parent, self_elt)
self.self_obj = sql_parent
- self.bsc_obj = self.self_obj.bscs[bsc_id]
- self.bsc_link = bsc_id
+ self.bsc_obj = bsc_elt.sql_cache_find(bsc_id)
+ self.child_id = child_id
self.peer_ta = rpki.x509.X509(DER=ta)
def sql_encode(self):
@@ -280,7 +282,7 @@ class repository_elt(base_elt, rpki.sql.sql_persistant):
def sql_decode(self, sql_parent, bsc_id, repos_id, uri, ta):
assert isinstance(sql_parent, self_elt)
self.self_obj = sql_parent
- self.bsc_obj = self.self_obj.bscs[bsc_id]
+ self.bsc_obj = bsc_elt.sql_cache_find(bsc_id)
self.repository_id = repos_id
self.peer_contact = uri
self.peer_ta = rpki.x509.X509(DER=ta)
diff --git a/scripts/rpki/sql.py b/scripts/rpki/sql.py
index e32dac48..02593930 100644
--- a/scripts/rpki/sql.py
+++ b/scripts/rpki/sql.py
@@ -11,6 +11,17 @@ def connect(cfg, section="sql"):
db = cfg.get(section, "sql-database"),
passwd = cfg.get(section, "sql-password"))
+## @var sql_cache
+# Cache of objects pulled from SQL.
+
+sql_cache = {}
+
+def cache_clear():
+ """Clear the object cache."""
+
+ sql_cache = {}
+
+
class sql_persistant(object):
"""Mixin for persistant class that needs to be stored in SQL.
"""
@@ -60,6 +71,24 @@ class sql_persistant(object):
# Command to DELETE this object from SQL
sql_delete_cmd = None
+ def sql_cache_add(self):
+ """Add self to the object cache."""
+
+ assert self.sql_id_name is not None
+ sql_cache[(self.__class__, self.sql_id_name)] = self
+
+ @classmethod
+ def sql_cache_find(*keys):
+ """Find an object in the object cache."""
+
+ return sql_cache.get(keys)
+
+ def cache_delete(*keys):
+ """Delete self from the object cache."""
+
+ assert self.sql_id_name is not None
+ del sql_cache[(self.__class__, self.sql_id_name)]
+
@classmethod
def sql_fetch(cls, db, cur=None, select_dict=None, sql_parent=None):
"""Fetch rows from SQL based on a canned query and a set of
@@ -81,13 +110,13 @@ class sql_persistant(object):
self = cls()
self.in_sql = True
self.sql_decode(sql_parent, *row)
+ if self.sql_id_name is not None:
+ cache_add(self, self.__class__, getattr(self, self.sql_id_name))
self_dict = self.sql_encode()
self.sql_fetch_hook(db, cur)
result.append(self)
for k,v in self.sql_children:
setattr(self, k, v.sql_fetch(db, cur, self_dict, self))
- if cls.sql_id_name is not None:
- result = dict((getattr(i, cls.sql_id_name), i) for i in result)
return result
def sql_store(self, db, cur=None):
@@ -99,6 +128,7 @@ class sql_persistant(object):
cur.execute(self.sql_insert_cmd, self.sql_encode())
if self.sql_id_name is not None:
setattr(self, self.sql_id_name, cur.lastrowid)
+ cache_add(self, self.__class__, getattr(self, self.sql_id_name))
self.sql_insert_hook(db, cur)
elif self.sql_dirty:
cur.execute(self.sql_update_cmd, self.sql_encode())
@@ -216,7 +246,7 @@ class ca_obj(sql_persistant):
WHERE ca_id = %(ca_id)s"""
sql_delete_cmd = """DELETE FROM ca WHERE ca_id = %(ca_id)s"""
- sql_children = (("ca_details", ca_detail_obj))
+ sql_children = (("ca_details", ca_detail_obj),)
def sql_decode(self, sql_parent, ca_id, last_crl_sn, next_crl_update, last_issued_sn, last_manifest_sn, next_manifest_update, sia_uri, parent_id):
assert isinstance(sql_parent, rpki.left_right.parent_elt)