diff options
Diffstat (limited to 'scripts/rpki')
-rw-r--r-- | scripts/rpki/exceptions.py | 3 | ||||
-rw-r--r-- | scripts/rpki/left_right.py | 87 | ||||
-rw-r--r-- | scripts/rpki/resource_set.py | 8 | ||||
-rw-r--r-- | scripts/rpki/sql.py | 36 |
4 files changed, 71 insertions, 63 deletions
diff --git a/scripts/rpki/exceptions.py b/scripts/rpki/exceptions.py index 51327bff..1b241832 100644 --- a/scripts/rpki/exceptions.py +++ b/scripts/rpki/exceptions.py @@ -67,3 +67,6 @@ class SubprocessError(Exception): class BadIRDBReply(Exception): """Unexpected reply to IRDB query.""" + +class NotFound(Exception): + """Object not found in database.""" diff --git a/scripts/rpki/left_right.py b/scripts/rpki/left_right.py index 92fe2835..92aeffbc 100644 --- a/scripts/rpki/left_right.py +++ b/scripts/rpki/left_right.py @@ -113,48 +113,44 @@ class data_elt(base_elt, rpki.sql.sql_persistant): operate. This is a separate method because the self object needs to override it. """ - return self.sql_fetch_where1(gctx, "%s = %s AND self_id = %s" % (self.sql_template.index, getattr(self, self.sql_template.index), self.self_id)) + where = self.sql_template.index + " = %s AND self_id = %s" + args = (getattr(self, self.sql_template.index), self.self_id) + r = self.sql_fetch_where1(gctx, where, args) + if r is None: + raise rpki.exceptions.NotFound, "Lookup failed where %s" + (where % args) + return r def serve_set(self, gctx, r_msg): """Handle a set action.""" db_pdu = self.serve_fetch_one(gctx) - if db_pdu is not None: - r_pdu = self.make_reply() - for a in db_pdu.sql_template.columns[1:]: - v = getattr(self, a) - if v is not None: - setattr(db_pdu, a, v) - db_pdu.sql_mark_dirty() - db_pdu.serve_pre_save_hook(gctx, self, r_pdu) - db_pdu.sql_store(gctx) - db_pdu.serve_post_save_hook(gctx, self, r_pdu) - r_msg.append(r_pdu) - else: - r_msg.append(make_error_report(self)) + r_pdu = self.make_reply() + for a in db_pdu.sql_template.columns[1:]: + v = getattr(self, a) + if v is not None: + setattr(db_pdu, a, v) + db_pdu.sql_mark_dirty() + db_pdu.serve_pre_save_hook(gctx, self, r_pdu) + db_pdu.sql_store(gctx) + db_pdu.serve_post_save_hook(gctx, self, r_pdu) + r_msg.append(r_pdu) def serve_get(self, gctx, r_msg): """Handle a get action.""" r_pdu = self.serve_fetch_one(gctx) - if r_pdu is not None: - self.make_reply(r_pdu) - r_msg.append(r_pdu) - else: - r_msg.append(make_error_report(self)) + self.make_reply(r_pdu) + r_msg.append(r_pdu) def serve_list(self, gctx, r_msg): """Handle a list action for non-self objects.""" - for r_pdu in self.sql_fetch_where(gctx, "self_id = %s" % self.self_id): + for r_pdu in self.sql_fetch_where(gctx, "self_id = %s", (self.self_id,)): self.make_reply(r_pdu) r_msg.append(r_pdu) def serve_destroy(self, gctx, r_msg): """Handle a destroy action.""" db_pdu = self.serve_fetch_one(gctx) - if db_pdu is not None: - db_pdu.sql_delete(gctx) - r_msg.append(self.make_reply()) - else: - r_msg.append(make_error_report(self)) + db_pdu.sql_delete(gctx) + r_msg.append(self.make_reply()) def serve_dispatch(self, gctx, r_msg): """Action dispatch handler.""" @@ -215,7 +211,7 @@ class self_elt(data_elt): def sql_fetch_hook(self, gctx): """Extra SQL fetch actions for self_elt -- handle extension preferences.""" - gctx.cur.execute("SELECT pref_name, pref_value FROM self_pref WHERE self_id = %s", self.self_id) + gctx.cur.execute("SELECT pref_name, pref_value FROM self_pref WHERE self_id = %s", (self.self_id,)) for name, value in gctx.cur.fetchall(): e = extension_preference_elt() e.name = name @@ -230,27 +226,27 @@ class self_elt(data_elt): def sql_delete_hook(self, gctx): """Extra SQL delete actions for self_elt -- handle extension preferences.""" - gctx.cur.execute("DELETE FROM self_pref WHERE self_id = %s", self.self_id) + gctx.cur.execute("DELETE FROM self_pref WHERE self_id = %s", (self.self_id,)) def bscs(self, gctx): """Fetch all BSC objects that link to this self object.""" - return bsc_elt.sql_fetch_where(gctx, "self_id = %s" % self.self_id) + return bsc_elt.sql_fetch_where(gctx, "self_id = %s", (self.self_id,)) def repositories(self, gctx): """Fetch all repository objects that link to this self object.""" - return repository_elt.sql_fetch_where(gctx, "self_id = %s" % self.self_id) + return repository_elt.sql_fetch_where(gctx, "self_id = %s", (self.self_id,)) def parents(self, gctx): """Fetch all parent objects that link to this self object.""" - return parent_elt.sql_fetch_where(gctx, "self_id = %s" % self.self_id) + return parent_elt.sql_fetch_where(gctx, "self_id = %s", (self.self_id,)) def children(self, gctx): """Fetch all child objects that link to this self object.""" - return child_elt.sql_fetch_where(gctx, "self_id = %s" % self.self_id) + return child_elt.sql_fetch_where(gctx, "self_id = %s", (self.self_id,)) def route_origins(self, gctx): """Fetch all route_origin objects that link to this self object.""" - return route_origin_elt.sql_fetch_where(gctx, "self_id = %s" % self.self_id) + return route_origin_elt.sql_fetch_where(gctx, "self_id = %s", (self.self_id,)) def serve_pre_save_hook(self, gctx, q_pdu, r_pdu): """Extra server actions for self_elt -- handle extension preferences.""" @@ -288,7 +284,10 @@ class self_elt(data_elt): """Find the self object on which a get, set, or destroy method should operate. """ - return self.sql_fetch(gctx, self.self_id) + r = self.sql_fetch(gctx, self.self_id) + if r is None: + raise rpki.exceptions.NotFound + return r def serve_list(self, gctx, r_msg): """Handle a list action for self objects. This is different from @@ -427,7 +426,7 @@ class bsc_elt(data_elt): def sql_fetch_hook(self, gctx): """Extra SQL fetch actions for bsc_elt -- handle signing certs.""" - gctx.cur.execute("SELECT cert FROM bsc_cert WHERE bsc_id = %s", self.bsc_id) + gctx.cur.execute("SELECT cert FROM bsc_cert WHERE bsc_id = %s", (self.bsc_id,)) self.signing_cert[:] = [rpki.x509.X509(DER = x) for (x,) in gctx.cur.fetchall()] def sql_insert_hook(self, gctx): @@ -438,19 +437,19 @@ class bsc_elt(data_elt): def sql_delete_hook(self, gctx): """Extra SQL delete actions for bsc_elt -- handle signing certs.""" - gctx.cur.execute("DELETE FROM bsc_cert WHERE bsc_id = %s", self.bsc_id) + gctx.cur.execute("DELETE FROM bsc_cert WHERE bsc_id = %s", (self.bsc_id,)) def repositories(self, gctx): """Fetch all repository objects that link to this BSC object.""" - return repository_elt.sql_fetch_where(gctx, "bsc_id = %s" % self.bsc_id) + return repository_elt.sql_fetch_where(gctx, "bsc_id = %s", (self.bsc_id,)) def parents(self, gctx): """Fetch all parent objects that link to this BSC object.""" - return parent_elt.sql_fetch_where(gctx, "bsc_id = %s" % self.bsc_id) + return parent_elt.sql_fetch_where(gctx, "bsc_id = %s", (self.bsc_id,)) def children(self, gctx): """Fetch all child objects that link to this BSC object.""" - return child_elt.sql_fetch_where(gctx, "bsc_id = %s" % self.bsc_id) + return child_elt.sql_fetch_where(gctx, "bsc_id = %s", (self.bsc_id,)) def serve_pre_save_hook(self, gctx, q_pdu, r_pdu): """Extra server actions for bsc_elt -- handle signing certs and key generation.""" @@ -520,7 +519,7 @@ class parent_elt(data_elt): def cas(self, gctx): """Fetch all CA objects that link to this parent object.""" - return rpki.sql.ca_obj.sql_fetch_where(gctx, "parent_id = %s" % self.parent_id) + return rpki.sql.ca_obj.sql_fetch_where(gctx, "parent_id = %s", (self.parent_id,)) def serve_post_save_hook(self, gctx, q_pdu, r_pdu): """Extra server actions for parent_elt.""" @@ -639,7 +638,7 @@ class child_elt(data_elt): def parents(self, gctx): """Fetch all parent objects that link to self object to which this child object links.""" - return parent_elt.sql_fetch_where(gctx, "self_id = %s" % self.self_id) + return parent_elt.sql_fetch_where(gctx, "self_id = %s", (self.self_id,)) def ca_from_class_name(self, gctx, class_name): """Fetch the CA corresponding to an up-down class_name.""" @@ -724,7 +723,7 @@ class repository_elt(data_elt): def parents(self, gctx): """Fetch all parent objects that link to this repository object.""" - return parent_elt.sql_fetch_where(gctx, "repository_id = %s" % self.repository_id) + return parent_elt.sql_fetch_where(gctx, "repository_id = %s", (self.repository_id,)) def startElement(self, stack, name, attrs): """Handle <repository/> element.""" @@ -811,11 +810,11 @@ class route_origin_elt(data_elt): self.ipv4 = rpki.resource_set.resource_set_ipv4.from_sql(gctx.cur, """ SELECT start_ip, end_ip FROM route_origin_range WHERE route_origin_id = %s AND start_ip NOT LIKE '%:%' - """, self.route_origin_id) + """, (self.route_origin_id,)) self.ipv6 = rpki.resource_set.resource_set_ipv6.from_sql(gctx.cur, """ SELECT start_ip, end_ip FROM route_origin_range WHERE route_origin_id = %s AND start_ip LIKE '%:%' - """, self.route_origin_id) + """, (self.route_origin_id,)) def sql_insert_hook(self, gctx): """Extra SQL insert actions for route_origin_elt -- handle address ranges.""" @@ -827,7 +826,7 @@ class route_origin_elt(data_elt): def sql_delete_hook(self, gctx): """Extra SQL delete actions for route_origin_elt -- handle address ranges.""" - gctx.cur.execute("DELETE FROM route_origin_range WHERE route_origin_id = %s", self.route_origin_id) + gctx.cur.execute("DELETE FROM route_origin_range WHERE route_origin_id = %s", (self.route_origin_id,)) def ca_detail(self, gctx): """Fetch all ca_detail objects that link to this route_origin object.""" diff --git a/scripts/rpki/resource_set.py b/scripts/rpki/resource_set.py index ab2d3891..baf68d82 100644 --- a/scripts/rpki/resource_set.py +++ b/scripts/rpki/resource_set.py @@ -246,18 +246,18 @@ class resource_set(list): return other.issubset(self) @classmethod - def from_sql(cls, cursor, query): + def from_sql(cls, cur, query, args = None): """Create resource set from an SQL query. - cursor is a DB API 2.0 cursor object. + cur is a DB API 2.0 cursor object. query is an SQL query that returns a sequence of (min, max) pairs. """ - cursor.execute(query) + cur.execute(query, args) return cls(ini = [cls.range_type(cls.range_type.datum_type(b), cls.range_type.datum_type(e)) - for (b,e) in cursor.fetchall()]) + for (b,e) in cur.fetchall()]) class resource_set_as(resource_set): """ASN resource set.""" diff --git a/scripts/rpki/sql.py b/scripts/rpki/sql.py index 305fb07f..865ef5a2 100644 --- a/scripts/rpki/sql.py +++ b/scripts/rpki/sql.py @@ -77,17 +77,19 @@ class sql_persistant(object): if key in sql_cache: return sql_cache[key] else: - return cls.sql_fetch_where1(gctx, "%s = %s" % (cls.sql_template.index, id)) + return cls.sql_fetch_where1(gctx, "%s = %s", (cls.sql_template.index, id)) @classmethod - def sql_fetch_where1(cls, gctx, where): + def sql_fetch_where1(cls, gctx, where, args = None): """Fetch one object from SQL, based on an arbitrary SQL WHERE expression.""" - results = cls.sql_fetch_where(gctx, where) + results = cls.sql_fetch_where(gctx, where, args) if len(results) == 0: return None elif len(results) == 1: return results[0] else: + if args is not None: + where = where % args raise rpki.exceptions.DBConsistancyError, \ "Database contained multiple matches for %s where %s" % (cls.__name__, where) @@ -97,12 +99,12 @@ class sql_persistant(object): return cls.sql_fetch_where(gctx, None) @classmethod - def sql_fetch_where(cls, gctx, where): + def sql_fetch_where(cls, gctx, where, args = None): """Fetch objects of this type matching an arbitrary SQL WHERE expression.""" if where is None: gctx.cur.execute(cls.sql_template.select) else: - gctx.cur.execute(cls.sql_template.select + " WHERE " + where) + gctx.cur.execute(cls.sql_template.select + " WHERE " + where, args) results = [] for row in gctx.cur.fetchall(): key = (cls, row[0]) @@ -225,11 +227,11 @@ class ca_obj(sql_persistant): def ca_details(self, gctx): """Fetch all ca_detail objects that link to this CA object.""" - return ca_detail_obj.sql_fetch_where(gctx, "ca_id = %s" % self.ca_id) + return ca_detail_obj.sql_fetch_where(gctx, "ca_id = %s", (self.ca_id,)) def fetch_active(self, gctx): """Fetch the active ca_detail for this CA, if any.""" - return ca_detail_obj.sql_fetch_where1(gctx, "ca_id = %s AND state = 'active'" % self.ca_id) + return ca_detail_obj.sql_fetch_where1(gctx, "ca_id = %s AND state = 'active'", (self.ca_id,)) def construct_sia_uri(self, gctx, parent, rc): """Construct the sia_uri value for this CA given configured @@ -260,7 +262,7 @@ class ca_obj(sql_persistant): rc_resources = rc.to_resource_bag() cert_map = dict((c.cert.get_SKI(), c) for c in rc.certs) - for ca_detail in ca_detail_obj.sql_fetch_where(gctx, "ca_id = %s AND latest_ca_cert IS NOT NULL" % self.ca_id): + for ca_detail in ca_detail_obj.sql_fetch_where(gctx, "ca_id = %s AND latest_ca_cert IS NOT NULL", (self.ca_id,)): ski = ca_detail.latest_ca_cert.get_SKI() if ca_detail.state != "deprecated": current_resources = ca_detail.latest_ca_cert.get_3779resources() @@ -375,7 +377,7 @@ class ca_detail_obj(sql_persistant): def route_origins(self, gctx): """Fetch all route_origin objects that link to this ca_detail.""" - return rpki.left_right.route_origin_elt.sql_fetch_where(gctx, "ca_detail_id = %s" % self.ca_detail_id) + return rpki.left_right.route_origin_elt.sql_fetch_where(gctx, "ca_detail_id = %s", (self.ca_detail_id,)) def crl_uri(self, ca): """Return publication URI for this ca_detail's CRL.""" @@ -431,7 +433,7 @@ class ca_detail_obj(sql_persistant): # This will need a callback when we go event-driven issue_response = rpki.up_down.issue_pdu.query(gctx, parent, ca, self) - self.latest_ca_cert = issue_response.classes[0].certs[0].cert + self.latest_ca_cert = issue_response.payload.classes[0].certs[0].cert new_resources = self.latest_ca_cert.get_3779resources() if sia_uri_changed or old_resources.oversized(new_resources): @@ -649,17 +651,21 @@ class child_cert_obj(sql_persistant): code calls this indirectly, through methods in other classes. """ + args = [] if revoked: where = "revoked IS NOT NULL" else: where = "revoked IS NULL" if child: - where += " AND child_id = %s" % child.child_id + where += " AND child_id = %s" + args.append(child.child_id) if ca_detail: - where += " AND ca_detail_id = %s" % ca_detail.ca_detail_id + where += " AND ca_detail_id = %s" + args.append(ca_detail.ca_detail_id) if ski: - where += " AND ski = '%s'" % ski + where += " AND ski = %s" + args.append(ski) if unique: - return cls.sql_fetch_where1(gctx, where) + return cls.sql_fetch_where1(gctx, where, args) else: - return cls.sql_fetch_where(gctx, where) + return cls.sql_fetch_where(gctx, where, args) |