diff options
Diffstat (limited to 'rpki/rootd.py')
-rw-r--r-- | rpki/rootd.py | 260 |
1 files changed, 180 insertions, 80 deletions
diff --git a/rpki/rootd.py b/rpki/rootd.py index e912a846..6b6aa0fa 100644 --- a/rpki/rootd.py +++ b/rpki/rootd.py @@ -46,15 +46,13 @@ rootd = None class list_pdu(rpki.up_down.list_pdu): def serve_pdu(self, q_msg, r_msg, ignored, callback, errback): r_msg.payload = rpki.up_down.list_response_pdu() - rootd.compose_response(r_msg) - callback() + rootd.compose_response(r_msg, callback, errback) class issue_pdu(rpki.up_down.issue_pdu): def serve_pdu(self, q_msg, r_msg, ignored, callback, errback): self.pkcs10.check_valid_request_ca() r_msg.payload = rpki.up_down.issue_response_pdu() - rootd.compose_response(r_msg, self.pkcs10) - callback() + rootd.compose_response(r_msg, callback, errback, self.pkcs10) class revoke_pdu(rpki.up_down.revoke_pdu): def serve_pdu(self, q_msg, r_msg, ignored, callback, errback): @@ -68,14 +66,15 @@ class revoke_pdu(rpki.up_down.revoke_pdu): raise rpki.exceptions.NotInDatabase logger.debug("Revoking certificate %s", self.ski) now = rpki.sundial.now() + pubd_msg = rpki.publication.msg.query() rootd.revoke_subject_cert(now) rootd.del_subject_cert() rootd.del_subject_pkcs10() - rootd.generate_crl_and_manifest(now) r_msg.payload = rpki.up_down.revoke_response_pdu() r_msg.payload.class_name = self.class_name r_msg.payload.ski = self.ski - callback() + rootd.generate_crl_and_manifest(now, pubd_msg) + rootd.publish(callback, errback, pubd_msg) class error_response_pdu(rpki.up_down.error_response_pdu): exceptions = rpki.up_down.error_response_pdu.exceptions.copy() @@ -84,23 +83,18 @@ class error_response_pdu(rpki.up_down.error_response_pdu): class message_pdu(rpki.up_down.message_pdu): - name2type = { - "list" : list_pdu, - "list_response" : rpki.up_down.list_response_pdu, - "issue" : issue_pdu, - "issue_response" : rpki.up_down.issue_response_pdu, - "revoke" : revoke_pdu, - "revoke_response" : rpki.up_down.revoke_response_pdu, - "error_response" : error_response_pdu } + name2type = dict( + rpki.up_down.message_pdu.name2type, + list = list_pdu, + issue = issue_pdu, + revoke = revoke_pdu, + error_response = error_response_pdu) - type2name = dict((v, k) for k, v in name2type.items()) + type2name = dict((v, k) for k, v in name2type.iteritems()) error_pdu_type = error_response_pdu def log_query(self, child): - """ - Log query we're handling. - """ logger.info("Serving %s query", self.type) class sax_handler(rpki.up_down.sax_handler): @@ -111,34 +105,30 @@ class cms_msg(rpki.up_down.cms_msg): class main(object): - def get_root_cert(self): - logger.debug("Read root cert %s", self.rpki_root_cert_file) - self.rpki_root_cert = rpki.x509.X509(Auto_file = self.rpki_root_cert_file) def root_newer_than_subject(self): - return os.stat(self.rpki_root_cert_file).st_mtime > \ - os.stat(os.path.join(self.rpki_root_dir, self.rpki_subject_cert)).st_mtime + return self.rpki_root_cert.mtime > os.stat(self.rpki_subject_cert_file).st_mtime + def get_subject_cert(self): - filename = os.path.join(self.rpki_root_dir, self.rpki_subject_cert) try: - x = rpki.x509.X509(Auto_file = filename) - logger.debug("Read subject cert %s", filename) + x = rpki.x509.X509(Auto_file = self.rpki_subject_cert_file) + logger.debug("Read subject cert %s", self.rpki_subject_cert_file) return x except IOError: return None + def set_subject_cert(self, cert): - filename = os.path.join(self.rpki_root_dir, self.rpki_subject_cert) - logger.debug("Writing subject cert %s, SKI %s", filename, cert.hSKI()) - f = open(filename, "wb") - f.write(cert.get_DER()) - f.close() + logger.debug("Writing subject cert %s, SKI %s", self.rpki_subject_cert_file, cert.hSKI()) + with open(self.rpki_subject_cert_file, "wb") as f: + f.write(cert.get_DER()) + def del_subject_cert(self): - filename = os.path.join(self.rpki_root_dir, self.rpki_subject_cert) - logger.debug("Deleting subject cert %s", filename) - os.remove(filename) + logger.debug("Deleting subject cert %s", self.rpki_subject_cert_file) + os.remove(self.rpki_subject_cert_file) + def get_subject_pkcs10(self): try: @@ -148,11 +138,12 @@ class main(object): except IOError: return None + def set_subject_pkcs10(self, pkcs10): logger.debug("Writing subject PKCS #10 %s", self.rpki_subject_pkcs10) - f = open(self.rpki_subject_pkcs10, "wb") - f.write(pkcs10.get_DER()) - f.close() + with open(self.rpki_subject_pkcs10, "wb") as f: + f.write(pkcs10.get_DER()) + def del_subject_pkcs10(self): logger.debug("Deleting subject PKCS #10 %s", self.rpki_subject_pkcs10) @@ -161,9 +152,11 @@ class main(object): except OSError: pass + def issue_subject_cert_maybe(self, new_pkcs10): now = rpki.sundial.now() subject_cert = self.get_subject_cert() + hash = None if subject_cert is None else rpki.x509.sha256(subject_cert.get_DER()).encode("hex") old_pkcs10 = self.get_subject_pkcs10() if new_pkcs10 is not None and new_pkcs10 != old_pkcs10: self.set_subject_pkcs10(new_pkcs10) @@ -179,17 +172,16 @@ class main(object): logger.debug("Root certificate has changed, regenerating subject") self.revoke_subject_cert(now) subject_cert = None - self.get_root_cert() if subject_cert is not None: - return subject_cert + return subject_cert, None pkcs10 = old_pkcs10 if new_pkcs10 is None else new_pkcs10 if pkcs10 is None: logger.debug("No PKCS #10 request, can't generate subject certificate yet") - return None + return None, None resources = self.rpki_root_cert.get_3779resources() notAfter = now + self.rpki_subject_lifetime logger.info("Generating subject cert %s with resources %s, expires %s", - self.rpki_base_uri + self.rpki_subject_cert, resources, notAfter) + self.rpki_subject_cert_uri, resources, notAfter) req_key = pkcs10.getPublicKey() req_sia = pkcs10.get_SIA() self.next_serial_number() @@ -199,15 +191,21 @@ class main(object): serial = self.serial_number, sia = req_sia, aia = self.rpki_root_cert_uri, - crldp = self.rpki_base_uri + self.rpki_root_crl, + crldp = self.rpki_root_crl_uri, resources = resources, notBefore = now, notAfter = notAfter) self.set_subject_cert(subject_cert) - self.generate_crl_and_manifest(now) - return subject_cert + pubd_msg = rpki.publication.msg.query() + pubd_msg.append(rpki.publication.publish_elt.make_pdu( + uri = self.rpki_subject_cert_uri, + hash = hash, + der = subject_cert.get_DER())) + self.generate_crl_and_manifest(now, pubd_msg) + return subject_cert, pubd_msg + - def generate_crl_and_manifest(self, now): + def generate_crl_and_manifest(self, now, pubd_msg): subject_cert = self.get_subject_cert() self.next_serial_number() self.next_crl_number() @@ -220,23 +218,26 @@ class main(object): thisUpdate = now, nextUpdate = now + self.rpki_subject_regen, revokedCertificates = self.revoked) - fn = os.path.join(self.rpki_root_dir, self.rpki_root_crl) - logger.debug("Writing CRL %s", fn) - f = open(fn, "wb") - f.write(crl.get_DER()) - f.close() - manifest_content = [(self.rpki_root_crl, crl)] + hash = self.read_hash_maybe(self.rpki_root_crl_file) + logger.debug("Writing CRL %s", self.rpki_root_crl_file) + with open(self.rpki_root_crl_file, "wb") as f: + f.write(crl.get_DER()) + pubd_msg.append(rpki.publication.publish_elt.make_pdu( + uri = self.rpki_root_crl_uri, + hash = hash, + der = crl.get_DER())) + manifest_content = [(os.path.basename(self.rpki_root_crl_uri), crl)] if subject_cert is not None: - manifest_content.append((self.rpki_subject_cert, subject_cert)) + manifest_content.append((os.path.basename(self.rpki_subject_cert_uri), subject_cert)) manifest_resources = rpki.resource_set.resource_bag.from_inheritance() manifest_keypair = rpki.x509.RSA.generate() manifest_cert = self.rpki_root_cert.issue( keypair = self.rpki_root_key, subject_key = manifest_keypair.get_public(), serial = self.serial_number, - sia = (None, None, self.rpki_base_uri + self.rpki_root_manifest), + sia = (None, None, self.rpki_root_manifest_uri), aia = self.rpki_root_cert_uri, - crldp = self.rpki_base_uri + self.rpki_root_crl, + crldp = self.rpki_root_crl_uri, resources = manifest_resources, notBefore = now, notAfter = now + self.rpki_subject_lifetime, @@ -248,17 +249,42 @@ class main(object): names_and_objs = manifest_content, keypair = manifest_keypair, certs = manifest_cert) - fn = os.path.join(self.rpki_root_dir, self.rpki_root_manifest) - logger.debug("Writing manifest %s", fn) - f = open(fn, "wb") - f.write(manifest.get_DER()) - f.close() + hash = self.read_hash_maybe(self.rpki_root_manifest_file) + logger.debug("Writing manifest %s", self.rpki_root_manifest_file) + with open(self.rpki_root_manifest_file, "wb") as f: + f.write(manifest.get_DER()) + pubd_msg.append(rpki.publication.publish_elt.make_pdu( + uri = self.rpki_root_manifest_uri, + hash = hash, + der = manifest.get_DER())) + hash = rpki.x509.sha256(self.rpki_root_cert.get_DER()).encode("hex") + if hash != self.rpki_root_cert_hash: + pubd_msg.append(rpki.publication.publish_elt.make_pdu( + uri = self.rpki_root_cert_uri, + hash = self.rpki_root_cert_hash, + der = self.rpki_root_cert.get_DER())) + self.rpki_root_cert_hash = hash + + + @staticmethod + def read_hash_maybe(fn): + """ + Return hash of an existing object, or None. + """ + + try: + with open(fn, "rb") as f: + return rpki.x509.sha256(f.read()).encode("hex") + except IOError: + return None + def revoke_subject_cert(self, now): self.revoked.append((self.get_subject_cert().getSerial(), now)) - def compose_response(self, r_msg, pkcs10 = None): - subject_cert = self.issue_subject_cert_maybe(pkcs10) + + def compose_response(self, r_msg, callback, errback, pkcs10 = None): + subject_cert, pubd_msg = self.issue_subject_cert_maybe(pkcs10) rc = rpki.up_down.class_elt() rc.class_name = self.rpki_class_name rc.cert_url = rpki.up_down.multi_uri(self.rpki_root_cert_uri) @@ -267,14 +293,78 @@ class main(object): r_msg.payload.classes.append(rc) if subject_cert is not None: rc.certs.append(rpki.up_down.certificate_elt()) - rc.certs[0].cert_url = rpki.up_down.multi_uri(self.rpki_base_uri + self.rpki_subject_cert) + rc.certs[0].cert_url = rpki.up_down.multi_uri(self.rpki_subject_cert_uri) rc.certs[0].cert = subject_cert + self.publish(callback, errback, pubd_msg) + + + def publish(self, callback, errback, q_msg): + + def done(r_msg): + if len(q_msg) != len(r_msg): + raise rpki.exceptions.BadPublicationReply("Wrong number of response PDUs from pubd: sent %r, got %r" % (q_msg, r_msg)) + callback() + + def fix_hashes(r_msg): + published_hash = dict((r_pdu.uri, r_pdu.hash) for r_pdu in r_msg) + for q_pdu in q_msg: + if q_pdu.hash is None and published_hash.get(q_pdu.uri) is not None: + logger.debug("Updating hash of %r to %s from previously published data", q_pdu, published_hash[q_pdu.uri]) + q_pdu.hash = published_hash[q_pdu.uri] + self.call_pubd(done, errback, q_msg) + + if not q_msg: + callback() + elif all(q_pdu.hash is not None for q_pdu in q_msg): + self.call_pubd(done, errback, q_msg) + else: + logger.debug("Some publication PDUs are missing hashes, checking...") + self.call_pubd(fix_hashes, errback, rpki.publication.msg.query(rpki.publication.list_elt())) + + + def call_pubd(self, callback, errback, q_msg): + + try: + if not q_msg: + return callback(()) + + for q_pdu in q_msg: + logger.info("Sending %r to pubd", q_pdu) + + q_der = rpki.publication.cms_msg().wrap(q_msg, self.rootd_bpki_key, self.rootd_bpki_cert, self.rootd_bpki_crl) + + def done(r_der): + try: + logger.debug("Received response from pubd") + r_cms = rpki.publication.cms_msg(DER = r_der) + r_msg = r_cms.unwrap((self.bpki_ta, self.pubd_bpki_cert)) + self.pubd_cms_timestamp = r_cms.check_replay(self.pubd_cms_timestamp, self.pubd_contact_uri) + for r_pdu in r_msg: + r_pdu.raise_if_error() + callback(r_msg) + except (rpki.async.ExitNow, SystemExit): + raise + except Exception, e: + errback(e) + + logger.debug("Sending request to pubd") + rpki.http.client( + url = self.pubd_contact_uri, + msg = q_der, + callback = done, + errback = errback) + + except (rpki.async.ExitNow, SystemExit): + raise + except Exception, e: + errback(e) + def up_down_handler(self, query, path, cb): try: q_cms = cms_msg(DER = query) q_msg = q_cms.unwrap((self.bpki_ta, self.child_bpki_cert)) - self.cms_timestamp = q_cms.check_replay(self.cms_timestamp, path) + self.rpkid_cms_timestamp = q_cms.check_replay(self.rpkid_cms_timestamp, path) except (rpki.async.ExitNow, SystemExit): raise except Exception, e: @@ -304,7 +394,7 @@ class main(object): def next_crl_number(self): if self.crl_number is None: try: - crl = rpki.x509.CRL(DER_file = os.path.join(self.rpki_root_dir, self.rpki_root_crl)) + crl = rpki.x509.CRL(DER_file = self.rpki_root_crl_file) self.crl_number = crl.getCRLNumber() except: # pylint: disable=W0702 self.crl_number = 0 @@ -328,11 +418,11 @@ class main(object): global rootd rootd = self # Gross, but simpler than what we'd have to do otherwise - self.rpki_root_cert = None self.serial_number = None self.crl_number = None self.revoked = [] - self.cms_timestamp = None + self.rpkid_cms_timestamp = None + self.pubd_cms_timestamp = None os.environ["TZ"] = "UTC" time.tzset() @@ -359,28 +449,38 @@ class main(object): self.rootd_bpki_crl = rpki.x509.CRL( Auto_update = self.cfg.get("rootd-bpki-crl")) self.child_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("child-bpki-cert")) + if self.cfg.has_option("pubd-bpki-cert"): + self.pubd_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("pubd-bpki-cert")) + else: + self.pubd_bpki_cert = None + self.http_server_host = self.cfg.get("server-host", "") self.http_server_port = self.cfg.getint("server-port") - self.rpki_class_name = self.cfg.get("rpki-class-name", "wombat") + self.rpki_class_name = self.cfg.get("rpki-class-name") - self.rpki_root_dir = self.cfg.get("rpki-root-dir") - self.rpki_base_uri = self.cfg.get("rpki-base-uri", "rsync://" + self.rpki_class_name + ".invalid/") + self.rpki_root_key = rpki.x509.RSA( Auto_update = self.cfg.get("rpki-root-key-file")) + self.rpki_root_cert = rpki.x509.X509(Auto_update = self.cfg.get("rpki-root-cert-file")) + self.rpki_root_cert_uri = self.cfg.get("rpki-root-cert-uri") + self.rpki_root_cert_hash = None - self.rpki_root_key = rpki.x509.RSA(Auto_update = self.cfg.get("rpki-root-key")) - self.rpki_root_cert_file = self.cfg.get("rpki-root-cert") - self.rpki_root_cert_uri = self.cfg.get("rpki-root-cert-uri", self.rpki_base_uri + "root.cer") + self.rpki_root_manifest_file = self.cfg.get("rpki-root-manifest-file") + self.rpki_root_manifest_uri = self.cfg.get("rpki-root-manifest-uri") - self.rpki_root_manifest = self.cfg.get("rpki-root-manifest", "root.mft") - self.rpki_root_crl = self.cfg.get("rpki-root-crl", "root.crl") - self.rpki_subject_cert = self.cfg.get("rpki-subject-cert", "child.cer") - self.rpki_subject_pkcs10 = self.cfg.get("rpki-subject-pkcs10", "child.pkcs10") + self.rpki_root_crl_file = self.cfg.get("rpki-root-crl-file") + self.rpki_root_crl_uri = self.cfg.get("rpki-root-crl-uri") + self.rpki_subject_cert_file = self.cfg.get("rpki-subject-cert-file") + self.rpki_subject_cert_uri = self.cfg.get("rpki-subject-cert-uri") + self.rpki_subject_pkcs10 = self.cfg.get("rpki-subject-pkcs10-file") self.rpki_subject_lifetime = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-lifetime", "8w")) - self.rpki_subject_regen = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-regen", self.rpki_subject_lifetime.convert_to_seconds() / 2)) + self.rpki_subject_regen = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-regen", + self.rpki_subject_lifetime.convert_to_seconds() / 2)) self.include_bpki_crl = self.cfg.getboolean("include-bpki-crl", False) - rpki.http.server(host = self.http_server_host, - port = self.http_server_port, - handlers = self.up_down_handler) + self.pubd_contact_uri = self.cfg.get("pubd-contact-uri") + + rpki.http.server(host = self.http_server_host, + port = self.http_server_port, + handlers = self.up_down_handler) |