aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2008-06-06 22:54:12 +0000
committerRob Austein <sra@hactrn.net>2008-06-06 22:54:12 +0000
commit88524489465b9ab52e3df199e18b33da34c7a6fb (patch)
tree950856a042ad6aa45e5316ca49d05b3255bac114
parent3a273ef8d516ed1fd37ded395168ccb35372aea4 (diff)
Refactor SQL code to form rpki.sq.session class, and add code to warn
about dirty objects in SQL cache on SQL errors. svn path=/rpkid/irdbd.py; revision=1852
-rwxr-xr-xrpkid/irdbd.py2
-rwxr-xr-xrpkid/pubd.py12
-rw-r--r--rpkid/rpki/left_right.py28
-rw-r--r--rpkid/rpki/resource_set.py25
-rw-r--r--rpkid/rpki/rpki_engine.py38
-rw-r--r--rpkid/rpki/sql.py113
-rw-r--r--rpkid/rpki/up_down.py4
7 files changed, 113 insertions, 109 deletions
diff --git a/rpkid/irdbd.py b/rpkid/irdbd.py
index a2dccaa4..22fef807 100755
--- a/rpkid/irdbd.py
+++ b/rpkid/irdbd.py
@@ -30,7 +30,7 @@ import rpki.exceptions, rpki.left_right, rpki.log, rpki.x509
def handler(query, path):
try:
- db.ping(True)
+ db.ping()
q_msg = rpki.left_right.cms_msg.unwrap(query, (bpki_ta, rpkid_cert))
diff --git a/rpkid/pubd.py b/rpkid/pubd.py
index 81d15524..4e1ebcc9 100755
--- a/rpkid/pubd.py
+++ b/rpkid/pubd.py
@@ -34,8 +34,7 @@ class pubd_context(rpki.rpki_engine.rpkid_context):
def __init__(self, cfg):
- self.db = rpki.sql.connect(cfg)
- self.cur = self.db.cursor()
+ self.sql = rpki.sql.session(cfg)
self.bpki_ta = rpki.x509.X509(Auto_file = cfg.get("bpki-ta"))
self.irbe_cert = rpki.x509.X509(Auto_file = cfg.get("irbe-cert"))
@@ -47,22 +46,19 @@ class pubd_context(rpki.rpki_engine.rpkid_context):
self.publication_base = cfg.get("publication-base", "publication/")
- self.sql_cache = {}
- self.sql_dirty = set()
-
def handler_common(self, query, client, certs, crl = None):
"""Common PDU handler code."""
q_msg = rpki.publication.cms_msg.unwrap(query, certs)
r_msg = q_msg.serve_top_level(self, client)
reply = rpki.publication.cms_msg.wrap(r_msg, self.pubd_key, self.pubd_cert, crl)
- self.sql_sweep()
+ self.sql.sweep()
return reply
def control_handler(self, query, path):
"""Process one PDU from the IRBE."""
rpki.log.trace()
try:
- self.db.ping(True)
+ self.sql.ping()
return 200, self.handler_common(query, None, (self.bpki_ta, self.irbe_cert))
except Exception, data:
rpki.log.error(traceback.format_exc())
@@ -72,7 +68,7 @@ class pubd_context(rpki.rpki_engine.rpkid_context):
"""Process one PDU from a client."""
rpki.log.trace()
try:
- self.db.ping(True)
+ self.sql.ping()
client_id = path.partition("/client/")[2]
if not client_id.isdigit():
raise rpki.exceptions.BadContactURL, "Bad path: %s" % path
diff --git a/rpkid/rpki/left_right.py b/rpkid/rpki/left_right.py
index 3be62e25..e6caf838 100644
--- a/rpkid/rpki/left_right.py
+++ b/rpkid/rpki/left_right.py
@@ -161,7 +161,7 @@ class self_elt(data_elt):
rpki.rpki_engine.ca_obj.create(parent, rc)
for ca in ca_map.values():
ca.delete(parent) # CA not listed by parent
- self.gctx.sql_sweep()
+ self.gctx.sql.sweep()
def update_children(self):
"""Check for updated IRDB data for all of this self's children and
@@ -546,19 +546,23 @@ class route_origin_elt(data_elt):
def sql_fetch_hook(self):
"""Extra SQL fetch actions for route_origin_elt -- handle prefix list."""
- self.ipv4 = rpki.resource_set.roa_prefix_set_ipv4.from_sql(self.gctx.cur, """
- SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix
- WHERE route_origin_id = %s AND address NOT LIKE '%:%'
- """, (self.route_origin_id,))
- self.ipv6 = rpki.resource_set.roa_prefix_set_ipv6.from_sql(self.gctx.cur, """
- SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix
- WHERE route_origin_id = %s AND address LIKE '%:%'
- """, (self.route_origin_id,))
+ self.ipv4 = rpki.resource_set.roa_prefix_set_ipv4.from_sql(
+ self.gctx.sql,
+ """
+ SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix
+ WHERE route_origin_id = %s AND address NOT LIKE '%:%'
+ """, (self.route_origin_id,))
+ self.ipv6 = rpki.resource_set.roa_prefix_set_ipv6.from_sql(
+ self.gctx.sql,
+ """
+ SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix
+ WHERE route_origin_id = %s AND address LIKE '%:%'
+ """, (self.route_origin_id,))
def sql_insert_hook(self):
"""Extra SQL insert actions for route_origin_elt -- handle address ranges."""
if self.ipv4 or self.ipv6:
- self.gctx.cur.executemany("""
+ self.gctx.sql.executemany("""
INSERT route_origin_prefix (route_origin_id, address, prefixlen, max_prefixlen)
VALUES (%s, %s, %s, %s)""",
((self.route_origin_id, x.address, x.prefixlen, x.max_prefixlen)
@@ -566,7 +570,7 @@ class route_origin_elt(data_elt):
def sql_delete_hook(self):
"""Extra SQL delete actions for route_origin_elt -- handle address ranges."""
- self.gctx.cur.execute("DELETE FROM route_origin_prefix WHERE route_origin_id = %s", (self.route_origin_id,))
+ self.gctx.sql.execute("DELETE FROM route_origin_prefix WHERE route_origin_id = %s", (self.route_origin_id,))
def ca_detail(self):
"""Fetch all ca_detail objects that link to this route_origin object."""
@@ -711,7 +715,7 @@ class route_origin_elt(data_elt):
rpki.rpki_engine.revoked_cert_obj.revoke(cert = cert, ca_detail = ca_detail)
repository.withdraw(roa, roa_uri)
repository.withdraw(cert, ee_uri)
- self.gctx.sql_sweep()
+ self.gctx.sql.sweep()
ca_detail.generate_crl()
ca_detail.generate_manifest()
diff --git a/rpkid/rpki/resource_set.py b/rpkid/rpki/resource_set.py
index 3d92b2ce..3b0ce4e5 100644
--- a/rpkid/rpki/resource_set.py
+++ b/rpkid/rpki/resource_set.py
@@ -298,18 +298,19 @@ class resource_set(list):
return other.issubset(self)
@classmethod
- def from_sql(cls, cur, query, args = None):
+ def from_sql(cls, sql, query, args = None):
"""Create resource set from an SQL query.
- cur is a DB API 2.0 cursor object.
+ sql is an object that supports execute() and fetchall() methods
+ like a DB API 2.0 cursor object.
query is an SQL query that returns a sequence of (min, max) pairs.
"""
- cur.execute(query, args)
+ sql.execute(query, args)
return cls(ini = [cls.range_type(cls.range_type.datum_type(b),
cls.range_type.datum_type(e))
- for (b,e) in cur.fetchall()])
+ for (b,e) in sql.fetchall()])
class resource_set_as(resource_set):
"""ASN resource set."""
@@ -667,15 +668,19 @@ class roa_prefix_set(list):
return self.resource_set_type([p.to_resource_range() for p in self])
@classmethod
- def from_sql(cls, cur, query, args = None):
- """Create ROA prefix set from an SQL query. cur is a DB API 2.0
- cursor object. query is an SQL query that returns a sequence of
- (address, prefixlen, max_prefixlen) triples.
+ def from_sql(cls, sql, query, args = None):
+ """Create ROA prefix set from an SQL query.
+
+ sql is an object that supports execute() and fetchall() methods
+ like a DB API 2.0 cursor object.
+
+ query is an SQL query that returns a sequence of (address,
+ prefixlen, max_prefixlen) triples.
"""
- cur.execute(query, args)
+ sql.execute(query, args)
return cls([cls.prefix_type(cls.prefix_type.range_type.datum_type(x), int(y), int(z))
- for (x,y,z) in cur.fetchall()])
+ for (x,y,z) in sql.fetchall()])
def to_roa_tuple(self):
"""Convert ROA prefix set into tuple format used by ROA ASN.1 encoder.
diff --git a/rpkid/rpki/rpki_engine.py b/rpkid/rpki/rpki_engine.py
index 18c17c48..cd16c21d 100644
--- a/rpkid/rpki/rpki_engine.py
+++ b/rpkid/rpki/rpki_engine.py
@@ -28,8 +28,7 @@ class rpkid_context(object):
def __init__(self, cfg):
- self.db = rpki.sql.connect(cfg)
- self.cur = self.db.cursor()
+ self.sql = rpki.sql.session(cfg)
self.bpki_ta = rpki.x509.X509(Auto_file = cfg.get("bpki-ta"))
self.irdb_cert = rpki.x509.X509(Auto_file = cfg.get("irdb-cert"))
@@ -44,9 +43,6 @@ class rpkid_context(object):
self.publication_kludge_base = cfg.get("publication-kludge-base", "publication/")
- self.sql_cache = {}
- self.sql_dirty = set()
-
def irdb_query(self, self_id, child_id = None):
"""Perform an IRDB callback query. In the long run this should not
be a blocking routine, it should instead issue a query and set up a
@@ -81,35 +77,17 @@ class rpkid_context(object):
v6 = r_msg[0].ipv6,
valid_until = r_msg[0].valid_until)
- def sql_cache_clear(self):
- """Clear the object cache."""
- self.sql_cache.clear()
-
- def sql_assert_pristine(self):
- """Assert that there are no dirty objects in the cache."""
- assert not self.sql_dirty, "Dirty objects in SQL cache: %s" % self.sql_dirty
-
- def sql_sweep(self):
- """Write any dirty objects out to SQL."""
- for s in self.sql_dirty.copy():
- rpki.log.debug("Sweeping %s" % repr(s))
- if s.sql_deleted:
- s.sql_delete()
- else:
- s.sql_store()
- self.sql_assert_pristine()
-
def left_right_handler(self, query, path):
"""Process one left-right PDU."""
rpki.log.trace()
try:
- self.db.ping(True)
+ self.sql.ping()
q_msg = rpki.left_right.cms_msg.unwrap(query, (self.bpki_ta, self.irbe_cert))
if q_msg.type != "query":
raise rpki.exceptions.BadQuery, "Message type is not query"
r_msg = q_msg.serve_top_level(self)
reply = rpki.left_right.cms_msg.wrap(r_msg, self.rpkid_key, self.rpkid_cert)
- self.sql_sweep()
+ self.sql.sweep()
return 200, reply
except Exception, data:
rpki.log.error(traceback.format_exc())
@@ -119,7 +97,7 @@ class rpkid_context(object):
"""Process one up-down PDU."""
rpki.log.trace()
try:
- self.db.ping(True)
+ self.sql.ping()
child_id = path.partition("/up-down/")[2]
if not child_id.isdigit():
raise rpki.exceptions.BadContactURL, "Bad path: %s" % path
@@ -127,7 +105,7 @@ class rpkid_context(object):
if child is None:
raise rpki.exceptions.ChildNotFound, "Could not find child %s" % child_id
reply = child.serve_up_down(query)
- self.sql_sweep()
+ self.sql.sweep()
return 200, reply
except Exception, data:
rpki.log.error(traceback.format_exc())
@@ -140,13 +118,13 @@ class rpkid_context(object):
rpki.log.trace()
try:
- self.db.ping(True)
+ self.sql.ping()
for s in rpki.left_right.self_elt.sql_fetch_all(self):
s.client_poll()
s.update_children()
s.update_roas()
s.regenerate_crls_and_manifests()
- self.sql_sweep()
+ self.sql.sweep()
return 200, "OK"
except Exception, data:
rpki.log.error(traceback.format_exc())
@@ -694,7 +672,7 @@ class child_cert_obj(rpki.sql.sql_persistant):
revoked_cert_obj.revoke(cert = self.cert, ca_detail = ca_detail)
repository = ca.parent().repository()
repository.withdraw(self.cert, self.uri(ca))
- self.gctx.sql_sweep()
+ self.gctx.sql.sweep()
self.sql_delete()
def reissue(self, ca_detail, resources = None, sia = None):
diff --git a/rpkid/rpki/sql.py b/rpkid/rpki/sql.py
index 81df728f..fa5927fa 100644
--- a/rpkid/rpki/sql.py
+++ b/rpkid/rpki/sql.py
@@ -15,59 +15,80 @@
# PERFORMANCE OF THIS SOFTWARE.
import MySQLdb, time, warnings, _mysql_exceptions
-import rpki.x509, rpki.resource_set, rpki.sundial
+import rpki.x509, rpki.resource_set, rpki.sundial, rpki.log
-def connect(cfg, throw_exception_on_warning = True):
- """Connect to a MySQL database using connection parameters from an
- rpki.config.parser object.
- """
-
- if throw_exception_on_warning:
- warnings.simplefilter("error", _mysql_exceptions.Warning)
-
- return MySQLdb.connect(user = cfg.get("sql-username"),
- db = cfg.get("sql-database"),
- passwd = cfg.get("sql-password"))
-
-class sesssion(object):
+class session(object):
"""SQL session layer."""
- def __init__(self, cfg):
+ _exceptions_enabled = False
- raise rpki.errorsNotImplementedYet, "This class is still under construction"
+ def __init__(self, cfg):
- warnings.simplefilter("error", _mysql_exceptions.Warning)
+ if not self._exceptions_enabled:
+ warnings.simplefilter("error", _mysql_exceptions.Warning)
+ self.__class__._exceptions_enabled = True
self.username = cfg.get("sql-username")
self.database = cfg.get("sql-database")
self.password = cfg.get("sql-password")
- self.sql_cache = {}
- self.sql_dirty = set()
+ self.cache = {}
+ self.dirty = set()
self.connect()
def connect(self):
- self.db = MySQLdb.connect(user = username, db = database, passwd = password)
+ self.db = MySQLdb.connect(user = self.username, db = self.database, passwd = self.password)
self.cur = self.db.cursor()
- def sql_cache_clear(self):
+ def close(self):
+ if self.cur:
+ self.cur.close()
+ self.cur = None
+ if self.db:
+ self.db.close()
+ self.db = None
+
+ def ping(self):
+ return self.db.ping(True)
+
+ def _wrap_execute(self, func, query, args):
+ try:
+ return func(query, args)
+ except _mysql_exceptions.MySQLError:
+ if self.dirty:
+ rpki.log.warn("MySQL exception with dirty objects in SQL cache!")
+ raise
+
+ def execute(self, query, args = None):
+ return self._wrap_execute(self.cur.execute, query, args)
+
+ def executemany(self, query, args):
+ return self._wrap_execute(self.cur.executemany, query, args)
+
+ def fetchall(self):
+ return self.cur.fetchall()
+
+ def lastrowid(self):
+ return self.cur.lastrowid
+
+ def cache_clear(self):
"""Clear the object cache."""
- self.sql_cache.clear()
+ self.cache.clear()
- def sql_assert_pristine(self):
+ def assert_pristine(self):
"""Assert that there are no dirty objects in the cache."""
- assert not self.sql_dirty, "Dirty objects in SQL cache: %s" % self.sql_dirty
+ assert not self.dirty, "Dirty objects in SQL cache: %s" % self.dirty
- def sql_sweep(self):
+ def sweep(self):
"""Write any dirty objects out to SQL."""
- for s in self.sql_dirty.copy():
+ for s in self.dirty.copy():
rpki.log.debug("Sweeping %s" % repr(s))
if s.sql_deleted:
s.sql_delete()
else:
s.sql_store()
- self.sql_assert_pristine()
+ self.assert_pristine()
class template(object):
"""SQL template generator."""
@@ -124,8 +145,8 @@ class sql_persistant(object):
return None
assert isinstance(id, (int, long)), "id should be an integer, was %s" % repr(type(id))
key = (cls, id)
- if key in gctx.sql_cache:
- return gctx.sql_cache[key]
+ if key in gctx.sql.cache:
+ return gctx.sql.cache[key]
else:
return cls.sql_fetch_where1(gctx, "%s = %%s" % cls.sql_template.index, (id,))
@@ -154,17 +175,17 @@ class sql_persistant(object):
assert args is None
if cls.sql_debug:
rpki.log.debug("sql_fetch_where(%s)" % repr(cls.sql_template.select))
- gctx.cur.execute(cls.sql_template.select)
+ gctx.sql.execute(cls.sql_template.select)
else:
query = cls.sql_template.select + " WHERE " + where
if cls.sql_debug:
rpki.log.debug("sql_fetch_where(%s, %s)" % (repr(query), repr(args)))
- gctx.cur.execute(query, args)
+ gctx.sql.execute(query, args)
results = []
- for row in gctx.cur.fetchall():
+ for row in gctx.sql.fetchall():
key = (cls, row[0])
- if key in gctx.sql_cache:
- results.append(gctx.sql_cache[key])
+ if key in gctx.sql.cache:
+ results.append(gctx.sql.cache[key])
else:
results.append(cls.sql_init(gctx, row, key))
return results
@@ -175,22 +196,22 @@ class sql_persistant(object):
self = cls()
self.gctx = gctx
self.sql_decode(dict(zip(cls.sql_template.columns, row)))
- gctx.sql_cache[key] = self
+ gctx.sql.cache[key] = self
self.sql_in_db = True
self.sql_fetch_hook()
return self
def sql_mark_dirty(self):
"""Mark this object as needing to be written back to SQL."""
- self.gctx.sql_dirty.add(self)
+ self.gctx.sql.dirty.add(self)
def sql_mark_clean(self):
"""Mark this object as not needing to be written back to SQL."""
- self.gctx.sql_dirty.discard(self)
+ self.gctx.sql.dirty.discard(self)
def sql_is_dirty(self):
"""Query whether this object needs to be written back to SQL."""
- return self in self.gctx.sql_dirty
+ return self in self.gctx.sql.dirty
def sql_mark_deleted(self):
"""Mark this object as needing to be deleted in SQL."""
@@ -199,15 +220,15 @@ class sql_persistant(object):
def sql_store(self):
"""Store this object to SQL."""
if not self.sql_in_db:
- self.gctx.cur.execute(self.sql_template.insert, self.sql_encode())
- setattr(self, self.sql_template.index, self.gctx.cur.lastrowid)
- self.gctx.sql_cache[(self.__class__, self.gctx.cur.lastrowid)] = self
+ self.gctx.sql.execute(self.sql_template.insert, self.sql_encode())
+ setattr(self, self.sql_template.index, self.gctx.sql.lastrowid())
+ self.gctx.sql.cache[(self.__class__, self.gctx.sql.lastrowid())] = self
self.sql_insert_hook()
else:
- self.gctx.cur.execute(self.sql_template.update, self.sql_encode())
+ self.gctx.sql.execute(self.sql_template.update, self.sql_encode())
self.sql_update_hook()
key = (self.__class__, getattr(self, self.sql_template.index))
- assert key in self.gctx.sql_cache and self.gctx.sql_cache[key] == self
+ assert key in self.gctx.sql.cache and self.gctx.sql.cache[key] == self
self.sql_mark_clean()
self.sql_in_db = True
@@ -215,11 +236,11 @@ class sql_persistant(object):
"""Delete this object from SQL."""
if self.sql_in_db:
id = getattr(self, self.sql_template.index)
- self.gctx.cur.execute(self.sql_template.delete, id)
+ self.gctx.sql.execute(self.sql_template.delete, id)
self.sql_delete_hook()
key = (self.__class__, id)
- if self.gctx.sql_cache.get(key) == self:
- del self.gctx.sql_cache[key]
+ if self.gctx.sql.cache.get(key) == self:
+ del self.gctx.sql.cache[key]
self.sql_in_db = False
self.sql_mark_clean()
diff --git a/rpkid/rpki/up_down.py b/rpkid/rpki/up_down.py
index df5445e4..50e1d701 100644
--- a/rpkid/rpki/up_down.py
+++ b/rpkid/rpki/up_down.py
@@ -306,7 +306,7 @@ class issue_pdu(base_elt):
resources = resources)
# Save anything we modified and generate response
- self.gctx.sql_sweep()
+ self.gctx.sql.sweep()
assert child_cert and child_cert.sql_in_db
c = certificate_elt()
c.cert_url = multi_uri(child_cert.uri(ca))
@@ -365,7 +365,7 @@ class revoke_pdu(revoke_syntax):
for ca_detail in child.ca_from_class_name(self.class_name).ca_details():
for child_cert in child.child_certs(ca_detail = ca_detail, ski = self.get_SKI()):
child_cert.revoke()
- self.gctx.sql_sweep()
+ self.gctx.sql.sweep()
r_msg.payload = revoke_response_pdu()
r_msg.payload.class_name = self.class_name
r_msg.payload.ski = self.ski