diff options
author | Rob Austein <sra@hactrn.net> | 2008-06-06 22:54:12 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2008-06-06 22:54:12 +0000 |
commit | 88524489465b9ab52e3df199e18b33da34c7a6fb (patch) | |
tree | 950856a042ad6aa45e5316ca49d05b3255bac114 | |
parent | 3a273ef8d516ed1fd37ded395168ccb35372aea4 (diff) |
Refactor SQL code to form rpki.sq.session class, and add code to warn
about dirty objects in SQL cache on SQL errors.
svn path=/rpkid/irdbd.py; revision=1852
-rwxr-xr-x | rpkid/irdbd.py | 2 | ||||
-rwxr-xr-x | rpkid/pubd.py | 12 | ||||
-rw-r--r-- | rpkid/rpki/left_right.py | 28 | ||||
-rw-r--r-- | rpkid/rpki/resource_set.py | 25 | ||||
-rw-r--r-- | rpkid/rpki/rpki_engine.py | 38 | ||||
-rw-r--r-- | rpkid/rpki/sql.py | 113 | ||||
-rw-r--r-- | rpkid/rpki/up_down.py | 4 |
7 files changed, 113 insertions, 109 deletions
diff --git a/rpkid/irdbd.py b/rpkid/irdbd.py index a2dccaa4..22fef807 100755 --- a/rpkid/irdbd.py +++ b/rpkid/irdbd.py @@ -30,7 +30,7 @@ import rpki.exceptions, rpki.left_right, rpki.log, rpki.x509 def handler(query, path): try: - db.ping(True) + db.ping() q_msg = rpki.left_right.cms_msg.unwrap(query, (bpki_ta, rpkid_cert)) diff --git a/rpkid/pubd.py b/rpkid/pubd.py index 81d15524..4e1ebcc9 100755 --- a/rpkid/pubd.py +++ b/rpkid/pubd.py @@ -34,8 +34,7 @@ class pubd_context(rpki.rpki_engine.rpkid_context): def __init__(self, cfg): - self.db = rpki.sql.connect(cfg) - self.cur = self.db.cursor() + self.sql = rpki.sql.session(cfg) self.bpki_ta = rpki.x509.X509(Auto_file = cfg.get("bpki-ta")) self.irbe_cert = rpki.x509.X509(Auto_file = cfg.get("irbe-cert")) @@ -47,22 +46,19 @@ class pubd_context(rpki.rpki_engine.rpkid_context): self.publication_base = cfg.get("publication-base", "publication/") - self.sql_cache = {} - self.sql_dirty = set() - def handler_common(self, query, client, certs, crl = None): """Common PDU handler code.""" q_msg = rpki.publication.cms_msg.unwrap(query, certs) r_msg = q_msg.serve_top_level(self, client) reply = rpki.publication.cms_msg.wrap(r_msg, self.pubd_key, self.pubd_cert, crl) - self.sql_sweep() + self.sql.sweep() return reply def control_handler(self, query, path): """Process one PDU from the IRBE.""" rpki.log.trace() try: - self.db.ping(True) + self.sql.ping() return 200, self.handler_common(query, None, (self.bpki_ta, self.irbe_cert)) except Exception, data: rpki.log.error(traceback.format_exc()) @@ -72,7 +68,7 @@ class pubd_context(rpki.rpki_engine.rpkid_context): """Process one PDU from a client.""" rpki.log.trace() try: - self.db.ping(True) + self.sql.ping() client_id = path.partition("/client/")[2] if not client_id.isdigit(): raise rpki.exceptions.BadContactURL, "Bad path: %s" % path diff --git a/rpkid/rpki/left_right.py b/rpkid/rpki/left_right.py index 3be62e25..e6caf838 100644 --- a/rpkid/rpki/left_right.py +++ b/rpkid/rpki/left_right.py @@ -161,7 +161,7 @@ class self_elt(data_elt): rpki.rpki_engine.ca_obj.create(parent, rc) for ca in ca_map.values(): ca.delete(parent) # CA not listed by parent - self.gctx.sql_sweep() + self.gctx.sql.sweep() def update_children(self): """Check for updated IRDB data for all of this self's children and @@ -546,19 +546,23 @@ class route_origin_elt(data_elt): def sql_fetch_hook(self): """Extra SQL fetch actions for route_origin_elt -- handle prefix list.""" - self.ipv4 = rpki.resource_set.roa_prefix_set_ipv4.from_sql(self.gctx.cur, """ - SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix - WHERE route_origin_id = %s AND address NOT LIKE '%:%' - """, (self.route_origin_id,)) - self.ipv6 = rpki.resource_set.roa_prefix_set_ipv6.from_sql(self.gctx.cur, """ - SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix - WHERE route_origin_id = %s AND address LIKE '%:%' - """, (self.route_origin_id,)) + self.ipv4 = rpki.resource_set.roa_prefix_set_ipv4.from_sql( + self.gctx.sql, + """ + SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix + WHERE route_origin_id = %s AND address NOT LIKE '%:%' + """, (self.route_origin_id,)) + self.ipv6 = rpki.resource_set.roa_prefix_set_ipv6.from_sql( + self.gctx.sql, + """ + SELECT address, prefixlen, max_prefixlen FROM route_origin_prefix + WHERE route_origin_id = %s AND address LIKE '%:%' + """, (self.route_origin_id,)) def sql_insert_hook(self): """Extra SQL insert actions for route_origin_elt -- handle address ranges.""" if self.ipv4 or self.ipv6: - self.gctx.cur.executemany(""" + self.gctx.sql.executemany(""" INSERT route_origin_prefix (route_origin_id, address, prefixlen, max_prefixlen) VALUES (%s, %s, %s, %s)""", ((self.route_origin_id, x.address, x.prefixlen, x.max_prefixlen) @@ -566,7 +570,7 @@ class route_origin_elt(data_elt): def sql_delete_hook(self): """Extra SQL delete actions for route_origin_elt -- handle address ranges.""" - self.gctx.cur.execute("DELETE FROM route_origin_prefix WHERE route_origin_id = %s", (self.route_origin_id,)) + self.gctx.sql.execute("DELETE FROM route_origin_prefix WHERE route_origin_id = %s", (self.route_origin_id,)) def ca_detail(self): """Fetch all ca_detail objects that link to this route_origin object.""" @@ -711,7 +715,7 @@ class route_origin_elt(data_elt): rpki.rpki_engine.revoked_cert_obj.revoke(cert = cert, ca_detail = ca_detail) repository.withdraw(roa, roa_uri) repository.withdraw(cert, ee_uri) - self.gctx.sql_sweep() + self.gctx.sql.sweep() ca_detail.generate_crl() ca_detail.generate_manifest() diff --git a/rpkid/rpki/resource_set.py b/rpkid/rpki/resource_set.py index 3d92b2ce..3b0ce4e5 100644 --- a/rpkid/rpki/resource_set.py +++ b/rpkid/rpki/resource_set.py @@ -298,18 +298,19 @@ class resource_set(list): return other.issubset(self) @classmethod - def from_sql(cls, cur, query, args = None): + def from_sql(cls, sql, query, args = None): """Create resource set from an SQL query. - cur is a DB API 2.0 cursor object. + sql is an object that supports execute() and fetchall() methods + like a DB API 2.0 cursor object. query is an SQL query that returns a sequence of (min, max) pairs. """ - cur.execute(query, args) + sql.execute(query, args) return cls(ini = [cls.range_type(cls.range_type.datum_type(b), cls.range_type.datum_type(e)) - for (b,e) in cur.fetchall()]) + for (b,e) in sql.fetchall()]) class resource_set_as(resource_set): """ASN resource set.""" @@ -667,15 +668,19 @@ class roa_prefix_set(list): return self.resource_set_type([p.to_resource_range() for p in self]) @classmethod - def from_sql(cls, cur, query, args = None): - """Create ROA prefix set from an SQL query. cur is a DB API 2.0 - cursor object. query is an SQL query that returns a sequence of - (address, prefixlen, max_prefixlen) triples. + def from_sql(cls, sql, query, args = None): + """Create ROA prefix set from an SQL query. + + sql is an object that supports execute() and fetchall() methods + like a DB API 2.0 cursor object. + + query is an SQL query that returns a sequence of (address, + prefixlen, max_prefixlen) triples. """ - cur.execute(query, args) + sql.execute(query, args) return cls([cls.prefix_type(cls.prefix_type.range_type.datum_type(x), int(y), int(z)) - for (x,y,z) in cur.fetchall()]) + for (x,y,z) in sql.fetchall()]) def to_roa_tuple(self): """Convert ROA prefix set into tuple format used by ROA ASN.1 encoder. diff --git a/rpkid/rpki/rpki_engine.py b/rpkid/rpki/rpki_engine.py index 18c17c48..cd16c21d 100644 --- a/rpkid/rpki/rpki_engine.py +++ b/rpkid/rpki/rpki_engine.py @@ -28,8 +28,7 @@ class rpkid_context(object): def __init__(self, cfg): - self.db = rpki.sql.connect(cfg) - self.cur = self.db.cursor() + self.sql = rpki.sql.session(cfg) self.bpki_ta = rpki.x509.X509(Auto_file = cfg.get("bpki-ta")) self.irdb_cert = rpki.x509.X509(Auto_file = cfg.get("irdb-cert")) @@ -44,9 +43,6 @@ class rpkid_context(object): self.publication_kludge_base = cfg.get("publication-kludge-base", "publication/") - self.sql_cache = {} - self.sql_dirty = set() - def irdb_query(self, self_id, child_id = None): """Perform an IRDB callback query. In the long run this should not be a blocking routine, it should instead issue a query and set up a @@ -81,35 +77,17 @@ class rpkid_context(object): v6 = r_msg[0].ipv6, valid_until = r_msg[0].valid_until) - def sql_cache_clear(self): - """Clear the object cache.""" - self.sql_cache.clear() - - def sql_assert_pristine(self): - """Assert that there are no dirty objects in the cache.""" - assert not self.sql_dirty, "Dirty objects in SQL cache: %s" % self.sql_dirty - - def sql_sweep(self): - """Write any dirty objects out to SQL.""" - for s in self.sql_dirty.copy(): - rpki.log.debug("Sweeping %s" % repr(s)) - if s.sql_deleted: - s.sql_delete() - else: - s.sql_store() - self.sql_assert_pristine() - def left_right_handler(self, query, path): """Process one left-right PDU.""" rpki.log.trace() try: - self.db.ping(True) + self.sql.ping() q_msg = rpki.left_right.cms_msg.unwrap(query, (self.bpki_ta, self.irbe_cert)) if q_msg.type != "query": raise rpki.exceptions.BadQuery, "Message type is not query" r_msg = q_msg.serve_top_level(self) reply = rpki.left_right.cms_msg.wrap(r_msg, self.rpkid_key, self.rpkid_cert) - self.sql_sweep() + self.sql.sweep() return 200, reply except Exception, data: rpki.log.error(traceback.format_exc()) @@ -119,7 +97,7 @@ class rpkid_context(object): """Process one up-down PDU.""" rpki.log.trace() try: - self.db.ping(True) + self.sql.ping() child_id = path.partition("/up-down/")[2] if not child_id.isdigit(): raise rpki.exceptions.BadContactURL, "Bad path: %s" % path @@ -127,7 +105,7 @@ class rpkid_context(object): if child is None: raise rpki.exceptions.ChildNotFound, "Could not find child %s" % child_id reply = child.serve_up_down(query) - self.sql_sweep() + self.sql.sweep() return 200, reply except Exception, data: rpki.log.error(traceback.format_exc()) @@ -140,13 +118,13 @@ class rpkid_context(object): rpki.log.trace() try: - self.db.ping(True) + self.sql.ping() for s in rpki.left_right.self_elt.sql_fetch_all(self): s.client_poll() s.update_children() s.update_roas() s.regenerate_crls_and_manifests() - self.sql_sweep() + self.sql.sweep() return 200, "OK" except Exception, data: rpki.log.error(traceback.format_exc()) @@ -694,7 +672,7 @@ class child_cert_obj(rpki.sql.sql_persistant): revoked_cert_obj.revoke(cert = self.cert, ca_detail = ca_detail) repository = ca.parent().repository() repository.withdraw(self.cert, self.uri(ca)) - self.gctx.sql_sweep() + self.gctx.sql.sweep() self.sql_delete() def reissue(self, ca_detail, resources = None, sia = None): diff --git a/rpkid/rpki/sql.py b/rpkid/rpki/sql.py index 81df728f..fa5927fa 100644 --- a/rpkid/rpki/sql.py +++ b/rpkid/rpki/sql.py @@ -15,59 +15,80 @@ # PERFORMANCE OF THIS SOFTWARE. import MySQLdb, time, warnings, _mysql_exceptions -import rpki.x509, rpki.resource_set, rpki.sundial +import rpki.x509, rpki.resource_set, rpki.sundial, rpki.log -def connect(cfg, throw_exception_on_warning = True): - """Connect to a MySQL database using connection parameters from an - rpki.config.parser object. - """ - - if throw_exception_on_warning: - warnings.simplefilter("error", _mysql_exceptions.Warning) - - return MySQLdb.connect(user = cfg.get("sql-username"), - db = cfg.get("sql-database"), - passwd = cfg.get("sql-password")) - -class sesssion(object): +class session(object): """SQL session layer.""" - def __init__(self, cfg): + _exceptions_enabled = False - raise rpki.errorsNotImplementedYet, "This class is still under construction" + def __init__(self, cfg): - warnings.simplefilter("error", _mysql_exceptions.Warning) + if not self._exceptions_enabled: + warnings.simplefilter("error", _mysql_exceptions.Warning) + self.__class__._exceptions_enabled = True self.username = cfg.get("sql-username") self.database = cfg.get("sql-database") self.password = cfg.get("sql-password") - self.sql_cache = {} - self.sql_dirty = set() + self.cache = {} + self.dirty = set() self.connect() def connect(self): - self.db = MySQLdb.connect(user = username, db = database, passwd = password) + self.db = MySQLdb.connect(user = self.username, db = self.database, passwd = self.password) self.cur = self.db.cursor() - def sql_cache_clear(self): + def close(self): + if self.cur: + self.cur.close() + self.cur = None + if self.db: + self.db.close() + self.db = None + + def ping(self): + return self.db.ping(True) + + def _wrap_execute(self, func, query, args): + try: + return func(query, args) + except _mysql_exceptions.MySQLError: + if self.dirty: + rpki.log.warn("MySQL exception with dirty objects in SQL cache!") + raise + + def execute(self, query, args = None): + return self._wrap_execute(self.cur.execute, query, args) + + def executemany(self, query, args): + return self._wrap_execute(self.cur.executemany, query, args) + + def fetchall(self): + return self.cur.fetchall() + + def lastrowid(self): + return self.cur.lastrowid + + def cache_clear(self): """Clear the object cache.""" - self.sql_cache.clear() + self.cache.clear() - def sql_assert_pristine(self): + def assert_pristine(self): """Assert that there are no dirty objects in the cache.""" - assert not self.sql_dirty, "Dirty objects in SQL cache: %s" % self.sql_dirty + assert not self.dirty, "Dirty objects in SQL cache: %s" % self.dirty - def sql_sweep(self): + def sweep(self): """Write any dirty objects out to SQL.""" - for s in self.sql_dirty.copy(): + for s in self.dirty.copy(): rpki.log.debug("Sweeping %s" % repr(s)) if s.sql_deleted: s.sql_delete() else: s.sql_store() - self.sql_assert_pristine() + self.assert_pristine() class template(object): """SQL template generator.""" @@ -124,8 +145,8 @@ class sql_persistant(object): return None assert isinstance(id, (int, long)), "id should be an integer, was %s" % repr(type(id)) key = (cls, id) - if key in gctx.sql_cache: - return gctx.sql_cache[key] + if key in gctx.sql.cache: + return gctx.sql.cache[key] else: return cls.sql_fetch_where1(gctx, "%s = %%s" % cls.sql_template.index, (id,)) @@ -154,17 +175,17 @@ class sql_persistant(object): assert args is None if cls.sql_debug: rpki.log.debug("sql_fetch_where(%s)" % repr(cls.sql_template.select)) - gctx.cur.execute(cls.sql_template.select) + gctx.sql.execute(cls.sql_template.select) else: query = cls.sql_template.select + " WHERE " + where if cls.sql_debug: rpki.log.debug("sql_fetch_where(%s, %s)" % (repr(query), repr(args))) - gctx.cur.execute(query, args) + gctx.sql.execute(query, args) results = [] - for row in gctx.cur.fetchall(): + for row in gctx.sql.fetchall(): key = (cls, row[0]) - if key in gctx.sql_cache: - results.append(gctx.sql_cache[key]) + if key in gctx.sql.cache: + results.append(gctx.sql.cache[key]) else: results.append(cls.sql_init(gctx, row, key)) return results @@ -175,22 +196,22 @@ class sql_persistant(object): self = cls() self.gctx = gctx self.sql_decode(dict(zip(cls.sql_template.columns, row))) - gctx.sql_cache[key] = self + gctx.sql.cache[key] = self self.sql_in_db = True self.sql_fetch_hook() return self def sql_mark_dirty(self): """Mark this object as needing to be written back to SQL.""" - self.gctx.sql_dirty.add(self) + self.gctx.sql.dirty.add(self) def sql_mark_clean(self): """Mark this object as not needing to be written back to SQL.""" - self.gctx.sql_dirty.discard(self) + self.gctx.sql.dirty.discard(self) def sql_is_dirty(self): """Query whether this object needs to be written back to SQL.""" - return self in self.gctx.sql_dirty + return self in self.gctx.sql.dirty def sql_mark_deleted(self): """Mark this object as needing to be deleted in SQL.""" @@ -199,15 +220,15 @@ class sql_persistant(object): def sql_store(self): """Store this object to SQL.""" if not self.sql_in_db: - self.gctx.cur.execute(self.sql_template.insert, self.sql_encode()) - setattr(self, self.sql_template.index, self.gctx.cur.lastrowid) - self.gctx.sql_cache[(self.__class__, self.gctx.cur.lastrowid)] = self + self.gctx.sql.execute(self.sql_template.insert, self.sql_encode()) + setattr(self, self.sql_template.index, self.gctx.sql.lastrowid()) + self.gctx.sql.cache[(self.__class__, self.gctx.sql.lastrowid())] = self self.sql_insert_hook() else: - self.gctx.cur.execute(self.sql_template.update, self.sql_encode()) + self.gctx.sql.execute(self.sql_template.update, self.sql_encode()) self.sql_update_hook() key = (self.__class__, getattr(self, self.sql_template.index)) - assert key in self.gctx.sql_cache and self.gctx.sql_cache[key] == self + assert key in self.gctx.sql.cache and self.gctx.sql.cache[key] == self self.sql_mark_clean() self.sql_in_db = True @@ -215,11 +236,11 @@ class sql_persistant(object): """Delete this object from SQL.""" if self.sql_in_db: id = getattr(self, self.sql_template.index) - self.gctx.cur.execute(self.sql_template.delete, id) + self.gctx.sql.execute(self.sql_template.delete, id) self.sql_delete_hook() key = (self.__class__, id) - if self.gctx.sql_cache.get(key) == self: - del self.gctx.sql_cache[key] + if self.gctx.sql.cache.get(key) == self: + del self.gctx.sql.cache[key] self.sql_in_db = False self.sql_mark_clean() diff --git a/rpkid/rpki/up_down.py b/rpkid/rpki/up_down.py index df5445e4..50e1d701 100644 --- a/rpkid/rpki/up_down.py +++ b/rpkid/rpki/up_down.py @@ -306,7 +306,7 @@ class issue_pdu(base_elt): resources = resources) # Save anything we modified and generate response - self.gctx.sql_sweep() + self.gctx.sql.sweep() assert child_cert and child_cert.sql_in_db c = certificate_elt() c.cert_url = multi_uri(child_cert.uri(ca)) @@ -365,7 +365,7 @@ class revoke_pdu(revoke_syntax): for ca_detail in child.ca_from_class_name(self.class_name).ca_details(): for child_cert in child.child_certs(ca_detail = ca_detail, ski = self.get_SKI()): child_cert.revoke() - self.gctx.sql_sweep() + self.gctx.sql.sweep() r_msg.payload = revoke_response_pdu() r_msg.payload.class_name = self.class_name r_msg.payload.ski = self.ski |