#!/usr/local/bin/python # $Id$ # Copyright (C) 2013 Dragon Research Labs ("DRL") # # Permission to use, copy, modify, and distribute this software for any # purpose with or without fee is hereby granted, provided that the above # copyright notice and this permission notice appear in all copies. # # THE SOFTWARE IS PROVIDED "AS IS" AND DRL DISCLAIMS ALL WARRANTIES WITH # REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY # AND FITNESS. IN NO EVENT SHALL DRL BE LIABLE FOR ANY SPECIAL, DIRECT, # INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM # LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE # OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR # PERFORMANCE OF THIS SOFTWARE. ######################################################################## # # DANGER WILL ROBINSON # # This is a PROTOTYPE of a local trust anchor mechanism. At the # moment, it DOES NOT WORK by any sane standard of measurement. It # produces output, but there is no particular reason to believe said # output is useful, and fairly good reason to believe that it is not. # # With luck, this may eventually mutate into something useful. For # now, just leave it alone unless you really know what you are doing, # in which case, on your head be it. # # YOU HAVE BEEN WARNED # ######################################################################## import os import sys import yaml import glob import time import shutil import base64 import socket import sqlite3 import weakref import rpki.POW import rpki.x509 import rpki.sundial import rpki.resource_set # Teach SQLite3 about our data types. sqlite3.register_adapter(rpki.POW.IPAddress, lambda x: buffer("_" + x.toBytes())) sqlite3.register_converter("RangeVal", lambda s: long(s) if s.isdigit() else rpki.POW.IPAddress.fromBytes(s[1:])) sqlite3.register_adapter(rpki.x509.X501DN, str) class main(object): tal_directory = None constraints = None rcynic_input = None rcynic_output = None tals = None keyfile = None ltakey = None ltacer = None ltauri = "rsync://localhost/lta" ltasia = ltauri + "/" ltaaia = ltauri + ".cer" ltamft = ltauri + "/lta.mft" ltacrl = ltauri + "/lta.crl" cer_delta = rpki.sundial.timedelta(days = 7) crl_delta = rpki.sundial.timedelta(hours = 1) all_mentioned_resources = rpki.resource_set.resource_bag() def __init__(self): print "Parsing YAML" self.parse_yaml() print print "Parsing TALs" self.parse_tals() print print "Creating DB" self.rpdb = RPDB(self.db_name) print print "Creating CA" self.create_ca() print print "Loading DB" self.rpdb.load(self.rcynic_input) print print "Compute resources we need to prune from input forest" self.compute_all_mentioned_resources() print print "Processing deletions" self.process_constraint_deletions() print print "Re-parenting TAs" self.re_parent_tas() print print "Generating CRL and manifest" self.generate_crl_and_manifest() print print "Committing final changes to DB" self.rpdb.commit() print print "Dumping para-objects" self.rpdb.dump_paras(self.rcynic_output) print print "Closing DB" self.rpdb.close() def create_ca(self): self.serial = Serial() self.ltakey = rpki.x509.RSA.generate(quiet = True) cer = OutgoingX509.self_certify( cn = "%s LTA Root Certificate" % socket.getfqdn(), keypair = self.ltakey, subject_key = self.ltakey.get_RSApublic(), serial = self.serial(), sia = (self.ltasia, self.ltamft, None), notAfter = rpki.sundial.now() + self.cer_delta, resources = rpki.resource_set.resource_bag.from_str("0-4294967295,0.0.0.0/0,::/0")) subject_id = self.rpdb.find_keyname(cer.getSubject(), cer.get_SKI()) self.rpdb.cur.execute("INSERT INTO outgoing (der, fn2, subject, issuer, uri, key) " "VALUES (?, 'cer', ?, ?, ?, ?)", (buffer(cer.get_DER()), subject_id, subject_id, self.ltaaia, buffer(self.ltakey.get_DER()))) self.ltacer = self.rpdb.find_outgoing_by_id(self.rpdb.cur.lastrowid) def parse_yaml(self, fn = "rcynic-lta.yaml"): y = yaml.safe_load(open(fn, "r")) self.db_name = y["db-name"] self.tal_directory = y["tal-directory"] self.rcynic_input = y["rcynic-input"] self.rcynic_output = y["rcynic-output"] self.keyfile = y["keyfile"] self.constraints = [Constraint(yy) for yy in y["constraints"]] def parse_tals(self): self.tals = {} for fn in glob.iglob(os.path.join(self.tal_directory, "*.tal")): with open(fn, "r") as f: uri = f.readline().strip() key = rpki.POW.Asymmetric.derReadPublic(base64.b64decode(f.read())) self.tals[uri] = key def compute_all_mentioned_resources(self): for constraint in self.constraints: self.all_mentioned_resources |= constraint.mentioned_resources def process_constraint_deletions(self): for obj in self.rpdb.find_by_resource_bag(self.all_mentioned_resources): self.add_para(obj, obj.resources - self.all_mentioned_resources) def re_parent_tas(self): for uri, key in self.tals.iteritems(): for ta in self.rpdb.find_by_ski_or_uri(key.calculateSKI(), uri): if ta.para_obj is None: self.add_para(ta, ta.resources - self.all_mentioned_resources) def add_para(self, obj, resources): return self.rpdb.add_para( obj = obj, resources = resources, serial = self.serial, ltacer = self.ltacer, ltasia = self.ltasia, ltaaia = self.ltaaia, ltamft = self.ltamft, ltacrl = self.ltacrl, ltakey = self.ltakey) def generate_crl_and_manifest(self): thisUpdate = rpki.sundial.now() nextUpdate = thisUpdate + self.crl_delta serial = self.serial() issuer = self.ltacer.getSubject() aki = buffer(self.ltacer.get_SKI()) crl = OutgoingCRL.generate( keypair = self.ltakey, issuer = self.ltacer, serial = serial, thisUpdate = thisUpdate, nextUpdate = nextUpdate, revokedCertificates = ()) issuer_id = self.rpdb.find_keyname(issuer, aki) self.rpdb.cur.execute("INSERT INTO outgoing (der, fn2, subject, issuer, uri) " "VALUES (?, 'crl', NULL, ?, ?)", (buffer(crl.get_DER()), issuer_id, self.ltacrl)) crl = self.rpdb.find_outgoing_by_id(self.rpdb.cur.lastrowid) key = rpki.x509.RSA.generate(quiet = True) cer = self.ltacer.issue( keypair = self.ltakey, subject_key = key.get_RSApublic(), serial = serial, sia = (None, None, self.ltamft), aia = self.ltaaia, crldp = self.ltacrl, resources = rpki.resource_set.resource_bag.from_inheritance(), notAfter = self.ltacer.getNotAfter(), is_ca = False) # Temporary kludge, need more general solution but that requires # more refactoring than I feel like doing this late in the day. # self.rpdb.cur.execute("SELECT fn2, der, uri FROM outgoing WHERE issuer = ?", (self.ltacer.rowid,)) names_and_objs = [(uri, OutgoingObject.create(rpdb = self.rpdb, rowid = None, fn2 = fn2, der = der, uri = uri)) for fn2, der, uri in self.rpdb.cur.fetchall()] mft = OutgoingSignedManifest.build( serial = serial, thisUpdate = thisUpdate, nextUpdate = nextUpdate, names_and_objs = names_and_objs, keypair = key, certs = cer) subject_id = self.rpdb.find_keyname(cer.getSubject(), cer.get_SKI()) self.rpdb.cur.execute("INSERT INTO outgoing (der, fn2, subject, issuer, uri, key) " "VALUES (?, 'mft', ?, ?, ?, ?)", (buffer(mft.get_DER()), subject_id, issuer_id, self.ltamft, buffer(key.get_DER()))) @staticmethod def parse_xki(s): """ Parse text form of an SKI or AKI. We accept two encodings: colon-delimited hexadecimal, and URL-safe Base64. The former is what OpenSSL prints in its text representation of SKI and AKI extensions; the latter is the g(SKI) value that some RPKI CA engines (including rpkid) use when constructing filenames. In either case, we check that the decoded result contains the right number of octets to be a SHA-1 hash. """ if ":" in s: b = "".join(chr(int(c, 16)) for c in s.split(":")) else: b = base64.urlsafe_b64decode(s + ("=" * (4 - len(s) % 4))) if len(b) != 20: raise RuntimeError("Bad length for SHA1 xKI value: %r" % s) return b class Serial(object): def __init__(self): self.value = long(time.time()) << 32 def __call__(self): self.value += 1 return self.value class Constraint(object): roa_asn = None roa_maxlen = None router_cert_key = None router_cert_subject = None def __init__(self, y): self.prefixes = rpki.resource_set.resource_bag.from_str(str(y.get("prefix", ""))) self.asns = rpki.resource_set.resource_bag.from_str(str(y.get("asn", ""))) self.ghostbuster = y.get("ghostbuster") if "roa" in y: self.roa_asn = long(y["roa"]["asn"]) if "maxlen" in y["roa"]: self.roa_maxlen = long(y["roa"]["maxlen"]) if "router-cert" in y: self.router_cert_key = y["router-cert"]["key"] self.router_cert_subject = y["router-cert"]["subject"] @property def mentioned_resources(self): return self.prefixes | self.asns class BaseObject(object): """ Mixin to add some SQL-related methods to classes derived from rpki.x509.DER_object. """ _rpdb = None _rowid = None _fn2 = None _fn2map = None _uri = None @property def rowid(self): return self._rowid @property def resources(self): return self.get_3779resources() @property def para_resources(self): return self.resources if self.para_obj is None else self.para_obj.resources @property def fn2(self): return self._fn2 @property def uri(self): return self._uri @classmethod def setfn2map(cls, **map): cls._fn2map = map for k, v in map.iteritems(): v._fn2 = k class IncomingObject(BaseObject): @property def para_obj(self): if getattr(self, "_para_id", None) is None: self._rpdb.cur.execute("SELECT replacement FROM incoming WHERE id = ?", (self.rowid,)) self._para_id = self._rpdb.cur.fetchone()[0] return self._rpdb.find_outgoing_by_id(self._para_id) @para_obj.setter def para_obj(self, value): if value is None: self._rpdb.cur.execute("DELETE FROM outgoing WHERE id IN (SELECT replacement FROM incoming WHERE id = ?)", (self.rowid,)) try: del self._para_id except AttributeError: pass else: assert isinstance(value.rowid, int) self._rpdb.cur.execute("UPDATE incoming SET replacement = ? WHERE id = ?", (value.rowid, self.rowid)) self._para_id = value.rowid @classmethod def fromFile(cls, fn): return cls._fn2map[os.path.splitext(fn)[1][1:]](DER_file = fn) @classmethod def create(cls, rpdb, rowid, fn2, der, uri): self = cls._fn2map[fn2](DER = der) self._uri = uri self._rpdb = rpdb self._rowid = rowid return self class OutgoingObject(BaseObject): @property def orig_obj(self): if getattr(self, "_orig_id", None) is None: self._rpdb.cur.execute("SELECT id FROM incoming WHERE replacement = ?", (self.rowid,)) r = self._rpdb.cur.fetchone() self._orig_id = None if r is None else r[0] return self._rpdb.find_incoming_by_id(self._orig_id) @classmethod def create(cls, rpdb, rowid, fn2, der, uri): self = cls._fn2map[fn2]() if der is not None: self.set(DER = der) self._rpdb = rpdb self._rowid = rowid self._uri = uri return self class IncomingX509 (rpki.x509.X509, IncomingObject): pass class IncomingCRL (rpki.x509.CRL, IncomingObject): pass class IncomingSignedManifest (rpki.x509.SignedManifest, IncomingObject): pass class IncomingROA (rpki.x509.ROA, IncomingObject): pass class IncomingGhostbuster (rpki.x509.Ghostbuster, IncomingObject): pass class OutgoingX509 (rpki.x509.X509, OutgoingObject): pass class OutgoingCRL (rpki.x509.CRL, OutgoingObject): pass class OutgoingSignedManifest (rpki.x509.SignedManifest, OutgoingObject): pass class OutgoingROA (rpki.x509.ROA, OutgoingObject): pass class OutgoingGhostbuster (rpki.x509.Ghostbuster, OutgoingObject): pass IncomingObject.setfn2map(cer = IncomingX509, crl = IncomingCRL, mft = IncomingSignedManifest, roa = IncomingROA, gbr = IncomingGhostbuster) OutgoingObject.setfn2map(cer = OutgoingX509, crl = OutgoingCRL, mft = OutgoingSignedManifest, roa = OutgoingROA, gbr = OutgoingGhostbuster) class RPDB(object): """ Relying party database. """ def __init__(self, db_name): try: os.unlink(db_name) except: pass self.db = sqlite3.connect(db_name, detect_types = sqlite3.PARSE_DECLTYPES) self.db.text_factory = str self.cur = self.db.cursor() self.incoming_cache = weakref.WeakValueDictionary() self.outgoing_cache = weakref.WeakValueDictionary() self.cur.executescript(''' PRAGMA foreign_keys = on; CREATE TABLE keyname ( id INTEGER PRIMARY KEY NOT NULL, name TEXT NOT NULL, keyid BLOB NOT NULL, UNIQUE (name, keyid)); CREATE TABLE incoming ( id INTEGER PRIMARY KEY NOT NULL, der BLOB NOT NULL, fn2 TEXT NOT NULL CHECK (fn2 IN ('cer', 'crl', 'mft', 'roa', 'gbr')), uri TEXT NOT NULL, depth INTEGER, deleted INTEGER NOT NULL DEFAULT 0, subject INTEGER REFERENCES keyname(id) ON DELETE RESTRICT ON UPDATE RESTRICT, issuer INTEGER NOT NULL REFERENCES keyname(id) ON DELETE RESTRICT ON UPDATE RESTRICT, replacement INTEGER REFERENCES outgoing(id) ON DELETE SET NULL ON UPDATE SET NULL, UNIQUE (der), UNIQUE (subject, issuer), CHECK ((subject IS NULL) == (fn2 == 'crl'))); CREATE TABLE outgoing ( id INTEGER PRIMARY KEY NOT NULL, der BLOB, key BLOB, fn2 TEXT NOT NULL CHECK (fn2 IN ('cer', 'crl', 'mft', 'roa', 'gbr')), uri TEXT NOT NULL, subject INTEGER REFERENCES keyname(id) ON DELETE RESTRICT ON UPDATE RESTRICT, issuer INTEGER NOT NULL REFERENCES keyname(id) ON DELETE RESTRICT ON UPDATE RESTRICT, UNIQUE (subject, issuer), CHECK ((key IS NULL) == (fn2 == 'crl')), CHECK ((subject IS NULL) == (fn2 == 'crl'))); CREATE TABLE range ( id INTEGER NOT NULL REFERENCES incoming(id) ON DELETE CASCADE ON UPDATE CASCADE, min RangeVal NOT NULL, max RangeVal NOT NULL, UNIQUE (id, min, max)); ''') def load(self, rcynic_input, spinner = 100): start = rpki.sundial.now() nobj = 0 for root, dirs, files in os.walk(rcynic_input): for fn in files: fn = os.path.join(root, fn) try: obj = IncomingObject.fromFile(fn) except: if spinner: sys.stderr.write("\r") sys.stderr.write("Couldn't read %s, skipping\n" % fn) continue if spinner and nobj % spinner == 0: sys.stderr.write("\r%s %d %s..." % ("|\\-/"[(nobj/spinner) & 3], nobj, rpki.sundial.now() - start)) nobj += 1 if obj.fn2 == "crl": ski = None aki = buffer(obj.get_AKI()) cer = None bag = None issuer = obj.getIssuer() subject = None else: if obj.fn2 == "cer": cer = obj else: cer = rpki.x509.X509(POW = obj.get_POW().certs()[0]) issuer = cer.getIssuer() subject = cer.getSubject() ski = buffer(cer.get_SKI()) aki = cer.get_AKI() if aki is None: assert subject == issuer aki = ski else: aki = buffer(aki) bag = cer.get_3779resources() der = buffer(obj.get_DER()) uri = "rsync://" + fn[len(rcynic_input) + 1:] self.cur.execute("SELECT id FROM incoming WHERE der = ?", (der,)) r = self.cur.fetchone() if r is not None: rowid = r[0] else: subject_id = None if ski is None else self.find_keyname(subject, ski) issuer_id = self.find_keyname(issuer, aki) self.cur.execute("INSERT INTO incoming (der, fn2, subject, issuer, uri) VALUES (?, ?, ?, ?, ?)", (der, obj.fn2, subject_id, issuer_id, uri)) rowid = self.cur.lastrowid if bag is not None: for rset in (bag.asn, bag.v4, bag.v6): if rset is not None: self.cur.executemany("REPLACE INTO range (id, min, max) VALUES (?, ?, ?)", ((rowid, i.min, i.max) for i in rset)) if spinner: sys.stderr.write("\r= %d objects in %s.\n" % (nobj, rpki.sundial.now() - start)) self.cur.execute("UPDATE incoming SET depth = 0 WHERE subject = issuer") for depth in xrange(1, 500): self.cur.execute("SELECT COUNT(*) FROM incoming WHERE depth IS NULL") if self.cur.fetchone()[0] == 0: break if spinner: sys.stderr.write("\rSetting depth %d..." % depth) self.cur.execute(""" UPDATE incoming SET depth = ? WHERE depth IS NULL AND issuer IN (SELECT subject FROM incoming WHERE depth = ?) """, (depth, depth - 1)) else: if spinner: sys.stderr.write("\rSetting depth %d is absurd, giving up, " % depth) if spinner: sys.stderr.write("\nCommitting...") self.db.commit() if spinner: sys.stderr.write("done.\n") def add_para(self, obj, resources, serial, ltacer, ltasia, ltaaia, ltamft, ltacrl, ltakey): assert isinstance(obj, IncomingX509) if obj.para_obj is not None: resources &= obj.para_obj.resources obj.para_obj = None if not resources: return pow = obj.get_POW() x = rpki.POW.X509() x.setVersion( pow.getVersion()) x.setSubject( pow.getSubject()) x.setNotBefore( pow.getNotBefore()) x.setNotAfter( pow.getNotAfter()) x.setPublicKey( pow.getPublicKey()) x.setSKI( pow.getSKI()) x.setBasicConstraints( pow.getBasicConstraints()) x.setKeyUsage( pow.getKeyUsage()) x.setCertificatePolicies( pow.getCertificatePolicies()) x.setSIA( *pow.getSIA()) x.setIssuer( ltacer.get_POW().getIssuer()) x.setAKI( ltacer.get_POW().getSKI()) x.setAIA( (ltaaia,)) x.setCRLDP( (ltacrl,)) x.setSerial( serial()) x.setRFC3779( asn = ((r.min, r.max) for r in resources.asn), ipv4 = ((r.min, r.max) for r in resources.v4), ipv6 = ((r.min, r.max) for r in resources.v6)) x.sign(ltakey.get_POW(), rpki.POW.SHA256_DIGEST) cer = OutgoingX509(POW = x) ski = buffer(cer.get_SKI()) aki = buffer(cer.get_AKI()) bag = cer.get_3779resources() issuer = cer.getIssuer() subject = cer.getSubject() der = buffer(cer.get_DER()) uri = ltasia + cer.gSKI() + ".cer" # This will want to change when we start generating replacement keys for everything. # This should really be a keypair, not just a public key, same comment. # key = buffer(pow.getPublicKey().derWritePublic()) subject_id = self.find_keyname(subject, ski) issuer_id = self.find_keyname(issuer, aki) self.cur.execute("INSERT INTO outgoing (der, fn2, subject, issuer, uri, key) " "VALUES (?, 'cer', ?, ?, ?, ?)", (der, subject_id, issuer_id, uri, key)) rowid = self.cur.lastrowid self.cur.execute("UPDATE incoming SET replacement = ? WHERE id = ?", (rowid, obj.rowid)) # Fix up _orig_id and _para_id here? Maybe later. #self.db.commit() def dump_paras(self, rcynic_output): shutil.rmtree(rcynic_output, ignore_errors = True) rsync = "rsync://" self.cur.execute("SELECT der, uri FROM outgoing") for der, uri in self.cur.fetchall(): assert uri.startswith(rsync) fn = os.path.join(rcynic_output, uri[len(rsync):]) dn = os.path.dirname(fn) if not os.path.exists(dn): os.makedirs(dn) with open(fn, "wb") as f: #print ">> Writing", f.name f.write(der) def find_keyname(self, name, keyid): keys = (name, buffer(keyid)) self.cur.execute("SELECT id FROM keyname WHERE name = ? AND keyid = ?", keys) result = self.cur.fetchone() if result is None: self.cur.execute("INSERT INTO keyname (name, keyid) VALUES (?, ?)", keys) result = self.cur.lastrowid else: result = result[0] return result def find_incoming_by_id(self, rowid): if rowid is None: return None if rowid in self.incoming_cache: return self.incoming_cache[rowid] r = self._find_results(None, "WHERE id = ?", [rowid]) assert len(r) < 2 return r[0] if r else None def find_outgoing_by_id(self, rowid): if rowid is None: return None if rowid in self.outgoing_cache: return self.outgoing_cache[rowid] self.cur.execute("SELECT fn2, der, key, uri FROM outgoing WHERE id = ?", (rowid,)) r = self.cur.fetchone() if r is None: return None fn2, der, key, uri = r obj = OutgoingObject.create(rpdb = self, rowid = rowid, fn2 = fn2, der = der, uri = uri) self.outgoing_cache[rowid] = obj return obj def find_by_ski_or_uri(self, ski, uri): if not ski and not uri: return [] j = "" w = [] a = [] if ski: j = "JOIN keyname ON incoming.subject = keyname.id" w.append("keyname.keyid = ?") a.append(buffer(ski)) if uri: w.append("incoming.uri = ?") a.append(uri) return self._find_results(None, "%s WHERE %s" % (j, " AND ".join(w)), a) # It's easiest to understand overlap conditions by understanding # non-overlap then inverting and and applying De Morgan's law. Ranges # A and B do not overlap if either A.min > B.max or A.max < B.min; # therefore they do overlap if A.min <= B.max and A.max >= B.min. def find_by_range(self, range_min, range_max = None, fn2 = None): if range_max is None: range_max = range_min if isinstance(range_min, (str, unicode)): range_min = long(range_min) if range_min.isdigit() else rpki.POW.IPAddress(range_min) if isinstance(range_max, (str, unicode)): range_max = long(range_max) if range_max.isdigit() else rpki.POW.IPAddress(range_max) assert isinstance(range_min, (int, long, rpki.POW.IPAddress)) assert isinstance(range_max, (int, long, rpki.POW.IPAddress)) return self._find_results( fn2, """ JOIN range ON incoming.id = range.id WHERE ? <= range.max AND ? >= range.min""", [range_min, range_max]) def find_by_resource_bag(self, bag, fn2 = None): assert bag.asn or bag.v4 or bag.v6 qset = [] aset = [] for rset in (bag.asn, bag.v4, bag.v6): if rset: for r in rset: qset.append("(? <= max AND ? >= min)") aset.append(r.min) aset.append(r.max) return self._find_results( fn2, """ JOIN range ON incoming.id = range.id WHERE """ + (" OR ".join(qset)), aset) def _find_results(self, fn2, query, args = None): if args is None: args = [] if fn2 is not None: query += " AND fn2 = ?" args.append(fn2) query = "SELECT incoming.id, incoming.fn2, incoming.der, incoming.uri FROM incoming " + query results = [] self.cur.execute(query, args) selections = self.cur.fetchall() for rowid, fn2, der, uri in selections: if rowid in self.incoming_cache: obj = self.incoming_cache[rowid] assert obj.rowid == rowid else: obj = IncomingObject.create(rpdb = self, rowid = rowid, fn2 = fn2, der = der, uri = uri) self.incoming_cache[rowid] = obj results.append(obj) return results def commit(self): self.db.commit() def close(self): self.commit() self.cur.close() self.db.close() if __name__ == "__main__": #profile = None profile = "rcynic-lta.prof" if profile: import cProfile prof = cProfile.Profile() try: prof.runcall(main) finally: prof.dump_stats(profile) sys.stderr.write("Dumped profile data to %s\n" % profile) else: main()