Diffstat (limited to 'rpki/pubd.py')
-rw-r--r--  rpki/pubd.py  247
1 file changed, 218 insertions, 29 deletions
diff --git a/rpki/pubd.py b/rpki/pubd.py
index 79315a78..0ee4d38c 100644
--- a/rpki/pubd.py
+++ b/rpki/pubd.py
@@ -23,6 +23,7 @@ RPKI publication engine.
 
 import os
 import re
+import uuid
 import time
 import logging
 import argparse
@@ -36,8 +37,11 @@ import rpki.exceptions
 import rpki.relaxng
 import rpki.log
 import rpki.publication
+import rpki.publication_control
 import rpki.daemonize
 
+from lxml.etree import Element, SubElement, ElementTree, Comment
+
 logger = logging.getLogger(__name__)
 
 class main(object):
@@ -96,6 +100,7 @@ class main(object):
     self.irbe_cert = rpki.x509.X509(Auto_update = self.cfg.get("irbe-cert"))
     self.pubd_cert = rpki.x509.X509(Auto_update = self.cfg.get("pubd-cert"))
     self.pubd_key  = rpki.x509.RSA( Auto_update = self.cfg.get("pubd-key"))
+    self.pubd_crl  = rpki.x509.CRL( Auto_update = self.cfg.get("pubd-crl"))
 
     self.http_server_host = self.cfg.get("server-host", "")
     self.http_server_port = self.cfg.getint("server-port")
@@ -104,45 +109,39 @@ class main(object):
 
     self.publication_multimodule = self.cfg.getboolean("publication-multimodule", False)
 
+    self.rrdp_expiration_interval = rpki.sundial.timedelta.parse(self.cfg.get("rrdp-expiration-interval", "6h"))
+    self.rrdp_publication_base = self.cfg.get("rrdp-publication-base", "rrdp-publication/")
+
+    self.session = session_obj.fetch(self)
+
     rpki.http.server(
       host     = self.http_server_host,
       port     = self.http_server_port,
       handlers = (("/control", self.control_handler),
                   ("/client/", self.client_handler)))
 
-  def handler_common(self, query, client, cb, certs, crl = None):
-    """
-    Common PDU handler code.
-    """
-
-    def done(r_msg):
-      reply = rpki.publication.cms_msg().wrap(r_msg, self.pubd_key, self.pubd_cert, crl)
-      self.sql.sweep()
-      cb(reply)
-
-    q_cms = rpki.publication.cms_msg(DER = query)
-    q_msg = q_cms.unwrap(certs)
-    if client is None:
-      self.irbe_cms_timestamp = q_cms.check_replay(self.irbe_cms_timestamp, "control")
-    else:
-      q_cms.check_replay_sql(client, client.client_handle)
-    q_msg.serve_top_level(self, client, done)
 
   def control_handler(self, query, path, cb):
    """
     Process one PDU from the IRBE.
     """
 
-    def done(body):
-      cb(200, body = body)
+    def done(r_msg):
+      self.sql.sweep()
+      cb(code = 200,
+         body = rpki.publication_control.cms_msg().wrap(r_msg, self.pubd_key, self.pubd_cert))
 
     try:
-      self.handler_common(query, None, done, (self.bpki_ta, self.irbe_cert))
+      q_cms = rpki.publication_control.cms_msg(DER = query)
+      q_msg = q_cms.unwrap((self.bpki_ta, self.irbe_cert))
+      self.irbe_cms_timestamp = q_cms.check_replay(self.irbe_cms_timestamp, "control")
+      q_msg.serve_top_level(self, done)
     except (rpki.async.ExitNow, SystemExit):
       raise
     except Exception, e:
       logger.exception("Unhandled exception processing control query, path %r", path)
-      cb(500, reason = "Unhandled exception %s: %s" % (e.__class__.__name__, e))
+      cb(code = 500, reason = "Unhandled exception %s: %s" % (e.__class__.__name__, e))
+
 
   client_url_regexp = re.compile("/client/([-A-Z0-9_/]+)$", re.I)
 
@@ -151,23 +150,213 @@ class main(object):
     Process one PDU from a client.
     """
 
-    def done(body):
-      cb(200, body = body)
+    def done(r_msg):
+      self.sql.sweep()
+      cb(code = 200,
+         body = rpki.publication.cms_msg().wrap(r_msg, self.pubd_key, self.pubd_cert, self.pubd_crl))
 
     try:
       match = self.client_url_regexp.search(path)
       if match is None:
         raise rpki.exceptions.BadContactURL("Bad path: %s" % path)
       client_handle = match.group(1)
-      client = rpki.publication.client_elt.sql_fetch_where1(self, "client_handle = %s", (client_handle,))
+      client = rpki.publication_control.client_elt.sql_fetch_where1(self, "client_handle = %s", (client_handle,))
       if client is None:
         raise rpki.exceptions.ClientNotFound("Could not find client %s" % client_handle)
-      config = rpki.publication.config_elt.fetch(self)
-      if config is None or config.bpki_crl is None:
-        raise rpki.exceptions.CMSCRLNotSet
-      self.handler_common(query, client, done, (self.bpki_ta, client.bpki_cert, client.bpki_glue), config.bpki_crl)
+      q_cms = rpki.publication.cms_msg(DER = query)
+      q_msg = q_cms.unwrap((self.bpki_ta, client.bpki_cert, client.bpki_glue))
+      q_cms.check_replay_sql(client, client.client_handle)
+      q_msg.serve_top_level(self, client, done)
     except (rpki.async.ExitNow, SystemExit):
       raise
     except Exception, e:
       logger.exception("Unhandled exception processing client query, path %r", path)
-      cb(500, reason = "Could not process PDU: %s" % e)
+      cb(code = 500,
+         reason = "Could not process PDU: %s" % e)
+
+
+class session_obj(rpki.sql.sql_persistent):
+  """
+  An RRDP session.
+  """
+
+  # We probably need additional columns or an additional table to
+  # handle cleanup of old serial numbers.  Not sure quite what these
+  # would look like, other than that the SQL datatypes are probably
+  # BIGINT and DATETIME.  Maybe a table to track time at which we
+  # retired a particular serial number, or, to save us the arithmetic,
+  # the corresponding cleanup time?
+
+  sql_template = rpki.sql.template(
+    "session",
+    "session_id",
+    "uuid")
+
+  def __repr__(self):
+    return rpki.log.log_repr(self, self.uuid, self.serial)
+
+  @classmethod
+  def fetch(cls, gctx):
+    """
+    Fetch the one and only session, creating it if necessary.
+    """
+
+    self = cls.sql_fetch(gctx, 1)
+    if self is None:
+      self = cls()
+      self.gctx = gctx
+      self.session_id = 1
+      self.uuid = uuid.uuid4()
+      self.sql_store()
+    return self
+
+  @property
+  def objects(self):
+    return object_obj.sql_fetch_where(self.gctx, "session_id = %s", (self.session_id,))
+
+  @property
+  def snapshots(self):
+    return snapshot_obj.sql_fetch_where(self.gctx, "session_id = %s", (self.session_id,))
+
+  @property
+  def current_snapshot(self):
+    return snapshot_obj.sql_fetch_where1(self.gctx,
+                                         "session_id = %s AND activated IS NOT NULL AND expires IS NULL",
+                                         (self.session_id,))
+
+  def new_snapshot(self):
+    return snapshot_obj.create(self)
+
+  def add_snapshot(self, new_snapshot):
+    now = rpki.sundial.now()
+    old_snapshot = self.current_snapshot
+    if old_snapshot is not None:
+      old_snapshot.expires = now + self.gctx.rrdp_expiration_interval
+      old_snapshot.sql_store()
+    new_snapshot.activated = now
+    new_snapshot.sql_store()
+
+  def expire_snapshots(self):
+    for snapshot in snapshot_obj.sql_fetch_where(self.gctx,
+                                                 "session_id = %s AND expires IS NOT NULL AND expires < %s",
+                                                 (self.session_id, rpki.sundial.now())):
+      snapshot.sql_delete()
+
+
+class snapshot_obj(rpki.sql.sql_persistent):
+  """
+  An RRDP session snapshot.
+  """
+
+  sql_template = rpki.sql.template(
+    "snapshot",
+    "snapshot_id",
+    ("activated", rpki.sundial.datetime),
+    ("expires", rpki.sundial.datetime),
+    "session_id")
+
+  @property
+  @rpki.sql.cache_reference
+  def session(self):
+    return session_obj.sql_fetch(self.gctx, self.session_id)
+
+  @classmethod
+  def create(cls, session):
+    self = cls()
+    self.gctx = session.gctx
+    self.session_id = session.session_id
+    self.activated = None
+    self.expires = None
+    self.sql_store()
+    return self
+
+  @property
+  def serial(self):
+    """
+    I know that using an SQL ID for any other purpose is usually a bad
+    idea, but in this case it has exactly the right properties, and we
+    really do want both the autoincrement behavior and the foreign key
+    behavior to tie to the snapshot serial numbers.  So risk it.
+
+    Well, OK, only almost the right properties.  auto-increment
+    probably does not back up if we ROLLBACK, which could leave gaps
+    in the sequence.  So may need to rework this, eg, to use a serial
+    field in the session object.  Ignore the issue until we have the
+    rest of this working.
+    """
+
+    return self.snapshot_id
+
+  def publish(self, client, obj, uri, hash):
+    if hash is not None:
+      self.withdraw(client, uri, hash)
+    if object_obj.current_object_at_uri(client, self, uri) is not None:
+      raise rpki.exceptions.ExistingObjectAtURI("Object already published at %s" % uri)
+    logger.debug("Publishing %s", uri)
+    return object_obj.create(client, self, obj, uri)
+
+  def withdraw(self, client, uri, hash):
+    obj = object_obj.current_object_at_uri(client, self, uri)
+    if obj is None:
+      raise rpki.exceptions.NoObjectAtURI("No object published at %s" % uri)
+    if obj.hash != hash:
+      raise rpki.exceptions.DifferentObjectAtURI("Found different object at %s (%s, %s)" % (uri, obj.hash, hash))
+    logger.debug("Withdrawing %s", uri)
+    obj.delete(self)
+
+
+
+class object_obj(rpki.sql.sql_persistent):
+  """
+  A published object.
+  """
+
+  sql_template = rpki.sql.template(
+    "object",
+    "object_id",
+    "uri",
+    "hash",
+    "payload",
+    "published_snapshot_id",
+    "withdrawn_snapshot_id",
+    "client_id",
+    "session_id")
+
+  def __repr__(self):
+    return rpki.log.log_repr(self, self.uri, self.published_snapshot_id, self.withdrawn_snapshot_id)
+
+  @property
+  @rpki.sql.cache_reference
+  def session(self):
+    return session_obj.sql_fetch(self.gctx, self.session_id)
+
+  @property
+  @rpki.sql.cache_reference
+  def client(self):
+    return rpki.publication_control.client_elt.sql_fetch(self.gctx, self.client_id)
+
+  @classmethod
+  def create(cls, client, snapshot, obj, uri):
+    self = cls()
+    self.gctx = snapshot.gctx
+    self.uri = uri
+    self.payload = obj
+    self.hash = rpki.x509.sha256(obj.get_Base64()).encode("hex")
+    logger.debug("Computed hash %s of %r", self.hash, obj)
+    self.published_snapshot_id = snapshot.snapshot_id
+    self.withdrawn_snapshot_id = None
+    self.session_id = snapshot.session_id
+    self.client_id = client.client_id
+    self.sql_mark_dirty()
+    return self
+
+  def delete(self, snapshot):
+    self.withdrawn_snapshot_id = snapshot.snapshot_id
+    #self.sql_mark_dirty()
+    self.sql_store()
+
+  @classmethod
+  def current_object_at_uri(cls, client, snapshot, uri):
+    return cls.sql_fetch_where1(client.gctx,
+                                "session_id = %s AND client_id = %s AND withdrawn_snapshot_id IS NULL AND uri = %s",
+                                (snapshot.session_id, client.client_id, uri))