diff options
author | Rob Austein <sra@hactrn.net> | 2015-10-26 06:29:00 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2015-10-26 06:29:00 +0000 |
commit | b46deb1417dc3596e9ac9fe2fe8cc0b7f42457e7 (patch) | |
tree | ca0dc0276d1adc168bc3337ce0564c4ec4957c1b /rpki | |
parent | 397beaf6d9900dc3b3cb612c89ebf1d57b1d16f6 (diff) |
"Any programmer who fails to comply with the standard naming, formatting,
or commenting conventions should be shot. If it so happens that it is
inconvenient to shoot him, then he is to be politely requested to recode
his program in adherence to the above standard."
-- Michael Spier, Digital Equipment Corporation
svn path=/branches/tk705/; revision=6152
Diffstat (limited to 'rpki')
46 files changed, 14121 insertions, 14120 deletions
diff --git a/rpki/adns.py b/rpki/adns.py index c5af3549..b0f235e7 100644 --- a/rpki/adns.py +++ b/rpki/adns.py @@ -31,14 +31,14 @@ import rpki.sundial import rpki.log try: - import dns.resolver, dns.rdatatype, dns.rdataclass, dns.name, dns.message - import dns.inet, dns.exception, dns.query, dns.rcode, dns.ipv4, dns.ipv6 + import dns.resolver, dns.rdatatype, dns.rdataclass, dns.name, dns.message + import dns.inet, dns.exception, dns.query, dns.rcode, dns.ipv4, dns.ipv6 except ImportError: - if __name__ == "__main__": - sys.stderr.write("DNSPython not available, skipping rpki.adns unit test\n") - sys.exit(0) - else: - raise + if __name__ == "__main__": + sys.stderr.write("DNSPython not available, skipping rpki.adns unit test\n") + sys.exit(0) + else: + raise logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ logger = logging.getLogger(__name__) resolver = dns.resolver.Resolver() if resolver.cache is None: - resolver.cache = dns.resolver.Cache() + resolver.cache = dns.resolver.Cache() ## @var nameservers # Nameservers from resolver.nameservers converted to (af, address) @@ -58,327 +58,327 @@ if resolver.cache is None: nameservers = [] for ns in resolver.nameservers: - try: - nameservers.append((socket.AF_INET, dns.ipv4.inet_aton(ns))) - continue - except: - pass - try: - nameservers.append((socket.AF_INET6, dns.ipv6.inet_aton(ns))) - continue - except: - pass - logger.error("Couldn't parse nameserver address %r", ns) + try: + nameservers.append((socket.AF_INET, dns.ipv4.inet_aton(ns))) + continue + except: + pass + try: + nameservers.append((socket.AF_INET6, dns.ipv6.inet_aton(ns))) + continue + except: + pass + logger.error("Couldn't parse nameserver address %r", ns) class dispatcher(asyncore.dispatcher): - """ - Basic UDP socket reader for use with asyncore. - """ - - def __init__(self, cb, eb, af, bufsize = 65535): - asyncore.dispatcher.__init__(self) - self.cb = cb - self.eb = eb - self.af = af - self.bufsize = bufsize - self.create_socket(af, socket.SOCK_DGRAM) - - def handle_read(self): - """ - Receive a packet, hand it off to query class callback. - """ - - wire, from_address = self.recvfrom(self.bufsize) - self.cb(self.af, from_address[0], from_address[1], wire) - - def handle_error(self): - """ - Pass errors to query class errback. - """ - - self.eb(sys.exc_info()[1]) - - def handle_connect(self): - """ - Quietly ignore UDP "connection" events. - """ - - pass - - def writable(self): """ - We don't need to hear about UDP socket becoming writable. + Basic UDP socket reader for use with asyncore. """ - return False + def __init__(self, cb, eb, af, bufsize = 65535): + asyncore.dispatcher.__init__(self) + self.cb = cb + self.eb = eb + self.af = af + self.bufsize = bufsize + self.create_socket(af, socket.SOCK_DGRAM) + def handle_read(self): + """ + Receive a packet, hand it off to query class callback. + """ -class query(object): - """ - Simplified (no search paths) asynchronous adaptation of - dns.resolver.Resolver.query() (q.v.). - """ - - def __init__(self, cb, eb, qname, qtype = dns.rdatatype.A, qclass = dns.rdataclass.IN): - if isinstance(qname, (str, unicode)): - qname = dns.name.from_text(qname) - if isinstance(qtype, str): - qtype = dns.rdatatype.from_text(qtype) - if isinstance(qclass, str): - qclass = dns.rdataclass.from_text(qclass) - assert qname.is_absolute() - self.cb = cb - self.eb = eb - self.qname = qname - self.qtype = qtype - self.qclass = qclass - self.start = time.time() - rpki.async.event_defer(self.go) - - def go(self): - """ - Start running the query. Check our cache before doing network - query; if we find an answer there, just return it. Otherwise - start the network query. - """ - - if resolver.cache: - answer = resolver.cache.get((self.qname, self.qtype, self.qclass)) - else: - answer = None - if answer: - self.cb(self, answer) - else: - self.timer = rpki.async.timer() - self.sockets = {} - self.request = dns.message.make_query(self.qname, self.qtype, self.qclass) - if resolver.keyname is not None: - self.request.use_tsig(resolver.keyring, resolver.keyname, resolver.keyalgorithm) - self.request.use_edns(resolver.edns, resolver.ednsflags, resolver.payload) - self.response = None - self.backoff = 0.10 - self.nameservers = nameservers[:] - self.loop1() - - def loop1(self): - """ - Outer loop. If we haven't got a response yet and still have - nameservers to check, start inner loop. Otherwise, we're done. - """ + wire, from_address = self.recvfrom(self.bufsize) + self.cb(self.af, from_address[0], from_address[1], wire) - self.timer.cancel() - if self.response is None and self.nameservers: - self.iterator = rpki.async.iterator(self.nameservers[:], self.loop2, self.done2) - else: - self.done1() + def handle_error(self): + """ + Pass errors to query class errback. + """ - def loop2(self, iterator, nameserver): - """ - Inner loop. Send query to next nameserver in our list, unless - we've hit the overall timeout for this query. - """ + self.eb(sys.exc_info()[1]) - self.timer.cancel() - try: - timeout = resolver._compute_timeout(self.start) - except dns.resolver.Timeout, e: - self.lose(e) - else: - af, addr = nameserver - if af not in self.sockets: - self.sockets[af] = dispatcher(self.socket_cb, self.socket_eb, af) - self.sockets[af].sendto(self.request.to_wire(), - (dns.inet.inet_ntop(af, addr), resolver.port)) - self.timer.set_handler(self.socket_timeout) - self.timer.set_errback(self.socket_eb) - self.timer.set(rpki.sundial.timedelta(seconds = timeout)) - - def socket_timeout(self): - """ - No answer from nameserver, move on to next one (inner loop). - """ + def handle_connect(self): + """ + Quietly ignore UDP "connection" events. + """ - self.response = None - self.iterator() + pass - def socket_eb(self, e): - """ - UDP socket signaled error. If it really is some kind of socket - error, handle as if we've timed out on this nameserver; otherwise, - pass error back to caller. - """ + def writable(self): + """ + We don't need to hear about UDP socket becoming writable. + """ - self.timer.cancel() - if isinstance(e, socket.error): - self.response = None - self.iterator() - else: - self.lose(e) + return False - def socket_cb(self, af, from_host, from_port, wire): - """ - Received a packet that might be a DNS message. If it doesn't look - like it came from one of our nameservers, just drop it and leave - the timer running. Otherwise, try parsing it: if it's an answer, - we're done, otherwise handle error appropriately and move on to - next nameserver. - """ - sender = (af, dns.inet.inet_pton(af, from_host)) - if from_port != resolver.port or sender not in self.nameservers: - return - self.timer.cancel() - try: - self.response = dns.message.from_wire(wire, keyring = self.request.keyring, request_mac = self.request.mac, one_rr_per_rrset = False) - except dns.exception.FormError: - self.nameservers.remove(sender) - else: - rcode = self.response.rcode() - if rcode in (dns.rcode.NOERROR, dns.rcode.NXDOMAIN): - self.done1() - return - if rcode != dns.rcode.SERVFAIL: - self.nameservers.remove(sender) - self.response = None - self.iterator() - - def done2(self): +class query(object): """ - Done with inner loop. If we still haven't got an answer and - haven't (yet?) eliminated all of our nameservers, wait a little - while before starting the cycle again, unless we've hit the - timeout threshold for the whole query. + Simplified (no search paths) asynchronous adaptation of + dns.resolver.Resolver.query() (q.v.). """ - if self.response is None and self.nameservers: - try: - delay = rpki.sundial.timedelta(seconds = min(resolver._compute_timeout(self.start), self.backoff)) - self.backoff *= 2 - self.timer.set_handler(self.loop1) - self.timer.set_errback(self.lose) - self.timer.set(delay) - except dns.resolver.Timeout, e: - self.lose(e) - else: - self.loop1() - - def cleanup(self): - """ - Shut down our timer and sockets. - """ + def __init__(self, cb, eb, qname, qtype = dns.rdatatype.A, qclass = dns.rdataclass.IN): + if isinstance(qname, (str, unicode)): + qname = dns.name.from_text(qname) + if isinstance(qtype, str): + qtype = dns.rdatatype.from_text(qtype) + if isinstance(qclass, str): + qclass = dns.rdataclass.from_text(qclass) + assert qname.is_absolute() + self.cb = cb + self.eb = eb + self.qname = qname + self.qtype = qtype + self.qclass = qclass + self.start = time.time() + rpki.async.event_defer(self.go) + + def go(self): + """ + Start running the query. Check our cache before doing network + query; if we find an answer there, just return it. Otherwise + start the network query. + """ + + if resolver.cache: + answer = resolver.cache.get((self.qname, self.qtype, self.qclass)) + else: + answer = None + if answer: + self.cb(self, answer) + else: + self.timer = rpki.async.timer() + self.sockets = {} + self.request = dns.message.make_query(self.qname, self.qtype, self.qclass) + if resolver.keyname is not None: + self.request.use_tsig(resolver.keyring, resolver.keyname, resolver.keyalgorithm) + self.request.use_edns(resolver.edns, resolver.ednsflags, resolver.payload) + self.response = None + self.backoff = 0.10 + self.nameservers = nameservers[:] + self.loop1() + + def loop1(self): + """ + Outer loop. If we haven't got a response yet and still have + nameservers to check, start inner loop. Otherwise, we're done. + """ + + self.timer.cancel() + if self.response is None and self.nameservers: + self.iterator = rpki.async.iterator(self.nameservers[:], self.loop2, self.done2) + else: + self.done1() + + def loop2(self, iterator, nameserver): + """ + Inner loop. Send query to next nameserver in our list, unless + we've hit the overall timeout for this query. + """ + + self.timer.cancel() + try: + timeout = resolver._compute_timeout(self.start) + except dns.resolver.Timeout, e: + self.lose(e) + else: + af, addr = nameserver + if af not in self.sockets: + self.sockets[af] = dispatcher(self.socket_cb, self.socket_eb, af) + self.sockets[af].sendto(self.request.to_wire(), + (dns.inet.inet_ntop(af, addr), resolver.port)) + self.timer.set_handler(self.socket_timeout) + self.timer.set_errback(self.socket_eb) + self.timer.set(rpki.sundial.timedelta(seconds = timeout)) + + def socket_timeout(self): + """ + No answer from nameserver, move on to next one (inner loop). + """ + + self.response = None + self.iterator() + + def socket_eb(self, e): + """ + UDP socket signaled error. If it really is some kind of socket + error, handle as if we've timed out on this nameserver; otherwise, + pass error back to caller. + """ + + self.timer.cancel() + if isinstance(e, socket.error): + self.response = None + self.iterator() + else: + self.lose(e) + + def socket_cb(self, af, from_host, from_port, wire): + """ + Received a packet that might be a DNS message. If it doesn't look + like it came from one of our nameservers, just drop it and leave + the timer running. Otherwise, try parsing it: if it's an answer, + we're done, otherwise handle error appropriately and move on to + next nameserver. + """ + + sender = (af, dns.inet.inet_pton(af, from_host)) + if from_port != resolver.port or sender not in self.nameservers: + return + self.timer.cancel() + try: + self.response = dns.message.from_wire(wire, keyring = self.request.keyring, request_mac = self.request.mac, one_rr_per_rrset = False) + except dns.exception.FormError: + self.nameservers.remove(sender) + else: + rcode = self.response.rcode() + if rcode in (dns.rcode.NOERROR, dns.rcode.NXDOMAIN): + self.done1() + return + if rcode != dns.rcode.SERVFAIL: + self.nameservers.remove(sender) + self.response = None + self.iterator() + + def done2(self): + """ + Done with inner loop. If we still haven't got an answer and + haven't (yet?) eliminated all of our nameservers, wait a little + while before starting the cycle again, unless we've hit the + timeout threshold for the whole query. + """ + + if self.response is None and self.nameservers: + try: + delay = rpki.sundial.timedelta(seconds = min(resolver._compute_timeout(self.start), self.backoff)) + self.backoff *= 2 + self.timer.set_handler(self.loop1) + self.timer.set_errback(self.lose) + self.timer.set(delay) + except dns.resolver.Timeout, e: + self.lose(e) + else: + self.loop1() + + def cleanup(self): + """ + Shut down our timer and sockets. + """ + + self.timer.cancel() + for s in self.sockets.itervalues(): + s.close() - self.timer.cancel() - for s in self.sockets.itervalues(): - s.close() + def lose(self, e): + """ + Something bad happened. Clean up, then pass error back to caller. + """ + + self.cleanup() + self.eb(self, e) + + def done1(self): + """ + Done with outer loop. If we got a useful answer, cache it, then + pass it back to caller; if we got an error, pass the appropriate + exception back to caller. + """ + + self.cleanup() + try: + if not self.nameservers: + raise dns.resolver.NoNameservers + if self.response.rcode() == dns.rcode.NXDOMAIN: + raise dns.resolver.NXDOMAIN + answer = dns.resolver.Answer(self.qname, self.qtype, self.qclass, self.response) + if resolver.cache: + resolver.cache.put((self.qname, self.qtype, self.qclass), answer) + self.cb(self, answer) + except (rpki.async.ExitNow, SystemExit): + raise + except Exception, e: + self.lose(e) - def lose(self, e): - """ - Something bad happened. Clean up, then pass error back to caller. - """ +class getaddrinfo(object): - self.cleanup() - self.eb(self, e) + typemap = { dns.rdatatype.A : socket.AF_INET, + dns.rdatatype.AAAA : socket.AF_INET6 } + + def __init__(self, cb, eb, host, address_families = typemap.values()): + self.cb = cb + self.eb = eb + self.host = host + self.result = [] + self.queries = [query(self.done, self.lose, host, qtype) + for qtype in self.typemap + if self.typemap[qtype] in address_families] + + def done(self, q, answer): + if answer is not None: + for a in answer: + self.result.append((self.typemap[a.rdtype], a.address)) + self.queries.remove(q) + if not self.queries: + self.cb(self.result) - def done1(self): - """ - Done with outer loop. If we got a useful answer, cache it, then - pass it back to caller; if we got an error, pass the appropriate - exception back to caller. - """ + def lose(self, q, e): + if isinstance(e, dns.resolver.NoAnswer): + self.done(q, None) + else: + for q in self.queries: + q.cleanup() + self.eb(e) - self.cleanup() - try: - if not self.nameservers: - raise dns.resolver.NoNameservers - if self.response.rcode() == dns.rcode.NXDOMAIN: - raise dns.resolver.NXDOMAIN - answer = dns.resolver.Answer(self.qname, self.qtype, self.qclass, self.response) - if resolver.cache: - resolver.cache.put((self.qname, self.qtype, self.qclass), answer) - self.cb(self, answer) - except (rpki.async.ExitNow, SystemExit): - raise - except Exception, e: - self.lose(e) +if __name__ == "__main__": -class getaddrinfo(object): + rpki.log.init("test-adns") + print "Some adns tests may take a minute or two, please be patient" - typemap = { dns.rdatatype.A : socket.AF_INET, - dns.rdatatype.AAAA : socket.AF_INET6 } - - def __init__(self, cb, eb, host, address_families = typemap.values()): - self.cb = cb - self.eb = eb - self.host = host - self.result = [] - self.queries = [query(self.done, self.lose, host, qtype) - for qtype in self.typemap - if self.typemap[qtype] in address_families] - - def done(self, q, answer): - if answer is not None: - for a in answer: - self.result.append((self.typemap[a.rdtype], a.address)) - self.queries.remove(q) - if not self.queries: - self.cb(self.result) - - def lose(self, q, e): - if isinstance(e, dns.resolver.NoAnswer): - self.done(q, None) - else: - for q in self.queries: - q.cleanup() - self.eb(e) + class test_getaddrinfo(object): -if __name__ == "__main__": + def __init__(self, qname): + self.qname = qname + getaddrinfo(self.done, self.lose, qname) - rpki.log.init("test-adns") - print "Some adns tests may take a minute or two, please be patient" + def done(self, result): + print "getaddrinfo(%s) returned: %s" % ( + self.qname, + ", ".join(str(r) for r in result)) - class test_getaddrinfo(object): + def lose(self, e): + print "getaddrinfo(%s) failed: %r" % (self.qname, e) - def __init__(self, qname): - self.qname = qname - getaddrinfo(self.done, self.lose, qname) + class test_query(object): - def done(self, result): - print "getaddrinfo(%s) returned: %s" % ( - self.qname, - ", ".join(str(r) for r in result)) + def __init__(self, qname, qtype = dns.rdatatype.A, qclass = dns.rdataclass.IN): + self.qname = qname + self.qtype = qtype + self.qclass = qclass + query(self.done, self.lose, qname, qtype = qtype, qclass = qclass) - def lose(self, e): - print "getaddrinfo(%s) failed: %r" % (self.qname, e) + def done(self, q, result): + print "query(%s, %s, %s) returned: %s" % ( + self.qname, + dns.rdatatype.to_text(self.qtype), + dns.rdataclass.to_text(self.qclass), + ", ".join(str(r) for r in result)) - class test_query(object): + def lose(self, q, e): + print "getaddrinfo(%s, %s, %s) failed: %r" % ( + self.qname, + dns.rdatatype.to_text(self.qtype), + dns.rdataclass.to_text(self.qclass), + e) - def __init__(self, qname, qtype = dns.rdatatype.A, qclass = dns.rdataclass.IN): - self.qname = qname - self.qtype = qtype - self.qclass = qclass - query(self.done, self.lose, qname, qtype = qtype, qclass = qclass) + if True: + for t in (dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.HINFO): + test_query("subvert-rpki.hactrn.net", t) + test_query("nonexistant.rpki.net") + test_query("subvert-rpki.hactrn.net", qclass = dns.rdataclass.CH) - def done(self, q, result): - print "query(%s, %s, %s) returned: %s" % ( - self.qname, - dns.rdatatype.to_text(self.qtype), - dns.rdataclass.to_text(self.qclass), - ", ".join(str(r) for r in result)) + for h in ("subvert-rpki.hactrn.net", "nonexistant.rpki.net"): + test_getaddrinfo(h) - def lose(self, q, e): - print "getaddrinfo(%s, %s, %s) failed: %r" % ( - self.qname, - dns.rdatatype.to_text(self.qtype), - dns.rdataclass.to_text(self.qclass), - e) - - if True: - for t in (dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.HINFO): - test_query("subvert-rpki.hactrn.net", t) - test_query("nonexistant.rpki.net") - test_query("subvert-rpki.hactrn.net", qclass = dns.rdataclass.CH) - - for h in ("subvert-rpki.hactrn.net", "nonexistant.rpki.net"): - test_getaddrinfo(h) - - rpki.async.event_loop() + rpki.async.event_loop() diff --git a/rpki/cli.py b/rpki/cli.py index 35999cb0..51ac0367 100644 --- a/rpki/cli.py +++ b/rpki/cli.py @@ -28,244 +28,244 @@ import argparse import traceback try: - import readline - have_readline = True + import readline + have_readline = True except ImportError: - have_readline = False + have_readline = False class BadCommandSyntax(Exception): - "Bad command line syntax." + "Bad command line syntax." class ExitArgparse(Exception): - "Exit method from ArgumentParser." + "Exit method from ArgumentParser." - def __init__(self, message = None, status = 0): - super(ExitArgparse, self).__init__() - self.message = message - self.status = status + def __init__(self, message = None, status = 0): + super(ExitArgparse, self).__init__() + self.message = message + self.status = status class Cmd(cmd.Cmd): - """ - Customized subclass of Python cmd module. - """ + """ + Customized subclass of Python cmd module. + """ - emptyline_repeats_last_command = False + emptyline_repeats_last_command = False - EOF_exits_command_loop = True + EOF_exits_command_loop = True - identchars = cmd.IDENTCHARS + "/-." + identchars = cmd.IDENTCHARS + "/-." - histfile = None + histfile = None - last_command_failed = False + last_command_failed = False - def onecmd(self, line): - """ - Wrap error handling around cmd.Cmd.onecmd(). Might want to do - something kinder than showing a traceback, eventually. - """ + def onecmd(self, line): + """ + Wrap error handling around cmd.Cmd.onecmd(). Might want to do + something kinder than showing a traceback, eventually. + """ - self.last_command_failed = False - try: - return cmd.Cmd.onecmd(self, line) - except SystemExit: - raise - except ExitArgparse, e: - if e.message is not None: - print e.message - self.last_command_failed = e.status != 0 - return False - except BadCommandSyntax, e: - print e - except: - traceback.print_exc() - self.last_command_failed = True - return False - - def do_EOF(self, arg): - if self.EOF_exits_command_loop and self.prompt: - print - return self.EOF_exits_command_loop - - def do_exit(self, arg): - """ - Exit program. - """ + self.last_command_failed = False + try: + return cmd.Cmd.onecmd(self, line) + except SystemExit: + raise + except ExitArgparse, e: + if e.message is not None: + print e.message + self.last_command_failed = e.status != 0 + return False + except BadCommandSyntax, e: + print e + except: + traceback.print_exc() + self.last_command_failed = True + return False + + def do_EOF(self, arg): + if self.EOF_exits_command_loop and self.prompt: + print + return self.EOF_exits_command_loop + + def do_exit(self, arg): + """ + Exit program. + """ + + return True + + do_quit = do_exit + + def emptyline(self): + """ + Handle an empty line. cmd module default is to repeat the last + command, which I find to be violation of the principal of least + astonishment, so my preference is that an empty line does nothing. + """ + + if self.emptyline_repeats_last_command: + cmd.Cmd.emptyline(self) + + def filename_complete(self, text, line, begidx, endidx): + """ + Filename completion handler, with hack to restore what I consider + the normal (bash-like) behavior when one hits the completion key + and there's only one match. + """ + + result = glob.glob(text + "*") + if len(result) == 1: + path = result.pop() + if os.path.isdir(path) or (os.path.islink(path) and os.path.isdir(os.path.join(path, "."))): + result.append(path + os.path.sep) + else: + result.append(path + " ") + return result + + def completenames(self, text, *ignored): + """ + Command name completion handler, with hack to restore what I + consider the normal (bash-like) behavior when one hits the + completion key and there's only one match. + """ + + result = cmd.Cmd.completenames(self, text, *ignored) + if len(result) == 1: + result[0] += " " + return result + + def help_help(self): + """ + Type "help [topic]" for help on a command, + or just "help" for a list of commands. + """ + + self.stdout.write(self.help_help.__doc__ + "\n") + + def complete_help(self, *args): + """ + Better completion function for help command arguments. + """ + + text = args[0] + names = self.get_names() + result = [] + for prefix in ("do_", "help_"): + result.extend(s[len(prefix):] for s in names if s.startswith(prefix + text) and s != "do_EOF") + return result + + if have_readline: + + def cmdloop_with_history(self): + """ + Better command loop, with history file and tweaked readline + completion delimiters. + """ + + old_completer_delims = readline.get_completer_delims() + if self.histfile is not None: + try: + readline.read_history_file(self.histfile) + except IOError: + pass + try: + readline.set_completer_delims("".join(set(old_completer_delims) - set(self.identchars))) + self.cmdloop() + finally: + if self.histfile is not None and readline.get_current_history_length(): + readline.write_history_file(self.histfile) + readline.set_completer_delims(old_completer_delims) + + else: + + cmdloop_with_history = cmd.Cmd.cmdloop - return True - do_quit = do_exit - def emptyline(self): +def yes_or_no(prompt, default = None, require_full_word = False): """ - Handle an empty line. cmd module default is to repeat the last - command, which I find to be violation of the principal of least - astonishment, so my preference is that an empty line does nothing. + Ask a yes-or-no question. """ - if self.emptyline_repeats_last_command: - cmd.Cmd.emptyline(self) + prompt = prompt.rstrip() + _yes_or_no_prompts[default] + while True: + answer = raw_input(prompt).strip().lower() + if not answer and default is not None: + return default + if answer == "yes" or (not require_full_word and answer.startswith("y")): + return True + if answer == "no" or (not require_full_word and answer.startswith("n")): + return False + print 'Please answer "yes" or "no"' - def filename_complete(self, text, line, begidx, endidx): - """ - Filename completion handler, with hack to restore what I consider - the normal (bash-like) behavior when one hits the completion key - and there's only one match. - """ +_yes_or_no_prompts = { + True : ' ("yes" or "no" ["yes"]) ', + False : ' ("yes" or "no" ["no"]) ', + None : ' ("yes" or "no") ' } - result = glob.glob(text + "*") - if len(result) == 1: - path = result.pop() - if os.path.isdir(path) or (os.path.islink(path) and os.path.isdir(os.path.join(path, "."))): - result.append(path + os.path.sep) - else: - result.append(path + " ") - return result - def completenames(self, text, *ignored): +class NonExitingArgumentParser(argparse.ArgumentParser): """ - Command name completion handler, with hack to restore what I - consider the normal (bash-like) behavior when one hits the - completion key and there's only one match. + ArgumentParser tweaked to throw ExitArgparse exception + rather than using sys.exit(), for use with command loop. """ - result = cmd.Cmd.completenames(self, text, *ignored) - if len(result) == 1: - result[0] += " " - return result + def exit(self, status = 0, message = None): + raise ExitArgparse(status = status, message = message) - def help_help(self): - """ - Type "help [topic]" for help on a command, - or just "help" for a list of commands. - """ - - self.stdout.write(self.help_help.__doc__ + "\n") - def complete_help(self, *args): - """ - Better completion function for help command arguments. +def parsecmd(subparsers, *arg_clauses): """ + Decorator to combine the argparse and cmd modules. - text = args[0] - names = self.get_names() - result = [] - for prefix in ("do_", "help_"): - result.extend(s[len(prefix):] for s in names if s.startswith(prefix + text) and s != "do_EOF") - return result + subparsers is an instance of argparse.ArgumentParser (or subclass) which was + returned by calling the .add_subparsers() method on an ArgumentParser instance + intended to handle parsing for the entire program on the command line. - if have_readline: - - def cmdloop_with_history(self): - """ - Better command loop, with history file and tweaked readline - completion delimiters. - """ - - old_completer_delims = readline.get_completer_delims() - if self.histfile is not None: - try: - readline.read_history_file(self.histfile) - except IOError: - pass - try: - readline.set_completer_delims("".join(set(old_completer_delims) - set(self.identchars))) - self.cmdloop() - finally: - if self.histfile is not None and readline.get_current_history_length(): - readline.write_history_file(self.histfile) - readline.set_completer_delims(old_completer_delims) - - else: - - cmdloop_with_history = cmd.Cmd.cmdloop + arg_clauses is a series of defarg() invocations defining arguments to be parsed + by the argparse code. + The decorator will use arg_clauses to construct two separate argparse parser + instances: one will be attached to the global parser as a subparser, the + other will be used to parse arguments for this command when invoked by cmd. + The decorator will replace the original do_whatever method with a wrapped version + which uses the local argparse instance to parse the single string supplied by + the cmd module. -def yes_or_no(prompt, default = None, require_full_word = False): - """ - Ask a yes-or-no question. - """ - - prompt = prompt.rstrip() + _yes_or_no_prompts[default] - while True: - answer = raw_input(prompt).strip().lower() - if not answer and default is not None: - return default - if answer == "yes" or (not require_full_word and answer.startswith("y")): - return True - if answer == "no" or (not require_full_word and answer.startswith("n")): - return False - print 'Please answer "yes" or "no"' - -_yes_or_no_prompts = { - True : ' ("yes" or "no" ["yes"]) ', - False : ' ("yes" or "no" ["no"]) ', - None : ' ("yes" or "no") ' } - - -class NonExitingArgumentParser(argparse.ArgumentParser): - """ - ArgumentParser tweaked to throw ExitArgparse exception - rather than using sys.exit(), for use with command loop. - """ - - def exit(self, status = 0, message = None): - raise ExitArgparse(status = status, message = message) + The intent is that, from the command's point of view, all of this should work + pretty much the same way regardless of whether the command was invoked from + the global command line or from within the cmd command loop. Either way, + the command method should get an argparse.Namespace object. + In theory, we could generate a completion handler from the argparse definitions, + much as the separate argcomplete package does. In practice this is a lot of + work and I'm not ready to get into that just yet. + """ -def parsecmd(subparsers, *arg_clauses): - """ - Decorator to combine the argparse and cmd modules. - - subparsers is an instance of argparse.ArgumentParser (or subclass) which was - returned by calling the .add_subparsers() method on an ArgumentParser instance - intended to handle parsing for the entire program on the command line. - - arg_clauses is a series of defarg() invocations defining arguments to be parsed - by the argparse code. - - The decorator will use arg_clauses to construct two separate argparse parser - instances: one will be attached to the global parser as a subparser, the - other will be used to parse arguments for this command when invoked by cmd. - - The decorator will replace the original do_whatever method with a wrapped version - which uses the local argparse instance to parse the single string supplied by - the cmd module. - - The intent is that, from the command's point of view, all of this should work - pretty much the same way regardless of whether the command was invoked from - the global command line or from within the cmd command loop. Either way, - the command method should get an argparse.Namespace object. - - In theory, we could generate a completion handler from the argparse definitions, - much as the separate argcomplete package does. In practice this is a lot of - work and I'm not ready to get into that just yet. - """ - - def decorate(func): - assert func.__name__.startswith("do_") - parser = NonExitingArgumentParser(description = func.__doc__, - prog = func.__name__[3:], - add_help = False) - subparser = subparsers.add_parser(func.__name__[3:], - description = func.__doc__, - help = func.__doc__.lstrip().partition("\n")[0]) - for positional, keywords in arg_clauses: - parser.add_argument(*positional, **keywords) - subparser.add_argument(*positional, **keywords) - subparser.set_defaults(func = func) - def wrapped(self, arg): - return func(self, parser.parse_args(shlex.split(arg))) - wrapped.argparser = parser - wrapped.__doc__ = func.__doc__ - return wrapped - return decorate + def decorate(func): + assert func.__name__.startswith("do_") + parser = NonExitingArgumentParser(description = func.__doc__, + prog = func.__name__[3:], + add_help = False) + subparser = subparsers.add_parser(func.__name__[3:], + description = func.__doc__, + help = func.__doc__.lstrip().partition("\n")[0]) + for positional, keywords in arg_clauses: + parser.add_argument(*positional, **keywords) + subparser.add_argument(*positional, **keywords) + subparser.set_defaults(func = func) + def wrapped(self, arg): + return func(self, parser.parse_args(shlex.split(arg))) + wrapped.argparser = parser + wrapped.__doc__ = func.__doc__ + return wrapped + return decorate def cmdarg(*positional, **keywords): - """ - Syntactic sugar to let us use keyword arguments normally when constructing - arguments for deferred calls to argparse.ArgumentParser.add_argument(). - """ + """ + Syntactic sugar to let us use keyword arguments normally when constructing + arguments for deferred calls to argparse.ArgumentParser.add_argument(). + """ - return positional, keywords + return positional, keywords diff --git a/rpki/config.py b/rpki/config.py index 99041259..5dd03a6d 100644 --- a/rpki/config.py +++ b/rpki/config.py @@ -33,10 +33,10 @@ logger = logging.getLogger(__name__) # Default name of config file if caller doesn't specify one explictly. try: - import rpki.autoconf - default_filename = os.path.join(rpki.autoconf.sysconfdir, "rpki.conf") + import rpki.autoconf + default_filename = os.path.join(rpki.autoconf.sysconfdir, "rpki.conf") except ImportError: - default_filename = None + default_filename = None ## @var rpki_conf_envname # Name of environment variable containing config file name. @@ -44,230 +44,230 @@ except ImportError: rpki_conf_envname = "RPKI_CONF" class parser(object): - """ - Extensions to stock Python ConfigParser: - - Read config file and set default section while initializing parser object. - - Support for OpenSSL-style subscripted options and a limited form of - OpenSSL-style indirect variable references (${section::option}). - - get-methods with default values and default section name. - - If no filename is given to the constructor (filename and - set_filename both None), we check for an environment variable naming - the config file, then finally we check for a global config file if - autoconf provided a directory name to check. - - NB: Programs which accept a configuration filename on the command - lines should pass that filename using set_filename so that we can - set the magic environment variable. Constraints from some external - libraries (principally Django) sometimes require library code to - look things up in the configuration file without the knowledge of - the controlling program, but setting the environment variable - insures that everybody's reading from the same script, as it were. - """ - - # Odd keyword-only calling sequence is a defense against old code - # that thinks it knows how __init__() handles positional arguments. - - def __init__(self, **kwargs): - section = kwargs.pop("section", None) - allow_missing = kwargs.pop("allow_missing", False) - set_filename = kwargs.pop("set_filename", None) - filename = kwargs.pop("filename", set_filename) - - assert not kwargs, "Unexpected keyword arguments: " + ", ".join("%s = %r" % kv for kv in kwargs.iteritems()) - - if set_filename is not None: - os.environ[rpki_conf_envname] = set_filename - - self.cfg = ConfigParser.RawConfigParser() - self.default_section = section - - self.filename = filename or os.getenv(rpki_conf_envname) or default_filename - - try: - with open(self.filename, "r") as f: - self.cfg.readfp(f) - except IOError: - if allow_missing: - self.filename = None - else: - raise - - - def has_section(self, section): """ - Test whether a section exists. - """ - - return self.cfg.has_section(section) + Extensions to stock Python ConfigParser: + Read config file and set default section while initializing parser object. - def has_option(self, option, section = None): - """ - Test whether an option exists. - """ + Support for OpenSSL-style subscripted options and a limited form of + OpenSSL-style indirect variable references (${section::option}). - if section is None: - section = self.default_section - return self.cfg.has_option(section, option) + get-methods with default values and default section name. + If no filename is given to the constructor (filename and + set_filename both None), we check for an environment variable naming + the config file, then finally we check for a global config file if + autoconf provided a directory name to check. - def multiget(self, option, section = None): + NB: Programs which accept a configuration filename on the command + lines should pass that filename using set_filename so that we can + set the magic environment variable. Constraints from some external + libraries (principally Django) sometimes require library code to + look things up in the configuration file without the knowledge of + the controlling program, but setting the environment variable + insures that everybody's reading from the same script, as it were. """ - Parse OpenSSL-style foo.0, foo.1, ... subscripted options. - Returns iteration of values matching the specified option name. - """ + # Odd keyword-only calling sequence is a defense against old code + # that thinks it knows how __init__() handles positional arguments. - matches = [] - if section is None: - section = self.default_section - if self.cfg.has_option(section, option): - yield self.cfg.get(section, option) - option += "." - matches = [o for o in self.cfg.options(section) if o.startswith(option) and o[len(option):].isdigit()] - matches.sort() - for option in matches: - yield self.cfg.get(section, option) + def __init__(self, **kwargs): + section = kwargs.pop("section", None) + allow_missing = kwargs.pop("allow_missing", False) + set_filename = kwargs.pop("set_filename", None) + filename = kwargs.pop("filename", set_filename) + assert not kwargs, "Unexpected keyword arguments: " + ", ".join("%s = %r" % kv for kv in kwargs.iteritems()) - _regexp = re.compile("\\${(.*?)::(.*?)}") + if set_filename is not None: + os.environ[rpki_conf_envname] = set_filename - def _repl(self, m): - """ - Replacement function for indirect variable substitution. - This is intended for use with re.subn(). - """ + self.cfg = ConfigParser.RawConfigParser() + self.default_section = section - section, option = m.group(1, 2) - if section == "ENV": - return os.getenv(option, "") - else: - return self.cfg.get(section, option) + self.filename = filename or os.getenv(rpki_conf_envname) or default_filename + try: + with open(self.filename, "r") as f: + self.cfg.readfp(f) + except IOError: + if allow_missing: + self.filename = None + else: + raise - def get(self, option, default = None, section = None): - """ - Get an option, perhaps with a default value. - """ - if section is None: - section = self.default_section - if default is not None and not self.cfg.has_option(section, option): - return default - val = self.cfg.get(section, option) - while True: - val, modified = self._regexp.subn(self._repl, val, 1) - if not modified: - return val + def has_section(self, section): + """ + Test whether a section exists. + """ + return self.cfg.has_section(section) - def getboolean(self, option, default = None, section = None): - """ - Get a boolean option, perhaps with a default value. - """ - v = self.get(option, default, section) - if isinstance(v, str): - v = v.lower() - if v not in self.cfg._boolean_states: - raise ValueError("Not a boolean: %s" % v) - v = self.cfg._boolean_states[v] - return v + def has_option(self, option, section = None): + """ + Test whether an option exists. + """ + if section is None: + section = self.default_section + return self.cfg.has_option(section, option) - def getint(self, option, default = None, section = None): - """ - Get an integer option, perhaps with a default value. - """ - return int(self.get(option, default, section)) + def multiget(self, option, section = None): + """ + Parse OpenSSL-style foo.0, foo.1, ... subscripted options. + Returns iteration of values matching the specified option name. + """ - def getlong(self, option, default = None, section = None): - """ - Get a long integer option, perhaps with a default value. - """ - - return long(self.get(option, default, section)) - - - def set_global_flags(self): - """ - Consolidated control for all the little global control flags - scattered through the libraries. This isn't a particularly good - place for this function to live, but it has to live somewhere and - making it a method of the config parser from which it gets all of - its data is less silly than the available alternatives. - """ - - # pylint: disable=W0621 - import rpki.x509 - import rpki.log - import rpki.daemonize - - for line in self.multiget("configure_logger"): - try: - name, level = line.split() - logging.getLogger(name).setLevel(getattr(logging, level.upper())) - except Exception, e: - logger.warning("Could not process configure_logger line %r: %s", line, e) - - try: - rpki.x509.CMS_object.debug_cms_certs = self.getboolean("debug_cms_certs") - except ConfigParser.NoOptionError: - pass - - try: - rpki.x509.XML_CMS_object.dump_outbound_cms = rpki.x509.DeadDrop(self.get("dump_outbound_cms")) - except OSError, e: - logger.warning("Couldn't initialize mailbox %s: %s", self.get("dump_outbound_cms"), e) - except ConfigParser.NoOptionError: - pass - - try: - rpki.x509.XML_CMS_object.dump_inbound_cms = rpki.x509.DeadDrop(self.get("dump_inbound_cms")) - except OSError, e: - logger.warning("Couldn't initialize mailbox %s: %s", self.get("dump_inbound_cms"), e) - except ConfigParser.NoOptionError: - pass - - try: - rpki.x509.XML_CMS_object.check_inbound_schema = self.getboolean("check_inbound_schema") - except ConfigParser.NoOptionError: - pass - - try: - rpki.x509.XML_CMS_object.check_outbound_schema = self.getboolean("check_outbound_schema") - except ConfigParser.NoOptionError: - pass - - try: - rpki.log.enable_tracebacks = self.getboolean("enable_tracebacks") - except ConfigParser.NoOptionError: - pass - - try: - rpki.daemonize.default_pid_directory = self.get("pid_directory") - except ConfigParser.NoOptionError: - pass - - try: - rpki.daemonize.pid_filename = self.get("pid_filename") - except ConfigParser.NoOptionError: - pass - - try: - rpki.x509.generate_insecure_debug_only_rsa_key = rpki.x509.insecure_debug_only_rsa_key_generator(*self.get("insecure-debug-only-rsa-key-db").split()) - except ConfigParser.NoOptionError: - pass - except: # pylint: disable=W0702 - logger.warning("insecure-debug-only-rsa-key-db configured but initialization failed, check for corrupted database file") - - try: - rpki.up_down.content_type = self.get("up_down_content_type") - except ConfigParser.NoOptionError: - pass + matches = [] + if section is None: + section = self.default_section + if self.cfg.has_option(section, option): + yield self.cfg.get(section, option) + option += "." + matches = [o for o in self.cfg.options(section) if o.startswith(option) and o[len(option):].isdigit()] + matches.sort() + for option in matches: + yield self.cfg.get(section, option) + + + _regexp = re.compile("\\${(.*?)::(.*?)}") + + def _repl(self, m): + """ + Replacement function for indirect variable substitution. + This is intended for use with re.subn(). + """ + + section, option = m.group(1, 2) + if section == "ENV": + return os.getenv(option, "") + else: + return self.cfg.get(section, option) + + + def get(self, option, default = None, section = None): + """ + Get an option, perhaps with a default value. + """ + + if section is None: + section = self.default_section + if default is not None and not self.cfg.has_option(section, option): + return default + val = self.cfg.get(section, option) + while True: + val, modified = self._regexp.subn(self._repl, val, 1) + if not modified: + return val + + + def getboolean(self, option, default = None, section = None): + """ + Get a boolean option, perhaps with a default value. + """ + + v = self.get(option, default, section) + if isinstance(v, str): + v = v.lower() + if v not in self.cfg._boolean_states: + raise ValueError("Not a boolean: %s" % v) + v = self.cfg._boolean_states[v] + return v + + + def getint(self, option, default = None, section = None): + """ + Get an integer option, perhaps with a default value. + """ + + return int(self.get(option, default, section)) + + + def getlong(self, option, default = None, section = None): + """ + Get a long integer option, perhaps with a default value. + """ + + return long(self.get(option, default, section)) + + + def set_global_flags(self): + """ + Consolidated control for all the little global control flags + scattered through the libraries. This isn't a particularly good + place for this function to live, but it has to live somewhere and + making it a method of the config parser from which it gets all of + its data is less silly than the available alternatives. + """ + + # pylint: disable=W0621 + import rpki.x509 + import rpki.log + import rpki.daemonize + + for line in self.multiget("configure_logger"): + try: + name, level = line.split() + logging.getLogger(name).setLevel(getattr(logging, level.upper())) + except Exception, e: + logger.warning("Could not process configure_logger line %r: %s", line, e) + + try: + rpki.x509.CMS_object.debug_cms_certs = self.getboolean("debug_cms_certs") + except ConfigParser.NoOptionError: + pass + + try: + rpki.x509.XML_CMS_object.dump_outbound_cms = rpki.x509.DeadDrop(self.get("dump_outbound_cms")) + except OSError, e: + logger.warning("Couldn't initialize mailbox %s: %s", self.get("dump_outbound_cms"), e) + except ConfigParser.NoOptionError: + pass + + try: + rpki.x509.XML_CMS_object.dump_inbound_cms = rpki.x509.DeadDrop(self.get("dump_inbound_cms")) + except OSError, e: + logger.warning("Couldn't initialize mailbox %s: %s", self.get("dump_inbound_cms"), e) + except ConfigParser.NoOptionError: + pass + + try: + rpki.x509.XML_CMS_object.check_inbound_schema = self.getboolean("check_inbound_schema") + except ConfigParser.NoOptionError: + pass + + try: + rpki.x509.XML_CMS_object.check_outbound_schema = self.getboolean("check_outbound_schema") + except ConfigParser.NoOptionError: + pass + + try: + rpki.log.enable_tracebacks = self.getboolean("enable_tracebacks") + except ConfigParser.NoOptionError: + pass + + try: + rpki.daemonize.default_pid_directory = self.get("pid_directory") + except ConfigParser.NoOptionError: + pass + + try: + rpki.daemonize.pid_filename = self.get("pid_filename") + except ConfigParser.NoOptionError: + pass + + try: + rpki.x509.generate_insecure_debug_only_rsa_key = rpki.x509.insecure_debug_only_rsa_key_generator(*self.get("insecure-debug-only-rsa-key-db").split()) + except ConfigParser.NoOptionError: + pass + except: # pylint: disable=W0702 + logger.warning("insecure-debug-only-rsa-key-db configured but initialization failed, check for corrupted database file") + + try: + rpki.up_down.content_type = self.get("up_down_content_type") + except ConfigParser.NoOptionError: + pass diff --git a/rpki/csv_utils.py b/rpki/csv_utils.py index 9034e96b..2864693c 100644 --- a/rpki/csv_utils.py +++ b/rpki/csv_utils.py @@ -22,93 +22,93 @@ import csv import os class BadCSVSyntax(Exception): - """ - Bad CSV syntax. - """ + """ + Bad CSV syntax. + """ class csv_reader(object): - """ - Reader for tab-delimited text that's (slightly) friendlier than the - stock Python csv module (which isn't intended for direct use by - humans anyway, and neither was this package originally, but that - seems to be the way that it has evolved...). - - Columns parameter specifies how many columns users of the reader - expect to see; lines with fewer columns will be padded with None - values. - - Original API design for this class courtesy of Warren Kumari, but - don't blame him if you don't like what I did with his ideas. - """ - - def __init__(self, filename, columns = None, min_columns = None, comment_characters = "#;"): - assert columns is None or isinstance(columns, int) - assert min_columns is None or isinstance(min_columns, int) - if columns is not None and min_columns is None: - min_columns = columns - self.filename = filename - self.columns = columns - self.min_columns = min_columns - self.comment_characters = comment_characters - self.file = open(filename, "r") - - def __iter__(self): - line_number = 0 - for line in self.file: - line_number += 1 - line = line.strip() - if not line or line[0] in self.comment_characters: - continue - fields = line.split() - if self.min_columns is not None and len(fields) < self.min_columns: - raise BadCSVSyntax("%s:%d: Not enough columns in line %r" % (self.filename, line_number, line)) - if self.columns is not None and len(fields) > self.columns: - raise BadCSVSyntax("%s:%d: Too many columns in line %r" % (self.filename, line_number, line)) - if self.columns is not None and len(fields) < self.columns: - fields += tuple(None for i in xrange(self.columns - len(fields))) - yield fields - - def __enter__(self): - return self - - def __exit__(self, _type, value, traceback): - self.file.close() + """ + Reader for tab-delimited text that's (slightly) friendlier than the + stock Python csv module (which isn't intended for direct use by + humans anyway, and neither was this package originally, but that + seems to be the way that it has evolved...). + + Columns parameter specifies how many columns users of the reader + expect to see; lines with fewer columns will be padded with None + values. + + Original API design for this class courtesy of Warren Kumari, but + don't blame him if you don't like what I did with his ideas. + """ + + def __init__(self, filename, columns = None, min_columns = None, comment_characters = "#;"): + assert columns is None or isinstance(columns, int) + assert min_columns is None or isinstance(min_columns, int) + if columns is not None and min_columns is None: + min_columns = columns + self.filename = filename + self.columns = columns + self.min_columns = min_columns + self.comment_characters = comment_characters + self.file = open(filename, "r") + + def __iter__(self): + line_number = 0 + for line in self.file: + line_number += 1 + line = line.strip() + if not line or line[0] in self.comment_characters: + continue + fields = line.split() + if self.min_columns is not None and len(fields) < self.min_columns: + raise BadCSVSyntax("%s:%d: Not enough columns in line %r" % (self.filename, line_number, line)) + if self.columns is not None and len(fields) > self.columns: + raise BadCSVSyntax("%s:%d: Too many columns in line %r" % (self.filename, line_number, line)) + if self.columns is not None and len(fields) < self.columns: + fields += tuple(None for i in xrange(self.columns - len(fields))) + yield fields + + def __enter__(self): + return self + + def __exit__(self, _type, value, traceback): + self.file.close() class csv_writer(object): - """ - Writer object for tab delimited text. We just use the stock CSV - module in excel-tab mode for this. + """ + Writer object for tab delimited text. We just use the stock CSV + module in excel-tab mode for this. - If "renmwo" is set (default), the file will be written to - a temporary name and renamed to the real filename after closing. - """ + If "renmwo" is set (default), the file will be written to + a temporary name and renamed to the real filename after closing. + """ - def __init__(self, filename, renmwo = True): - self.filename = filename - self.renmwo = "%s.~renmwo%d~" % (filename, os.getpid()) if renmwo else filename - self.file = open(self.renmwo, "w") - self.writer = csv.writer(self.file, dialect = csv.get_dialect("excel-tab")) + def __init__(self, filename, renmwo = True): + self.filename = filename + self.renmwo = "%s.~renmwo%d~" % (filename, os.getpid()) if renmwo else filename + self.file = open(self.renmwo, "w") + self.writer = csv.writer(self.file, dialect = csv.get_dialect("excel-tab")) - def __enter__(self): - return self + def __enter__(self): + return self - def __exit__(self, _type, value, traceback): - self.close() + def __exit__(self, _type, value, traceback): + self.close() - def close(self): - """ - Close this writer. - """ + def close(self): + """ + Close this writer. + """ - if self.file is not None: - self.file.close() - self.file = None - if self.filename != self.renmwo: - os.rename(self.renmwo, self.filename) + if self.file is not None: + self.file.close() + self.file = None + if self.filename != self.renmwo: + os.rename(self.renmwo, self.filename) - def __getattr__(self, attr): - """ - Fake inheritance from whatever object csv.writer deigns to give us. - """ + def __getattr__(self, attr): + """ + Fake inheritance from whatever object csv.writer deigns to give us. + """ - return getattr(self.writer, attr) + return getattr(self.writer, attr) diff --git a/rpki/daemonize.py b/rpki/daemonize.py index 6a825566..bd59fca0 100644 --- a/rpki/daemonize.py +++ b/rpki/daemonize.py @@ -80,56 +80,56 @@ default_pid_directory = "/var/run/rpki" pid_filename = None def daemon(nochdir = False, noclose = False, pidfile = None): - """ - Make this program become a daemon, like 4.4BSD daemon(3), and - write its pid out to a file with cleanup on exit. - """ - - if pidfile is None: - if pid_filename is None: - prog = os.path.splitext(os.path.basename(sys.argv[0]))[0] - pidfile = os.path.join(default_pid_directory, "%s.pid" % prog) + """ + Make this program become a daemon, like 4.4BSD daemon(3), and + write its pid out to a file with cleanup on exit. + """ + + if pidfile is None: + if pid_filename is None: + prog = os.path.splitext(os.path.basename(sys.argv[0]))[0] + pidfile = os.path.join(default_pid_directory, "%s.pid" % prog) + else: + pidfile = pid_filename + + old_sighup_action = signal.signal(signal.SIGHUP, signal.SIG_IGN) + + try: + pid = os.fork() + except OSError, e: + sys.exit("fork() failed: %d (%s)" % (e.errno, e.strerror)) else: - pidfile = pid_filename + if pid > 0: + os._exit(0) - old_sighup_action = signal.signal(signal.SIGHUP, signal.SIG_IGN) + if not nochdir: + os.chdir("/") - try: - pid = os.fork() - except OSError, e: - sys.exit("fork() failed: %d (%s)" % (e.errno, e.strerror)) - else: - if pid > 0: - os._exit(0) + os.setsid() - if not nochdir: - os.chdir("/") + if not noclose: + sys.stdout.flush() + sys.stderr.flush() + fd = os.open(os.devnull, os.O_RDWR) + os.dup2(fd, 0) + os.dup2(fd, 1) + os.dup2(fd, 2) + if fd > 2: + os.close(fd) - os.setsid() + signal.signal(signal.SIGHUP, old_sighup_action) - if not noclose: - sys.stdout.flush() - sys.stderr.flush() - fd = os.open(os.devnull, os.O_RDWR) - os.dup2(fd, 0) - os.dup2(fd, 1) - os.dup2(fd, 2) - if fd > 2: - os.close(fd) + def delete_pid_file(): + try: + os.unlink(pidfile) + except OSError: + pass - signal.signal(signal.SIGHUP, old_sighup_action) + atexit.register(delete_pid_file) - def delete_pid_file(): try: - os.unlink(pidfile) - except OSError: - pass - - atexit.register(delete_pid_file) - - try: - f = open(pidfile, "w") - f.write("%d\n" % os.getpid()) - f.close() - except IOError, e: - logger.warning("Couldn't write PID file %s: %s", pidfile, e.strerror) + f = open(pidfile, "w") + f.write("%d\n" % os.getpid()) + f.close() + except IOError, e: + logger.warning("Couldn't write PID file %s: %s", pidfile, e.strerror) diff --git a/rpki/django_settings/common.py b/rpki/django_settings/common.py index 2f676660..3860d40b 100644 --- a/rpki/django_settings/common.py +++ b/rpki/django_settings/common.py @@ -58,7 +58,7 @@ class DatabaseConfigurator(object): default_sql_engine = "mysql" - def configure(self, cfg, section): + def configure(self, cfg, section): # pylint: disable=W0621 self.cfg = cfg self.section = section engine = cfg.get("sql-engine", section = section, diff --git a/rpki/exceptions.py b/rpki/exceptions.py index f456dfc5..cbdb9f83 100644 --- a/rpki/exceptions.py +++ b/rpki/exceptions.py @@ -22,222 +22,222 @@ Exception definitions for RPKI modules. """ class RPKI_Exception(Exception): - "Base class for RPKI exceptions." + "Base class for RPKI exceptions." class NotInDatabase(RPKI_Exception): - "Lookup failed for an object expected to be in the database." + "Lookup failed for an object expected to be in the database." class BadURISyntax(RPKI_Exception): - "Illegal syntax for a URI." + "Illegal syntax for a URI." class BadStatusCode(RPKI_Exception): - "Unrecognized protocol status code." + "Unrecognized protocol status code." class BadQuery(RPKI_Exception): - "Unexpected protocol query." + "Unexpected protocol query." class DBConsistancyError(RPKI_Exception): - "Found multiple matches for a database query that shouldn't ever return that." + "Found multiple matches for a database query that shouldn't ever return that." class CMSVerificationFailed(RPKI_Exception): - "Verification of a CMS message failed." + "Verification of a CMS message failed." class HTTPRequestFailed(RPKI_Exception): - "HTTP request failed." + "HTTP request failed." class DERObjectConversionError(RPKI_Exception): - "Error trying to convert a DER-based object from one representation to another." + "Error trying to convert a DER-based object from one representation to another." class NotACertificateChain(RPKI_Exception): - "Certificates don't form a proper chain." + "Certificates don't form a proper chain." class BadContactURL(RPKI_Exception): - "Error trying to parse contact URL." + "Error trying to parse contact URL." class BadClassNameSyntax(RPKI_Exception): - "Illegal syntax for a class_name." + "Illegal syntax for a class_name." class BadIssueResponse(RPKI_Exception): - "issue_response PDU with wrong number of classes or certificates." + "issue_response PDU with wrong number of classes or certificates." class NotImplementedYet(RPKI_Exception): - "Internal error -- not implemented yet." + "Internal error -- not implemented yet." class BadPKCS10(RPKI_Exception): - "Bad PKCS #10 object." + "Bad PKCS #10 object." class UpstreamError(RPKI_Exception): - "Received an error from upstream." + "Received an error from upstream." class ChildNotFound(RPKI_Exception): - "Could not find specified child in database." + "Could not find specified child in database." class BSCNotFound(RPKI_Exception): - "Could not find specified BSC in database." + "Could not find specified BSC in database." class BadSender(RPKI_Exception): - "Unexpected XML sender value." + "Unexpected XML sender value." class ClassNameMismatch(RPKI_Exception): - "class_name does not match child context." + "class_name does not match child context." class ClassNameUnknown(RPKI_Exception): - "Unknown class_name." + "Unknown class_name." class SKIMismatch(RPKI_Exception): - "SKI value in response does not match request." + "SKI value in response does not match request." class SubprocessError(RPKI_Exception): - "Subprocess returned unexpected error." + "Subprocess returned unexpected error." class BadIRDBReply(RPKI_Exception): - "Unexpected reply to IRDB query." + "Unexpected reply to IRDB query." class NotFound(RPKI_Exception): - "Object not found in database." + "Object not found in database." class MustBePrefix(RPKI_Exception): - "Resource range cannot be expressed as a prefix." + "Resource range cannot be expressed as a prefix." class TLSValidationError(RPKI_Exception): - "TLS certificate validation error." + "TLS certificate validation error." class MultipleTLSEECert(TLSValidationError): - "Received more than one TLS EE certificate." + "Received more than one TLS EE certificate." class ReceivedTLSCACert(TLSValidationError): - "Received CA certificate via TLS." + "Received CA certificate via TLS." class WrongEContentType(RPKI_Exception): - "Received wrong CMS eContentType." + "Received wrong CMS eContentType." class EmptyPEM(RPKI_Exception): - "Couldn't find PEM block to convert." + "Couldn't find PEM block to convert." class UnexpectedCMSCerts(RPKI_Exception): - "Received CMS certs when not expecting any." + "Received CMS certs when not expecting any." class UnexpectedCMSCRLs(RPKI_Exception): - "Received CMS CRLs when not expecting any." + "Received CMS CRLs when not expecting any." class MissingCMSEEcert(RPKI_Exception): - "Didn't receive CMS EE cert when expecting one." + "Didn't receive CMS EE cert when expecting one." class MissingCMSCRL(RPKI_Exception): - "Didn't receive CMS CRL when expecting one." + "Didn't receive CMS CRL when expecting one." class UnparsableCMSDER(RPKI_Exception): - "Alleged CMS DER wasn't parsable." + "Alleged CMS DER wasn't parsable." class CMSCRLNotSet(RPKI_Exception): - "CMS CRL has not been configured." + "CMS CRL has not been configured." class ServerShuttingDown(RPKI_Exception): - "Server is shutting down." + "Server is shutting down." class NoActiveCA(RPKI_Exception): - "No active ca_detail for specified class." + "No active ca_detail for specified class." class BadClientURL(RPKI_Exception): - "URL given to HTTP client does not match profile." + "URL given to HTTP client does not match profile." class ClientNotFound(RPKI_Exception): - "Could not find specified client in database." + "Could not find specified client in database." class BadExtension(RPKI_Exception): - "Forbidden X.509 extension." + "Forbidden X.509 extension." class ForbiddenURI(RPKI_Exception): - "Forbidden URI, does not start with correct base URI." + "Forbidden URI, does not start with correct base URI." class HTTPClientAborted(RPKI_Exception): - "HTTP client connection closed while in request-sent state." + "HTTP client connection closed while in request-sent state." class BadPublicationReply(RPKI_Exception): - "Unexpected reply to publication query." + "Unexpected reply to publication query." class DuplicateObject(RPKI_Exception): - "Attempt to create an object that already exists." + "Attempt to create an object that already exists." class EmptyROAPrefixList(RPKI_Exception): - "Can't create ROA with an empty prefix list." + "Can't create ROA with an empty prefix list." class NoCoveringCertForROA(RPKI_Exception): - "Couldn't find a covering certificate to generate ROA." + "Couldn't find a covering certificate to generate ROA." class BSCNotReady(RPKI_Exception): - "BSC not yet in a usable state, signing_cert not set." + "BSC not yet in a usable state, signing_cert not set." class HTTPUnexpectedState(RPKI_Exception): - "HTTP event occurred in an unexpected state." + "HTTP event occurred in an unexpected state." class HTTPBadVersion(RPKI_Exception): - "HTTP couldn't parse HTTP version." + "HTTP couldn't parse HTTP version." class HandleTranslationError(RPKI_Exception): - "Internal error translating protocol handle -> SQL id." + "Internal error translating protocol handle -> SQL id." class NoObjectAtURI(RPKI_Exception): - "No object published at specified URI." + "No object published at specified URI." class ExistingObjectAtURI(RPKI_Exception): - "An object has already been published at specified URI." + "An object has already been published at specified URI." class DifferentObjectAtURI(RPKI_Exception): - "An object with a different hash exists at specified URI." + "An object with a different hash exists at specified URI." class CMSContentNotSet(RPKI_Exception): - """ - Inner content of a CMS_object has not been set. If object is known - to be valid, the .extract() method should be able to set the - content; otherwise, only the .verify() method (which checks - signatures) is safe. - """ + """ + Inner content of a CMS_object has not been set. If object is known + to be valid, the .extract() method should be able to set the + content; otherwise, only the .verify() method (which checks + signatures) is safe. + """ class HTTPTimeout(RPKI_Exception): - "HTTP connection timed out." + "HTTP connection timed out." class BadIPResource(RPKI_Exception): - "Parse failure for alleged IP resource string." + "Parse failure for alleged IP resource string." class BadROAPrefix(RPKI_Exception): - "Parse failure for alleged ROA prefix string." + "Parse failure for alleged ROA prefix string." class CommandParseFailure(RPKI_Exception): - "Failed to parse command line." + "Failed to parse command line." class CMSCertHasExpired(RPKI_Exception): - "CMS certificate has expired." + "CMS certificate has expired." class TrustedCMSCertHasExpired(RPKI_Exception): - "Trusted CMS certificate has expired." + "Trusted CMS certificate has expired." class MultipleCMSEECert(RPKI_Exception): - "Can't have more than one CMS EE certificate in validation chain." + "Can't have more than one CMS EE certificate in validation chain." class ResourceOverlap(RPKI_Exception): - "Overlapping resources in resource_set." + "Overlapping resources in resource_set." class CMSReplay(RPKI_Exception): - "Possible CMS replay attack detected." + "Possible CMS replay attack detected." class PastNotAfter(RPKI_Exception): - "Requested notAfter value is already in the past." + "Requested notAfter value is already in the past." class NullValidityInterval(RPKI_Exception): - "Requested validity interval is null." + "Requested validity interval is null." class BadX510DN(RPKI_Exception): - "X.510 distinguished name does not match profile." + "X.510 distinguished name does not match profile." class BadAutonomousSystemNumber(RPKI_Exception): - "Bad AutonomousSystem number." + "Bad AutonomousSystem number." class WrongEKU(RPKI_Exception): - "Extended Key Usage extension does not match profile." + "Extended Key Usage extension does not match profile." class UnexpectedUpDownResponse(RPKI_Exception): - "Up-down message is not of the expected type." + "Up-down message is not of the expected type." class BadContentType(RPKI_Exception): - "Bad HTTP Content-Type." + "Bad HTTP Content-Type." diff --git a/rpki/fields.py b/rpki/fields.py index a470e272..1390d4ac 100644 --- a/rpki/fields.py +++ b/rpki/fields.py @@ -35,78 +35,78 @@ logger = logging.getLogger(__name__) class EnumField(models.PositiveSmallIntegerField): - """ - An enumeration type that uses strings in Python and small integers - in SQL. - """ + """ + An enumeration type that uses strings in Python and small integers + in SQL. + """ - description = "An enumeration type" + description = "An enumeration type" - __metaclass__ = models.SubfieldBase + __metaclass__ = models.SubfieldBase - def __init__(self, *args, **kwargs): - if isinstance(kwargs.get("choices"), (tuple, list)) and isinstance(kwargs["choices"][0], (str, unicode)): - kwargs["choices"] = tuple(enumerate(kwargs["choices"], 1)) - # Might need something here to handle string-valued default parameter - models.PositiveSmallIntegerField.__init__(self, *args, **kwargs) - self.enum_i2s = dict(self.flatchoices) - self.enum_s2i = dict((v, k) for k, v in self.flatchoices) + def __init__(self, *args, **kwargs): + if isinstance(kwargs.get("choices"), (tuple, list)) and isinstance(kwargs["choices"][0], (str, unicode)): + kwargs["choices"] = tuple(enumerate(kwargs["choices"], 1)) + # Might need something here to handle string-valued default parameter + models.PositiveSmallIntegerField.__init__(self, *args, **kwargs) + self.enum_i2s = dict(self.flatchoices) + self.enum_s2i = dict((v, k) for k, v in self.flatchoices) - def to_python(self, value): - return self.enum_i2s.get(value, value) + def to_python(self, value): + return self.enum_i2s.get(value, value) - def get_prep_value(self, value): - return self.enum_s2i.get(value, value) + def get_prep_value(self, value): + return self.enum_s2i.get(value, value) class SundialField(models.DateTimeField): - """ - A field type for our customized datetime objects. - """ - __metaclass__ = models.SubfieldBase + """ + A field type for our customized datetime objects. + """ + __metaclass__ = models.SubfieldBase - description = "A datetime type using our customized datetime objects" + description = "A datetime type using our customized datetime objects" - def to_python(self, value): - if isinstance(value, rpki.sundial.pydatetime.datetime): - return rpki.sundial.datetime.from_datetime( - models.DateTimeField.to_python(self, value)) - else: - return value + def to_python(self, value): + if isinstance(value, rpki.sundial.pydatetime.datetime): + return rpki.sundial.datetime.from_datetime( + models.DateTimeField.to_python(self, value)) + else: + return value - def get_prep_value(self, value): - if isinstance(value, rpki.sundial.datetime): - return value.to_datetime() - else: - return value + def get_prep_value(self, value): + if isinstance(value, rpki.sundial.datetime): + return value.to_datetime() + else: + return value class BlobField(models.Field): - """ - Old BLOB field type, predating Django's BinaryField type. + """ + Old BLOB field type, predating Django's BinaryField type. - Do not use, this is only here for backwards compatabilty during migrations. - """ + Do not use, this is only here for backwards compatabilty during migrations. + """ - __metaclass__ = models.SubfieldBase - description = "Raw BLOB type without ASN.1 encoding/decoding" + __metaclass__ = models.SubfieldBase + description = "Raw BLOB type without ASN.1 encoding/decoding" - def __init__(self, *args, **kwargs): - self.blob_type = kwargs.pop("blob_type", None) - kwargs["serialize"] = False - kwargs["blank"] = True - kwargs["default"] = None - models.Field.__init__(self, *args, **kwargs) + def __init__(self, *args, **kwargs): + self.blob_type = kwargs.pop("blob_type", None) + kwargs["serialize"] = False + kwargs["blank"] = True + kwargs["default"] = None + models.Field.__init__(self, *args, **kwargs) - def db_type(self, connection): - if self.blob_type is not None: - return self.blob_type - elif connection.settings_dict['ENGINE'] == "django.db.backends.mysql": - return "LONGBLOB" - elif connection.settings_dict['ENGINE'] == "django.db.backends.posgresql": - return "bytea" - else: - return "BLOB" + def db_type(self, connection): + if self.blob_type is not None: + return self.blob_type + elif connection.settings_dict['ENGINE'] == "django.db.backends.mysql": + return "LONGBLOB" + elif connection.settings_dict['ENGINE'] == "django.db.backends.posgresql": + return "bytea" + else: + return "BLOB" # For reasons which now escape me, I had a few fields in the old @@ -124,70 +124,70 @@ class BlobField(models.Field): # backwards compatability during migrations, class DERField(models.BinaryField): - """ - Field class for DER objects, with automatic translation between - ASN.1 and Python types. This is an abstract class, concrete field - classes are derived from it. - """ - - def __init__(self, *args, **kwargs): - kwargs["blank"] = True - kwargs["default"] = None - super(DERField, self).__init__(*args, **kwargs) - - def deconstruct(self): - name, path, args, kwargs = super(DERField, self).deconstruct() - del kwargs["blank"] - del kwargs["default"] - return name, path, args, kwargs - - def from_db_value(self, value, expression, connection, context): - if value is not None: - value = self.rpki_type(DER = str(value)) - return value - - def to_python(self, value): - value = super(DERField, self).to_python(value) - if value is not None and not isinstance(value, self.rpki_type): - value = self.rpki_type(DER = str(value)) - return value - - def get_prep_value(self, value): - if value is not None: - value = value.get_DER() - return super(DERField, self).get_prep_value(value) + """ + Field class for DER objects, with automatic translation between + ASN.1 and Python types. This is an abstract class, concrete field + classes are derived from it. + """ + + def __init__(self, *args, **kwargs): + kwargs["blank"] = True + kwargs["default"] = None + super(DERField, self).__init__(*args, **kwargs) + + def deconstruct(self): + name, path, args, kwargs = super(DERField, self).deconstruct() + del kwargs["blank"] + del kwargs["default"] + return name, path, args, kwargs + + def from_db_value(self, value, expression, connection, context): + if value is not None: + value = self.rpki_type(DER = str(value)) + return value + + def to_python(self, value): + value = super(DERField, self).to_python(value) + if value is not None and not isinstance(value, self.rpki_type): + value = self.rpki_type(DER = str(value)) + return value + + def get_prep_value(self, value): + if value is not None: + value = value.get_DER() + return super(DERField, self).get_prep_value(value) class CertificateField(DERField): - description = "X.509 certificate" - rpki_type = rpki.x509.X509 + description = "X.509 certificate" + rpki_type = rpki.x509.X509 class RSAPrivateKeyField(DERField): - description = "RSA keypair" - rpki_type = rpki.x509.RSA + description = "RSA keypair" + rpki_type = rpki.x509.RSA KeyField = RSAPrivateKeyField class PublicKeyField(DERField): - description = "RSA keypair" - rpki_type = rpki.x509.PublicKey + description = "RSA keypair" + rpki_type = rpki.x509.PublicKey class CRLField(DERField): - description = "Certificate Revocation List" - rpki_type = rpki.x509.CRL + description = "Certificate Revocation List" + rpki_type = rpki.x509.CRL class PKCS10Field(DERField): - description = "PKCS #10 certificate request" - rpki_type = rpki.x509.PKCS10 + description = "PKCS #10 certificate request" + rpki_type = rpki.x509.PKCS10 class ManifestField(DERField): - description = "RPKI Manifest" - rpki_type = rpki.x509.SignedManifest + description = "RPKI Manifest" + rpki_type = rpki.x509.SignedManifest class ROAField(DERField): - description = "ROA" - rpki_type = rpki.x509.ROA + description = "ROA" + rpki_type = rpki.x509.ROA class GhostbusterField(DERField): - description = "Ghostbuster Record" - rpki_type = rpki.x509.Ghostbuster + description = "Ghostbuster Record" + rpki_type = rpki.x509.Ghostbuster diff --git a/rpki/gui/app/forms.py b/rpki/gui/app/forms.py index 306b8dce..4a95c8da 100644 --- a/rpki/gui/app/forms.py +++ b/rpki/gui/app/forms.py @@ -170,105 +170,105 @@ def ROARequestFormFactory(conf): """ class Cls(forms.Form): - """Form for entering a ROA request. - - Handles both IPv4 and IPv6.""" - - prefix = forms.CharField( - widget=forms.TextInput(attrs={ - 'autofocus': 'true', 'placeholder': 'Prefix', - 'class': 'span4' - }) - ) - max_prefixlen = forms.CharField( - required=False, - widget=forms.TextInput(attrs={ - 'placeholder': 'Max len', - 'class': 'span1' - }) - ) - asn = forms.IntegerField( - widget=forms.TextInput(attrs={ - 'placeholder': 'ASN', - 'class': 'span1' - }) - ) - protect_children = forms.BooleanField(required=False) - - def __init__(self, *args, **kwargs): - kwargs['auto_id'] = False - super(Cls, self).__init__(*args, **kwargs) - self.conf = conf # conf is the arg to ROARequestFormFactory - self.inline = True - self.use_table = False - - def _as_resource_range(self): - """Convert the prefix in the form to a - rpki.resource_set.resource_range_ip object. - - If there is no mask provided, assume the closest classful mask. - - """ - prefix = self.cleaned_data.get('prefix') - if '/' not in prefix: - p = IPAddress(prefix) - - # determine the first nonzero bit starting from the lsb and - # subtract from the address size to find the closest classful - # mask that contains this single address - prefixlen = 0 - while (p != 0) and (p & 1) == 0: - prefixlen = prefixlen + 1 - p = p >> 1 - mask = p.bits - (8 * (prefixlen / 8)) - prefix = prefix + '/' + str(mask) - - return resource_range_ip.parse_str(prefix) - - def clean_asn(self): - value = self.cleaned_data.get('asn') - if value < 0: - raise forms.ValidationError('AS must be a positive value or 0') - return value - - def clean_prefix(self): - try: - r = self._as_resource_range() - except: - raise forms.ValidationError('invalid prefix') - - manager = models.ResourceRangeAddressV4 if r.version == 4 else models.ResourceRangeAddressV6 - if not manager.objects.filter(cert__conf=self.conf, - prefix_min__lte=r.min, - prefix_max__gte=r.max).exists(): - raise forms.ValidationError('prefix is not allocated to you') - return str(r) - - def clean_max_prefixlen(self): - v = self.cleaned_data.get('max_prefixlen') - if v: - if v[0] == '/': - v = v[1:] # allow user to specify /24 - try: - if int(v) < 0: - raise forms.ValidationError('max prefix length must be positive or 0') - except ValueError: - raise forms.ValidationError('invalid integer value') - return v - - def clean(self): - if 'prefix' in self.cleaned_data: - r = self._as_resource_range() - max_prefixlen = self.cleaned_data.get('max_prefixlen') - max_prefixlen = int(max_prefixlen) if max_prefixlen else r.prefixlen() - if max_prefixlen < r.prefixlen(): - raise forms.ValidationError( - 'max prefix length must be greater than or equal to the prefix length') - if max_prefixlen > r.min.bits: - raise forms.ValidationError( - 'max prefix length (%d) is out of range for IP version (%d)' % (max_prefixlen, r.min.bits)) - self.cleaned_data['max_prefixlen'] = str(max_prefixlen) - return self.cleaned_data + """Form for entering a ROA request. + + Handles both IPv4 and IPv6.""" + + prefix = forms.CharField( + widget=forms.TextInput(attrs={ + 'autofocus': 'true', 'placeholder': 'Prefix', + 'class': 'span4' + }) + ) + max_prefixlen = forms.CharField( + required=False, + widget=forms.TextInput(attrs={ + 'placeholder': 'Max len', + 'class': 'span1' + }) + ) + asn = forms.IntegerField( + widget=forms.TextInput(attrs={ + 'placeholder': 'ASN', + 'class': 'span1' + }) + ) + protect_children = forms.BooleanField(required=False) + + def __init__(self, *args, **kwargs): + kwargs['auto_id'] = False + super(Cls, self).__init__(*args, **kwargs) + self.conf = conf # conf is the arg to ROARequestFormFactory + self.inline = True + self.use_table = False + + def _as_resource_range(self): + """Convert the prefix in the form to a + rpki.resource_set.resource_range_ip object. + + If there is no mask provided, assume the closest classful mask. + + """ + prefix = self.cleaned_data.get('prefix') + if '/' not in prefix: + p = IPAddress(prefix) + + # determine the first nonzero bit starting from the lsb and + # subtract from the address size to find the closest classful + # mask that contains this single address + prefixlen = 0 + while (p != 0) and (p & 1) == 0: + prefixlen = prefixlen + 1 + p = p >> 1 + mask = p.bits - (8 * (prefixlen / 8)) + prefix = prefix + '/' + str(mask) + + return resource_range_ip.parse_str(prefix) + + def clean_asn(self): + value = self.cleaned_data.get('asn') + if value < 0: + raise forms.ValidationError('AS must be a positive value or 0') + return value + + def clean_prefix(self): + try: + r = self._as_resource_range() + except: + raise forms.ValidationError('invalid prefix') + + manager = models.ResourceRangeAddressV4 if r.version == 4 else models.ResourceRangeAddressV6 + if not manager.objects.filter(cert__conf=self.conf, + prefix_min__lte=r.min, + prefix_max__gte=r.max).exists(): + raise forms.ValidationError('prefix is not allocated to you') + return str(r) + + def clean_max_prefixlen(self): + v = self.cleaned_data.get('max_prefixlen') + if v: + if v[0] == '/': + v = v[1:] # allow user to specify /24 + try: + if int(v) < 0: + raise forms.ValidationError('max prefix length must be positive or 0') + except ValueError: + raise forms.ValidationError('invalid integer value') + return v + + def clean(self): + if 'prefix' in self.cleaned_data: + r = self._as_resource_range() + max_prefixlen = self.cleaned_data.get('max_prefixlen') + max_prefixlen = int(max_prefixlen) if max_prefixlen else r.prefixlen() + if max_prefixlen < r.prefixlen(): + raise forms.ValidationError( + 'max prefix length must be greater than or equal to the prefix length') + if max_prefixlen > r.min.bits: + raise forms.ValidationError( + 'max prefix length (%d) is out of range for IP version (%d)' % (max_prefixlen, r.min.bits)) + self.cleaned_data['max_prefixlen'] = str(max_prefixlen) + return self.cleaned_data return Cls diff --git a/rpki/gui/app/views.py b/rpki/gui/app/views.py index 28b8a498..1d468a07 100644 --- a/rpki/gui/app/views.py +++ b/rpki/gui/app/views.py @@ -148,27 +148,27 @@ def generic_import(request, queryset, configure, form_class=None, if handle == '': handle = None try: - # configure_repository returns None, so can't use tuple expansion - # here. Unpack the tuple below if post_import_redirect is None. - r = configure(z, tmpf.name, handle) + # configure_repository returns None, so can't use tuple expansion + # here. Unpack the tuple below if post_import_redirect is None. + r = configure(z, tmpf.name, handle) except lxml.etree.XMLSyntaxError as e: - logger.exception('caught XMLSyntaxError while parsing uploaded file') + logger.exception('caught XMLSyntaxError while parsing uploaded file') messages.error( request, 'The uploaded file has an invalid XML syntax' ) else: - # force rpkid run now - z.synchronize_ca(poke=True) - if post_import_redirect: - url = post_import_redirect - else: - _, handle = r - url = queryset.get(issuer=conf, - handle=handle).get_absolute_url() - return http.HttpResponseRedirect(url) + # force rpkid run now + z.synchronize_ca(poke=True) + if post_import_redirect: + url = post_import_redirect + else: + _, handle = r + url = queryset.get(issuer=conf, + handle=handle).get_absolute_url() + return http.HttpResponseRedirect(url) finally: - os.remove(tmpf.name) + os.remove(tmpf.name) else: form = form_class() @@ -474,10 +474,10 @@ def child_add_prefix(request, pk): child.address_ranges.create(start_ip=str(r.min), end_ip=str(r.max), version=version) Zookeeper( - handle=conf.handle, - logstream=logstream, - disable_signal_handlers=True - ).run_rpkid_now() + handle=conf.handle, + logstream=logstream, + disable_signal_handlers=True + ).run_rpkid_now() return http.HttpResponseRedirect(child.get_absolute_url()) else: form = forms.AddNetForm(child=child) @@ -497,10 +497,10 @@ def child_add_asn(request, pk): r = resource_range_as.parse_str(asns) child.asns.create(start_as=r.min, end_as=r.max) Zookeeper( - handle=conf.handle, - logstream=logstream, - disable_signal_handlers=True - ).run_rpkid_now() + handle=conf.handle, + logstream=logstream, + disable_signal_handlers=True + ).run_rpkid_now() return http.HttpResponseRedirect(child.get_absolute_url()) else: form = forms.AddASNForm(child=child) @@ -531,10 +531,10 @@ def child_edit(request, pk): models.ChildASN.objects.filter(child=child).exclude(pk__in=form.cleaned_data.get('as_ranges')).delete() models.ChildNet.objects.filter(child=child).exclude(pk__in=form.cleaned_data.get('address_ranges')).delete() Zookeeper( - handle=conf.handle, - logstream=logstream, - disable_signal_handlers=True - ).run_rpkid_now() + handle=conf.handle, + logstream=logstream, + disable_signal_handlers=True + ).run_rpkid_now() return http.HttpResponseRedirect(child.get_absolute_url()) else: form = form_class(initial={ @@ -713,27 +713,27 @@ def roa_create_multi(request): v = [] rng.chop_into_prefixes(v) init.extend([{'asn': asn, 'prefix': str(p)} for p in v]) - extra = 0 if init else 1 + extra = 0 if init else 1 formset = formset_factory(forms.ROARequestFormFactory(conf), extra=extra)(initial=init) elif request.method == 'POST': formset = formset_factory(forms.ROARequestFormFactory(conf), extra=0)(request.POST, request.FILES) - # We need to check .has_changed() because .is_valid() will return true - # if the user clicks the Preview button without filling in the blanks - # in the ROA form, leaving the form invalid from this view's POV. + # We need to check .has_changed() because .is_valid() will return true + # if the user clicks the Preview button without filling in the blanks + # in the ROA form, leaving the form invalid from this view's POV. if formset.has_changed() and formset.is_valid(): routes = [] v = [] query = Q() # for matching routes roas = [] for form in formset: - asn = form.cleaned_data['asn'] - rng = resource_range_ip.parse_str(form.cleaned_data['prefix']) - max_prefixlen = int(form.cleaned_data['max_prefixlen']) + asn = form.cleaned_data['asn'] + rng = resource_range_ip.parse_str(form.cleaned_data['prefix']) + max_prefixlen = int(form.cleaned_data['max_prefixlen']) protect_children = form.cleaned_data['protect_children'] roas.append((rng, max_prefixlen, asn, protect_children)) - v.append({'prefix': str(rng), 'max_prefixlen': max_prefixlen, - 'asn': asn}) + v.append({'prefix': str(rng), 'max_prefixlen': max_prefixlen, + 'asn': asn}) query |= Q(prefix_min__gte=rng.min, prefix_max__lte=rng.max) @@ -1451,10 +1451,10 @@ class RouterImportView(FormView): def form_valid(self, form): conf = get_conf(self.request.user, self.request.session['handle']) - tmpf = NamedTemporaryFile(prefix='import', suffix='.xml', - delete=False) - tmpf.write(form.cleaned_data['xml'].read()) - tmpf.close() + tmpf = NamedTemporaryFile(prefix='import', suffix='.xml', + delete=False) + tmpf.write(form.cleaned_data['xml'].read()) + tmpf.close() z = Zookeeper(handle=conf.handle, disable_signal_handlers=True) z.add_router_certificate_request(tmpf.name) z.run_rpkid_now() diff --git a/rpki/gui/cacheview/tests.py b/rpki/gui/cacheview/tests.py index daca07bf..c2958c72 100644 --- a/rpki/gui/cacheview/tests.py +++ b/rpki/gui/cacheview/tests.py @@ -21,4 +21,3 @@ Another way to test that 1 + 1 is equal to 2. >>> 1 + 1 == 2 True """} - diff --git a/rpki/http_simple.py b/rpki/http_simple.py index ee9cac35..6f73def5 100644 --- a/rpki/http_simple.py +++ b/rpki/http_simple.py @@ -31,106 +31,106 @@ default_content_type = "application/x-rpki" class HTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler): - """ - HTTP request handler simple RPKI servers. - """ - - def do_POST(self): - try: - content_type = self.headers.get("Content-Type") - content_length = self.headers.get("Content-Length") - for handler_path, handler, handler_content_type in self.rpki_handlers: - if self.path.startswith(handler_path) and content_type in handler_content_type: - return handler(self, - self.rfile.read() - if content_length is None else - self.rfile.read(int(content_length))) - self.send_error(404, "No handler for path %s" % self.path) - except Exception, e: - logger.exception("Unhandled exception") - self.send_error(501, "Unhandled exception %s" % e) - - def send_cms_response(self, der): - self.send_response(200) - self.send_header("Content-Type", default_content_type) - self.send_header("Content-Length", str(len(der))) - self.end_headers() - self.wfile.write(der) - - def log_message(self, *args): - logger.info(*args, extra = dict(context = "%s:%s" % self.client_address)) - - def send_error(self, code, message = None): - # BaseHTTPRequestHandler.send_error() generates HTML error messages, - # which we don't want, so we override the method to suppress this. - self.send_response(code, message) - self.send_header("Content-Type", default_content_type) - self.send_header("Connection", "close") - self.end_headers() + """ + HTTP request handler simple RPKI servers. + """ + + def do_POST(self): + try: + content_type = self.headers.get("Content-Type") + content_length = self.headers.get("Content-Length") + for handler_path, handler, handler_content_type in self.rpki_handlers: + if self.path.startswith(handler_path) and content_type in handler_content_type: + return handler(self, + self.rfile.read() + if content_length is None else + self.rfile.read(int(content_length))) + self.send_error(404, "No handler for path %s" % self.path) + except Exception, e: + logger.exception("Unhandled exception") + self.send_error(501, "Unhandled exception %s" % e) + + def send_cms_response(self, der): + self.send_response(200) + self.send_header("Content-Type", default_content_type) + self.send_header("Content-Length", str(len(der))) + self.end_headers() + self.wfile.write(der) + + def log_message(self, *args): + logger.info(*args, extra = dict(context = "%s:%s" % self.client_address)) + + def send_error(self, code, message = None): + # BaseHTTPRequestHandler.send_error() generates HTML error messages, + # which we don't want, so we override the method to suppress this. + self.send_response(code, message) + self.send_header("Content-Type", default_content_type) + self.send_header("Connection", "close") + self.end_headers() def server(handlers, port, host = ""): - """ - Run an HTTP server and wait (forever) for connections. - """ + """ + Run an HTTP server and wait (forever) for connections. + """ - if isinstance(handlers, (tuple, list)): - handlers = tuple(h[:3] if len(h) > 2 else (h[0], h[1], default_content_type) - for h in handlers) - else: - handlers = (("/", handlers, default_content_type),) + if isinstance(handlers, (tuple, list)): + handlers = tuple(h[:3] if len(h) > 2 else (h[0], h[1], default_content_type) + for h in handlers) + else: + handlers = (("/", handlers, default_content_type),) - class RequestHandler(HTTPRequestHandler): - rpki_handlers = handlers + class RequestHandler(HTTPRequestHandler): + rpki_handlers = handlers - BaseHTTPServer.HTTPServer((host, port), RequestHandler).serve_forever() + BaseHTTPServer.HTTPServer((host, port), RequestHandler).serve_forever() class BadURL(Exception): - "Bad contact URL" + "Bad contact URL" class RequestFailed(Exception): - "HTTP returned failure" + "HTTP returned failure" class BadContentType(Exception): - "Wrong HTTP Content-Type" + "Wrong HTTP Content-Type" def client(proto_cms_msg, client_key, client_cert, server_ta, server_cert, url, q_msg, debug = False, replay_track = None, client_crl = None, content_type = default_content_type): - """ - Issue single a query and return the response, handling all the CMS and XML goo. - """ + """ + Issue single a query and return the response, handling all the CMS and XML goo. + """ - u = urlparse.urlparse(url) + u = urlparse.urlparse(url) - if u.scheme not in ("", "http") or u.username or u.password or u.params or u.query or u.fragment: - raise BadURL("Unusable URL %s", url) + if u.scheme not in ("", "http") or u.username or u.password or u.params or u.query or u.fragment: + raise BadURL("Unusable URL %s", url) - q_cms = proto_cms_msg() - q_der = q_cms.wrap(q_msg, client_key, client_cert, client_crl) + q_cms = proto_cms_msg() + q_der = q_cms.wrap(q_msg, client_key, client_cert, client_crl) - if debug: - debug.write("<!-- Query -->\n" + q_cms.pretty_print_content() + "\n") + if debug: + debug.write("<!-- Query -->\n" + q_cms.pretty_print_content() + "\n") - http = httplib.HTTPConnection(u.hostname, u.port or httplib.HTTP_PORT) - http.request("POST", u.path, q_der, {"Content-Type" : content_type}) - r = http.getresponse() + http = httplib.HTTPConnection(u.hostname, u.port or httplib.HTTP_PORT) + http.request("POST", u.path, q_der, {"Content-Type" : content_type}) + r = http.getresponse() - if r.status != 200: - raise RequestFailed("HTTP request failed with status %r reason %r" % (r.status, r.reason)) + if r.status != 200: + raise RequestFailed("HTTP request failed with status %r reason %r" % (r.status, r.reason)) - if r.getheader("Content-Type") != content_type: - raise BadContentType("HTTP Content-Type %r, expected %r" % (r.getheader("Content-Type"), content_type)) + if r.getheader("Content-Type") != content_type: + raise BadContentType("HTTP Content-Type %r, expected %r" % (r.getheader("Content-Type"), content_type)) - r_der = r.read() - r_cms = proto_cms_msg(DER = r_der) - r_msg = r_cms.unwrap((server_ta, server_cert)) + r_der = r.read() + r_cms = proto_cms_msg(DER = r_der) + r_msg = r_cms.unwrap((server_ta, server_cert)) - if replay_track is not None: - replay_track.cms_timestamp = r_cms.check_replay(replay_track.cms_timestamp, url) + if replay_track is not None: + replay_track.cms_timestamp = r_cms.check_replay(replay_track.cms_timestamp, url) - if debug: - debug.write("<!-- Reply -->\n" + r_cms.pretty_print_content() + "\n") + if debug: + debug.write("<!-- Reply -->\n" + r_cms.pretty_print_content() + "\n") - return r_msg + return r_msg diff --git a/rpki/ipaddrs.py b/rpki/ipaddrs.py index 25eefd0d..5117585c 100644 --- a/rpki/ipaddrs.py +++ b/rpki/ipaddrs.py @@ -48,99 +48,99 @@ once, here, thus avoiding a lot of duplicate code elsewhere. import socket, struct class v4addr(long): - """ - IPv4 address. + """ + IPv4 address. - Derived from long, but supports IPv4 print syntax. - """ + Derived from long, but supports IPv4 print syntax. + """ - bits = 32 - ipversion = 4 + bits = 32 + ipversion = 4 - def __new__(cls, x): - """ - Construct a v4addr object. - """ + def __new__(cls, x): + """ + Construct a v4addr object. + """ - if isinstance(x, unicode): - x = x.encode("ascii") - if isinstance(x, str): - return cls.from_bytes(socket.inet_pton(socket.AF_INET, ".".join(str(int(i)) for i in x.split(".")))) - else: - return long.__new__(cls, x) + if isinstance(x, unicode): + x = x.encode("ascii") + if isinstance(x, str): + return cls.from_bytes(socket.inet_pton(socket.AF_INET, ".".join(str(int(i)) for i in x.split(".")))) + else: + return long.__new__(cls, x) - def to_bytes(self): - """ - Convert a v4addr object to a raw byte string. - """ + def to_bytes(self): + """ + Convert a v4addr object to a raw byte string. + """ - return struct.pack("!I", long(self)) + return struct.pack("!I", long(self)) - @classmethod - def from_bytes(cls, x): - """ - Convert from a raw byte string to a v4addr object. - """ + @classmethod + def from_bytes(cls, x): + """ + Convert from a raw byte string to a v4addr object. + """ - return cls(struct.unpack("!I", x)[0]) + return cls(struct.unpack("!I", x)[0]) - def __str__(self): - """ - Convert a v4addr object to string format. - """ + def __str__(self): + """ + Convert a v4addr object to string format. + """ - return socket.inet_ntop(socket.AF_INET, self.to_bytes()) + return socket.inet_ntop(socket.AF_INET, self.to_bytes()) class v6addr(long): - """ - IPv6 address. + """ + IPv6 address. - Derived from long, but supports IPv6 print syntax. - """ + Derived from long, but supports IPv6 print syntax. + """ - bits = 128 - ipversion = 6 + bits = 128 + ipversion = 6 - def __new__(cls, x): - """ - Construct a v6addr object. - """ + def __new__(cls, x): + """ + Construct a v6addr object. + """ - if isinstance(x, unicode): - x = x.encode("ascii") - if isinstance(x, str): - return cls.from_bytes(socket.inet_pton(socket.AF_INET6, x)) - else: - return long.__new__(cls, x) + if isinstance(x, unicode): + x = x.encode("ascii") + if isinstance(x, str): + return cls.from_bytes(socket.inet_pton(socket.AF_INET6, x)) + else: + return long.__new__(cls, x) - def to_bytes(self): - """ - Convert a v6addr object to a raw byte string. - """ + def to_bytes(self): + """ + Convert a v6addr object to a raw byte string. + """ - return struct.pack("!QQ", long(self) >> 64, long(self) & 0xFFFFFFFFFFFFFFFF) + return struct.pack("!QQ", long(self) >> 64, long(self) & 0xFFFFFFFFFFFFFFFF) - @classmethod - def from_bytes(cls, x): - """ - Convert from a raw byte string to a v6addr object. - """ + @classmethod + def from_bytes(cls, x): + """ + Convert from a raw byte string to a v6addr object. + """ - x = struct.unpack("!QQ", x) - return cls((x[0] << 64) | x[1]) + x = struct.unpack("!QQ", x) + return cls((x[0] << 64) | x[1]) - def __str__(self): - """ - Convert a v6addr object to string format. - """ + def __str__(self): + """ + Convert a v6addr object to string format. + """ - return socket.inet_ntop(socket.AF_INET6, self.to_bytes()) + return socket.inet_ntop(socket.AF_INET6, self.to_bytes()) def parse(s): - """ - Parse a string as either an IPv4 or IPv6 address, and return object of appropriate class. - """ + """ + Parse a string as either an IPv4 or IPv6 address, and return object of appropriate class. + """ - if isinstance(s, unicode): - s = s.encode("ascii") - return v6addr(s) if ":" in s else v4addr(s) + if isinstance(s, unicode): + s = s.encode("ascii") + return v6addr(s) if ":" in s else v4addr(s) diff --git a/rpki/irdb/models.py b/rpki/irdb/models.py index d2d6256b..4ff5734a 100644 --- a/rpki/irdb/models.py +++ b/rpki/irdb/models.py @@ -65,480 +65,480 @@ ee_certificate_lifetime = rpki.sundial.timedelta(days = 60) # Field classes class HandleField(django.db.models.CharField): - """ - A handle field class. Replace this with SlugField? - """ + """ + A handle field class. Replace this with SlugField? + """ - description = 'A "handle" in one of the RPKI protocols' + description = 'A "handle" in one of the RPKI protocols' - def __init__(self, *args, **kwargs): - kwargs["max_length"] = 120 - django.db.models.CharField.__init__(self, *args, **kwargs) + def __init__(self, *args, **kwargs): + kwargs["max_length"] = 120 + django.db.models.CharField.__init__(self, *args, **kwargs) class SignedReferralField(DERField): - description = "CMS signed object containing XML" - rpki_type = rpki.x509.SignedReferral + description = "CMS signed object containing XML" + rpki_type = rpki.x509.SignedReferral # Custom managers class CertificateManager(django.db.models.Manager): - def get_or_certify(self, **kwargs): - """ - Sort of like .get_or_create(), but for models containing - certificates which need to be generated based on other fields. - - Takes keyword arguments like .get(), checks for existing object. - If none, creates a new one; if found an existing object but some - of the non-key fields don't match, updates the existing object. - Runs certification method for new or updated objects. Returns a - tuple consisting of the object and a boolean indicating whether - anything has changed. - """ + def get_or_certify(self, **kwargs): + """ + Sort of like .get_or_create(), but for models containing + certificates which need to be generated based on other fields. + + Takes keyword arguments like .get(), checks for existing object. + If none, creates a new one; if found an existing object but some + of the non-key fields don't match, updates the existing object. + Runs certification method for new or updated objects. Returns a + tuple consisting of the object and a boolean indicating whether + anything has changed. + """ - changed = False + changed = False - try: - obj = self.get(**self._get_or_certify_keys(kwargs)) + try: + obj = self.get(**self._get_or_certify_keys(kwargs)) - except self.model.DoesNotExist: - obj = self.model(**kwargs) - changed = True + except self.model.DoesNotExist: + obj = self.model(**kwargs) + changed = True - else: - for k in kwargs: - if getattr(obj, k) != kwargs[k]: - setattr(obj, k, kwargs[k]) - changed = True + else: + for k in kwargs: + if getattr(obj, k) != kwargs[k]: + setattr(obj, k, kwargs[k]) + changed = True - if changed: - obj.avow() - obj.save() + if changed: + obj.avow() + obj.save() - return obj, changed + return obj, changed - def _get_or_certify_keys(self, kwargs): - assert len(self.model._meta.unique_together) == 1 - return dict((k, kwargs[k]) for k in self.model._meta.unique_together[0]) + def _get_or_certify_keys(self, kwargs): + assert len(self.model._meta.unique_together) == 1 + return dict((k, kwargs[k]) for k in self.model._meta.unique_together[0]) class ResourceHolderCAManager(CertificateManager): - def _get_or_certify_keys(self, kwargs): - return { "handle" : kwargs["handle"] } + def _get_or_certify_keys(self, kwargs): + return { "handle" : kwargs["handle"] } class ServerCAManager(CertificateManager): - def _get_or_certify_keys(self, kwargs): - return { "pk" : 1 } + def _get_or_certify_keys(self, kwargs): + return { "pk" : 1 } class ResourceHolderEEManager(CertificateManager): - def _get_or_certify_keys(self, kwargs): - return { "issuer" : kwargs["issuer"] } + def _get_or_certify_keys(self, kwargs): + return { "issuer" : kwargs["issuer"] } ### class CA(django.db.models.Model): - certificate = CertificateField() - private_key = RSAPrivateKeyField() - latest_crl = CRLField() - - # Might want to bring these into line with what rpkid does. Current - # variables here were chosen to map easily to what OpenSSL command - # line tool was keeping on disk. - - next_serial = django.db.models.BigIntegerField(default = 1) - next_crl_number = django.db.models.BigIntegerField(default = 1) - last_crl_update = SundialField() - next_crl_update = SundialField() - - class Meta: - abstract = True - - def avow(self): - if self.private_key is None: - self.private_key = rpki.x509.RSA.generate(quiet = True) - now = rpki.sundial.now() - notAfter = now + ca_certificate_lifetime - self.certificate = rpki.x509.X509.bpki_self_certify( - keypair = self.private_key, - subject_name = self.subject_name, - serial = self.next_serial, - now = now, - notAfter = notAfter) - self.next_serial += 1 - self.generate_crl() - return self.certificate - - def certify(self, subject_name, subject_key, validity_interval, is_ca, pathLenConstraint = None): - now = rpki.sundial.now() - notAfter = now + validity_interval - result = self.certificate.bpki_certify( - keypair = self.private_key, - subject_name = subject_name, - subject_key = subject_key, - serial = self.next_serial, - now = now, - notAfter = notAfter, - is_ca = is_ca, - pathLenConstraint = pathLenConstraint) - self.next_serial += 1 - return result - - def revoke(self, cert): - Revocation.objects.create( - issuer = self, - revoked = rpki.sundial.now(), - serial = cert.certificate.getSerial(), - expires = cert.certificate.getNotAfter() + crl_interval) - cert.delete() - self.generate_crl() - - def generate_crl(self): - now = rpki.sundial.now() - self.revocations.filter(expires__lt = now).delete() - revoked = [(r.serial, r.revoked) for r in self.revocations.all()] - self.latest_crl = rpki.x509.CRL.generate( - keypair = self.private_key, - issuer = self.certificate, - serial = self.next_crl_number, - thisUpdate = now, - nextUpdate = now + crl_interval, - revokedCertificates = revoked) - self.last_crl_update = now - self.next_crl_update = now + crl_interval - self.next_crl_number += 1 + certificate = CertificateField() + private_key = RSAPrivateKeyField() + latest_crl = CRLField() + + # Might want to bring these into line with what rpkid does. Current + # variables here were chosen to map easily to what OpenSSL command + # line tool was keeping on disk. + + next_serial = django.db.models.BigIntegerField(default = 1) + next_crl_number = django.db.models.BigIntegerField(default = 1) + last_crl_update = SundialField() + next_crl_update = SundialField() + + class Meta: + abstract = True + + def avow(self): + if self.private_key is None: + self.private_key = rpki.x509.RSA.generate(quiet = True) + now = rpki.sundial.now() + notAfter = now + ca_certificate_lifetime + self.certificate = rpki.x509.X509.bpki_self_certify( + keypair = self.private_key, + subject_name = self.subject_name, + serial = self.next_serial, + now = now, + notAfter = notAfter) + self.next_serial += 1 + self.generate_crl() + return self.certificate + + def certify(self, subject_name, subject_key, validity_interval, is_ca, pathLenConstraint = None): + now = rpki.sundial.now() + notAfter = now + validity_interval + result = self.certificate.bpki_certify( + keypair = self.private_key, + subject_name = subject_name, + subject_key = subject_key, + serial = self.next_serial, + now = now, + notAfter = notAfter, + is_ca = is_ca, + pathLenConstraint = pathLenConstraint) + self.next_serial += 1 + return result + + def revoke(self, cert): + Revocation.objects.create( + issuer = self, + revoked = rpki.sundial.now(), + serial = cert.certificate.getSerial(), + expires = cert.certificate.getNotAfter() + crl_interval) + cert.delete() + self.generate_crl() + + def generate_crl(self): + now = rpki.sundial.now() + self.revocations.filter(expires__lt = now).delete() + revoked = [(r.serial, r.revoked) for r in self.revocations.all()] + self.latest_crl = rpki.x509.CRL.generate( + keypair = self.private_key, + issuer = self.certificate, + serial = self.next_crl_number, + thisUpdate = now, + nextUpdate = now + crl_interval, + revokedCertificates = revoked) + self.last_crl_update = now + self.next_crl_update = now + crl_interval + self.next_crl_number += 1 class ServerCA(CA): - objects = ServerCAManager() + objects = ServerCAManager() - def __unicode__(self): - return "" + def __unicode__(self): + return "" - @property - def subject_name(self): - if self.certificate is not None: - return self.certificate.getSubject() - else: - return rpki.x509.X501DN.from_cn("%s BPKI server CA" % socket.gethostname()) + @property + def subject_name(self): + if self.certificate is not None: + return self.certificate.getSubject() + else: + return rpki.x509.X501DN.from_cn("%s BPKI server CA" % socket.gethostname()) class ResourceHolderCA(CA): - handle = HandleField(unique = True) - objects = ResourceHolderCAManager() + handle = HandleField(unique = True) + objects = ResourceHolderCAManager() - def __unicode__(self): - return self.handle + def __unicode__(self): + return self.handle - @property - def subject_name(self): - if self.certificate is not None: - return self.certificate.getSubject() - else: - return rpki.x509.X501DN.from_cn("%s BPKI resource CA" % self.handle) + @property + def subject_name(self): + if self.certificate is not None: + return self.certificate.getSubject() + else: + return rpki.x509.X501DN.from_cn("%s BPKI resource CA" % self.handle) class Certificate(django.db.models.Model): - certificate = CertificateField() - objects = CertificateManager() + certificate = CertificateField() + objects = CertificateManager() - class Meta: - abstract = True - unique_together = ("issuer", "handle") + class Meta: + abstract = True + unique_together = ("issuer", "handle") - def revoke(self): - self.issuer.revoke(self) + def revoke(self): + self.issuer.revoke(self) class CrossCertification(Certificate): - handle = HandleField() - ta = CertificateField() + handle = HandleField() + ta = CertificateField() - class Meta: - abstract = True + class Meta: + abstract = True - def avow(self): - self.certificate = self.issuer.certify( - subject_name = self.ta.getSubject(), - subject_key = self.ta.getPublicKey(), - validity_interval = ee_certificate_lifetime, - is_ca = True, - pathLenConstraint = 0) + def avow(self): + self.certificate = self.issuer.certify( + subject_name = self.ta.getSubject(), + subject_key = self.ta.getPublicKey(), + validity_interval = ee_certificate_lifetime, + is_ca = True, + pathLenConstraint = 0) - def __unicode__(self): - return self.handle + def __unicode__(self): + return self.handle class HostedCA(Certificate): - issuer = django.db.models.ForeignKey(ServerCA) - hosted = django.db.models.OneToOneField(ResourceHolderCA, related_name = "hosted_by") + issuer = django.db.models.ForeignKey(ServerCA) + hosted = django.db.models.OneToOneField(ResourceHolderCA, related_name = "hosted_by") - def avow(self): - self.certificate = self.issuer.certify( - subject_name = self.hosted.certificate.getSubject(), - subject_key = self.hosted.certificate.getPublicKey(), - validity_interval = ee_certificate_lifetime, - is_ca = True, - pathLenConstraint = 1) + def avow(self): + self.certificate = self.issuer.certify( + subject_name = self.hosted.certificate.getSubject(), + subject_key = self.hosted.certificate.getPublicKey(), + validity_interval = ee_certificate_lifetime, + is_ca = True, + pathLenConstraint = 1) - class Meta: - unique_together = ("issuer", "hosted") + class Meta: + unique_together = ("issuer", "hosted") - def __unicode__(self): - return self.hosted.handle + def __unicode__(self): + return self.hosted.handle class Revocation(django.db.models.Model): - serial = django.db.models.BigIntegerField() - revoked = SundialField() - expires = SundialField() + serial = django.db.models.BigIntegerField() + revoked = SundialField() + expires = SundialField() - class Meta: - abstract = True - unique_together = ("issuer", "serial") + class Meta: + abstract = True + unique_together = ("issuer", "serial") class ServerRevocation(Revocation): - issuer = django.db.models.ForeignKey(ServerCA, related_name = "revocations") + issuer = django.db.models.ForeignKey(ServerCA, related_name = "revocations") class ResourceHolderRevocation(Revocation): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "revocations") + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "revocations") class EECertificate(Certificate): - private_key = RSAPrivateKeyField() + private_key = RSAPrivateKeyField() - class Meta: - abstract = True + class Meta: + abstract = True - def avow(self): - if self.private_key is None: - self.private_key = rpki.x509.RSA.generate(quiet = True) - self.certificate = self.issuer.certify( - subject_name = self.subject_name, - subject_key = self.private_key.get_public(), - validity_interval = ee_certificate_lifetime, - is_ca = False) + def avow(self): + if self.private_key is None: + self.private_key = rpki.x509.RSA.generate(quiet = True) + self.certificate = self.issuer.certify( + subject_name = self.subject_name, + subject_key = self.private_key.get_public(), + validity_interval = ee_certificate_lifetime, + is_ca = False) class ServerEE(EECertificate): - issuer = django.db.models.ForeignKey(ServerCA, related_name = "ee_certificates") - purpose = EnumField(choices = ("rpkid", "pubd", "irdbd", "irbe")) + issuer = django.db.models.ForeignKey(ServerCA, related_name = "ee_certificates") + purpose = EnumField(choices = ("rpkid", "pubd", "irdbd", "irbe")) - class Meta: - unique_together = ("issuer", "purpose") + class Meta: + unique_together = ("issuer", "purpose") - @property - def subject_name(self): - return rpki.x509.X501DN.from_cn("%s BPKI %s EE" % (socket.gethostname(), - self.get_purpose_display())) + @property + def subject_name(self): + return rpki.x509.X501DN.from_cn("%s BPKI %s EE" % (socket.gethostname(), + self.get_purpose_display())) class Referral(EECertificate): - issuer = django.db.models.OneToOneField(ResourceHolderCA, related_name = "referral_certificate") - objects = ResourceHolderEEManager() + issuer = django.db.models.OneToOneField(ResourceHolderCA, related_name = "referral_certificate") + objects = ResourceHolderEEManager() - @property - def subject_name(self): - return rpki.x509.X501DN.from_cn("%s BPKI Referral EE" % self.issuer.handle) + @property + def subject_name(self): + return rpki.x509.X501DN.from_cn("%s BPKI Referral EE" % self.issuer.handle) class Turtle(django.db.models.Model): - service_uri = django.db.models.CharField(max_length = 255) + service_uri = django.db.models.CharField(max_length = 255) class Rootd(EECertificate, Turtle): - issuer = django.db.models.OneToOneField(ResourceHolderCA, related_name = "rootd") - objects = ResourceHolderEEManager() + issuer = django.db.models.OneToOneField(ResourceHolderCA, related_name = "rootd") + objects = ResourceHolderEEManager() - @property - def subject_name(self): - return rpki.x509.X501DN.from_cn("%s BPKI rootd EE" % self.issuer.handle) + @property + def subject_name(self): + return rpki.x509.X501DN.from_cn("%s BPKI rootd EE" % self.issuer.handle) class BSC(Certificate): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "bscs") - handle = HandleField() - pkcs10 = PKCS10Field() + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "bscs") + handle = HandleField() + pkcs10 = PKCS10Field() - def avow(self): - self.certificate = self.issuer.certify( - subject_name = self.pkcs10.getSubject(), - subject_key = self.pkcs10.getPublicKey(), - validity_interval = ee_certificate_lifetime, - is_ca = False) + def avow(self): + self.certificate = self.issuer.certify( + subject_name = self.pkcs10.getSubject(), + subject_key = self.pkcs10.getPublicKey(), + validity_interval = ee_certificate_lifetime, + is_ca = False) - def __unicode__(self): - return self.handle + def __unicode__(self): + return self.handle class ResourceSet(django.db.models.Model): - valid_until = SundialField() + valid_until = SundialField() - class Meta: - abstract = True + class Meta: + abstract = True - @property - def resource_bag(self): - raw_asn, raw_net = self._select_resource_bag() - asns = rpki.resource_set.resource_set_as.from_django( - (a.start_as, a.end_as) for a in raw_asn) - ipv4 = rpki.resource_set.resource_set_ipv4.from_django( - (a.start_ip, a.end_ip) for a in raw_net if a.version == "IPv4") - ipv6 = rpki.resource_set.resource_set_ipv6.from_django( - (a.start_ip, a.end_ip) for a in raw_net if a.version == "IPv6") - return rpki.resource_set.resource_bag( - valid_until = self.valid_until, asn = asns, v4 = ipv4, v6 = ipv6) + @property + def resource_bag(self): + raw_asn, raw_net = self._select_resource_bag() + asns = rpki.resource_set.resource_set_as.from_django( + (a.start_as, a.end_as) for a in raw_asn) + ipv4 = rpki.resource_set.resource_set_ipv4.from_django( + (a.start_ip, a.end_ip) for a in raw_net if a.version == "IPv4") + ipv6 = rpki.resource_set.resource_set_ipv6.from_django( + (a.start_ip, a.end_ip) for a in raw_net if a.version == "IPv6") + return rpki.resource_set.resource_bag( + valid_until = self.valid_until, asn = asns, v4 = ipv4, v6 = ipv6) - # Writing of .setter method deferred until something needs it. + # Writing of .setter method deferred until something needs it. class ResourceSetASN(django.db.models.Model): - start_as = django.db.models.BigIntegerField() - end_as = django.db.models.BigIntegerField() + start_as = django.db.models.BigIntegerField() + end_as = django.db.models.BigIntegerField() - class Meta: - abstract = True + class Meta: + abstract = True - def as_resource_range(self): - return rpki.resource_set.resource_range_as(self.start_as, self.end_as) + def as_resource_range(self): + return rpki.resource_set.resource_range_as(self.start_as, self.end_as) class ResourceSetNet(django.db.models.Model): - start_ip = django.db.models.CharField(max_length = 40) - end_ip = django.db.models.CharField(max_length = 40) - version = EnumField(choices = ip_version_choices) + start_ip = django.db.models.CharField(max_length = 40) + end_ip = django.db.models.CharField(max_length = 40) + version = EnumField(choices = ip_version_choices) - class Meta: - abstract = True + class Meta: + abstract = True - def as_resource_range(self): - return rpki.resource_set.resource_range_ip.from_strings(self.start_ip, self.end_ip) + def as_resource_range(self): + return rpki.resource_set.resource_range_ip.from_strings(self.start_ip, self.end_ip) class Child(CrossCertification, ResourceSet): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "children") - name = django.db.models.TextField(null = True, blank = True) - - def _select_resource_bag(self): - child_asn = rpki.irdb.models.ChildASN.objects.raw(""" - SELECT * - FROM irdb_childasn - WHERE child_id = %s - """, [self.id]) - child_net = list(rpki.irdb.models.ChildNet.objects.raw(""" - SELECT * - FROM irdb_childnet - WHERE child_id = %s - """, [self.id])) - return child_asn, child_net - - class Meta: - unique_together = ("issuer", "handle") + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "children") + name = django.db.models.TextField(null = True, blank = True) + + def _select_resource_bag(self): + child_asn = rpki.irdb.models.ChildASN.objects.raw(""" + SELECT * + FROM irdb_childasn + WHERE child_id = %s + """, [self.id]) + child_net = list(rpki.irdb.models.ChildNet.objects.raw(""" + SELECT * + FROM irdb_childnet + WHERE child_id = %s + """, [self.id])) + return child_asn, child_net + + class Meta: + unique_together = ("issuer", "handle") class ChildASN(ResourceSetASN): - child = django.db.models.ForeignKey(Child, related_name = "asns") + child = django.db.models.ForeignKey(Child, related_name = "asns") - class Meta: - unique_together = ("child", "start_as", "end_as") + class Meta: + unique_together = ("child", "start_as", "end_as") class ChildNet(ResourceSetNet): - child = django.db.models.ForeignKey(Child, related_name = "address_ranges") + child = django.db.models.ForeignKey(Child, related_name = "address_ranges") - class Meta: - unique_together = ("child", "start_ip", "end_ip", "version") + class Meta: + unique_together = ("child", "start_ip", "end_ip", "version") class Parent(CrossCertification, Turtle): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "parents") - parent_handle = HandleField() - child_handle = HandleField() - repository_type = EnumField(choices = ("none", "offer", "referral")) - referrer = HandleField(null = True, blank = True) - referral_authorization = SignedReferralField(null = True, blank = True) + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "parents") + parent_handle = HandleField() + child_handle = HandleField() + repository_type = EnumField(choices = ("none", "offer", "referral")) + referrer = HandleField(null = True, blank = True) + referral_authorization = SignedReferralField(null = True, blank = True) - # This shouldn't be necessary - class Meta: - unique_together = ("issuer", "handle") + # This shouldn't be necessary + class Meta: + unique_together = ("issuer", "handle") class ROARequest(django.db.models.Model): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "roa_requests") - asn = django.db.models.BigIntegerField() - - @property - def roa_prefix_bag(self): - prefixes = list(rpki.irdb.models.ROARequestPrefix.objects.raw(""" - SELECT * - FROM irdb_roarequestprefix - WHERE roa_request_id = %s - """, [self.id])) - v4 = rpki.resource_set.roa_prefix_set_ipv4.from_django( - (p.prefix, p.prefixlen, p.max_prefixlen) for p in prefixes if p.version == "IPv4") - v6 = rpki.resource_set.roa_prefix_set_ipv6.from_django( - (p.prefix, p.prefixlen, p.max_prefixlen) for p in prefixes if p.version == "IPv6") - return rpki.resource_set.roa_prefix_bag(v4 = v4, v6 = v6) - - # Writing of .setter method deferred until something needs it. + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "roa_requests") + asn = django.db.models.BigIntegerField() + + @property + def roa_prefix_bag(self): + prefixes = list(rpki.irdb.models.ROARequestPrefix.objects.raw(""" + SELECT * + FROM irdb_roarequestprefix + WHERE roa_request_id = %s + """, [self.id])) + v4 = rpki.resource_set.roa_prefix_set_ipv4.from_django( + (p.prefix, p.prefixlen, p.max_prefixlen) for p in prefixes if p.version == "IPv4") + v6 = rpki.resource_set.roa_prefix_set_ipv6.from_django( + (p.prefix, p.prefixlen, p.max_prefixlen) for p in prefixes if p.version == "IPv6") + return rpki.resource_set.roa_prefix_bag(v4 = v4, v6 = v6) + + # Writing of .setter method deferred until something needs it. class ROARequestPrefix(django.db.models.Model): - roa_request = django.db.models.ForeignKey(ROARequest, related_name = "prefixes") - version = EnumField(choices = ip_version_choices) - prefix = django.db.models.CharField(max_length = 40) - prefixlen = django.db.models.PositiveSmallIntegerField() - max_prefixlen = django.db.models.PositiveSmallIntegerField() + roa_request = django.db.models.ForeignKey(ROARequest, related_name = "prefixes") + version = EnumField(choices = ip_version_choices) + prefix = django.db.models.CharField(max_length = 40) + prefixlen = django.db.models.PositiveSmallIntegerField() + max_prefixlen = django.db.models.PositiveSmallIntegerField() - def as_roa_prefix(self): - if self.version == 'IPv4': - return rpki.resource_set.roa_prefix_ipv4(rpki.POW.IPAddress(self.prefix), self.prefixlen, self.max_prefixlen) - else: - return rpki.resource_set.roa_prefix_ipv6(rpki.POW.IPAddress(self.prefix), self.prefixlen, self.max_prefixlen) + def as_roa_prefix(self): + if self.version == 'IPv4': + return rpki.resource_set.roa_prefix_ipv4(rpki.POW.IPAddress(self.prefix), self.prefixlen, self.max_prefixlen) + else: + return rpki.resource_set.roa_prefix_ipv6(rpki.POW.IPAddress(self.prefix), self.prefixlen, self.max_prefixlen) - def as_resource_range(self): - return self.as_roa_prefix().to_resource_range() + def as_resource_range(self): + return self.as_roa_prefix().to_resource_range() - class Meta: - unique_together = ("roa_request", "version", "prefix", "prefixlen", "max_prefixlen") + class Meta: + unique_together = ("roa_request", "version", "prefix", "prefixlen", "max_prefixlen") class GhostbusterRequest(django.db.models.Model): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "ghostbuster_requests") - parent = django.db.models.ForeignKey(Parent, related_name = "ghostbuster_requests", null = True) - vcard = django.db.models.TextField() + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "ghostbuster_requests") + parent = django.db.models.ForeignKey(Parent, related_name = "ghostbuster_requests", null = True) + vcard = django.db.models.TextField() class EECertificateRequest(ResourceSet): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "ee_certificate_requests") - pkcs10 = PKCS10Field() - gski = django.db.models.CharField(max_length = 27) - cn = django.db.models.CharField(max_length = 64) - sn = django.db.models.CharField(max_length = 64) - eku = django.db.models.TextField(null = True) - - def _select_resource_bag(self): - ee_asn = rpki.irdb.models.EECertificateRequestASN.objects.raw(""" - SELECT * - FROM irdb_eecertificaterequestasn - WHERE ee_certificate_request_id = %s - """, [self.id]) - ee_net = rpki.irdb.models.EECertificateRequestNet.objects.raw(""" - SELECT * - FROM irdb_eecertificaterequestnet - WHERE ee_certificate_request_id = %s - """, [self.id]) - return ee_asn, ee_net - - class Meta: - unique_together = ("issuer", "gski") + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "ee_certificate_requests") + pkcs10 = PKCS10Field() + gski = django.db.models.CharField(max_length = 27) + cn = django.db.models.CharField(max_length = 64) + sn = django.db.models.CharField(max_length = 64) + eku = django.db.models.TextField(null = True) + + def _select_resource_bag(self): + ee_asn = rpki.irdb.models.EECertificateRequestASN.objects.raw(""" + SELECT * + FROM irdb_eecertificaterequestasn + WHERE ee_certificate_request_id = %s + """, [self.id]) + ee_net = rpki.irdb.models.EECertificateRequestNet.objects.raw(""" + SELECT * + FROM irdb_eecertificaterequestnet + WHERE ee_certificate_request_id = %s + """, [self.id]) + return ee_asn, ee_net + + class Meta: + unique_together = ("issuer", "gski") class EECertificateRequestASN(ResourceSetASN): - ee_certificate_request = django.db.models.ForeignKey(EECertificateRequest, related_name = "asns") + ee_certificate_request = django.db.models.ForeignKey(EECertificateRequest, related_name = "asns") - class Meta: - unique_together = ("ee_certificate_request", "start_as", "end_as") + class Meta: + unique_together = ("ee_certificate_request", "start_as", "end_as") class EECertificateRequestNet(ResourceSetNet): - ee_certificate_request = django.db.models.ForeignKey(EECertificateRequest, related_name = "address_ranges") + ee_certificate_request = django.db.models.ForeignKey(EECertificateRequest, related_name = "address_ranges") - class Meta: - unique_together = ("ee_certificate_request", "start_ip", "end_ip", "version") + class Meta: + unique_together = ("ee_certificate_request", "start_ip", "end_ip", "version") class Repository(CrossCertification): - issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "repositories") - client_handle = HandleField() - service_uri = django.db.models.CharField(max_length = 255) - sia_base = django.db.models.TextField() - rrdp_notification_uri = django.db.models.TextField(null = True) - turtle = django.db.models.OneToOneField(Turtle, related_name = "repository") + issuer = django.db.models.ForeignKey(ResourceHolderCA, related_name = "repositories") + client_handle = HandleField() + service_uri = django.db.models.CharField(max_length = 255) + sia_base = django.db.models.TextField() + rrdp_notification_uri = django.db.models.TextField(null = True) + turtle = django.db.models.OneToOneField(Turtle, related_name = "repository") - # This shouldn't be necessary - class Meta: - unique_together = ("issuer", "handle") + # This shouldn't be necessary + class Meta: + unique_together = ("issuer", "handle") class Client(CrossCertification): - issuer = django.db.models.ForeignKey(ServerCA, related_name = "clients") - sia_base = django.db.models.TextField() + issuer = django.db.models.ForeignKey(ServerCA, related_name = "clients") + sia_base = django.db.models.TextField() - # This shouldn't be necessary - class Meta: - unique_together = ("issuer", "handle") + # This shouldn't be necessary + class Meta: + unique_together = ("issuer", "handle") diff --git a/rpki/irdb/router.py b/rpki/irdb/router.py index 0aaf53ce..3cbd52f9 100644 --- a/rpki/irdb/router.py +++ b/rpki/irdb/router.py @@ -27,69 +27,69 @@ accomplishes this. """ class DBContextRouter(object): - """ - A Django database router for use with multiple IRDBs. - - This router is designed to work in conjunction with the - rpki.irdb.database context handler (q.v.). - """ - - _app = "irdb" - - _database = None - - def db_for_read(self, model, **hints): - if model._meta.app_label == self._app: - return self._database - else: - return None - - def db_for_write(self, model, **hints): - if model._meta.app_label == self._app: - return self._database - else: - return None - - def allow_relation(self, obj1, obj2, **hints): - if self._database is None: - return None - elif obj1._meta.app_label == self._app and obj2._meta.app_label == self._app: - return True - else: - return None - - def allow_migrate(self, db, model): - if db == self._database and model._meta.app_label == self._app: - return True - else: - return None + """ + A Django database router for use with multiple IRDBs. + + This router is designed to work in conjunction with the + rpki.irdb.database context handler (q.v.). + """ + + _app = "irdb" + + _database = None + + def db_for_read(self, model, **hints): + if model._meta.app_label == self._app: + return self._database + else: + return None + + def db_for_write(self, model, **hints): + if model._meta.app_label == self._app: + return self._database + else: + return None + + def allow_relation(self, obj1, obj2, **hints): + if self._database is None: + return None + elif obj1._meta.app_label == self._app and obj2._meta.app_label == self._app: + return True + else: + return None + + def allow_migrate(self, db, model): + if db == self._database and model._meta.app_label == self._app: + return True + else: + return None class database(object): - """ - Context manager for use with DBContextRouter. Use thusly: - - with rpki.irdb.database("blarg"): - do_stuff() - - This binds IRDB operations to database blarg for the duration of - the call to do_stuff(), then restores the prior state. - """ - - def __init__(self, name, on_entry = None, on_exit = None): - if not isinstance(name, str): - raise ValueError("database name must be a string, not %r" % name) - self.name = name - self.on_entry = on_entry - self.on_exit = on_exit - - def __enter__(self): - if self.on_entry is not None: - self.on_entry() - self.former = DBContextRouter._database - DBContextRouter._database = self.name - - def __exit__(self, _type, value, traceback): - assert DBContextRouter._database is self.name - DBContextRouter._database = self.former - if self.on_exit is not None: - self.on_exit() + """ + Context manager for use with DBContextRouter. Use thusly: + + with rpki.irdb.database("blarg"): + do_stuff() + + This binds IRDB operations to database blarg for the duration of + the call to do_stuff(), then restores the prior state. + """ + + def __init__(self, name, on_entry = None, on_exit = None): + if not isinstance(name, str): + raise ValueError("database name must be a string, not %r" % name) + self.name = name + self.on_entry = on_entry + self.on_exit = on_exit + + def __enter__(self): + if self.on_entry is not None: + self.on_entry() + self.former = DBContextRouter._database + DBContextRouter._database = self.name + + def __exit__(self, _type, value, traceback): + assert DBContextRouter._database is self.name + DBContextRouter._database = self.former + if self.on_exit is not None: + self.on_exit() diff --git a/rpki/irdb/zookeeper.py b/rpki/irdb/zookeeper.py index 7202f421..a65f1f5f 100644 --- a/rpki/irdb/zookeeper.py +++ b/rpki/irdb/zookeeper.py @@ -96,1651 +96,1654 @@ class CouldntFindRepoParent(Exception): "Couldn't find repository's parent." def B64Element(e, tag, obj, **kwargs): - """ - Create an XML element containing Base64 encoded data taken from a - DER object. - """ - - if e is None: - se = Element(tag, **kwargs) - else: - se = SubElement(e, tag, **kwargs) - if e is not None and e.text is None: - e.text = "\n" - se.text = "\n" + obj.get_Base64() - se.tail = "\n" - return se - -class PEM_writer(object): - """ - Write PEM files to disk, keeping track of which ones we've already - written and setting the file mode appropriately. - - Comparing the old file with what we're about to write serves no real - purpose except to calm users who find repeated messages about - writing the same file confusing. - """ - - def __init__(self, logstream = None): - self.wrote = set() - self.logstream = logstream - - def __call__(self, filename, obj, compare = True): - filename = os.path.realpath(filename) - if filename in self.wrote: - return - tempname = filename - pem = obj.get_PEM() - if not filename.startswith("/dev/"): - try: - if compare and pem == open(filename, "r").read(): - return - except: # pylint: disable=W0702 - pass - tempname += ".%s.tmp" % os.getpid() - mode = 0400 if filename.endswith(".key") else 0444 - if self.logstream is not None: - self.logstream.write("Writing %s\n" % filename) - f = os.fdopen(os.open(tempname, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode), "w") - f.write(pem) - f.close() - if tempname != filename: - os.rename(tempname, filename) - self.wrote.add(filename) - - -def etree_read(filename_or_etree_wrapper, schema = rpki.relaxng.oob_setup): - """ - Read an etree from a file, verifying then stripping XML namespace - cruft. As a convenience, we also accept an etree_wrapper object in - place of a filename, in which case we deepcopy the etree directly - from the etree_wrapper and there's no need for a file. - """ - - if isinstance(filename_or_etree_wrapper, etree_wrapper): - e = copy.deepcopy(filename_or_etree_wrapper.etree) - else: - e = ElementTree(file = filename_or_etree_wrapper).getroot() - schema.assertValid(e) - return e - - -class etree_wrapper(object): - """ - Wrapper for ETree objects so we can return them as function results - without requiring the caller to understand much about them. - """ - - def __init__(self, e, msg = None, debug = False, schema = rpki.relaxng.oob_setup): - self.msg = msg - e = copy.deepcopy(e) - if debug: - print ElementToString(e) - schema.assertValid(e) - self.etree = e - - def __str__(self): - return ElementToString(self.etree) - - def save(self, filename, logstream = None): - filename = os.path.realpath(filename) - tempname = filename - if not filename.startswith("/dev/"): - tempname += ".%s.tmp" % os.getpid() - ElementTree(self.etree).write(tempname) - if tempname != filename: - os.rename(tempname, filename) - if logstream is not None: - logstream.write("Wrote %s\n" % filename) - if self.msg is not None: - logstream.write(self.msg + "\n") - - @property - def file(self): - from cStringIO import StringIO - return StringIO(ElementToString(self.etree)) - - -class Zookeeper(object): - - ## @var show_xml - # If not None, a file-like object to which to prettyprint XML, for debugging. - - show_xml = None - - def __init__(self, cfg = None, handle = None, logstream = None, disable_signal_handlers = False): - - if cfg is None: - cfg = rpki.config.parser() - - if handle is None: - handle = cfg.get("handle", section = myrpki_section) - - self.cfg = cfg - - self.logstream = logstream - self.disable_signal_handlers = disable_signal_handlers - - self.run_rpkid = cfg.getboolean("run_rpkid", section = myrpki_section) - self.run_pubd = cfg.getboolean("run_pubd", section = myrpki_section) - self.run_rootd = cfg.getboolean("run_rootd", section = myrpki_section) - - if self.run_rootd and (not self.run_pubd or not self.run_rpkid): - raise CantRunRootd("Can't run rootd unless also running rpkid and pubd") - - self.default_repository = cfg.get("default_repository", "", section = myrpki_section) - self.pubd_contact_info = cfg.get("pubd_contact_info", "", section = myrpki_section) - - self.rsync_module = cfg.get("publication_rsync_module", section = myrpki_section) - self.rsync_server = cfg.get("publication_rsync_server", section = myrpki_section) - - self.reset_identity(handle) - - - def reset_identity(self, handle): - """ - Select handle of current resource holding entity. - """ - - if handle is None: - raise MissingHandle - self.handle = handle - - - def set_logstream(self, logstream): - """ - Set log stream for this Zookeeper. The log stream is a file-like - object, or None to suppress all logging. - """ - - self.logstream = logstream - - - def log(self, msg): - """ - Send some text to this Zookeeper's log stream, if one is set. - """ - - if self.logstream is not None: - self.logstream.write(msg) - self.logstream.write("\n") - - - @property - def resource_ca(self): - """ - Get ResourceHolderCA object associated with current handle. - """ - - if self.handle is None: - raise HandleNotSet - return rpki.irdb.models.ResourceHolderCA.objects.get(handle = self.handle) - - - @property - def server_ca(self): - """ - Get ServerCA object. - """ - - return rpki.irdb.models.ServerCA.objects.get() - - - @django.db.transaction.atomic - def initialize_server_bpki(self): """ - Initialize server BPKI portion of an RPKI installation. Reads the - configuration file and generates the initial BPKI server - certificates needed to start daemons. + Create an XML element containing Base64 encoded data taken from a + DER object. """ - if self.run_rpkid or self.run_pubd: - server_ca, created = rpki.irdb.models.ServerCA.objects.get_or_certify() - rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "irbe") - - if self.run_rpkid: - rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "rpkid") - rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "irdbd") - - if self.run_pubd: - rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "pubd") - - - @django.db.transaction.atomic - def initialize_resource_bpki(self): - """ - Initialize the resource-holding BPKI for an RPKI installation. - Returns XML describing the resource holder. - - This method is present primarily for backwards compatibility with - the old combined initialize() method which initialized both the - server BPKI and the default resource-holding BPKI in a single - method call. In the long run we want to replace this with - something that takes a handle as argument and creates the - resource-holding BPKI idenity if needed. - """ - - resource_ca, created = rpki.irdb.models.ResourceHolderCA.objects.get_or_certify(handle = self.handle) - return self.generate_identity() - - - def initialize(self): - """ - Backwards compatibility wrapper: calls initialize_server_bpki() - and initialize_resource_bpki(), returns latter's result. - """ - - self.initialize_server_bpki() - return self.initialize_resource_bpki() - - - def generate_identity(self): - """ - Generate identity XML. Broken out of .initialize() because it's - easier for the GUI this way. - """ - - e = Element(tag_oob_child_request, nsmap = oob_nsmap, version = oob_version, - child_handle = self.handle) - B64Element(e, tag_oob_child_bpki_ta, self.resource_ca.certificate) - return etree_wrapper(e, msg = 'This is the "identity" file you will need to send to your parent') - - - @django.db.transaction.atomic - def delete_tenant(self): - """ - Delete the ResourceHolderCA object corresponding to the current handle. - This corresponds to deleting an rpkid <tenant/> object. - - This code assumes the normal Django cascade-on-delete behavior, - that is, we assume that deleting the ResourceHolderCA object - deletes all the subordinate objects that refer to it via foreign - key relationships. - """ - - resource_ca = self.resource_ca - if resource_ca is not None: - resource_ca.delete() + if e is None: + se = Element(tag, **kwargs) else: - self.log("No such ResourceHolderCA \"%s\"" % self.handle) - - - @django.db.transaction.atomic - def configure_rootd(self): - - assert self.run_rpkid and self.run_pubd and self.run_rootd - - rpki.irdb.models.Rootd.objects.get_or_certify( - issuer = self.resource_ca, - service_uri = "http://localhost:%s/" % self.cfg.get("rootd_server_port", - section = myrpki_section)) - - return self.generate_rootd_repository_offer() - - - def generate_rootd_repository_offer(self): - """ - Generate repository offer for rootd. Split out of - configure_rootd() because that's easier for the GUI. - """ - - try: - self.resource_ca.repositories.get(handle = self.handle) - return None - - except rpki.irdb.models.Repository.DoesNotExist: - e = Element(tag_oob_publisher_request, nsmap = oob_nsmap, version = oob_version, - publisher_handle = self.handle) - B64Element(e, tag_oob_publisher_bpki_ta, self.resource_ca.certificate) - return etree_wrapper(e, msg = 'This is the "repository offer" file for you to use if you want to publish in your own repository') - - - def write_bpki_files(self): - """ - Write out BPKI certificate, key, and CRL files for daemons that - need them. - """ + se = SubElement(e, tag, **kwargs) + if e is not None and e.text is None: + e.text = "\n" + se.text = "\n" + obj.get_Base64() + se.tail = "\n" + return se - writer = PEM_writer(self.logstream) - - if self.run_rpkid: - rpkid = self.server_ca.ee_certificates.get(purpose = "rpkid") - writer(self.cfg.get("bpki-ta", section = rpkid_section), self.server_ca.certificate) - writer(self.cfg.get("rpkid-key", section = rpkid_section), rpkid.private_key) - writer(self.cfg.get("rpkid-cert", section = rpkid_section), rpkid.certificate) - writer(self.cfg.get("irdb-cert", section = rpkid_section), - self.server_ca.ee_certificates.get(purpose = "irdbd").certificate) - writer(self.cfg.get("irbe-cert", section = rpkid_section), - self.server_ca.ee_certificates.get(purpose = "irbe").certificate) - - if self.run_pubd: - pubd = self.server_ca.ee_certificates.get(purpose = "pubd") - writer(self.cfg.get("bpki-ta", section = pubd_section), self.server_ca.certificate) - writer(self.cfg.get("pubd-key", section = pubd_section), pubd.private_key) - writer(self.cfg.get("pubd-cert", section = pubd_section), pubd.certificate) - writer(self.cfg.get("irbe-cert", section = pubd_section), - self.server_ca.ee_certificates.get(purpose = "irbe").certificate) - - if self.run_rootd: - try: - rootd = rpki.irdb.models.ResourceHolderCA.objects.get(handle = self.handle).rootd - writer(self.cfg.get("bpki-ta", section = rootd_section), self.server_ca.certificate) - writer(self.cfg.get("rootd-bpki-crl", section = rootd_section), self.server_ca.latest_crl) - writer(self.cfg.get("rootd-bpki-key", section = rootd_section), rootd.private_key) - writer(self.cfg.get("rootd-bpki-cert", section = rootd_section), rootd.certificate) - writer(self.cfg.get("child-bpki-cert", section = rootd_section), rootd.issuer.certificate) - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - self.log("rootd enabled but resource holding entity not yet configured, skipping rootd setup") - except rpki.irdb.models.Rootd.DoesNotExist: - self.log("rootd enabled but not yet configured, skipping rootd setup") - - - @django.db.transaction.atomic - def update_bpki(self): +class PEM_writer(object): """ - Update BPKI certificates. Assumes an existing RPKI installation. - - Basic plan here is to reissue all BPKI certificates we can, right - now. In the long run we might want to be more clever about only - touching ones that need maintenance, but this will do for a start. + Write PEM files to disk, keeping track of which ones we've already + written and setting the file mode appropriately. + + Comparing the old file with what we're about to write serves no real + purpose except to calm users who find repeated messages about + writing the same file confusing. + """ + + def __init__(self, logstream = None): + self.wrote = set() + self.logstream = logstream + + def __call__(self, filename, obj, compare = True): + filename = os.path.realpath(filename) + if filename in self.wrote: + return + tempname = filename + pem = obj.get_PEM() + if not filename.startswith("/dev/"): + try: + if compare and pem == open(filename, "r").read(): + return + except: # pylint: disable=W0702 + pass + tempname += ".%s.tmp" % os.getpid() + mode = 0400 if filename.endswith(".key") else 0444 + if self.logstream is not None: + self.logstream.write("Writing %s\n" % filename) + f = os.fdopen(os.open(tempname, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode), "w") + f.write(pem) + f.close() + if tempname != filename: + os.rename(tempname, filename) + self.wrote.add(filename) - We also reissue CRLs for all CAs. - - Most likely this should be run under cron. - """ - for model in (rpki.irdb.models.ServerCA, - rpki.irdb.models.ResourceHolderCA, - rpki.irdb.models.ServerEE, - rpki.irdb.models.Referral, - rpki.irdb.models.Rootd, - rpki.irdb.models.HostedCA, - rpki.irdb.models.BSC, - rpki.irdb.models.Child, - rpki.irdb.models.Parent, - rpki.irdb.models.Client, - rpki.irdb.models.Repository): - for obj in model.objects.all(): - self.log("Regenerating BPKI certificate %s" % obj.certificate.getSubject()) - obj.avow() - obj.save() - - self.log("Regenerating Server BPKI CRL") - self.server_ca.generate_crl() - self.server_ca.save() - - for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): - self.log("Regenerating BPKI CRL for Resource Holder %s" % ca.handle) - ca.generate_crl() - ca.save() - - - @staticmethod - def _compose_left_right_query(): +def etree_read(filename_or_etree_wrapper, schema = rpki.relaxng.oob_setup): """ - Compose top level element of a left-right query. + Read an etree from a file, verifying then stripping XML namespace + cruft. As a convenience, we also accept an etree_wrapper object in + place of a filename, in which case we deepcopy the etree directly + from the etree_wrapper and there's no need for a file. """ - return Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, - type = "query", version = rpki.left_right.version) + if isinstance(filename_or_etree_wrapper, etree_wrapper): + e = copy.deepcopy(filename_or_etree_wrapper.etree) + else: + e = ElementTree(file = filename_or_etree_wrapper).getroot() + schema.assertValid(e) + return e - @staticmethod - def _compose_publication_control_query(): +class etree_wrapper(object): """ - Compose top level element of a publication-control query. + Wrapper for ETree objects so we can return them as function results + without requiring the caller to understand much about them. """ - return Element(rpki.publication_control.tag_msg, nsmap = rpki.publication_control.nsmap, - type = "query", version = rpki.publication_control.version) + def __init__(self, e, msg = None, debug = False, schema = rpki.relaxng.oob_setup): + self.msg = msg + e = copy.deepcopy(e) + if debug: + print ElementToString(e) + schema.assertValid(e) + self.etree = e + def __str__(self): + return ElementToString(self.etree) - @django.db.transaction.atomic - def synchronize_bpki(self): - """ - Synchronize BPKI updates. This is separate from .update_bpki() - because this requires rpkid to be running and none of the other - BPKI update stuff does; there may be circumstances under which it - makes sense to do the rest of the BPKI update and allow this to - fail with a warning. - """ + def save(self, filename, logstream = None): + filename = os.path.realpath(filename) + tempname = filename + if not filename.startswith("/dev/"): + tempname += ".%s.tmp" % os.getpid() + ElementTree(self.etree).write(tempname) + if tempname != filename: + os.rename(tempname, filename) + if logstream is not None: + logstream.write("Wrote %s\n" % filename) + if self.msg is not None: + logstream.write(self.msg + "\n") - if self.run_rpkid: - q_msg = self._compose_left_right_query() - - for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): - q_pdu = SubElement(q_msg, rpki.left_right.tag_tenant, - action = "set", - tag = "%s__tenant" % ca.handle, - tenant_handle = ca.handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = ca.certificate.get_Base64() - - for bsc in rpki.irdb.models.BSC.objects.all(): - q_pdu = SubElement(q_msg, rpki.left_right.tag_bsc, - action = "set", - tag = "%s__bsc__%s" % (bsc.issuer.handle, bsc.handle), - tenant_handle = bsc.issuer.handle, - bsc_handle = bsc.handle) - SubElement(q_pdu, rpki.left_right.tag_signing_cert).text = bsc.certificate.get_Base64() - SubElement(q_pdu, rpki.left_right.tag_signing_cert_crl).text = bsc.issuer.latest_crl.get_Base64() - - for repository in rpki.irdb.models.Repository.objects.all(): - q_pdu = SubElement(q_msg, rpki.left_right.tag_repository, - action = "set", - tag = "%s__repository__%s" % (repository.issuer.handle, repository.handle), - tenant_handle = repository.issuer.handle, - repository_handle = repository.handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = repository.certificate.get_Base64() - - for parent in rpki.irdb.models.Parent.objects.all(): - q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, - action = "set", - tag = "%s__parent__%s" % (parent.issuer.handle, parent.handle), - tenant_handle = parent.issuer.handle, - parent_handle = parent.handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = parent.certificate.get_Base64() - - for rootd in rpki.irdb.models.Rootd.objects.all(): - q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, - action = "set", - tag = "%s__rootd" % rootd.issuer.handle, - tenant_handle = rootd.issuer.handle, - parent_handle = rootd.issuer.handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = rootd.certificate.get_Base64() - - for child in rpki.irdb.models.Child.objects.all(): - q_pdu = SubElement(q_msg, rpki.left_right.tag_child, - action = "set", - tag = "%s__child__%s" % (child.issuer.handle, child.handle), - tenant_handle = child.issuer.handle, - child_handle = child.handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = child.certificate.get_Base64() - - if len(q_msg) > 0: - self.call_rpkid(q_msg) + @property + def file(self): + from cStringIO import StringIO + return StringIO(ElementToString(self.etree)) - if self.run_pubd: - q_msg = self._compose_publication_control_query() - for client in self.server_ca.clients.all(): - q_pdu = SubElement(q_msg, rpki.publication_control.tag_client, action = "set", client_handle = client.handle) - SubElement(q_pdu, rpki.publication_control.tag_bpki_cert).text = client.certificate.get_Base64() +class Zookeeper(object): - if len(q_msg) > 0: - self.call_pubd(q_msg) + ## @var show_xml + # If not None, a file-like object to which to prettyprint XML, for debugging. + show_xml = None - @django.db.transaction.atomic - def configure_child(self, filename, child_handle = None, valid_until = None): - """ - Configure a new child of this RPKI entity, given the child's XML - identity file as an input. Extracts the child's data from the - XML, cross-certifies the child's resource-holding BPKI - certificate, and generates an XML file describing the relationship - between the child and this parent, including this parent's BPKI - data and up-down protocol service URI. - """ + def __init__(self, cfg = None, handle = None, logstream = None, disable_signal_handlers = False): - x = etree_read(filename) + if cfg is None: + cfg = rpki.config.parser() - if x.tag != tag_oob_child_request: - raise BadXMLMessage("Expected %s, got %s", tag_oob_child_request, x.tag) + if handle is None: + handle = cfg.get("handle", section = myrpki_section) - if child_handle is None: - child_handle = x.get("child_handle") + self.cfg = cfg - if valid_until is None: - valid_until = rpki.sundial.now() + rpki.sundial.timedelta(days = 365) - else: - valid_until = rpki.sundial.datetime.fromXMLtime(valid_until) - if valid_until < rpki.sundial.now(): - raise PastExpiration("Specified new expiration time %s has passed" % valid_until) + self.logstream = logstream + self.disable_signal_handlers = disable_signal_handlers - self.log("Child calls itself %r, we call it %r" % (x.get("child_handle"), child_handle)) + self.run_rpkid = cfg.getboolean("run_rpkid", section = myrpki_section) + self.run_pubd = cfg.getboolean("run_pubd", section = myrpki_section) + self.run_rootd = cfg.getboolean("run_rootd", section = myrpki_section) - child, created = rpki.irdb.models.Child.objects.get_or_certify( - issuer = self.resource_ca, - handle = child_handle, - ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_child_bpki_ta)), - valid_until = valid_until) + if self.run_rootd and (not self.run_pubd or not self.run_rpkid): + raise CantRunRootd("Can't run rootd unless also running rpkid and pubd") - return self.generate_parental_response(child), child_handle + self.default_repository = cfg.get("default_repository", "", section = myrpki_section) + self.pubd_contact_info = cfg.get("pubd_contact_info", "", section = myrpki_section) + self.rsync_module = cfg.get("publication_rsync_module", section = myrpki_section) + self.rsync_server = cfg.get("publication_rsync_server", section = myrpki_section) - @django.db.transaction.atomic - def generate_parental_response(self, child): - """ - Generate parental response XML. Broken out of .configure_child() - for GUI. - """ + self.reset_identity(handle) - service_uri = "http://%s:%s/up-down/%s/%s" % ( - self.cfg.get("rpkid_server_host", section = myrpki_section), - self.cfg.get("rpkid_server_port", section = myrpki_section), - self.handle, child.handle) - e = Element(tag_oob_parent_response, nsmap = oob_nsmap, version = oob_version, - service_uri = service_uri, - child_handle = child.handle, - parent_handle = self.handle) - B64Element(e, tag_oob_parent_bpki_ta, self.resource_ca.certificate) + def reset_identity(self, handle): + """ + Select handle of current resource holding entity. + """ - try: - if self.default_repository: - repo = self.resource_ca.repositories.get(handle = self.default_repository) - else: - repo = self.resource_ca.repositories.get() - except rpki.irdb.models.Repository.DoesNotExist: - repo = None + if handle is None: + raise MissingHandle + self.handle = handle - if repo is None: - self.log("Couldn't find any usable repositories, not giving referral") - elif repo.handle == self.handle: - SubElement(e, tag_oob_offer) + def set_logstream(self, logstream): + """ + Set log stream for this Zookeeper. The log stream is a file-like + object, or None to suppress all logging. + """ - else: - proposed_sia_base = repo.sia_base + child.handle + "/" - referral_cert, created = rpki.irdb.models.Referral.objects.get_or_certify(issuer = self.resource_ca) - auth = rpki.x509.SignedReferral() - auth.set_content(B64Element(None, tag_oob_authorization, child.ta, - nsmap = oob_nsmap, version = oob_version, - authorized_sia_base = proposed_sia_base)) - auth.schema_check() - auth.sign(referral_cert.private_key, referral_cert.certificate, self.resource_ca.latest_crl) - B64Element(e, tag_oob_referral, auth, referrer = repo.client_handle) + self.logstream = logstream - return etree_wrapper(e, msg = "Send this file back to the child you just configured") + def log(self, msg): + """ + Send some text to this Zookeeper's log stream, if one is set. + """ - @django.db.transaction.atomic - def delete_child(self, child_handle): - """ - Delete a child of this RPKI entity. - """ - - self.resource_ca.children.get(handle = child_handle).delete() + if self.logstream is not None: + self.logstream.write(msg) + self.logstream.write("\n") - @django.db.transaction.atomic - def configure_parent(self, filename, parent_handle = None): - """ - Configure a new parent of this RPKI entity, given the output of - the parent's configure_child command as input. Reads the parent's - response XML, extracts the parent's BPKI and service URI - information, cross-certifies the parent's BPKI data into this - entity's BPKI, and checks for offers or referrals of publication - service. If a publication offer or referral is present, we - generate a request-for-service message to that repository, in case - the user wants to avail herself of the referral or offer. - """ + @property + def resource_ca(self): + """ + Get ResourceHolderCA object associated with current handle. + """ - x = etree_read(filename) + if self.handle is None: + raise HandleNotSet + return rpki.irdb.models.ResourceHolderCA.objects.get(handle = self.handle) - if x.tag != tag_oob_parent_response: - raise BadXMLMessage("Expected %s, got %s", tag_oob_parent_response, x.tag) - if parent_handle is None: - parent_handle = x.get("parent_handle") + @property + def server_ca(self): + """ + Get ServerCA object. + """ - offer = x.find(tag_oob_offer) - referral = x.find(tag_oob_referral) + return rpki.irdb.models.ServerCA.objects.get() - if offer is not None: - repository_type = "offer" - referrer = None - referral_authorization = None - elif referral is not None: - repository_type = "referral" - referrer = referral.get("referrer") - referral_authorization = rpki.x509.SignedReferral(Base64 = referral.text) + @django.db.transaction.atomic + def initialize_server_bpki(self): + """ + Initialize server BPKI portion of an RPKI installation. Reads the + configuration file and generates the initial BPKI server + certificates needed to start daemons. + """ - else: - repository_type = "none" - referrer = None - referral_authorization = None + if self.run_rpkid or self.run_pubd: + server_ca, created = rpki.irdb.models.ServerCA.objects.get_or_certify() + rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "irbe") - self.log("Parent calls itself %r, we call it %r" % (x.get("parent_handle"), parent_handle)) - self.log("Parent calls us %r" % x.get("child_handle")) + if self.run_rpkid: + rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "rpkid") + rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "irdbd") - parent, created = rpki.irdb.models.Parent.objects.get_or_certify( - issuer = self.resource_ca, - handle = parent_handle, - child_handle = x.get("child_handle"), - parent_handle = x.get("parent_handle"), - service_uri = x.get("service_uri"), - ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_parent_bpki_ta)), - repository_type = repository_type, - referrer = referrer, - referral_authorization = referral_authorization) + if self.run_pubd: + rpki.irdb.models.ServerEE.objects.get_or_certify(issuer = server_ca, purpose = "pubd") - return self.generate_repository_request(parent), parent_handle + @django.db.transaction.atomic + def initialize_resource_bpki(self): + """ + Initialize the resource-holding BPKI for an RPKI installation. + Returns XML describing the resource holder. - def generate_repository_request(self, parent): - """ - Generate repository request for a given parent. - """ + This method is present primarily for backwards compatibility with + the old combined initialize() method which initialized both the + server BPKI and the default resource-holding BPKI in a single + method call. In the long run we want to replace this with + something that takes a handle as argument and creates the + resource-holding BPKI idenity if needed. + """ - e = Element(tag_oob_publisher_request, nsmap = oob_nsmap, version = oob_version, - publisher_handle = self.handle) - B64Element(e, tag_oob_publisher_bpki_ta, self.resource_ca.certificate) - if parent.repository_type == "referral": - B64Element(e, tag_oob_referral, parent.referral_authorization, - referrer = parent.referrer) + resource_ca, created = rpki.irdb.models.ResourceHolderCA.objects.get_or_certify(handle = self.handle) + return self.generate_identity() - return etree_wrapper(e, msg = "This is the file to send to the repository operator") + def initialize(self): + """ + Backwards compatibility wrapper: calls initialize_server_bpki() + and initialize_resource_bpki(), returns latter's result. + """ - @django.db.transaction.atomic - def delete_parent(self, parent_handle): - """ - Delete a parent of this RPKI entity. - """ + self.initialize_server_bpki() + return self.initialize_resource_bpki() - self.resource_ca.parents.get(handle = parent_handle).delete() + def generate_identity(self): + """ + Generate identity XML. Broken out of .initialize() because it's + easier for the GUI this way. + """ - @django.db.transaction.atomic - def delete_rootd(self): - """ - Delete rootd associated with this RPKI entity. - """ + e = Element(tag_oob_child_request, nsmap = oob_nsmap, version = oob_version, + child_handle = self.handle) + B64Element(e, tag_oob_child_bpki_ta, self.resource_ca.certificate) + return etree_wrapper(e, msg = 'This is the "identity" file you will need to send to your parent') - self.resource_ca.rootd.delete() + @django.db.transaction.atomic + def delete_tenant(self): + """ + Delete the ResourceHolderCA object corresponding to the current handle. + This corresponds to deleting an rpkid <tenant/> object. - @django.db.transaction.atomic - def configure_publication_client(self, filename, sia_base = None, flat = False): - """ - Configure publication server to know about a new client, given the - client's request-for-service message as input. Reads the client's - request for service, cross-certifies the client's BPKI data, and - generates a response message containing the repository's BPKI data - and service URI. - """ + This code assumes the normal Django cascade-on-delete behavior, + that is, we assume that deleting the ResourceHolderCA object + deletes all the subordinate objects that refer to it via foreign + key relationships. + """ - x = etree_read(filename) - - if x.tag != tag_oob_publisher_request: - raise BadXMLMessage("Expected %s, got %s", tag_oob_publisher_request, x.tag) - - client_ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_publisher_bpki_ta)) - - referral = x.find(tag_oob_referral) - - default_sia_base = "rsync://{self.rsync_server}/{self.rsync_module}/{handle}/".format( - self = self, handle = x.get("publisher_handle")) - - if sia_base is None and flat: - self.log("Flat publication structure forced, homing client at top-level") - sia_base = default_sia_base - - if sia_base is None and referral is not None: - self.log("This looks like a referral, checking") - try: - referrer = referral.get("referrer") - referrer = self.server_ca.clients.get(handle = referrer) - referral = rpki.x509.SignedReferral(Base64 = referral.text) - referral = referral.unwrap(ta = (referrer.certificate, self.server_ca.certificate)) - if rpki.x509.X509(Base64 = referral.text) != client_ta: - raise BadXMLMessage("Referral trust anchor does not match") - sia_base = referral.get("authorized_sia_base") - except rpki.irdb.models.Client.DoesNotExist: - self.log("We have no record of the client ({}) alleged to have made this referral".format(referrer)) - - if sia_base is None and referral is None: - self.log("This might be an offer, checking") - try: - parent = rpki.irdb.models.ResourceHolderCA.objects.get(children__ta = client_ta) - if "/" in parent.repositories.get(ta = self.server_ca.certificate).client_handle: - self.log("Client's parent is not top-level, this is not a valid offer") + resource_ca = self.resource_ca + if resource_ca is not None: + resource_ca.delete() else: - self.log("Found client and its parent, nesting") - sia_base = "rsync://{self.rsync_server}/{self.rsync_module}/{parent_handle}/{client_handle}/".format( - self = self, parent_handle = parent.handle, client_handle = x.get("publisher_handle")) - except rpki.irdb.models.Repository.DoesNotExist: - self.log("Found client's parent, but repository isn't set, this shouldn't happen!") - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - try: - rpki.irdb.models.Rootd.objects.get(issuer__certificate = client_ta) - self.log("This client's parent is rootd") - sia_base = default_sia_base - except rpki.irdb.models.Rootd.DoesNotExist: - self.log("We don't host this client's parent, so we didn't make an offer") - - if sia_base is None: - self.log("Don't know where else to nest this client, so defaulting to top-level") - sia_base = default_sia_base + self.log("No such ResourceHolderCA \"%s\"" % self.handle) - if not sia_base.startswith("rsync://"): - raise BadXMLMessage("Malformed sia_base parameter %r, should start with 'rsync://'" % sia_base) - client_handle = "/".join(sia_base.rstrip("/").split("/")[4:]) + @django.db.transaction.atomic + def configure_rootd(self): - self.log("Client calls itself %r, we call it %r" % ( - x.get("publisher_handle"), client_handle)) + assert self.run_rpkid and self.run_pubd and self.run_rootd - client, created = rpki.irdb.models.Client.objects.get_or_certify( - issuer = self.server_ca, - handle = client_handle, - ta = client_ta, - sia_base = sia_base) + rpki.irdb.models.Rootd.objects.get_or_certify( + issuer = self.resource_ca, + service_uri = "http://localhost:%s/" % self.cfg.get("rootd_server_port", + section = myrpki_section)) - return self.generate_repository_response(client), client_handle + return self.generate_rootd_repository_offer() - def generate_repository_response(self, client): - """ - Generate repository response XML to a given client. - """ - - service_uri = "http://{host}:{port}/client/{handle}".format( - host = self.cfg.get("pubd_server_host", section = myrpki_section), - port = self.cfg.get("pubd_server_port", section = myrpki_section), - handle = client.handle) - - rrdp_uri = self.cfg.get("publication_rrdp_notification_uri", section = myrpki_section, - default = "") or None - - e = Element(tag_oob_repository_response, nsmap = oob_nsmap, version = oob_version, - service_uri = service_uri, - publisher_handle = client.handle, - sia_base = client.sia_base) - - if rrdp_uri is not None: - e.set("rrdp_notification_uri", rrdp_uri) - - B64Element(e, tag_oob_repository_bpki_ta, self.server_ca.certificate) - return etree_wrapper(e, msg = "Send this file back to the publication client you just configured") + def generate_rootd_repository_offer(self): + """ + Generate repository offer for rootd. Split out of + configure_rootd() because that's easier for the GUI. + """ + try: + self.resource_ca.repositories.get(handle = self.handle) + return None - @django.db.transaction.atomic - def delete_publication_client(self, client_handle): - """ - Delete a publication client of this RPKI entity. - """ + except rpki.irdb.models.Repository.DoesNotExist: + e = Element(tag_oob_publisher_request, nsmap = oob_nsmap, version = oob_version, + publisher_handle = self.handle) + B64Element(e, tag_oob_publisher_bpki_ta, self.resource_ca.certificate) + return etree_wrapper(e, msg = 'This is the "repository offer" file for you to use if you want to publish in your own repository') + + + def write_bpki_files(self): + """ + Write out BPKI certificate, key, and CRL files for daemons that + need them. + """ + + writer = PEM_writer(self.logstream) + + if self.run_rpkid: + rpkid = self.server_ca.ee_certificates.get(purpose = "rpkid") + writer(self.cfg.get("bpki-ta", section = rpkid_section), self.server_ca.certificate) + writer(self.cfg.get("rpkid-key", section = rpkid_section), rpkid.private_key) + writer(self.cfg.get("rpkid-cert", section = rpkid_section), rpkid.certificate) + writer(self.cfg.get("irdb-cert", section = rpkid_section), + self.server_ca.ee_certificates.get(purpose = "irdbd").certificate) + writer(self.cfg.get("irbe-cert", section = rpkid_section), + self.server_ca.ee_certificates.get(purpose = "irbe").certificate) + + if self.run_pubd: + pubd = self.server_ca.ee_certificates.get(purpose = "pubd") + writer(self.cfg.get("bpki-ta", section = pubd_section), self.server_ca.certificate) + writer(self.cfg.get("pubd-key", section = pubd_section), pubd.private_key) + writer(self.cfg.get("pubd-cert", section = pubd_section), pubd.certificate) + writer(self.cfg.get("irbe-cert", section = pubd_section), + self.server_ca.ee_certificates.get(purpose = "irbe").certificate) + + if self.run_rootd: + try: + rootd = rpki.irdb.models.ResourceHolderCA.objects.get(handle = self.handle).rootd + writer(self.cfg.get("bpki-ta", section = rootd_section), self.server_ca.certificate) + writer(self.cfg.get("rootd-bpki-crl", section = rootd_section), self.server_ca.latest_crl) + writer(self.cfg.get("rootd-bpki-key", section = rootd_section), rootd.private_key) + writer(self.cfg.get("rootd-bpki-cert", section = rootd_section), rootd.certificate) + writer(self.cfg.get("child-bpki-cert", section = rootd_section), rootd.issuer.certificate) + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + self.log("rootd enabled but resource holding entity not yet configured, skipping rootd setup") + except rpki.irdb.models.Rootd.DoesNotExist: + self.log("rootd enabled but not yet configured, skipping rootd setup") + + + @django.db.transaction.atomic + def update_bpki(self): + """ + Update BPKI certificates. Assumes an existing RPKI installation. + + Basic plan here is to reissue all BPKI certificates we can, right + now. In the long run we might want to be more clever about only + touching ones that need maintenance, but this will do for a start. + + We also reissue CRLs for all CAs. + + Most likely this should be run under cron. + """ + + for model in (rpki.irdb.models.ServerCA, + rpki.irdb.models.ResourceHolderCA, + rpki.irdb.models.ServerEE, + rpki.irdb.models.Referral, + rpki.irdb.models.Rootd, + rpki.irdb.models.HostedCA, + rpki.irdb.models.BSC, + rpki.irdb.models.Child, + rpki.irdb.models.Parent, + rpki.irdb.models.Client, + rpki.irdb.models.Repository): + for obj in model.objects.all(): + self.log("Regenerating BPKI certificate %s" % obj.certificate.getSubject()) + obj.avow() + obj.save() + + self.log("Regenerating Server BPKI CRL") + self.server_ca.generate_crl() + self.server_ca.save() + + for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): + self.log("Regenerating BPKI CRL for Resource Holder %s" % ca.handle) + ca.generate_crl() + ca.save() + + + @staticmethod + def _compose_left_right_query(): + """ + Compose top level element of a left-right query. + """ + + return Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, + type = "query", version = rpki.left_right.version) + + + @staticmethod + def _compose_publication_control_query(): + """ + Compose top level element of a publication-control query. + """ + + return Element(rpki.publication_control.tag_msg, nsmap = rpki.publication_control.nsmap, + type = "query", version = rpki.publication_control.version) + + + @django.db.transaction.atomic + def synchronize_bpki(self): + """ + Synchronize BPKI updates. This is separate from .update_bpki() + because this requires rpkid to be running and none of the other + BPKI update stuff does; there may be circumstances under which it + makes sense to do the rest of the BPKI update and allow this to + fail with a warning. + """ + + if self.run_rpkid: + q_msg = self._compose_left_right_query() + + for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): + q_pdu = SubElement(q_msg, rpki.left_right.tag_tenant, + action = "set", + tag = "%s__tenant" % ca.handle, + tenant_handle = ca.handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = ca.certificate.get_Base64() + + for bsc in rpki.irdb.models.BSC.objects.all(): + q_pdu = SubElement(q_msg, rpki.left_right.tag_bsc, + action = "set", + tag = "%s__bsc__%s" % (bsc.issuer.handle, bsc.handle), + tenant_handle = bsc.issuer.handle, + bsc_handle = bsc.handle) + SubElement(q_pdu, rpki.left_right.tag_signing_cert).text = bsc.certificate.get_Base64() + SubElement(q_pdu, rpki.left_right.tag_signing_cert_crl).text = bsc.issuer.latest_crl.get_Base64() + + for repository in rpki.irdb.models.Repository.objects.all(): + q_pdu = SubElement(q_msg, rpki.left_right.tag_repository, + action = "set", + tag = "%s__repository__%s" % (repository.issuer.handle, repository.handle), + tenant_handle = repository.issuer.handle, + repository_handle = repository.handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = repository.certificate.get_Base64() + + for parent in rpki.irdb.models.Parent.objects.all(): + q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, + action = "set", + tag = "%s__parent__%s" % (parent.issuer.handle, parent.handle), + tenant_handle = parent.issuer.handle, + parent_handle = parent.handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = parent.certificate.get_Base64() + + for rootd in rpki.irdb.models.Rootd.objects.all(): + q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, + action = "set", + tag = "%s__rootd" % rootd.issuer.handle, + tenant_handle = rootd.issuer.handle, + parent_handle = rootd.issuer.handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = rootd.certificate.get_Base64() + + for child in rpki.irdb.models.Child.objects.all(): + q_pdu = SubElement(q_msg, rpki.left_right.tag_child, + action = "set", + tag = "%s__child__%s" % (child.issuer.handle, child.handle), + tenant_handle = child.issuer.handle, + child_handle = child.handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = child.certificate.get_Base64() + + if len(q_msg) > 0: + self.call_rpkid(q_msg) + + if self.run_pubd: + q_msg = self._compose_publication_control_query() + + for client in self.server_ca.clients.all(): + q_pdu = SubElement(q_msg, rpki.publication_control.tag_client, action = "set", client_handle = client.handle) + SubElement(q_pdu, rpki.publication_control.tag_bpki_cert).text = client.certificate.get_Base64() + + if len(q_msg) > 0: + self.call_pubd(q_msg) + + + @django.db.transaction.atomic + def configure_child(self, filename, child_handle = None, valid_until = None): + """ + Configure a new child of this RPKI entity, given the child's XML + identity file as an input. Extracts the child's data from the + XML, cross-certifies the child's resource-holding BPKI + certificate, and generates an XML file describing the relationship + between the child and this parent, including this parent's BPKI + data and up-down protocol service URI. + """ + + x = etree_read(filename) + + if x.tag != tag_oob_child_request: + raise BadXMLMessage("Expected %s, got %s", tag_oob_child_request, x.tag) + + if child_handle is None: + child_handle = x.get("child_handle") + + if valid_until is None: + valid_until = rpki.sundial.now() + rpki.sundial.timedelta(days = 365) + else: + valid_until = rpki.sundial.datetime.fromXMLtime(valid_until) + if valid_until < rpki.sundial.now(): + raise PastExpiration("Specified new expiration time %s has passed" % valid_until) - self.server_ca.clients.get(handle = client_handle).delete() + self.log("Child calls itself %r, we call it %r" % (x.get("child_handle"), child_handle)) + child, created = rpki.irdb.models.Child.objects.get_or_certify( + issuer = self.resource_ca, + handle = child_handle, + ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_child_bpki_ta)), + valid_until = valid_until) - @django.db.transaction.atomic - def configure_repository(self, filename, parent_handle = None): - """ - Configure a publication repository for this RPKI entity, given the - repository's response to our request-for-service message as input. - Reads the repository's response, extracts and cross-certifies the - BPKI data and service URI, and links the repository data with the - corresponding parent data in our local database. - """ + return self.generate_parental_response(child), child_handle - x = etree_read(filename) - if x.tag != tag_oob_repository_response: - raise BadXMLMessage("Expected %s, got %s", tag_oob_repository_response, x.tag) + @django.db.transaction.atomic + def generate_parental_response(self, child): + """ + Generate parental response XML. Broken out of .configure_child() + for GUI. + """ - self.log("Repository calls us %r" % (x.get("publisher_handle"))) + service_uri = "http://%s:%s/up-down/%s/%s" % ( + self.cfg.get("rpkid_server_host", section = myrpki_section), + self.cfg.get("rpkid_server_port", section = myrpki_section), + self.handle, child.handle) - if parent_handle is not None: - self.log("Explicit parent_handle given") - try: - if parent_handle == self.handle: - turtle = self.resource_ca.rootd - else: - turtle = self.resource_ca.parents.get(handle = parent_handle) - except (rpki.irdb.models.Parent.DoesNotExist, rpki.irdb.models.Rootd.DoesNotExist): - self.log("Could not find parent %r in our database" % parent_handle) - raise CouldntFindRepoParent + e = Element(tag_oob_parent_response, nsmap = oob_nsmap, version = oob_version, + service_uri = service_uri, + child_handle = child.handle, + parent_handle = self.handle) + B64Element(e, tag_oob_parent_bpki_ta, self.resource_ca.certificate) - else: - turtles = [] - for parent in self.resource_ca.parents.all(): try: - _ = parent.repository + if self.default_repository: + repo = self.resource_ca.repositories.get(handle = self.default_repository) + else: + repo = self.resource_ca.repositories.get() except rpki.irdb.models.Repository.DoesNotExist: - turtles.append(parent) - try: - _ = self.resource_ca.rootd.repository - except rpki.irdb.models.Repository.DoesNotExist: - turtles.append(self.resource_ca.rootd) - except rpki.irdb.models.Rootd.DoesNotExist: - pass - if len(turtles) != 1: - self.log("No explicit parent_handle given and unable to guess") - raise CouldntFindRepoParent - turtle = turtles[0] - if isinstance(turtle, rpki.irdb.models.Rootd): - parent_handle = self.handle - else: - parent_handle = turtle.handle - self.log("No explicit parent_handle given, guessing parent {}".format(parent_handle)) - - rpki.irdb.models.Repository.objects.get_or_certify( - issuer = self.resource_ca, - handle = parent_handle, - client_handle = x.get("publisher_handle"), - service_uri = x.get("service_uri"), - sia_base = x.get("sia_base"), - rrdp_notification_uri = x.get("rrdp_notification_uri"), - ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_repository_bpki_ta)), - turtle = turtle) - - - @django.db.transaction.atomic - def delete_repository(self, repository_handle): - """ - Delete a repository of this RPKI entity. - """ - - self.resource_ca.repositories.get(handle = repository_handle).delete() - - - @django.db.transaction.atomic - def renew_children(self, child_handle, valid_until = None): - """ - Update validity period for one child entity or, if child_handle is - None, for all child entities. - """ - - if child_handle is None: - children = self.resource_ca.children.all() - else: - children = self.resource_ca.children.filter(handle = child_handle) - - if valid_until is None: - valid_until = rpki.sundial.now() + rpki.sundial.timedelta(days = 365) - else: - valid_until = rpki.sundial.datetime.fromXMLtime(valid_until) - if valid_until < rpki.sundial.now(): - raise PastExpiration("Specified new expiration time %s has passed" % valid_until) - - self.log("New validity date %s" % valid_until) + repo = None - for child in children: - child.valid_until = valid_until - child.save() - - - @django.db.transaction.atomic - def load_prefixes(self, filename, ignore_missing_children = False): - """ - Whack IRDB to match prefixes.csv. - """ + if repo is None: + self.log("Couldn't find any usable repositories, not giving referral") - grouped4 = {} - grouped6 = {} + elif repo.handle == self.handle: + SubElement(e, tag_oob_offer) - for handle, prefix in csv_reader(filename, columns = 2): - grouped = grouped6 if ":" in prefix else grouped4 - if handle not in grouped: - grouped[handle] = [] - grouped[handle].append(prefix) - - primary_keys = [] - - for version, grouped, rset in ((4, grouped4, rpki.resource_set.resource_set_ipv4), - (6, grouped6, rpki.resource_set.resource_set_ipv6)): - for handle, prefixes in grouped.iteritems(): - try: - child = self.resource_ca.children.get(handle = handle) - except rpki.irdb.models.Child.DoesNotExist: - if not ignore_missing_children: - raise else: - for prefix in rset(",".join(prefixes)): - obj, created = rpki.irdb.models.ChildNet.objects.get_or_create( - child = child, - start_ip = str(prefix.min), - end_ip = str(prefix.max), - version = version) - primary_keys.append(obj.pk) - - q = rpki.irdb.models.ChildNet.objects - q = q.filter(child__issuer = self.resource_ca) - q = q.exclude(pk__in = primary_keys) - q.delete() - - - @django.db.transaction.atomic - def load_asns(self, filename, ignore_missing_children = False): - """ - Whack IRDB to match asns.csv. - """ - - grouped = {} - - for handle, asn in csv_reader(filename, columns = 2): - if handle not in grouped: - grouped[handle] = [] - grouped[handle].append(asn) - - primary_keys = [] - - for handle, asns in grouped.iteritems(): - try: - child = self.resource_ca.children.get(handle = handle) - except rpki.irdb.models.Child.DoesNotExist: - if not ignore_missing_children: - raise - else: - for asn in rpki.resource_set.resource_set_as(",".join(asns)): - obj, created = rpki.irdb.models.ChildASN.objects.get_or_create( - child = child, - start_as = str(asn.min), - end_as = str(asn.max)) - primary_keys.append(obj.pk) - - q = rpki.irdb.models.ChildASN.objects - q = q.filter(child__issuer = self.resource_ca) - q = q.exclude(pk__in = primary_keys) - q.delete() - - - @django.db.transaction.atomic - def load_roa_requests(self, filename): - """ - Whack IRDB to match roa.csv. - """ - - grouped = {} - - # format: p/n-m asn group - for pnm, asn, group in csv_reader(filename, columns = 3): - key = (asn, group) - if key not in grouped: - grouped[key] = [] - grouped[key].append(pnm) - - # Deleting and recreating all the ROA requests is inefficient, - # but rpkid's current representation of ROA requests is wrong - # (see #32), so it's not worth a lot of effort here as we're - # just going to have to rewrite this soon anyway. + proposed_sia_base = repo.sia_base + child.handle + "/" + referral_cert, created = rpki.irdb.models.Referral.objects.get_or_certify(issuer = self.resource_ca) + auth = rpki.x509.SignedReferral() + auth.set_content(B64Element(None, tag_oob_authorization, child.ta, + nsmap = oob_nsmap, version = oob_version, + authorized_sia_base = proposed_sia_base)) + auth.schema_check() + auth.sign(referral_cert.private_key, referral_cert.certificate, self.resource_ca.latest_crl) + B64Element(e, tag_oob_referral, auth, referrer = repo.client_handle) + + return etree_wrapper(e, msg = "Send this file back to the child you just configured") + + + @django.db.transaction.atomic + def delete_child(self, child_handle): + """ + Delete a child of this RPKI entity. + """ + + self.resource_ca.children.get(handle = child_handle).delete() + + + @django.db.transaction.atomic + def configure_parent(self, filename, parent_handle = None): + """ + Configure a new parent of this RPKI entity, given the output of + the parent's configure_child command as input. Reads the parent's + response XML, extracts the parent's BPKI and service URI + information, cross-certifies the parent's BPKI data into this + entity's BPKI, and checks for offers or referrals of publication + service. If a publication offer or referral is present, we + generate a request-for-service message to that repository, in case + the user wants to avail herself of the referral or offer. + """ + + x = etree_read(filename) + + if x.tag != tag_oob_parent_response: + raise BadXMLMessage("Expected %s, got %s", tag_oob_parent_response, x.tag) + + if parent_handle is None: + parent_handle = x.get("parent_handle") + + offer = x.find(tag_oob_offer) + referral = x.find(tag_oob_referral) + + if offer is not None: + repository_type = "offer" + referrer = None + referral_authorization = None + + elif referral is not None: + repository_type = "referral" + referrer = referral.get("referrer") + referral_authorization = rpki.x509.SignedReferral(Base64 = referral.text) - self.resource_ca.roa_requests.all().delete() - - for key, pnms in grouped.iteritems(): - asn, group = key - - roa_request = self.resource_ca.roa_requests.create(asn = asn) - - for pnm in pnms: - if ":" in pnm: - p = rpki.resource_set.roa_prefix_ipv6.parse_str(pnm) - v = 6 else: - p = rpki.resource_set.roa_prefix_ipv4.parse_str(pnm) - v = 4 - roa_request.prefixes.create( - version = v, - prefix = str(p.prefix), - prefixlen = int(p.prefixlen), - max_prefixlen = int(p.max_prefixlen)) - + repository_type = "none" + referrer = None + referral_authorization = None + + self.log("Parent calls itself %r, we call it %r" % (x.get("parent_handle"), parent_handle)) + self.log("Parent calls us %r" % x.get("child_handle")) + + parent, created = rpki.irdb.models.Parent.objects.get_or_certify( + issuer = self.resource_ca, + handle = parent_handle, + child_handle = x.get("child_handle"), + parent_handle = x.get("parent_handle"), + service_uri = x.get("service_uri"), + ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_parent_bpki_ta)), + repository_type = repository_type, + referrer = referrer, + referral_authorization = referral_authorization) + + return self.generate_repository_request(parent), parent_handle + + + def generate_repository_request(self, parent): + """ + Generate repository request for a given parent. + """ + + e = Element(tag_oob_publisher_request, nsmap = oob_nsmap, version = oob_version, + publisher_handle = self.handle) + B64Element(e, tag_oob_publisher_bpki_ta, self.resource_ca.certificate) + if parent.repository_type == "referral": + B64Element(e, tag_oob_referral, parent.referral_authorization, + referrer = parent.referrer) + + return etree_wrapper(e, msg = "This is the file to send to the repository operator") + + + @django.db.transaction.atomic + def delete_parent(self, parent_handle): + """ + Delete a parent of this RPKI entity. + """ + + self.resource_ca.parents.get(handle = parent_handle).delete() + + + @django.db.transaction.atomic + def delete_rootd(self): + """ + Delete rootd associated with this RPKI entity. + """ + + self.resource_ca.rootd.delete() + + + @django.db.transaction.atomic + def configure_publication_client(self, filename, sia_base = None, flat = False): + """ + Configure publication server to know about a new client, given the + client's request-for-service message as input. Reads the client's + request for service, cross-certifies the client's BPKI data, and + generates a response message containing the repository's BPKI data + and service URI. + """ + + x = etree_read(filename) + + if x.tag != tag_oob_publisher_request: + raise BadXMLMessage("Expected %s, got %s", tag_oob_publisher_request, x.tag) + + client_ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_publisher_bpki_ta)) + + referral = x.find(tag_oob_referral) + + default_sia_base = "rsync://{self.rsync_server}/{self.rsync_module}/{handle}/".format( + self = self, + handle = x.get("publisher_handle")) + + if sia_base is None and flat: + self.log("Flat publication structure forced, homing client at top-level") + sia_base = default_sia_base + + if sia_base is None and referral is not None: + self.log("This looks like a referral, checking") + try: + referrer = referral.get("referrer") + referrer = self.server_ca.clients.get(handle = referrer) + referral = rpki.x509.SignedReferral(Base64 = referral.text) + referral = referral.unwrap(ta = (referrer.certificate, self.server_ca.certificate)) + if rpki.x509.X509(Base64 = referral.text) != client_ta: + raise BadXMLMessage("Referral trust anchor does not match") + sia_base = referral.get("authorized_sia_base") + except rpki.irdb.models.Client.DoesNotExist: + self.log("We have no record of the client ({}) alleged to have made this referral".format(referrer)) + + if sia_base is None and referral is None: + self.log("This might be an offer, checking") + try: + parent = rpki.irdb.models.ResourceHolderCA.objects.get(children__ta = client_ta) + if "/" in parent.repositories.get(ta = self.server_ca.certificate).client_handle: + self.log("Client's parent is not top-level, this is not a valid offer") + else: + self.log("Found client and its parent, nesting") + sia_base = "rsync://{self.rsync_server}/{self.rsync_module}/{parent_handle}/{client_handle}/".format( + self = self, + parent_handle = parent.handle, + client_handle = x.get("publisher_handle")) + except rpki.irdb.models.Repository.DoesNotExist: + self.log("Found client's parent, but repository isn't set, this shouldn't happen!") + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + try: + rpki.irdb.models.Rootd.objects.get(issuer__certificate = client_ta) + self.log("This client's parent is rootd") + sia_base = default_sia_base + except rpki.irdb.models.Rootd.DoesNotExist: + self.log("We don't host this client's parent, so we didn't make an offer") + + if sia_base is None: + self.log("Don't know where else to nest this client, so defaulting to top-level") + sia_base = default_sia_base + + if not sia_base.startswith("rsync://"): + raise BadXMLMessage("Malformed sia_base parameter %r, should start with 'rsync://'" % sia_base) + + client_handle = "/".join(sia_base.rstrip("/").split("/")[4:]) + + self.log("Client calls itself %r, we call it %r" % ( + x.get("publisher_handle"), client_handle)) + + client, created = rpki.irdb.models.Client.objects.get_or_certify( + issuer = self.server_ca, + handle = client_handle, + ta = client_ta, + sia_base = sia_base) + + return self.generate_repository_response(client), client_handle + + + def generate_repository_response(self, client): + """ + Generate repository response XML to a given client. + """ + + service_uri = "http://{host}:{port}/client/{handle}".format( + host = self.cfg.get("pubd_server_host", section = myrpki_section), + port = self.cfg.get("pubd_server_port", section = myrpki_section), + handle = client.handle) + + rrdp_uri = self.cfg.get("publication_rrdp_notification_uri", section = myrpki_section, + default = "") or None + + e = Element(tag_oob_repository_response, nsmap = oob_nsmap, version = oob_version, + service_uri = service_uri, + publisher_handle = client.handle, + sia_base = client.sia_base) + + if rrdp_uri is not None: + e.set("rrdp_notification_uri", rrdp_uri) + + B64Element(e, tag_oob_repository_bpki_ta, self.server_ca.certificate) + return etree_wrapper(e, msg = "Send this file back to the publication client you just configured") + + + @django.db.transaction.atomic + def delete_publication_client(self, client_handle): + """ + Delete a publication client of this RPKI entity. + """ + + self.server_ca.clients.get(handle = client_handle).delete() + + + @django.db.transaction.atomic + def configure_repository(self, filename, parent_handle = None): + """ + Configure a publication repository for this RPKI entity, given the + repository's response to our request-for-service message as input. + Reads the repository's response, extracts and cross-certifies the + BPKI data and service URI, and links the repository data with the + corresponding parent data in our local database. + """ + + x = etree_read(filename) + + if x.tag != tag_oob_repository_response: + raise BadXMLMessage("Expected %s, got %s", tag_oob_repository_response, x.tag) + + self.log("Repository calls us %r" % (x.get("publisher_handle"))) + + if parent_handle is not None: + self.log("Explicit parent_handle given") + try: + if parent_handle == self.handle: + turtle = self.resource_ca.rootd + else: + turtle = self.resource_ca.parents.get(handle = parent_handle) + except (rpki.irdb.models.Parent.DoesNotExist, rpki.irdb.models.Rootd.DoesNotExist): + self.log("Could not find parent %r in our database" % parent_handle) + raise CouldntFindRepoParent - @django.db.transaction.atomic - def load_ghostbuster_requests(self, filename, parent = None): - """ - Whack IRDB to match ghostbusters.vcard. - - This accepts one or more vCards from a file. - """ - - self.resource_ca.ghostbuster_requests.filter(parent = parent).delete() + else: + turtles = [] + for parent in self.resource_ca.parents.all(): + try: + _ = parent.repository + except rpki.irdb.models.Repository.DoesNotExist: + turtles.append(parent) + try: + _ = self.resource_ca.rootd.repository + except rpki.irdb.models.Repository.DoesNotExist: + turtles.append(self.resource_ca.rootd) + except rpki.irdb.models.Rootd.DoesNotExist: + pass + if len(turtles) != 1: + self.log("No explicit parent_handle given and unable to guess") + raise CouldntFindRepoParent + turtle = turtles[0] + if isinstance(turtle, rpki.irdb.models.Rootd): + parent_handle = self.handle + else: + parent_handle = turtle.handle + self.log("No explicit parent_handle given, guessing parent {}".format(parent_handle)) + + rpki.irdb.models.Repository.objects.get_or_certify( + issuer = self.resource_ca, + handle = parent_handle, + client_handle = x.get("publisher_handle"), + service_uri = x.get("service_uri"), + sia_base = x.get("sia_base"), + rrdp_notification_uri = x.get("rrdp_notification_uri"), + ta = rpki.x509.X509(Base64 = x.findtext(tag_oob_repository_bpki_ta)), + turtle = turtle) + + + @django.db.transaction.atomic + def delete_repository(self, repository_handle): + """ + Delete a repository of this RPKI entity. + """ + + self.resource_ca.repositories.get(handle = repository_handle).delete() + + + @django.db.transaction.atomic + def renew_children(self, child_handle, valid_until = None): + """ + Update validity period for one child entity or, if child_handle is + None, for all child entities. + """ + + if child_handle is None: + children = self.resource_ca.children.all() + else: + children = self.resource_ca.children.filter(handle = child_handle) - vcard = [] + if valid_until is None: + valid_until = rpki.sundial.now() + rpki.sundial.timedelta(days = 365) + else: + valid_until = rpki.sundial.datetime.fromXMLtime(valid_until) + if valid_until < rpki.sundial.now(): + raise PastExpiration("Specified new expiration time %s has passed" % valid_until) + + self.log("New validity date %s" % valid_until) + + for child in children: + child.valid_until = valid_until + child.save() + + + @django.db.transaction.atomic + def load_prefixes(self, filename, ignore_missing_children = False): + """ + Whack IRDB to match prefixes.csv. + """ + + grouped4 = {} + grouped6 = {} + + for handle, prefix in csv_reader(filename, columns = 2): + grouped = grouped6 if ":" in prefix else grouped4 + if handle not in grouped: + grouped[handle] = [] + grouped[handle].append(prefix) + + primary_keys = [] + + for version, grouped, rset in ((4, grouped4, rpki.resource_set.resource_set_ipv4), + (6, grouped6, rpki.resource_set.resource_set_ipv6)): + for handle, prefixes in grouped.iteritems(): + try: + child = self.resource_ca.children.get(handle = handle) + except rpki.irdb.models.Child.DoesNotExist: + if not ignore_missing_children: + raise + else: + for prefix in rset(",".join(prefixes)): + obj, created = rpki.irdb.models.ChildNet.objects.get_or_create( + child = child, + start_ip = str(prefix.min), + end_ip = str(prefix.max), + version = version) + primary_keys.append(obj.pk) + + q = rpki.irdb.models.ChildNet.objects + q = q.filter(child__issuer = self.resource_ca) + q = q.exclude(pk__in = primary_keys) + q.delete() + + + @django.db.transaction.atomic + def load_asns(self, filename, ignore_missing_children = False): + """ + Whack IRDB to match asns.csv. + """ + + grouped = {} + + for handle, asn in csv_reader(filename, columns = 2): + if handle not in grouped: + grouped[handle] = [] + grouped[handle].append(asn) + + primary_keys = [] + + for handle, asns in grouped.iteritems(): + try: + child = self.resource_ca.children.get(handle = handle) + except rpki.irdb.models.Child.DoesNotExist: + if not ignore_missing_children: + raise + else: + for asn in rpki.resource_set.resource_set_as(",".join(asns)): + obj, created = rpki.irdb.models.ChildASN.objects.get_or_create( + child = child, + start_as = str(asn.min), + end_as = str(asn.max)) + primary_keys.append(obj.pk) + + q = rpki.irdb.models.ChildASN.objects + q = q.filter(child__issuer = self.resource_ca) + q = q.exclude(pk__in = primary_keys) + q.delete() + + + @django.db.transaction.atomic + def load_roa_requests(self, filename): + """ + Whack IRDB to match roa.csv. + """ + + grouped = {} + + # format: p/n-m asn group + for pnm, asn, group in csv_reader(filename, columns = 3): + key = (asn, group) + if key not in grouped: + grouped[key] = [] + grouped[key].append(pnm) + + # Deleting and recreating all the ROA requests is inefficient, + # but rpkid's current representation of ROA requests is wrong + # (see #32), so it's not worth a lot of effort here as we're + # just going to have to rewrite this soon anyway. + + self.resource_ca.roa_requests.all().delete() + + for key, pnms in grouped.iteritems(): + asn, group = key + + roa_request = self.resource_ca.roa_requests.create(asn = asn) + + for pnm in pnms: + if ":" in pnm: + p = rpki.resource_set.roa_prefix_ipv6.parse_str(pnm) + v = 6 + else: + p = rpki.resource_set.roa_prefix_ipv4.parse_str(pnm) + v = 4 + roa_request.prefixes.create( + version = v, + prefix = str(p.prefix), + prefixlen = int(p.prefixlen), + max_prefixlen = int(p.max_prefixlen)) + + + @django.db.transaction.atomic + def load_ghostbuster_requests(self, filename, parent = None): + """ + Whack IRDB to match ghostbusters.vcard. + + This accepts one or more vCards from a file. + """ + + self.resource_ca.ghostbuster_requests.filter(parent = parent).delete() - for line in open(filename, "r"): - if not vcard and not line.upper().startswith("BEGIN:VCARD"): - continue - vcard.append(line) - if line.upper().startswith("END:VCARD"): - self.resource_ca.ghostbuster_requests.create(vcard = "".join(vcard), parent = parent) vcard = [] + for line in open(filename, "r"): + if not vcard and not line.upper().startswith("BEGIN:VCARD"): + continue + vcard.append(line) + if line.upper().startswith("END:VCARD"): + self.resource_ca.ghostbuster_requests.create(vcard = "".join(vcard), parent = parent) + vcard = [] - def call_rpkid(self, q_msg, suppress_error_check = False): - """ - Issue a call to rpkid, return result. - """ - url = "http://%s:%s/left-right" % ( - self.cfg.get("rpkid_server_host", section = myrpki_section), - self.cfg.get("rpkid_server_port", section = myrpki_section)) + def call_rpkid(self, q_msg, suppress_error_check = False): + """ + Issue a call to rpkid, return result. + """ - rpkid = self.server_ca.ee_certificates.get(purpose = "rpkid") - irbe = self.server_ca.ee_certificates.get(purpose = "irbe") + url = "http://%s:%s/left-right" % ( + self.cfg.get("rpkid_server_host", section = myrpki_section), + self.cfg.get("rpkid_server_port", section = myrpki_section)) - r_msg = rpki.http_simple.client( - proto_cms_msg = rpki.left_right.cms_msg, - client_key = irbe.private_key, - client_cert = irbe.certificate, - server_ta = self.server_ca.certificate, - server_cert = rpkid.certificate, - url = url, - q_msg = q_msg, - debug = self.show_xml) + rpkid = self.server_ca.ee_certificates.get(purpose = "rpkid") + irbe = self.server_ca.ee_certificates.get(purpose = "irbe") - if not suppress_error_check: - self.check_error_report(r_msg) - return r_msg + r_msg = rpki.http_simple.client( + proto_cms_msg = rpki.left_right.cms_msg, + client_key = irbe.private_key, + client_cert = irbe.certificate, + server_ta = self.server_ca.certificate, + server_cert = rpkid.certificate, + url = url, + q_msg = q_msg, + debug = self.show_xml) + if not suppress_error_check: + self.check_error_report(r_msg) + return r_msg - def _rpkid_tenant_control(self, *bools): - assert all(isinstance(b, str) for b in bools) - q_msg = self._compose_left_right_query() - q_pdu = SubElement(q_msg, rpki.left_right.tag_tenant, action = "set", tenant_handle = self.handle) - for b in bools: - q_pdu.set(b, "yes") - return self.call_rpkid(q_msg) + def _rpkid_tenant_control(self, *bools): + assert all(isinstance(b, str) for b in bools) + q_msg = self._compose_left_right_query() + q_pdu = SubElement(q_msg, rpki.left_right.tag_tenant, action = "set", tenant_handle = self.handle) + for b in bools: + q_pdu.set(b, "yes") + return self.call_rpkid(q_msg) - def run_rpkid_now(self): - """ - Poke rpkid to immediately run the cron job for the current handle. - This method is used by the GUI when a user has changed something in the - IRDB (ghostbuster, roa) which does not require a full synchronize() call, - to force the object to be immediately issued. - """ + def run_rpkid_now(self): + """ + Poke rpkid to immediately run the cron job for the current handle. - return self._rpkid_tenant_control("run_now") + This method is used by the GUI when a user has changed something in the + IRDB (ghostbuster, roa) which does not require a full synchronize() call, + to force the object to be immediately issued. + """ + return self._rpkid_tenant_control("run_now") - def publish_world_now(self): - """ - Poke rpkid to (re)publish everything for the current handle. - """ - return self._rpkid_tenant_control("publish_world_now") + def publish_world_now(self): + """ + Poke rpkid to (re)publish everything for the current handle. + """ + return self._rpkid_tenant_control("publish_world_now") - def reissue(self): - """ - Poke rpkid to reissue everything for the current handle. - """ - return self._rpkid_tenant_control("reissue") + def reissue(self): + """ + Poke rpkid to reissue everything for the current handle. + """ + return self._rpkid_tenant_control("reissue") - def rekey(self): - """ - Poke rpkid to rekey all RPKI certificates received for the current - handle. - """ - - return self._rpkid_tenant_control("rekey") - - - def revoke(self): - """ - Poke rpkid to revoke old RPKI keys for the current handle. - """ - - return self._rpkid_tenant_control("revoke") + def rekey(self): + """ + Poke rpkid to rekey all RPKI certificates received for the current + handle. + """ + + return self._rpkid_tenant_control("rekey") - def revoke_forgotten(self): - """ - Poke rpkid to revoke old forgotten RPKI keys for the current handle. - """ - return self._rpkid_tenant_control("revoke_forgotten") + def revoke(self): + """ + Poke rpkid to revoke old RPKI keys for the current handle. + """ + return self._rpkid_tenant_control("revoke") - def clear_all_sql_cms_replay_protection(self): - """ - Tell rpkid and pubd to clear replay protection for all SQL-based - entities. This is a fairly blunt instrument, but as we don't - expect this to be necessary except in the case of gross - misconfiguration, it should suffice. - """ - if self.run_rpkid: - q_msg = self._compose_left_right_query() - for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): - SubElement(q_msg, rpki.left_right.tag_tenant, action = "set", - tenant_handle = ca.handle, clear_replay_protection = "yes") - self.call_rpkid(q_msg) + def revoke_forgotten(self): + """ + Poke rpkid to revoke old forgotten RPKI keys for the current handle. + """ + + return self._rpkid_tenant_control("revoke_forgotten") - if self.run_pubd: - q_msg = self._compose_publication_control_query() - for client in self.server_ca.clients.all(): - SubElement(q_msg, rpki.publication_control.tag_client, action = "set", - client_handle = client.handle, clear_reply_protection = "yes") - self.call_pubd(q_msg) + def clear_all_sql_cms_replay_protection(self): + """ + Tell rpkid and pubd to clear replay protection for all SQL-based + entities. This is a fairly blunt instrument, but as we don't + expect this to be necessary except in the case of gross + misconfiguration, it should suffice. + """ - def call_pubd(self, q_msg): - """ - Issue a call to pubd, return result. - """ + if self.run_rpkid: + q_msg = self._compose_left_right_query() + for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): + SubElement(q_msg, rpki.left_right.tag_tenant, action = "set", + tenant_handle = ca.handle, clear_replay_protection = "yes") + self.call_rpkid(q_msg) - url = "http://%s:%s/control" % ( - self.cfg.get("pubd_server_host", section = myrpki_section), - self.cfg.get("pubd_server_port", section = myrpki_section)) + if self.run_pubd: + q_msg = self._compose_publication_control_query() + for client in self.server_ca.clients.all(): + SubElement(q_msg, rpki.publication_control.tag_client, action = "set", + client_handle = client.handle, clear_reply_protection = "yes") + self.call_pubd(q_msg) - pubd = self.server_ca.ee_certificates.get(purpose = "pubd") - irbe = self.server_ca.ee_certificates.get(purpose = "irbe") - r_msg = rpki.http_simple.client( - proto_cms_msg = rpki.publication_control.cms_msg, - client_key = irbe.private_key, - client_cert = irbe.certificate, - server_ta = self.server_ca.certificate, - server_cert = pubd.certificate, - url = url, - q_msg = q_msg, - debug = self.show_xml) + def call_pubd(self, q_msg): + """ + Issue a call to pubd, return result. + """ - self.check_error_report(r_msg) - return r_msg + url = "http://%s:%s/control" % ( + self.cfg.get("pubd_server_host", section = myrpki_section), + self.cfg.get("pubd_server_port", section = myrpki_section)) + pubd = self.server_ca.ee_certificates.get(purpose = "pubd") + irbe = self.server_ca.ee_certificates.get(purpose = "irbe") - def check_error_report(self, r_msg): - """ - Check a response from rpkid or pubd for error_report PDUs, log and - throw exceptions as needed. - """ + r_msg = rpki.http_simple.client( + proto_cms_msg = rpki.publication_control.cms_msg, + client_key = irbe.private_key, + client_cert = irbe.certificate, + server_ta = self.server_ca.certificate, + server_cert = pubd.certificate, + url = url, + q_msg = q_msg, + debug = self.show_xml) - failed = False - for r_pdu in r_msg.getiterator(rpki.left_right.tag_report_error): - failed = True - self.log("rpkid reported failure: %s" % r_pdu.get("error_code")) - if r_pdu.text: - self.log(r_pdu.text) - for r_pdu in r_msg.getiterator(rpki.publication_control.tag_report_error): - failed = True - self.log("pubd reported failure: %s" % r_pdu.get("error_code")) - if r_pdu.text: - self.log(r_pdu.text) - if failed: - raise CouldntTalkToDaemon - - - @django.db.transaction.atomic - def synchronize(self, *handles_to_poke): - """ - Configure RPKI daemons with the data built up by the other - commands in this program. Commands which modify the IRDB and want - to whack everything into sync should call this when they're done, - but be warned that this can be slow with a lot of CAs. + self.check_error_report(r_msg) + return r_msg - Any arguments given are handles of CAs which should be poked with a - <tenant run_now="yes"/> operation. - """ - for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): - self.synchronize_rpkid_one_ca_core(ca, ca.handle in handles_to_poke) - self.synchronize_pubd_core() - self.synchronize_rpkid_deleted_core() + def check_error_report(self, r_msg): + """ + Check a response from rpkid or pubd for error_report PDUs, log and + throw exceptions as needed. + """ + failed = False + for r_pdu in r_msg.getiterator(rpki.left_right.tag_report_error): + failed = True + self.log("rpkid reported failure: %s" % r_pdu.get("error_code")) + if r_pdu.text: + self.log(r_pdu.text) + for r_pdu in r_msg.getiterator(rpki.publication_control.tag_report_error): + failed = True + self.log("pubd reported failure: %s" % r_pdu.get("error_code")) + if r_pdu.text: + self.log(r_pdu.text) + if failed: + raise CouldntTalkToDaemon - @django.db.transaction.atomic - def synchronize_ca(self, ca = None, poke = False): - """ - Synchronize one CA. Most commands which modify a CA should call - this. CA to synchronize defaults to the current resource CA. - """ - if ca is None: - ca = self.resource_ca - self.synchronize_rpkid_one_ca_core(ca, poke) + @django.db.transaction.atomic + def synchronize(self, *handles_to_poke): + """ + Configure RPKI daemons with the data built up by the other + commands in this program. Commands which modify the IRDB and want + to whack everything into sync should call this when they're done, + but be warned that this can be slow with a lot of CAs. + + Any arguments given are handles of CAs which should be poked with a + <tenant run_now="yes"/> operation. + """ + + for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): + self.synchronize_rpkid_one_ca_core(ca, ca.handle in handles_to_poke) + self.synchronize_pubd_core() + self.synchronize_rpkid_deleted_core() - @django.db.transaction.atomic - def synchronize_deleted_ca(self): - """ - Delete CAs which are present in rpkid's database but not in the - IRDB. - """ + @django.db.transaction.atomic + def synchronize_ca(self, ca = None, poke = False): + """ + Synchronize one CA. Most commands which modify a CA should call + this. CA to synchronize defaults to the current resource CA. + """ + + if ca is None: + ca = self.resource_ca + self.synchronize_rpkid_one_ca_core(ca, poke) + + + @django.db.transaction.atomic + def synchronize_deleted_ca(self): + """ + Delete CAs which are present in rpkid's database but not in the + IRDB. + """ + + self.synchronize_rpkid_deleted_core() + + + @django.db.transaction.atomic + def synchronize_pubd(self): + """ + Synchronize pubd. Most commands which modify pubd should call this. + """ + + self.synchronize_pubd_core() + + + def synchronize_rpkid_one_ca_core(self, ca, poke = False): + """ + Synchronize one CA. This is the core synchronization code. Don't + call this directly, instead call one of the methods that calls + this inside a Django commit wrapper. + + This method configures rpkid with data built up by the other + commands in this program. Most commands which modify IRDB values + related to rpkid should call this when they're done. + + If poke is True, we append a left-right run_now operation for this + CA to the end of whatever other commands this method generates. + """ + + # We can use a single BSC for everything -- except BSC key + # rollovers. Drive off that bridge when we get to it. + + bsc_handle = "bsc" + + # A default RPKI CRL cycle time of six hours seems sane. One + # might make a case for a day instead, but we've been running with + # six hours for a while now and haven't seen a lot of whining. + + tenant_crl_interval = self.cfg.getint("tenant_crl_interval", 6 * 60 * 60, section = myrpki_section) + + # regen_margin now just controls how long before RPKI certificate + # expiration we should regenerate; it used to control the interval + # before RPKI CRL staleness at which to regenerate the CRL, but + # using the same timer value for both of these is hopeless. + # + # A default regeneration margin of two weeks gives enough time for + # humans to react. We add a two hour fudge factor in the hope + # that this will regenerate certificates just *before* the + # companion cron job warns of impending doom. + + tenant_regen_margin = self.cfg.getint("tenant_regen_margin", 14 * 24 * 60 * 60 + 2 * 60, section = myrpki_section) + + # See what rpkid already has on file for this entity. + + q_msg = self._compose_left_right_query() + SubElement(q_msg, rpki.left_right.tag_tenant, action = "get", tenant_handle = ca.handle) + SubElement(q_msg, rpki.left_right.tag_bsc, action = "list", tenant_handle = ca.handle) + SubElement(q_msg, rpki.left_right.tag_repository, action = "list", tenant_handle = ca.handle) + SubElement(q_msg, rpki.left_right.tag_parent, action = "list", tenant_handle = ca.handle) + SubElement(q_msg, rpki.left_right.tag_child, action = "list", tenant_handle = ca.handle) + + r_msg = self.call_rpkid(q_msg, suppress_error_check = True) + + self.check_error_report(r_msg) + + tenant_pdu = r_msg.find(rpki.left_right.tag_tenant) + + bsc_pdus = dict((r_pdu.get("bsc_handle"), r_pdu) + for r_pdu in r_msg.getiterator(rpki.left_right.tag_bsc)) + repository_pdus = dict((r_pdu.get("repository_handle"), r_pdu) + for r_pdu in r_msg.getiterator(rpki.left_right.tag_repository)) + parent_pdus = dict((r_pdu.get("parent_handle"), r_pdu) + for r_pdu in r_msg.getiterator(rpki.left_right.tag_parent)) + child_pdus = dict((r_pdu.get("child_handle"), r_pdu) + for r_pdu in r_msg.getiterator(rpki.left_right.tag_child)) + + q_msg = self._compose_left_right_query() + + tenant_cert, created = rpki.irdb.models.HostedCA.objects.get_or_certify( + issuer = self.server_ca, + hosted = ca) + + # There should be exactly one <tenant/> object per hosted entity, by definition + + if (tenant_pdu is None or + tenant_pdu.get("crl_interval") != str(tenant_crl_interval) or + tenant_pdu.get("regen_margin") != str(tenant_regen_margin) or + tenant_pdu.findtext(rpki.left_right.tag_bpki_cert, "").decode("base64") != tenant_cert.certificate.get_DER()): + q_pdu = SubElement(q_msg, rpki.left_right.tag_tenant, + action = "create" if tenant_pdu is None else "set", + tag = "tenant", + tenant_handle = ca.handle, + crl_interval = str(tenant_crl_interval), + regen_margin = str(tenant_regen_margin)) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = ca.certificate.get_Base64() + + # In general we only need one <bsc/> per <tenant/>. BSC objects + # are a little unusual in that the keypair and PKCS #10 + # subelement are generated by rpkid, so complete setup requires + # two round trips. + + bsc_pdu = bsc_pdus.pop(bsc_handle, None) + + if bsc_pdu is None or bsc_pdu.find(rpki.left_right.tag_pkcs10_request) is None: + SubElement(q_msg, rpki.left_right.tag_bsc, + action = "create" if bsc_pdu is None else "set", + tag = "bsc", + tenant_handle = ca.handle, + bsc_handle = bsc_handle, + generate_keypair = "yes") + + for bsc_handle in bsc_pdus: + SubElement(q_msg, rpki.left_right.tag_bsc, + action = "destroy", tenant_handle = ca.handle, bsc_handle = bsc_handle) + + # If we've already got actions queued up, run them now, so we + # can finish setting up the BSC before anything tries to use it. + + if len(q_msg) > 0: + SubElement(q_msg, rpki.left_right.tag_bsc, action = "list", tag = "bsc", tenant_handle = ca.handle) + r_msg = self.call_rpkid(q_msg) + bsc_pdus = dict((r_pdu.get("bsc_handle"), r_pdu) + for r_pdu in r_msg.getiterator(rpki.left_right.tag_bsc) + if r_pdu.get("action") == "list") + bsc_pdu = bsc_pdus.pop(bsc_handle, None) + + q_msg = self._compose_left_right_query() + + bsc_pkcs10 = bsc_pdu.find(rpki.left_right.tag_pkcs10_request) + assert bsc_pkcs10 is not None + + bsc, created = rpki.irdb.models.BSC.objects.get_or_certify( + issuer = ca, + handle = bsc_handle, + pkcs10 = rpki.x509.PKCS10(Base64 = bsc_pkcs10.text)) + + if (bsc_pdu.findtext(rpki.left_right.tag_signing_cert, "").decode("base64") != bsc.certificate.get_DER() or + bsc_pdu.findtext(rpki.left_right.tag_signing_cert_crl, "").decode("base64") != ca.latest_crl.get_DER()): + q_pdu = SubElement(q_msg, rpki.left_right.tag_bsc, + action = "set", + tag = "bsc", + tenant_handle = ca.handle, + bsc_handle = bsc_handle) + SubElement(q_pdu, rpki.left_right.tag_signing_cert).text = bsc.certificate.get_Base64() + SubElement(q_pdu, rpki.left_right.tag_signing_cert_crl).text = ca.latest_crl.get_Base64() + + # At present we need one <repository/> per <parent/>, not because + # rpkid requires that, but because pubd does. pubd probably should + # be fixed to support a single client allowed to update multiple + # trees, but for the moment the easiest way forward is just to + # enforce a 1:1 mapping between <parent/> and <repository/> objects + + for repository in ca.repositories.all(): + + repository_pdu = repository_pdus.pop(repository.handle, None) + + if (repository_pdu is None or + repository_pdu.get("bsc_handle") != bsc_handle or + repository_pdu.get("peer_contact_uri") != repository.service_uri or + repository_pdu.get("rrdp_notification_uri") != repository.rrdp_notification_uri or + repository_pdu.findtext(rpki.left_right.tag_bpki_cert, "").decode("base64") != repository.certificate.get_DER()): + q_pdu = SubElement(q_msg, rpki.left_right.tag_repository, + action = "create" if repository_pdu is None else "set", + tag = repository.handle, + tenant_handle = ca.handle, + repository_handle = repository.handle, + bsc_handle = bsc_handle, + peer_contact_uri = repository.service_uri) + if repository.rrdp_notification_uri: + q_pdu.set("rrdp_notification_uri", repository.rrdp_notification_uri) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = repository.certificate.get_Base64() + + for repository_handle in repository_pdus: + SubElement(q_msg, rpki.left_right.tag_repository, action = "destroy", + tenant_handle = ca.handle, repository_handle = repository_handle) + + # <parent/> setup code currently assumes 1:1 mapping between + # <repository/> and <parent/>, and further assumes that the handles + # for an associated pair are the identical (that is: + # parent.repository_handle == parent.parent_handle). + # + # If no such repository exists, our choices are to ignore the + # parent entry or throw an error. For now, we ignore the parent. + + for parent in ca.parents.all(): + + try: + parent_pdu = parent_pdus.pop(parent.handle, None) + + if (parent_pdu is None or + parent_pdu.get("bsc_handle") != bsc_handle or + parent_pdu.get("repository_handle") != parent.handle or + parent_pdu.get("peer_contact_uri") != parent.service_uri or + parent_pdu.get("sia_base") != parent.repository.sia_base or + parent_pdu.get("sender_name") != parent.child_handle or + parent_pdu.get("recipient_name") != parent.parent_handle or + parent_pdu.findtext(rpki.left_right.tag_bpki_cert, "").decode("base64") != parent.certificate.get_DER()): + q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, + action = "create" if parent_pdu is None else "set", + tag = parent.handle, + tenant_handle = ca.handle, + parent_handle = parent.handle, + bsc_handle = bsc_handle, + repository_handle = parent.handle, + peer_contact_uri = parent.service_uri, + sia_base = parent.repository.sia_base, + sender_name = parent.child_handle, + recipient_name = parent.parent_handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = parent.certificate.get_Base64() + + except rpki.irdb.models.Repository.DoesNotExist: + pass - self.synchronize_rpkid_deleted_core() - - - @django.db.transaction.atomic - def synchronize_pubd(self): - """ - Synchronize pubd. Most commands which modify pubd should call this. - """ - - self.synchronize_pubd_core() - - - def synchronize_rpkid_one_ca_core(self, ca, poke = False): - """ - Synchronize one CA. This is the core synchronization code. Don't - call this directly, instead call one of the methods that calls - this inside a Django commit wrapper. - - This method configures rpkid with data built up by the other - commands in this program. Most commands which modify IRDB values - related to rpkid should call this when they're done. - - If poke is True, we append a left-right run_now operation for this - CA to the end of whatever other commands this method generates. - """ - - # We can use a single BSC for everything -- except BSC key - # rollovers. Drive off that bridge when we get to it. - - bsc_handle = "bsc" - - # A default RPKI CRL cycle time of six hours seems sane. One - # might make a case for a day instead, but we've been running with - # six hours for a while now and haven't seen a lot of whining. - - tenant_crl_interval = self.cfg.getint("tenant_crl_interval", 6 * 60 * 60, section = myrpki_section) - - # regen_margin now just controls how long before RPKI certificate - # expiration we should regenerate; it used to control the interval - # before RPKI CRL staleness at which to regenerate the CRL, but - # using the same timer value for both of these is hopeless. - # - # A default regeneration margin of two weeks gives enough time for - # humans to react. We add a two hour fudge factor in the hope - # that this will regenerate certificates just *before* the - # companion cron job warns of impending doom. - - tenant_regen_margin = self.cfg.getint("tenant_regen_margin", 14 * 24 * 60 * 60 + 2 * 60, section = myrpki_section) - - # See what rpkid already has on file for this entity. - - q_msg = self._compose_left_right_query() - SubElement(q_msg, rpki.left_right.tag_tenant, action = "get", tenant_handle = ca.handle) - SubElement(q_msg, rpki.left_right.tag_bsc, action = "list", tenant_handle = ca.handle) - SubElement(q_msg, rpki.left_right.tag_repository, action = "list", tenant_handle = ca.handle) - SubElement(q_msg, rpki.left_right.tag_parent, action = "list", tenant_handle = ca.handle) - SubElement(q_msg, rpki.left_right.tag_child, action = "list", tenant_handle = ca.handle) - - r_msg = self.call_rpkid(q_msg, suppress_error_check = True) - - self.check_error_report(r_msg) - - tenant_pdu = r_msg.find(rpki.left_right.tag_tenant) - - bsc_pdus = dict((r_pdu.get("bsc_handle"), r_pdu) - for r_pdu in r_msg.getiterator(rpki.left_right.tag_bsc)) - repository_pdus = dict((r_pdu.get("repository_handle"), r_pdu) - for r_pdu in r_msg.getiterator(rpki.left_right.tag_repository)) - parent_pdus = dict((r_pdu.get("parent_handle"), r_pdu) - for r_pdu in r_msg.getiterator(rpki.left_right.tag_parent)) - child_pdus = dict((r_pdu.get("child_handle"), r_pdu) - for r_pdu in r_msg.getiterator(rpki.left_right.tag_child)) - - q_msg = self._compose_left_right_query() - - tenant_cert, created = rpki.irdb.models.HostedCA.objects.get_or_certify( - issuer = self.server_ca, - hosted = ca) - - # There should be exactly one <tenant/> object per hosted entity, by definition - - if (tenant_pdu is None or - tenant_pdu.get("crl_interval") != str(tenant_crl_interval) or - tenant_pdu.get("regen_margin") != str(tenant_regen_margin) or - tenant_pdu.findtext(rpki.left_right.tag_bpki_cert, "").decode("base64") != tenant_cert.certificate.get_DER()): - q_pdu = SubElement(q_msg, rpki.left_right.tag_tenant, - action = "create" if tenant_pdu is None else "set", - tag = "tenant", - tenant_handle = ca.handle, - crl_interval = str(tenant_crl_interval), - regen_margin = str(tenant_regen_margin)) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = ca.certificate.get_Base64() - - # In general we only need one <bsc/> per <tenant/>. BSC objects - # are a little unusual in that the keypair and PKCS #10 - # subelement are generated by rpkid, so complete setup requires - # two round trips. - - bsc_pdu = bsc_pdus.pop(bsc_handle, None) - - if bsc_pdu is None or bsc_pdu.find(rpki.left_right.tag_pkcs10_request) is None: - SubElement(q_msg, rpki.left_right.tag_bsc, - action = "create" if bsc_pdu is None else "set", - tag = "bsc", - tenant_handle = ca.handle, - bsc_handle = bsc_handle, - generate_keypair = "yes") - - for bsc_handle in bsc_pdus: - SubElement(q_msg, rpki.left_right.tag_bsc, - action = "destroy", tenant_handle = ca.handle, bsc_handle = bsc_handle) - - # If we've already got actions queued up, run them now, so we - # can finish setting up the BSC before anything tries to use it. - - if len(q_msg) > 0: - SubElement(q_msg, rpki.left_right.tag_bsc, action = "list", tag = "bsc", tenant_handle = ca.handle) - r_msg = self.call_rpkid(q_msg) - bsc_pdus = dict((r_pdu.get("bsc_handle"), r_pdu) - for r_pdu in r_msg.getiterator(rpki.left_right.tag_bsc) - if r_pdu.get("action") == "list") - bsc_pdu = bsc_pdus.pop(bsc_handle, None) - - q_msg = self._compose_left_right_query() - - bsc_pkcs10 = bsc_pdu.find(rpki.left_right.tag_pkcs10_request) - assert bsc_pkcs10 is not None - - bsc, created = rpki.irdb.models.BSC.objects.get_or_certify( - issuer = ca, - handle = bsc_handle, - pkcs10 = rpki.x509.PKCS10(Base64 = bsc_pkcs10.text)) - - if (bsc_pdu.findtext(rpki.left_right.tag_signing_cert, "").decode("base64") != bsc.certificate.get_DER() or - bsc_pdu.findtext(rpki.left_right.tag_signing_cert_crl, "").decode("base64") != ca.latest_crl.get_DER()): - q_pdu = SubElement(q_msg, rpki.left_right.tag_bsc, - action = "set", - tag = "bsc", - tenant_handle = ca.handle, - bsc_handle = bsc_handle) - SubElement(q_pdu, rpki.left_right.tag_signing_cert).text = bsc.certificate.get_Base64() - SubElement(q_pdu, rpki.left_right.tag_signing_cert_crl).text = ca.latest_crl.get_Base64() - - # At present we need one <repository/> per <parent/>, not because - # rpkid requires that, but because pubd does. pubd probably should - # be fixed to support a single client allowed to update multiple - # trees, but for the moment the easiest way forward is just to - # enforce a 1:1 mapping between <parent/> and <repository/> objects - - for repository in ca.repositories.all(): - - repository_pdu = repository_pdus.pop(repository.handle, None) - - if (repository_pdu is None or - repository_pdu.get("bsc_handle") != bsc_handle or - repository_pdu.get("peer_contact_uri") != repository.service_uri or - repository_pdu.get("rrdp_notification_uri") != repository.rrdp_notification_uri or - repository_pdu.findtext(rpki.left_right.tag_bpki_cert, "").decode("base64") != repository.certificate.get_DER()): - q_pdu = SubElement(q_msg, rpki.left_right.tag_repository, - action = "create" if repository_pdu is None else "set", - tag = repository.handle, - tenant_handle = ca.handle, - repository_handle = repository.handle, - bsc_handle = bsc_handle, - peer_contact_uri = repository.service_uri) - if repository.rrdp_notification_uri: - q_pdu.set("rrdp_notification_uri", repository.rrdp_notification_uri) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = repository.certificate.get_Base64() - - for repository_handle in repository_pdus: - SubElement(q_msg, rpki.left_right.tag_repository, action = "destroy", - tenant_handle = ca.handle, repository_handle = repository_handle) - - # <parent/> setup code currently assumes 1:1 mapping between - # <repository/> and <parent/>, and further assumes that the handles - # for an associated pair are the identical (that is: - # parent.repository_handle == parent.parent_handle). - # - # If no such repository exists, our choices are to ignore the - # parent entry or throw an error. For now, we ignore the parent. - - for parent in ca.parents.all(): - - try: - parent_pdu = parent_pdus.pop(parent.handle, None) - - if (parent_pdu is None or - parent_pdu.get("bsc_handle") != bsc_handle or - parent_pdu.get("repository_handle") != parent.handle or - parent_pdu.get("peer_contact_uri") != parent.service_uri or - parent_pdu.get("sia_base") != parent.repository.sia_base or - parent_pdu.get("sender_name") != parent.child_handle or - parent_pdu.get("recipient_name") != parent.parent_handle or - parent_pdu.findtext(rpki.left_right.tag_bpki_cert, "").decode("base64") != parent.certificate.get_DER()): - q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, - action = "create" if parent_pdu is None else "set", - tag = parent.handle, - tenant_handle = ca.handle, - parent_handle = parent.handle, - bsc_handle = bsc_handle, - repository_handle = parent.handle, - peer_contact_uri = parent.service_uri, - sia_base = parent.repository.sia_base, - sender_name = parent.child_handle, - recipient_name = parent.parent_handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = parent.certificate.get_Base64() - - except rpki.irdb.models.Repository.DoesNotExist: - pass - - try: - - parent_pdu = parent_pdus.pop(ca.handle, None) - - if (parent_pdu is None or - parent_pdu.get("bsc_handle") != bsc_handle or - parent_pdu.get("repository_handle") != ca.handle or - parent_pdu.get("peer_contact_uri") != ca.rootd.service_uri or - parent_pdu.get("sia_base") != ca.rootd.repository.sia_base or - parent_pdu.get("sender_name") != ca.handle or - parent_pdu.get("recipient_name") != ca.handle or - parent_pdu.findtext(rpki.left_right.tag_bpki_cert).decode("base64") != ca.rootd.certificate.get_DER()): - q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, - action = "create" if parent_pdu is None else "set", - tag = ca.handle, - tenant_handle = ca.handle, - parent_handle = ca.handle, - bsc_handle = bsc_handle, - repository_handle = ca.handle, - peer_contact_uri = ca.rootd.service_uri, - sia_base = ca.rootd.repository.sia_base, - sender_name = ca.handle, - recipient_name = ca.handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = ca.rootd.certificate.get_Base64() - - except rpki.irdb.models.Rootd.DoesNotExist: - pass - - for parent_handle in parent_pdus: - SubElement(q_msg, rpki.left_right.tag_parent, action = "destroy", - tenant_handle = ca.handle, parent_handle = parent_handle) - - # Children are simpler than parents, because they call us, so no URL - # to construct and figuring out what certificate to use is their - # problem, not ours. - - for child in ca.children.all(): - - child_pdu = child_pdus.pop(child.handle, None) - - if (child_pdu is None or - child_pdu.get("bsc_handle") != bsc_handle or - child_pdu.findtext(rpki.left_right.tag_bpki_cert).decode("base64") != child.certificate.get_DER()): - q_pdu = SubElement(q_msg, rpki.left_right.tag_child, - action = "create" if child_pdu is None else "set", - tag = child.handle, - tenant_handle = ca.handle, - child_handle = child.handle, - bsc_handle = bsc_handle) - SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = child.certificate.get_Base64() - - for child_handle in child_pdus: - SubElement(q_msg, rpki.left_right.tag_child, action = "destroy", - tenant_handle = ca.handle, child_handle = child_handle) - - # If caller wants us to poke rpkid, add that to the very end of the message - - if poke: - SubElement(q_msg, rpki.left_right.tag_tenant, action = "set", tenant_handle = ca.handle, run_now = "yes") - - # If we changed anything, ship updates off to rpkid. - - if len(q_msg) > 0: - self.call_rpkid(q_msg) - - - def synchronize_pubd_core(self): - """ - Configure pubd with data built up by the other commands in this - program. This is the core synchronization code. Don't call this - directly, instead call a methods that calls this inside a Django - commit wrapper. - - This method configures pubd with data built up by the other - commands in this program. Commands which modify IRDB fields - related to pubd should call this when they're done. - """ - - # If we're not running pubd, the rest of this is a waste of time - - if not self.run_pubd: - return - - # See what pubd already has on file - - q_msg = self._compose_publication_control_query() - SubElement(q_msg, rpki.publication_control.tag_client, action = "list") - r_msg = self.call_pubd(q_msg) - client_pdus = dict((r_pdu.get("client_handle"), r_pdu) - for r_pdu in r_msg) + try: - # Check all clients + parent_pdu = parent_pdus.pop(ca.handle, None) + + if (parent_pdu is None or + parent_pdu.get("bsc_handle") != bsc_handle or + parent_pdu.get("repository_handle") != ca.handle or + parent_pdu.get("peer_contact_uri") != ca.rootd.service_uri or + parent_pdu.get("sia_base") != ca.rootd.repository.sia_base or + parent_pdu.get("sender_name") != ca.handle or + parent_pdu.get("recipient_name") != ca.handle or + parent_pdu.findtext(rpki.left_right.tag_bpki_cert).decode("base64") != ca.rootd.certificate.get_DER()): + q_pdu = SubElement(q_msg, rpki.left_right.tag_parent, + action = "create" if parent_pdu is None else "set", + tag = ca.handle, + tenant_handle = ca.handle, + parent_handle = ca.handle, + bsc_handle = bsc_handle, + repository_handle = ca.handle, + peer_contact_uri = ca.rootd.service_uri, + sia_base = ca.rootd.repository.sia_base, + sender_name = ca.handle, + recipient_name = ca.handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = ca.rootd.certificate.get_Base64() - q_msg = self._compose_publication_control_query() + except rpki.irdb.models.Rootd.DoesNotExist: + pass - for client in self.server_ca.clients.all(): + for parent_handle in parent_pdus: + SubElement(q_msg, rpki.left_right.tag_parent, action = "destroy", + tenant_handle = ca.handle, parent_handle = parent_handle) - client_pdu = client_pdus.pop(client.handle, None) + # Children are simpler than parents, because they call us, so no URL + # to construct and figuring out what certificate to use is their + # problem, not ours. - if (client_pdu is None or - client_pdu.get("base_uri") != client.sia_base or - client_pdu.findtext(rpki.publication_control.tag_bpki_cert, "").decode("base64") != client.certificate.get_DER()): - q_pdu = SubElement(q_msg, rpki.publication_control.tag_client, - action = "create" if client_pdu is None else "set", - client_handle = client.handle, - base_uri = client.sia_base) - SubElement(q_pdu, rpki.publication_control.tag_bpki_cert).text = client.certificate.get_Base64() + for child in ca.children.all(): - # rootd instances are also a weird sort of client + child_pdu = child_pdus.pop(child.handle, None) - for rootd in rpki.irdb.models.Rootd.objects.all(): + if (child_pdu is None or + child_pdu.get("bsc_handle") != bsc_handle or + child_pdu.findtext(rpki.left_right.tag_bpki_cert).decode("base64") != child.certificate.get_DER()): + q_pdu = SubElement(q_msg, rpki.left_right.tag_child, + action = "create" if child_pdu is None else "set", + tag = child.handle, + tenant_handle = ca.handle, + child_handle = child.handle, + bsc_handle = bsc_handle) + SubElement(q_pdu, rpki.left_right.tag_bpki_cert).text = child.certificate.get_Base64() - client_handle = rootd.issuer.handle + "-root" - client_pdu = client_pdus.pop(client_handle, None) - sia_base = "rsync://%s/%s/%s/" % (self.rsync_server, self.rsync_module, client_handle) + for child_handle in child_pdus: + SubElement(q_msg, rpki.left_right.tag_child, action = "destroy", + tenant_handle = ca.handle, child_handle = child_handle) - if (client_pdu is None or - client_pdu.get("base_uri") != sia_base or - client_pdu.findtext(rpki.publication_control.tag_bpki_cert, "").decode("base64") != rootd.issuer.certificate.get_DER()): - q_pdu = SubElement(q_msg, rpki.publication_control.tag_client, - action = "create" if client_pdu is None else "set", - client_handle = client_handle, - base_uri = sia_base) - SubElement(q_pdu, rpki.publication_control.tag_bpki_cert).text = rootd.issuer.certificate.get_Base64() + # If caller wants us to poke rpkid, add that to the very end of the message - # Delete any unknown clients + if poke: + SubElement(q_msg, rpki.left_right.tag_tenant, action = "set", tenant_handle = ca.handle, run_now = "yes") - for client_handle in client_pdus: - SubElement(q_msg, rpki.publication_control.tag_client, action = "destroy", client_handle = client_handle) + # If we changed anything, ship updates off to rpkid. - # If we changed anything, ship updates off to pubd + if len(q_msg) > 0: + self.call_rpkid(q_msg) - if len(q_msg) > 0: - self.call_pubd(q_msg) + def synchronize_pubd_core(self): + """ + Configure pubd with data built up by the other commands in this + program. This is the core synchronization code. Don't call this + directly, instead call a methods that calls this inside a Django + commit wrapper. - def synchronize_rpkid_deleted_core(self): - """ - Remove any <tenant/> objects present in rpkid's database but not - present in the IRDB. This is the core synchronization code. - Don't call this directly, instead call a methods that calls this - inside a Django commit wrapper. - """ + This method configures pubd with data built up by the other + commands in this program. Commands which modify IRDB fields + related to pubd should call this when they're done. + """ - q_msg = self._compose_left_right_query() - SubElement(q_msg, rpki.left_right.tag_tenant, action = "list") - self.call_rpkid(q_msg) + # If we're not running pubd, the rest of this is a waste of time - tenant_handles = set(s.get("tenant_handle") for s in q_msg) - ca_handles = set(ca.handle for ca in rpki.irdb.models.ResourceHolderCA.objects.all()) - assert ca_handles <= tenant_handles + if not self.run_pubd: + return - q_msg = self._compose_left_right_query() - for handle in (tenant_handles - ca_handles): - SubElement(q_msg, rpki.left_right.tag_tenant, action = "destroy", tenant_handle = handle) + # See what pubd already has on file - if len(q_msg) > 0: - self.call_rpkid(q_msg) + q_msg = self._compose_publication_control_query() + SubElement(q_msg, rpki.publication_control.tag_client, action = "list") + r_msg = self.call_pubd(q_msg) + client_pdus = dict((r_pdu.get("client_handle"), r_pdu) + for r_pdu in r_msg) + # Check all clients - @django.db.transaction.atomic - def add_ee_certificate_request(self, pkcs10, resources): - """ - Check a PKCS #10 request to see if it complies with the - specification for a RPKI EE certificate; if it does, add an - EECertificateRequest for it to the IRDB. + q_msg = self._compose_publication_control_query() - Not yet sure what we want for update and delete semantics here, so - for the moment this is straight addition. See methods like - .load_asns() and .load_prefixes() for other strategies. - """ + for client in self.server_ca.clients.all(): - pkcs10.check_valid_request_ee() - ee_request = self.resource_ca.ee_certificate_requests.create( - pkcs10 = pkcs10, - gski = pkcs10.gSKI(), - valid_until = resources.valid_until) - for r in resources.asn: - ee_request.asns.create(start_as = str(r.min), end_as = str(r.max)) - for r in resources.v4: - ee_request.address_ranges.create(start_ip = str(r.min), end_ip = str(r.max), version = 4) - for r in resources.v6: - ee_request.address_ranges.create(start_ip = str(r.min), end_ip = str(r.max), version = 6) - - - @django.db.transaction.atomic - def add_router_certificate_request(self, router_certificate_request_xml, valid_until = None): - """ - Read XML file containing one or more router certificate requests, - attempt to add request(s) to IRDB. + client_pdu = client_pdus.pop(client.handle, None) - Check each PKCS #10 request to see if it complies with the - specification for a router certificate; if it does, create an EE - certificate request for it along with the ASN resources and - router-ID supplied in the XML. - """ + if (client_pdu is None or + client_pdu.get("base_uri") != client.sia_base or + client_pdu.findtext(rpki.publication_control.tag_bpki_cert, "").decode("base64") != client.certificate.get_DER()): + q_pdu = SubElement(q_msg, rpki.publication_control.tag_client, + action = "create" if client_pdu is None else "set", + client_handle = client.handle, + base_uri = client.sia_base) + SubElement(q_pdu, rpki.publication_control.tag_bpki_cert).text = client.certificate.get_Base64() - x = etree_read(router_certificate_request_xml, schema = rpki.relaxng.router_certificate) + # rootd instances are also a weird sort of client - for x in x.getiterator(tag_router_certificate_request): + for rootd in rpki.irdb.models.Rootd.objects.all(): - pkcs10 = rpki.x509.PKCS10(Base64 = x.text) - router_id = long(x.get("router_id")) - asns = rpki.resource_set.resource_set_as(x.get("asn")) - if not valid_until: - valid_until = x.get("valid_until") + client_handle = rootd.issuer.handle + "-root" + client_pdu = client_pdus.pop(client_handle, None) + sia_base = "rsync://%s/%s/%s/" % (self.rsync_server, self.rsync_module, client_handle) - if valid_until and isinstance(valid_until, (str, unicode)): - valid_until = rpki.sundial.datetime.fromXMLtime(valid_until) + if (client_pdu is None or + client_pdu.get("base_uri") != sia_base or + client_pdu.findtext(rpki.publication_control.tag_bpki_cert, "").decode("base64") != rootd.issuer.certificate.get_DER()): + q_pdu = SubElement(q_msg, rpki.publication_control.tag_client, + action = "create" if client_pdu is None else "set", + client_handle = client_handle, + base_uri = sia_base) + SubElement(q_pdu, rpki.publication_control.tag_bpki_cert).text = rootd.issuer.certificate.get_Base64() - if not valid_until: - valid_until = rpki.sundial.now() + rpki.sundial.timedelta(days = 365) - elif valid_until < rpki.sundial.now(): - raise PastExpiration("Specified expiration date %s has already passed" % valid_until) + # Delete any unknown clients - pkcs10.check_valid_request_router() + for client_handle in client_pdus: + SubElement(q_msg, rpki.publication_control.tag_client, action = "destroy", client_handle = client_handle) - cn = "ROUTER-%08x" % asns[0].min - sn = "%08x" % router_id + # If we changed anything, ship updates off to pubd - ee_request = self.resource_ca.ee_certificate_requests.create( - pkcs10 = pkcs10, - gski = pkcs10.gSKI(), - valid_until = valid_until, - cn = cn, - sn = sn, - eku = rpki.oids.id_kp_bgpsec_router) + if len(q_msg) > 0: + self.call_pubd(q_msg) - for r in asns: - ee_request.asns.create(start_as = str(r.min), end_as = str(r.max)) + def synchronize_rpkid_deleted_core(self): + """ + Remove any <tenant/> objects present in rpkid's database but not + present in the IRDB. This is the core synchronization code. + Don't call this directly, instead call a methods that calls this + inside a Django commit wrapper. + """ - @django.db.transaction.atomic - def delete_router_certificate_request(self, gski): - """ - Delete a router certificate request from this RPKI entity. - """ + q_msg = self._compose_left_right_query() + SubElement(q_msg, rpki.left_right.tag_tenant, action = "list") + self.call_rpkid(q_msg) - self.resource_ca.ee_certificate_requests.get(gski = gski).delete() + tenant_handles = set(s.get("tenant_handle") for s in q_msg) + ca_handles = set(ca.handle for ca in rpki.irdb.models.ResourceHolderCA.objects.all()) + assert ca_handles <= tenant_handles + + q_msg = self._compose_left_right_query() + for handle in (tenant_handles - ca_handles): + SubElement(q_msg, rpki.left_right.tag_tenant, action = "destroy", tenant_handle = handle) + + if len(q_msg) > 0: + self.call_rpkid(q_msg) + + + @django.db.transaction.atomic + def add_ee_certificate_request(self, pkcs10, resources): + """ + Check a PKCS #10 request to see if it complies with the + specification for a RPKI EE certificate; if it does, add an + EECertificateRequest for it to the IRDB. + + Not yet sure what we want for update and delete semantics here, so + for the moment this is straight addition. See methods like + .load_asns() and .load_prefixes() for other strategies. + """ + + pkcs10.check_valid_request_ee() + ee_request = self.resource_ca.ee_certificate_requests.create( + pkcs10 = pkcs10, + gski = pkcs10.gSKI(), + valid_until = resources.valid_until) + for r in resources.asn: + ee_request.asns.create(start_as = str(r.min), end_as = str(r.max)) + for r in resources.v4: + ee_request.address_ranges.create(start_ip = str(r.min), end_ip = str(r.max), version = 4) + for r in resources.v6: + ee_request.address_ranges.create(start_ip = str(r.min), end_ip = str(r.max), version = 6) + + + @django.db.transaction.atomic + def add_router_certificate_request(self, router_certificate_request_xml, valid_until = None): + """ + Read XML file containing one or more router certificate requests, + attempt to add request(s) to IRDB. + + Check each PKCS #10 request to see if it complies with the + specification for a router certificate; if it does, create an EE + certificate request for it along with the ASN resources and + router-ID supplied in the XML. + """ + + x = etree_read(router_certificate_request_xml, schema = rpki.relaxng.router_certificate) + + for x in x.getiterator(tag_router_certificate_request): + + pkcs10 = rpki.x509.PKCS10(Base64 = x.text) + router_id = long(x.get("router_id")) + asns = rpki.resource_set.resource_set_as(x.get("asn")) + if not valid_until: + valid_until = x.get("valid_until") + + if valid_until and isinstance(valid_until, (str, unicode)): + valid_until = rpki.sundial.datetime.fromXMLtime(valid_until) + + if not valid_until: + valid_until = rpki.sundial.now() + rpki.sundial.timedelta(days = 365) + elif valid_until < rpki.sundial.now(): + raise PastExpiration("Specified expiration date %s has already passed" % valid_until) + + pkcs10.check_valid_request_router() + + cn = "ROUTER-%08x" % asns[0].min + sn = "%08x" % router_id + + ee_request = self.resource_ca.ee_certificate_requests.create( + pkcs10 = pkcs10, + gski = pkcs10.gSKI(), + valid_until = valid_until, + cn = cn, + sn = sn, + eku = rpki.oids.id_kp_bgpsec_router) + + for r in asns: + ee_request.asns.create(start_as = str(r.min), end_as = str(r.max)) + + + @django.db.transaction.atomic + def delete_router_certificate_request(self, gski): + """ + Delete a router certificate request from this RPKI entity. + """ + + self.resource_ca.ee_certificate_requests.get(gski = gski).delete() diff --git a/rpki/irdbd.py b/rpki/irdbd.py index 96757477..91859f5d 100644 --- a/rpki/irdbd.py +++ b/rpki/irdbd.py @@ -41,183 +41,183 @@ logger = logging.getLogger(__name__) class main(object): - def handle_list_resources(self, q_pdu, r_msg): - tenant_handle = q_pdu.get("tenant_handle") - child_handle = q_pdu.get("child_handle") - child = rpki.irdb.models.Child.objects.get(issuer__handle = tenant_handle, handle = child_handle) - resources = child.resource_bag - r_pdu = SubElement(r_msg, rpki.left_right.tag_list_resources, tenant_handle = tenant_handle, child_handle = child_handle, - valid_until = child.valid_until.strftime("%Y-%m-%dT%H:%M:%SZ")) - for k, v in (("asn", resources.asn), - ("ipv4", resources.v4), - ("ipv6", resources.v6), - ("tag", q_pdu.get("tag"))): - if v: - r_pdu.set(k, str(v)) - - def handle_list_roa_requests(self, q_pdu, r_msg): - tenant_handle = q_pdu.get("tenant_handle") - for request in rpki.irdb.models.ROARequest.objects.raw(""" - SELECT irdb_roarequest.* - FROM irdb_roarequest, irdb_resourceholderca - WHERE irdb_roarequest.issuer_id = irdb_resourceholderca.id - AND irdb_resourceholderca.handle = %s - """, [tenant_handle]): - prefix_bag = request.roa_prefix_bag - r_pdu = SubElement(r_msg, rpki.left_right.tag_list_roa_requests, tenant_handle = tenant_handle, asn = str(request.asn)) - for k, v in (("ipv4", prefix_bag.v4), - ("ipv6", prefix_bag.v6), - ("tag", q_pdu.get("tag"))): - if v: - r_pdu.set(k, str(v)) - - def handle_list_ghostbuster_requests(self, q_pdu, r_msg): - tenant_handle = q_pdu.get("tenant_handle") - parent_handle = q_pdu.get("parent_handle") - ghostbusters = rpki.irdb.models.GhostbusterRequest.objects.filter(issuer__handle = tenant_handle, parent__handle = parent_handle) - if ghostbusters.count() == 0: - ghostbusters = rpki.irdb.models.GhostbusterRequest.objects.filter(issuer__handle = tenant_handle, parent = None) - for ghostbuster in ghostbusters: - r_pdu = SubElement(r_msg, q_pdu.tag, tenant_handle = tenant_handle, parent_handle = parent_handle) - if q_pdu.get("tag"): - r_pdu.set("tag", q_pdu.get("tag")) - r_pdu.text = ghostbuster.vcard - - def handle_list_ee_certificate_requests(self, q_pdu, r_msg): - tenant_handle = q_pdu.get("tenant_handle") - for ee_req in rpki.irdb.models.EECertificateRequest.objects.filter(issuer__handle = tenant_handle): - resources = ee_req.resource_bag - r_pdu = SubElement(r_msg, q_pdu.tag, tenant_handle = tenant_handle, gski = ee_req.gski, - valid_until = ee_req.valid_until.strftime("%Y-%m-%dT%H:%M:%SZ"), - cn = ee_req.cn, sn = ee_req.sn) - for k, v in (("asn", resources.asn), - ("ipv4", resources.v4), - ("ipv6", resources.v6), - ("eku", ee_req.eku), - ("tag", q_pdu.get("tag"))): - if v: - r_pdu.set(k, str(v)) - SubElement(r_pdu, rpki.left_right.tag_pkcs10).text = ee_req.pkcs10.get_Base64() - - def handler(self, request, q_der): - try: - from django.db import connection - connection.cursor() # Reconnect to mysqld if necessary - self.start_new_transaction() - serverCA = rpki.irdb.models.ServerCA.objects.get() - rpkid = serverCA.ee_certificates.get(purpose = "rpkid") - irdbd = serverCA.ee_certificates.get(purpose = "irdbd") - q_cms = rpki.left_right.cms_msg(DER = q_der) - q_msg = q_cms.unwrap((serverCA.certificate, rpkid.certificate)) - self.cms_timestamp = q_cms.check_replay(self.cms_timestamp, request.path) - if q_msg.get("type") != "query": - raise rpki.exceptions.BadQuery("Message type is %s, expected query" % q_msg.get("type")) - r_msg = Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, - type = "reply", version = rpki.left_right.version) - try: - for q_pdu in q_msg: - getattr(self, "handle_" + q_pdu.tag[len(rpki.left_right.xmlns):])(q_pdu, r_msg) - - except Exception, e: - logger.exception("Exception processing PDU %r", q_pdu) - r_pdu = SubElement(r_msg, rpki.left_right.tag_report_error, error_code = e.__class__.__name__) - r_pdu.text = str(e) - if q_pdu.get("tag") is not None: - r_pdu.set("tag", q_pdu.get("tag")) - - request.send_cms_response(rpki.left_right.cms_msg().wrap(r_msg, irdbd.private_key, irdbd.certificate)) - - except Exception, e: - logger.exception("Unhandled exception while processing HTTP request") - request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) - - def __init__(self, **kwargs): - - global rpki # pylint: disable=W0602 - - os.environ.update(TZ = "UTC", - DJANGO_SETTINGS_MODULE = "rpki.django_settings.irdb") - time.tzset() - - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument("-c", "--config", - help = "override default location of configuration file") - parser.add_argument("-f", "--foreground", action = "store_true", - help = "do not daemonize") - parser.add_argument("--pidfile", - help = "override default location of pid file") - parser.add_argument("--profile", - help = "enable profiling, saving data to PROFILE") - rpki.log.argparse_setup(parser) - args = parser.parse_args() - - rpki.log.init("irdbd", args) - - self.cfg = rpki.config.parser(set_filename = args.config, section = "irdbd") - self.cfg.set_global_flags() - - if not args.foreground: - rpki.daemonize.daemon(pidfile = args.pidfile) - - if args.profile: - import cProfile - prof = cProfile.Profile() - try: - prof.runcall(self.main) - finally: - prof.dump_stats(args.profile) - logger.info("Dumped profile data to %s", args.profile) - else: - self.main() - - def main(self): - - startup_msg = self.cfg.get("startup-message", "") - if startup_msg: - logger.info(startup_msg) - - # Now that we know which configuration file to use, it's OK to - # load modules that require Django's settings module. - - import django - django.setup() - - global rpki # pylint: disable=W0602 - import rpki.irdb # pylint: disable=W0621 - - self.http_server_host = self.cfg.get("server-host", "") - self.http_server_port = self.cfg.getint("server-port") - - self.cms_timestamp = None - - rpki.http_simple.server( - host = self.http_server_host, - port = self.http_server_port, - handlers = self.handler) - - def start_new_transaction(self): - - # Entirely too much fun with read-only access to transactional databases. - # - # http://stackoverflow.com/questions/3346124/how-do-i-force-django-to-ignore-any-caches-and-reload-data - # http://devblog.resolversystems.com/?p=439 - # http://groups.google.com/group/django-users/browse_thread/thread/e25cec400598c06d - # http://stackoverflow.com/questions/1028671/python-mysqldb-update-query-fails - # http://dev.mysql.com/doc/refman/5.0/en/set-transaction.html - # - # It turns out that MySQL is doing us a favor with this weird - # transactional behavior on read, because without it there's a - # race condition if multiple updates are committed to the IRDB - # while we're in the middle of processing a query. Note that - # proper transaction management by the committers doesn't protect - # us, this is a transactional problem on read. So we need to use - # explicit transaction management. Since irdbd is a read-only - # consumer of IRDB data, this means we need to commit an empty - # transaction at the beginning of processing each query, to reset - # the transaction isolation snapshot. - - import django.db.transaction - - with django.db.transaction.atomic(): - #django.db.transaction.commit() - pass + def handle_list_resources(self, q_pdu, r_msg): + tenant_handle = q_pdu.get("tenant_handle") + child_handle = q_pdu.get("child_handle") + child = rpki.irdb.models.Child.objects.get(issuer__handle = tenant_handle, handle = child_handle) + resources = child.resource_bag + r_pdu = SubElement(r_msg, rpki.left_right.tag_list_resources, tenant_handle = tenant_handle, child_handle = child_handle, + valid_until = child.valid_until.strftime("%Y-%m-%dT%H:%M:%SZ")) + for k, v in (("asn", resources.asn), + ("ipv4", resources.v4), + ("ipv6", resources.v6), + ("tag", q_pdu.get("tag"))): + if v: + r_pdu.set(k, str(v)) + + def handle_list_roa_requests(self, q_pdu, r_msg): + tenant_handle = q_pdu.get("tenant_handle") + for request in rpki.irdb.models.ROARequest.objects.raw(""" + SELECT irdb_roarequest.* + FROM irdb_roarequest, irdb_resourceholderca + WHERE irdb_roarequest.issuer_id = irdb_resourceholderca.id + AND irdb_resourceholderca.handle = %s + """, [tenant_handle]): + prefix_bag = request.roa_prefix_bag + r_pdu = SubElement(r_msg, rpki.left_right.tag_list_roa_requests, tenant_handle = tenant_handle, asn = str(request.asn)) + for k, v in (("ipv4", prefix_bag.v4), + ("ipv6", prefix_bag.v6), + ("tag", q_pdu.get("tag"))): + if v: + r_pdu.set(k, str(v)) + + def handle_list_ghostbuster_requests(self, q_pdu, r_msg): + tenant_handle = q_pdu.get("tenant_handle") + parent_handle = q_pdu.get("parent_handle") + ghostbusters = rpki.irdb.models.GhostbusterRequest.objects.filter(issuer__handle = tenant_handle, parent__handle = parent_handle) + if ghostbusters.count() == 0: + ghostbusters = rpki.irdb.models.GhostbusterRequest.objects.filter(issuer__handle = tenant_handle, parent = None) + for ghostbuster in ghostbusters: + r_pdu = SubElement(r_msg, q_pdu.tag, tenant_handle = tenant_handle, parent_handle = parent_handle) + if q_pdu.get("tag"): + r_pdu.set("tag", q_pdu.get("tag")) + r_pdu.text = ghostbuster.vcard + + def handle_list_ee_certificate_requests(self, q_pdu, r_msg): + tenant_handle = q_pdu.get("tenant_handle") + for ee_req in rpki.irdb.models.EECertificateRequest.objects.filter(issuer__handle = tenant_handle): + resources = ee_req.resource_bag + r_pdu = SubElement(r_msg, q_pdu.tag, tenant_handle = tenant_handle, gski = ee_req.gski, + valid_until = ee_req.valid_until.strftime("%Y-%m-%dT%H:%M:%SZ"), + cn = ee_req.cn, sn = ee_req.sn) + for k, v in (("asn", resources.asn), + ("ipv4", resources.v4), + ("ipv6", resources.v6), + ("eku", ee_req.eku), + ("tag", q_pdu.get("tag"))): + if v: + r_pdu.set(k, str(v)) + SubElement(r_pdu, rpki.left_right.tag_pkcs10).text = ee_req.pkcs10.get_Base64() + + def handler(self, request, q_der): + try: + from django.db import connection + connection.cursor() # Reconnect to mysqld if necessary + self.start_new_transaction() + serverCA = rpki.irdb.models.ServerCA.objects.get() + rpkid = serverCA.ee_certificates.get(purpose = "rpkid") + irdbd = serverCA.ee_certificates.get(purpose = "irdbd") + q_cms = rpki.left_right.cms_msg(DER = q_der) + q_msg = q_cms.unwrap((serverCA.certificate, rpkid.certificate)) + self.cms_timestamp = q_cms.check_replay(self.cms_timestamp, request.path) + if q_msg.get("type") != "query": + raise rpki.exceptions.BadQuery("Message type is %s, expected query" % q_msg.get("type")) + r_msg = Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, + type = "reply", version = rpki.left_right.version) + try: + for q_pdu in q_msg: + getattr(self, "handle_" + q_pdu.tag[len(rpki.left_right.xmlns):])(q_pdu, r_msg) + + except Exception, e: + logger.exception("Exception processing PDU %r", q_pdu) + r_pdu = SubElement(r_msg, rpki.left_right.tag_report_error, error_code = e.__class__.__name__) + r_pdu.text = str(e) + if q_pdu.get("tag") is not None: + r_pdu.set("tag", q_pdu.get("tag")) + + request.send_cms_response(rpki.left_right.cms_msg().wrap(r_msg, irdbd.private_key, irdbd.certificate)) + + except Exception, e: + logger.exception("Unhandled exception while processing HTTP request") + request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) + + def __init__(self, **kwargs): + + global rpki # pylint: disable=W0602 + + os.environ.update(TZ = "UTC", + DJANGO_SETTINGS_MODULE = "rpki.django_settings.irdb") + time.tzset() + + parser = argparse.ArgumentParser(description = __doc__) + parser.add_argument("-c", "--config", + help = "override default location of configuration file") + parser.add_argument("-f", "--foreground", action = "store_true", + help = "do not daemonize") + parser.add_argument("--pidfile", + help = "override default location of pid file") + parser.add_argument("--profile", + help = "enable profiling, saving data to PROFILE") + rpki.log.argparse_setup(parser) + args = parser.parse_args() + + rpki.log.init("irdbd", args) + + self.cfg = rpki.config.parser(set_filename = args.config, section = "irdbd") + self.cfg.set_global_flags() + + if not args.foreground: + rpki.daemonize.daemon(pidfile = args.pidfile) + + if args.profile: + import cProfile + prof = cProfile.Profile() + try: + prof.runcall(self.main) + finally: + prof.dump_stats(args.profile) + logger.info("Dumped profile data to %s", args.profile) + else: + self.main() + + def main(self): + + startup_msg = self.cfg.get("startup-message", "") + if startup_msg: + logger.info(startup_msg) + + # Now that we know which configuration file to use, it's OK to + # load modules that require Django's settings module. + + import django + django.setup() + + global rpki # pylint: disable=W0602 + import rpki.irdb # pylint: disable=W0621 + + self.http_server_host = self.cfg.get("server-host", "") + self.http_server_port = self.cfg.getint("server-port") + + self.cms_timestamp = None + + rpki.http_simple.server( + host = self.http_server_host, + port = self.http_server_port, + handlers = self.handler) + + def start_new_transaction(self): + + # Entirely too much fun with read-only access to transactional databases. + # + # http://stackoverflow.com/questions/3346124/how-do-i-force-django-to-ignore-any-caches-and-reload-data + # http://devblog.resolversystems.com/?p=439 + # http://groups.google.com/group/django-users/browse_thread/thread/e25cec400598c06d + # http://stackoverflow.com/questions/1028671/python-mysqldb-update-query-fails + # http://dev.mysql.com/doc/refman/5.0/en/set-transaction.html + # + # It turns out that MySQL is doing us a favor with this weird + # transactional behavior on read, because without it there's a + # race condition if multiple updates are committed to the IRDB + # while we're in the middle of processing a query. Note that + # proper transaction management by the committers doesn't protect + # us, this is a transactional problem on read. So we need to use + # explicit transaction management. Since irdbd is a read-only + # consumer of IRDB data, this means we need to commit an empty + # transaction at the beginning of processing each query, to reset + # the transaction isolation snapshot. + + import django.db.transaction + + with django.db.transaction.atomic(): + #django.db.transaction.commit() + pass diff --git a/rpki/left_right.py b/rpki/left_right.py index 387e908f..3572ee98 100644 --- a/rpki/left_right.py +++ b/rpki/left_right.py @@ -71,9 +71,9 @@ allowed_content_types = (content_type,) class cms_msg(rpki.x509.XML_CMS_object): - """ - CMS-signed left-right PDU. - """ + """ + CMS-signed left-right PDU. + """ - encoding = "us-ascii" - schema = rpki.relaxng.left_right + encoding = "us-ascii" + schema = rpki.relaxng.left_right diff --git a/rpki/log.py b/rpki/log.py index 828982da..8afee4ba 100644 --- a/rpki/log.py +++ b/rpki/log.py @@ -30,12 +30,12 @@ import argparse import traceback as tb try: - have_setproctitle = False - if os.getenv("DISABLE_SETPROCTITLE") is None: - import setproctitle # pylint: disable=F0401 - have_setproctitle = True + have_setproctitle = False + if os.getenv("DISABLE_SETPROCTITLE") is None: + import setproctitle # pylint: disable=F0401 + have_setproctitle = True except ImportError: - pass + pass logger = logging.getLogger(__name__) @@ -67,234 +67,234 @@ proctitle_extra = os.path.basename(os.getcwd()) class Formatter(object): - """ - Reimplementation (easier than subclassing in this case) of - logging.Formatter. - - It turns out that the logging code only cares about this class's - .format(record) method, everything else is internal; so long as - .format() converts a record into a properly formatted string, the - logging code is happy. - - So, rather than mess around with dynamically constructing and - deconstructing and tweaking format strings and ten zillion options - we don't use, we just provide our own implementation that supports - what we do need. - """ - - converter = time.gmtime - - def __init__(self, ident, handler): - self.ident = ident - self.is_syslog = isinstance(handler, logging.handlers.SysLogHandler) - - def format(self, record): - return "".join(self.coformat(record)).rstrip("\n") - - def coformat(self, record): - - try: - if not self.is_syslog: - yield time.strftime("%Y-%m-%d %H:%M:%S ", time.gmtime(record.created)) - except: # pylint: disable=W0702 - yield "[$!$Time format failed]" - - try: - yield "%s[%d]: " % (self.ident, record.process) - except: # pylint: disable=W0702 - yield "[$!$ident format failed]" - - try: - if isinstance(record.context, (str, unicode)): - yield record.context + " " - else: - yield repr(record.context) + " " - except AttributeError: - pass - except: # pylint: disable=W0702 - yield "[$!$context format failed]" - - try: - yield record.getMessage() - except: # pylint: disable=W0702 - yield "[$!$record.getMessage() failed]" - - try: - if record.exc_info: - if self.is_syslog or not enable_tracebacks: - lines = tb.format_exception_only(record.exc_info[0], record.exc_info[1]) - lines.insert(0, ": ") - else: - lines = tb.format_exception(record.exc_info[0], record.exc_info[1], record.exc_info[2]) - lines.insert(0, "\n") - for line in lines: - yield line - except: # pylint: disable=W0702 - yield "[$!$exception formatting failed]" + """ + Reimplementation (easier than subclassing in this case) of + logging.Formatter. + + It turns out that the logging code only cares about this class's + .format(record) method, everything else is internal; so long as + .format() converts a record into a properly formatted string, the + logging code is happy. + + So, rather than mess around with dynamically constructing and + deconstructing and tweaking format strings and ten zillion options + we don't use, we just provide our own implementation that supports + what we do need. + """ + + converter = time.gmtime + + def __init__(self, ident, handler): + self.ident = ident + self.is_syslog = isinstance(handler, logging.handlers.SysLogHandler) + + def format(self, record): + return "".join(self.coformat(record)).rstrip("\n") + + def coformat(self, record): + + try: + if not self.is_syslog: + yield time.strftime("%Y-%m-%d %H:%M:%S ", time.gmtime(record.created)) + except: # pylint: disable=W0702 + yield "[$!$Time format failed]" + + try: + yield "%s[%d]: " % (self.ident, record.process) + except: # pylint: disable=W0702 + yield "[$!$ident format failed]" + + try: + if isinstance(record.context, (str, unicode)): + yield record.context + " " + else: + yield repr(record.context) + " " + except AttributeError: + pass + except: # pylint: disable=W0702 + yield "[$!$context format failed]" + + try: + yield record.getMessage() + except: # pylint: disable=W0702 + yield "[$!$record.getMessage() failed]" + + try: + if record.exc_info: + if self.is_syslog or not enable_tracebacks: + lines = tb.format_exception_only(record.exc_info[0], record.exc_info[1]) + lines.insert(0, ": ") + else: + lines = tb.format_exception(record.exc_info[0], record.exc_info[1], record.exc_info[2]) + lines.insert(0, "\n") + for line in lines: + yield line + except: # pylint: disable=W0702 + yield "[$!$exception formatting failed]" def argparse_setup(parser, default_thunk = None): - """ - Set up argparse stuff for functionality in this module. + """ + Set up argparse stuff for functionality in this module. - Default logging destination is syslog, but you can change this - by setting default_thunk to a callable which takes no arguments - and which returns a instance of a logging.Handler subclass. + Default logging destination is syslog, but you can change this + by setting default_thunk to a callable which takes no arguments + and which returns a instance of a logging.Handler subclass. - Also see rpki.log.init(). - """ + Also see rpki.log.init(). + """ - class LogLevelAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string = None): - setattr(namespace, self.dest, getattr(logging, values.upper())) + class LogLevelAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string = None): + setattr(namespace, self.dest, getattr(logging, values.upper())) - parser.add_argument("--log-level", default = logging.WARNING, action = LogLevelAction, - choices = ("debug", "info", "warning", "error", "critical"), - help = "how verbosely to log") + parser.add_argument("--log-level", default = logging.WARNING, action = LogLevelAction, + choices = ("debug", "info", "warning", "error", "critical"), + help = "how verbosely to log") - group = parser.add_mutually_exclusive_group() + group = parser.add_mutually_exclusive_group() - syslog_address = "/dev/log" if os.path.exists("/dev/log") else ("localhost", logging.handlers.SYSLOG_UDP_PORT) + syslog_address = "/dev/log" if os.path.exists("/dev/log") else ("localhost", logging.handlers.SYSLOG_UDP_PORT) - class SyslogAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string = None): - namespace.log_handler = lambda: logging.handlers.SysLogHandler(address = syslog_address, facility = values) + class SyslogAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string = None): + namespace.log_handler = lambda: logging.handlers.SysLogHandler(address = syslog_address, facility = values) - group.add_argument("--log-syslog", nargs = "?", const = "daemon", action = SyslogAction, - choices = sorted(logging.handlers.SysLogHandler.facility_names.keys()), - help = "send logging to syslog") + group.add_argument("--log-syslog", nargs = "?", const = "daemon", action = SyslogAction, + choices = sorted(logging.handlers.SysLogHandler.facility_names.keys()), + help = "send logging to syslog") - class StreamAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string = None): - namespace.log_handler = lambda: logging.StreamHandler(stream = self.const) + class StreamAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string = None): + namespace.log_handler = lambda: logging.StreamHandler(stream = self.const) - group.add_argument("--log-stderr", nargs = 0, action = StreamAction, const = sys.stderr, - help = "send logging to standard error") + group.add_argument("--log-stderr", nargs = 0, action = StreamAction, const = sys.stderr, + help = "send logging to standard error") - group.add_argument("--log-stdout", nargs = 0, action = StreamAction, const = sys.stdout, - help = "send logging to standard output") + group.add_argument("--log-stdout", nargs = 0, action = StreamAction, const = sys.stdout, + help = "send logging to standard output") - class WatchedFileAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string = None): - namespace.log_handler = lambda: logging.handlers.WatchedFileHandler(filename = values) + class WatchedFileAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string = None): + namespace.log_handler = lambda: logging.handlers.WatchedFileHandler(filename = values) - group.add_argument("--log-file", action = WatchedFileAction, - help = "send logging to a file, reopening if rotated away") + group.add_argument("--log-file", action = WatchedFileAction, + help = "send logging to a file, reopening if rotated away") - class RotatingFileAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string = None): - namespace.log_handler = lambda: logging.handlers.RotatingFileHandler( - filename = values[0], - maxBytes = int(values[1]) * 1024, - backupCount = int(values[2])) + class RotatingFileAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string = None): + namespace.log_handler = lambda: logging.handlers.RotatingFileHandler( + filename = values[0], + maxBytes = int(values[1]) * 1024, + backupCount = int(values[2])) - group.add_argument("--log-rotating-file", action = RotatingFileAction, - nargs = 3, metavar = ("FILENAME", "KBYTES", "COUNT"), - help = "send logging to rotating file") + group.add_argument("--log-rotating-file", action = RotatingFileAction, + nargs = 3, metavar = ("FILENAME", "KBYTES", "COUNT"), + help = "send logging to rotating file") - class TimedRotatingFileAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string = None): - namespace.log_handler = lambda: logging.handlers.TimedRotatingFileHandler( - filename = values[0], - interval = int(values[1]), - backupCount = int(values[2]), - when = "H", - utc = True) + class TimedRotatingFileAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string = None): + namespace.log_handler = lambda: logging.handlers.TimedRotatingFileHandler( + filename = values[0], + interval = int(values[1]), + backupCount = int(values[2]), + when = "H", + utc = True) - group.add_argument("--log-timed-rotating-file", action = TimedRotatingFileAction, - nargs = 3, metavar = ("FILENAME", "HOURS", "COUNT"), - help = "send logging to timed rotating file") + group.add_argument("--log-timed-rotating-file", action = TimedRotatingFileAction, + nargs = 3, metavar = ("FILENAME", "HOURS", "COUNT"), + help = "send logging to timed rotating file") - if default_thunk is None: - default_thunk = lambda: logging.handlers.SysLogHandler(address = syslog_address, facility = "daemon") + if default_thunk is None: + default_thunk = lambda: logging.handlers.SysLogHandler(address = syslog_address, facility = "daemon") - parser.set_defaults(log_handler = default_thunk) + parser.set_defaults(log_handler = default_thunk) def init(ident = None, args = None): - """ - Initialize logging system. + """ + Initialize logging system. - Default logging destination is stderr if "args" is not specified. - """ + Default logging destination is stderr if "args" is not specified. + """ - # pylint: disable=E1103 + # pylint: disable=E1103 - if ident is None: - ident = os.path.basename(sys.argv[0]) + if ident is None: + ident = os.path.basename(sys.argv[0]) - if args is None: - args = argparse.Namespace(log_level = logging.WARNING, - log_handler = logging.StreamHandler) + if args is None: + args = argparse.Namespace(log_level = logging.WARNING, + log_handler = logging.StreamHandler) - handler = args.log_handler() - handler.setFormatter(Formatter(ident, handler)) + handler = args.log_handler() + handler.setFormatter(Formatter(ident, handler)) - root_logger = logging.getLogger() - root_logger.addHandler(handler) - root_logger.setLevel(args.log_level) + root_logger = logging.getLogger() + root_logger.addHandler(handler) + root_logger.setLevel(args.log_level) - if ident and have_setproctitle and use_setproctitle: - if proctitle_extra: - setproctitle.setproctitle("%s (%s)" % (ident, proctitle_extra)) - else: - setproctitle.setproctitle(ident) + if ident and have_setproctitle and use_setproctitle: + if proctitle_extra: + setproctitle.setproctitle("%s (%s)" % (ident, proctitle_extra)) + else: + setproctitle.setproctitle(ident) def class_logger(module_logger, attribute = "logger"): - """ - Class decorator to add a class-level Logger object as a class - attribute. This allows control of debugging messages at the class - level rather than just the module level. + """ + Class decorator to add a class-level Logger object as a class + attribute. This allows control of debugging messages at the class + level rather than just the module level. - This decorator takes the module logger as an argument. - """ + This decorator takes the module logger as an argument. + """ - def decorator(cls): - setattr(cls, attribute, module_logger.getChild(cls.__name__)) - return cls - return decorator + def decorator(cls): + setattr(cls, attribute, module_logger.getChild(cls.__name__)) + return cls + return decorator def log_repr(obj, *tokens): - """ - Constructor for __repr__() strings, handles suppression of Python - IDs as needed, includes tenant_handle when available. - """ + """ + Constructor for __repr__() strings, handles suppression of Python + IDs as needed, includes tenant_handle when available. + """ - # pylint: disable=W0702 + # pylint: disable=W0702 - words = ["%s.%s" % (obj.__class__.__module__, obj.__class__.__name__)] - try: - words.append("{%s}" % obj.tenant.tenant_handle) - except: - pass + words = ["%s.%s" % (obj.__class__.__module__, obj.__class__.__name__)] + try: + words.append("{%s}" % obj.tenant.tenant_handle) + except: + pass - for token in tokens: - if token is not None: - try: - s = str(token) - except: - s = "???" - logger.exception("Failed to generate repr() string for object of type %r", type(token)) - if s: - words.append(s) + for token in tokens: + if token is not None: + try: + s = str(token) + except: + s = "???" + logger.exception("Failed to generate repr() string for object of type %r", type(token)) + if s: + words.append(s) - if show_python_ids: - words.append(" at %#x" % id(obj)) + if show_python_ids: + words.append(" at %#x" % id(obj)) - return "<" + " ".join(words) + ">" + return "<" + " ".join(words) + ">" def show_stack(stack_logger = None): - """ - Log a stack trace. - """ + """ + Log a stack trace. + """ - if stack_logger is None: - stack_logger = logger + if stack_logger is None: + stack_logger = logger - for frame in tb.format_stack(): - for line in frame.split("\n"): - if line: - stack_logger.debug("%s", line.rstrip()) + for frame in tb.format_stack(): + for line in frame.split("\n"): + if line: + stack_logger.debug("%s", line.rstrip()) diff --git a/rpki/myrpki.py b/rpki/myrpki.py index 2ae912f0..929c2a70 100644 --- a/rpki/myrpki.py +++ b/rpki/myrpki.py @@ -19,5 +19,5 @@ This is a tombstone for a program that no longer exists. """ if __name__ != "__main__": # sic -- don't break regression tests - import sys - sys.exit('"myrpki" is obsolete. Please use "rpkic" instead.') + import sys + sys.exit('"myrpki" is obsolete. Please use "rpkic" instead.') diff --git a/rpki/mysql_import.py b/rpki/mysql_import.py index 538e1916..bbb7ac22 100644 --- a/rpki/mysql_import.py +++ b/rpki/mysql_import.py @@ -52,11 +52,11 @@ from __future__ import with_statement import warnings if hasattr(warnings, "catch_warnings"): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - import MySQLdb + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + import MySQLdb else: - import MySQLdb + import MySQLdb import _mysql_exceptions diff --git a/rpki/oids.py b/rpki/oids.py index afb95020..abc928bc 100644 --- a/rpki/oids.py +++ b/rpki/oids.py @@ -82,22 +82,22 @@ id_sha256 = "2.16.840.1.101.3.4.2.1" _oid2name = {} for _sym in dir(): - if not _sym.startswith("_"): - _val = globals()[_sym] - if not isinstance(_val, str) or not all(_v.isdigit() for _v in _val.split(".")): - raise ValueError("Bad OID definition: %s = %r" % (_sym, _val)) - _oid2name[_val] = _sym.replace("_", "-") + if not _sym.startswith("_"): + _val = globals()[_sym] + if not isinstance(_val, str) or not all(_v.isdigit() for _v in _val.split(".")): + raise ValueError("Bad OID definition: %s = %r" % (_sym, _val)) + _oid2name[_val] = _sym.replace("_", "-") # pylint: disable=W0631 del _sym del _val def oid2name(oid): - """ - Translate an OID into a string suitable for printing. - """ + """ + Translate an OID into a string suitable for printing. + """ - if not isinstance(oid, (str, unicode)) or not all(o.isdigit() for o in oid.split(".")): - raise ValueError("Parameter does not look like an OID string: " + repr(oid)) + if not isinstance(oid, (str, unicode)) or not all(o.isdigit() for o in oid.split(".")): + raise ValueError("Parameter does not look like an OID string: " + repr(oid)) - return _oid2name.get(oid, oid) + return _oid2name.get(oid, oid) diff --git a/rpki/old_irdbd.py b/rpki/old_irdbd.py index 9294ee84..fca1f1d9 100644 --- a/rpki/old_irdbd.py +++ b/rpki/old_irdbd.py @@ -46,270 +46,270 @@ logger = logging.getLogger(__name__) class main(object): - def handle_list_resources(self, q_pdu, r_msg): - - r_pdu = rpki.left_right.list_resources_elt() - r_pdu.tag = q_pdu.tag - r_pdu.self_handle = q_pdu.self_handle - r_pdu.child_handle = q_pdu.child_handle - - self.cur.execute( - """ - SELECT registrant_id, valid_until - FROM registrant - WHERE registry_handle = %s AND registrant_handle = %s - """, - (q_pdu.self_handle, q_pdu.child_handle)) - - if self.cur.rowcount != 1: - raise rpki.exceptions.NotInDatabase( - "This query should have produced a single exact match, something's messed up" - " (rowcount = %d, self_handle = %s, child_handle = %s)" - % (self.cur.rowcount, q_pdu.self_handle, q_pdu.child_handle)) - - registrant_id, valid_until = self.cur.fetchone() - - r_pdu.valid_until = valid_until.strftime("%Y-%m-%dT%H:%M:%SZ") - - r_pdu.asn = rpki.resource_set.resource_set_as.from_sql( - self.cur, - """ - SELECT start_as, end_as - FROM registrant_asn - WHERE registrant_id = %s - """, - (registrant_id,)) - - r_pdu.ipv4 = rpki.resource_set.resource_set_ipv4.from_sql( - self.cur, - """ - SELECT start_ip, end_ip - FROM registrant_net - WHERE registrant_id = %s AND version = 4 - """, - (registrant_id,)) - - r_pdu.ipv6 = rpki.resource_set.resource_set_ipv6.from_sql( - self.cur, - """ - SELECT start_ip, end_ip - FROM registrant_net - WHERE registrant_id = %s AND version = 6 - """, - (registrant_id,)) - - r_msg.append(r_pdu) - - - def handle_list_roa_requests(self, q_pdu, r_msg): - - self.cur.execute( - "SELECT roa_request_id, asn FROM roa_request WHERE self_handle = %s", - (q_pdu.self_handle,)) - - for roa_request_id, asn in self.cur.fetchall(): - - r_pdu = rpki.left_right.list_roa_requests_elt() - r_pdu.tag = q_pdu.tag - r_pdu.self_handle = q_pdu.self_handle - r_pdu.asn = asn - - r_pdu.ipv4 = rpki.resource_set.roa_prefix_set_ipv4.from_sql( - self.cur, - """ - SELECT prefix, prefixlen, max_prefixlen - FROM roa_request_prefix - WHERE roa_request_id = %s AND version = 4 - """, - (roa_request_id,)) - - r_pdu.ipv6 = rpki.resource_set.roa_prefix_set_ipv6.from_sql( - self.cur, - """ - SELECT prefix, prefixlen, max_prefixlen - FROM roa_request_prefix - WHERE roa_request_id = %s AND version = 6 - """, - (roa_request_id,)) - - r_msg.append(r_pdu) - - - def handle_list_ghostbuster_requests(self, q_pdu, r_msg): - - self.cur.execute( - """ - SELECT vcard - FROM ghostbuster_request - WHERE self_handle = %s AND parent_handle = %s - """, - (q_pdu.self_handle, q_pdu.parent_handle)) - - vcards = [result[0] for result in self.cur.fetchall()] - - if not vcards: - - self.cur.execute( - """ - SELECT vcard - FROM ghostbuster_request - WHERE self_handle = %s AND parent_handle IS NULL - """, - (q_pdu.self_handle,)) - - vcards = [result[0] for result in self.cur.fetchall()] - - for vcard in vcards: - r_pdu = rpki.left_right.list_ghostbuster_requests_elt() - r_pdu.tag = q_pdu.tag - r_pdu.self_handle = q_pdu.self_handle - r_pdu.parent_handle = q_pdu.parent_handle - r_pdu.vcard = vcard - r_msg.append(r_pdu) - - - def handle_list_ee_certificate_requests(self, q_pdu, r_msg): - - self.cur.execute( - """ - SELECT ee_certificate_id, pkcs10, gski, cn, sn, eku, valid_until - FROM ee_certificate - WHERE self_handle = %s - """, - (q_pdu.self_handle,)) - - for ee_certificate_id, pkcs10, gski, cn, sn, eku, valid_until in self.cur.fetchall(): - - r_pdu = rpki.left_right.list_ee_certificate_requests_elt() - r_pdu.tag = q_pdu.tag - r_pdu.self_handle = q_pdu.self_handle - r_pdu.valid_until = valid_until.strftime("%Y-%m-%dT%H:%M:%SZ") - r_pdu.pkcs10 = rpki.x509.PKCS10(DER = pkcs10) - r_pdu.gski = gski - r_pdu.cn = cn - r_pdu.sn = sn - r_pdu.eku = eku - - r_pdu.asn = rpki.resource_set.resource_set_as.from_sql( - self.cur, - """ - SELECT start_as, end_as - FROM ee_certificate_asn - WHERE ee_certificate_id = %s - """, - (ee_certificate_id,)) + def handle_list_resources(self, q_pdu, r_msg): + + r_pdu = rpki.left_right.list_resources_elt() + r_pdu.tag = q_pdu.tag + r_pdu.self_handle = q_pdu.self_handle + r_pdu.child_handle = q_pdu.child_handle + + self.cur.execute( + """ + SELECT registrant_id, valid_until + FROM registrant + WHERE registry_handle = %s AND registrant_handle = %s + """, + (q_pdu.self_handle, q_pdu.child_handle)) + + if self.cur.rowcount != 1: + raise rpki.exceptions.NotInDatabase( + "This query should have produced a single exact match, something's messed up" + " (rowcount = %d, self_handle = %s, child_handle = %s)" + % (self.cur.rowcount, q_pdu.self_handle, q_pdu.child_handle)) + + registrant_id, valid_until = self.cur.fetchone() + + r_pdu.valid_until = valid_until.strftime("%Y-%m-%dT%H:%M:%SZ") + + r_pdu.asn = rpki.resource_set.resource_set_as.from_sql( + self.cur, + """ + SELECT start_as, end_as + FROM registrant_asn + WHERE registrant_id = %s + """, + (registrant_id,)) + + r_pdu.ipv4 = rpki.resource_set.resource_set_ipv4.from_sql( + self.cur, + """ + SELECT start_ip, end_ip + FROM registrant_net + WHERE registrant_id = %s AND version = 4 + """, + (registrant_id,)) + + r_pdu.ipv6 = rpki.resource_set.resource_set_ipv6.from_sql( + self.cur, + """ + SELECT start_ip, end_ip + FROM registrant_net + WHERE registrant_id = %s AND version = 6 + """, + (registrant_id,)) + + r_msg.append(r_pdu) + + + def handle_list_roa_requests(self, q_pdu, r_msg): + + self.cur.execute( + "SELECT roa_request_id, asn FROM roa_request WHERE self_handle = %s", + (q_pdu.self_handle,)) + + for roa_request_id, asn in self.cur.fetchall(): + + r_pdu = rpki.left_right.list_roa_requests_elt() + r_pdu.tag = q_pdu.tag + r_pdu.self_handle = q_pdu.self_handle + r_pdu.asn = asn + + r_pdu.ipv4 = rpki.resource_set.roa_prefix_set_ipv4.from_sql( + self.cur, + """ + SELECT prefix, prefixlen, max_prefixlen + FROM roa_request_prefix + WHERE roa_request_id = %s AND version = 4 + """, + (roa_request_id,)) + + r_pdu.ipv6 = rpki.resource_set.roa_prefix_set_ipv6.from_sql( + self.cur, + """ + SELECT prefix, prefixlen, max_prefixlen + FROM roa_request_prefix + WHERE roa_request_id = %s AND version = 6 + """, + (roa_request_id,)) + + r_msg.append(r_pdu) + + + def handle_list_ghostbuster_requests(self, q_pdu, r_msg): + + self.cur.execute( + """ + SELECT vcard + FROM ghostbuster_request + WHERE self_handle = %s AND parent_handle = %s + """, + (q_pdu.self_handle, q_pdu.parent_handle)) + + vcards = [result[0] for result in self.cur.fetchall()] + + if not vcards: + + self.cur.execute( + """ + SELECT vcard + FROM ghostbuster_request + WHERE self_handle = %s AND parent_handle IS NULL + """, + (q_pdu.self_handle,)) + + vcards = [result[0] for result in self.cur.fetchall()] + + for vcard in vcards: + r_pdu = rpki.left_right.list_ghostbuster_requests_elt() + r_pdu.tag = q_pdu.tag + r_pdu.self_handle = q_pdu.self_handle + r_pdu.parent_handle = q_pdu.parent_handle + r_pdu.vcard = vcard + r_msg.append(r_pdu) + + + def handle_list_ee_certificate_requests(self, q_pdu, r_msg): + + self.cur.execute( + """ + SELECT ee_certificate_id, pkcs10, gski, cn, sn, eku, valid_until + FROM ee_certificate + WHERE self_handle = %s + """, + (q_pdu.self_handle,)) + + for ee_certificate_id, pkcs10, gski, cn, sn, eku, valid_until in self.cur.fetchall(): + + r_pdu = rpki.left_right.list_ee_certificate_requests_elt() + r_pdu.tag = q_pdu.tag + r_pdu.self_handle = q_pdu.self_handle + r_pdu.valid_until = valid_until.strftime("%Y-%m-%dT%H:%M:%SZ") + r_pdu.pkcs10 = rpki.x509.PKCS10(DER = pkcs10) + r_pdu.gski = gski + r_pdu.cn = cn + r_pdu.sn = sn + r_pdu.eku = eku + + r_pdu.asn = rpki.resource_set.resource_set_as.from_sql( + self.cur, + """ + SELECT start_as, end_as + FROM ee_certificate_asn + WHERE ee_certificate_id = %s + """, + (ee_certificate_id,)) + + r_pdu.ipv4 = rpki.resource_set.resource_set_ipv4.from_sql( + self.cur, + """ + SELECT start_ip, end_ip + FROM ee_certificate_net + WHERE ee_certificate_id = %s AND version = 4 + """, + (ee_certificate_id,)) + + r_pdu.ipv6 = rpki.resource_set.resource_set_ipv6.from_sql( + self.cur, + """ + SELECT start_ip, end_ip + FROM ee_certificate_net + WHERE ee_certificate_id = %s AND version = 6 + """, + (ee_certificate_id,)) + + r_msg.append(r_pdu) + + + handle_dispatch = { + rpki.left_right.list_resources_elt : handle_list_resources, + rpki.left_right.list_roa_requests_elt : handle_list_roa_requests, + rpki.left_right.list_ghostbuster_requests_elt : handle_list_ghostbuster_requests, + rpki.left_right.list_ee_certificate_requests_elt : handle_list_ee_certificate_requests } + + def handler(self, request, q_der): + try: + + self.db.ping(True) + + r_msg = rpki.left_right.msg.reply() - r_pdu.ipv4 = rpki.resource_set.resource_set_ipv4.from_sql( - self.cur, - """ - SELECT start_ip, end_ip - FROM ee_certificate_net - WHERE ee_certificate_id = %s AND version = 4 - """, - (ee_certificate_id,)) - - r_pdu.ipv6 = rpki.resource_set.resource_set_ipv6.from_sql( - self.cur, - """ - SELECT start_ip, end_ip - FROM ee_certificate_net - WHERE ee_certificate_id = %s AND version = 6 - """, - (ee_certificate_id,)) - - r_msg.append(r_pdu) - - - handle_dispatch = { - rpki.left_right.list_resources_elt : handle_list_resources, - rpki.left_right.list_roa_requests_elt : handle_list_roa_requests, - rpki.left_right.list_ghostbuster_requests_elt : handle_list_ghostbuster_requests, - rpki.left_right.list_ee_certificate_requests_elt : handle_list_ee_certificate_requests } - - def handler(self, request, q_der): - try: - - self.db.ping(True) + try: - r_msg = rpki.left_right.msg.reply() + q_msg = rpki.left_right.cms_msg_saxify(DER = q_der).unwrap((self.bpki_ta, self.rpkid_cert)) - try: + if not isinstance(q_msg, rpki.left_right.msg) or not q_msg.is_query(): + raise rpki.exceptions.BadQuery("Unexpected %r PDU" % q_msg) - q_msg = rpki.left_right.cms_msg_saxify(DER = q_der).unwrap((self.bpki_ta, self.rpkid_cert)) - - if not isinstance(q_msg, rpki.left_right.msg) or not q_msg.is_query(): - raise rpki.exceptions.BadQuery("Unexpected %r PDU" % q_msg) - - for q_pdu in q_msg: + for q_pdu in q_msg: - try: + try: - try: - h = self.handle_dispatch[type(q_pdu)] - except KeyError: - raise rpki.exceptions.BadQuery("Unexpected %r PDU" % q_pdu) - else: - h(self, q_pdu, r_msg) + try: + h = self.handle_dispatch[type(q_pdu)] + except KeyError: + raise rpki.exceptions.BadQuery("Unexpected %r PDU" % q_pdu) + else: + h(self, q_pdu, r_msg) - except Exception, e: - logger.exception("Exception serving PDU %r", q_pdu) - r_msg.append(rpki.left_right.report_error_elt.from_exception(e, q_pdu.self_handle, q_pdu.tag)) + except Exception, e: + logger.exception("Exception serving PDU %r", q_pdu) + r_msg.append(rpki.left_right.report_error_elt.from_exception(e, q_pdu.self_handle, q_pdu.tag)) - except Exception, e: - logger.exception("Exception decoding query") - r_msg.append(rpki.left_right.report_error_elt.from_exception(e)) + except Exception, e: + logger.exception("Exception decoding query") + r_msg.append(rpki.left_right.report_error_elt.from_exception(e)) - request.send_cms_response(rpki.left_right.cms_msg_saxify().wrap(r_msg, self.irdbd_key, self.irdbd_cert)) + request.send_cms_response(rpki.left_right.cms_msg_saxify().wrap(r_msg, self.irdbd_key, self.irdbd_cert)) - except Exception, e: - logger.exception("Unhandled exception, returning HTTP failure") - request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) + except Exception, e: + logger.exception("Unhandled exception, returning HTTP failure") + request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) - def __init__(self): + def __init__(self): - os.environ["TZ"] = "UTC" - time.tzset() + os.environ["TZ"] = "UTC" + time.tzset() - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument("-c", "--config", - help = "override default location of configuration file") - parser.add_argument("-f", "--foreground", action = "store_true", - help = "do not daemonize (ignored, old_irdbd never daemonizes)") - rpki.log.argparse_setup(parser) - args = parser.parse_args() + parser = argparse.ArgumentParser(description = __doc__) + parser.add_argument("-c", "--config", + help = "override default location of configuration file") + parser.add_argument("-f", "--foreground", action = "store_true", + help = "do not daemonize (ignored, old_irdbd never daemonizes)") + rpki.log.argparse_setup(parser) + args = parser.parse_args() - rpki.log.init("irdbd", args) + rpki.log.init("irdbd", args) - self.cfg = rpki.config.parser(set_filename = args.config, section = "irdbd") + self.cfg = rpki.config.parser(set_filename = args.config, section = "irdbd") - startup_msg = self.cfg.get("startup-message", "") - if startup_msg: - logger.info(startup_msg) + startup_msg = self.cfg.get("startup-message", "") + if startup_msg: + logger.info(startup_msg) - self.cfg.set_global_flags() + self.cfg.set_global_flags() - self.db = MySQLdb.connect(user = self.cfg.get("sql-username"), - db = self.cfg.get("sql-database"), - passwd = self.cfg.get("sql-password")) + self.db = MySQLdb.connect(user = self.cfg.get("sql-username"), + db = self.cfg.get("sql-database"), + passwd = self.cfg.get("sql-password")) - self.cur = self.db.cursor() - self.db.autocommit(True) + self.cur = self.db.cursor() + self.db.autocommit(True) - self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) - self.rpkid_cert = rpki.x509.X509(Auto_update = self.cfg.get("rpkid-cert")) - self.irdbd_cert = rpki.x509.X509(Auto_update = self.cfg.get("irdbd-cert")) - self.irdbd_key = rpki.x509.RSA( Auto_update = self.cfg.get("irdbd-key")) + self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) + self.rpkid_cert = rpki.x509.X509(Auto_update = self.cfg.get("rpkid-cert")) + self.irdbd_cert = rpki.x509.X509(Auto_update = self.cfg.get("irdbd-cert")) + self.irdbd_key = rpki.x509.RSA( Auto_update = self.cfg.get("irdbd-key")) - u = urlparse.urlparse(self.cfg.get("http-url")) + u = urlparse.urlparse(self.cfg.get("http-url")) - assert u.scheme in ("", "http") and \ - u.username is None and \ - u.password is None and \ - u.params == "" and \ - u.query == "" and \ - u.fragment == "" + assert u.scheme in ("", "http") and \ + u.username is None and \ + u.password is None and \ + u.params == "" and \ + u.query == "" and \ + u.fragment == "" - rpki.http_simple.server(host = u.hostname or "localhost", - port = u.port or 443, - handlers = ((u.path, self.handler),)) + rpki.http_simple.server(host = u.hostname or "localhost", + port = u.port or 443, + handlers = ((u.path, self.handler),)) diff --git a/rpki/pubd.py b/rpki/pubd.py index f917c18d..ee258f26 100644 --- a/rpki/pubd.py +++ b/rpki/pubd.py @@ -45,252 +45,252 @@ logger = logging.getLogger(__name__) class main(object): - """ - Main program for pubd. - """ - - def __init__(self): - - os.environ.update(TZ = "UTC", - DJANGO_SETTINGS_MODULE = "rpki.django_settings.pubd") - time.tzset() - - self.irbe_cms_timestamp = None - - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument("-c", "--config", - help = "override default location of configuration file") - parser.add_argument("-f", "--foreground", action = "store_true", - help = "do not daemonize") - parser.add_argument("--pidfile", - help = "override default location of pid file") - parser.add_argument("--profile", - help = "enable profiling, saving data to PROFILE") - rpki.log.argparse_setup(parser) - args = parser.parse_args() - - self.profile = args.profile - - rpki.log.init("pubd", args) - - self.cfg = rpki.config.parser(set_filename = args.config, section = "pubd") - self.cfg.set_global_flags() - - if not args.foreground: - rpki.daemonize.daemon(pidfile = args.pidfile) - - if self.profile: - import cProfile - prof = cProfile.Profile() - try: - prof.runcall(self.main) - finally: - prof.dump_stats(self.profile) - logger.info("Dumped profile data to %s", self.profile) - else: - self.main() - - def main(self): - - if self.profile: - logger.info("Running in profile mode with output to %s", self.profile) - - import django - django.setup() - - global rpki # pylint: disable=W0602 - import rpki.pubdb # pylint: disable=W0621 - - self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) - self.irbe_cert = rpki.x509.X509(Auto_update = self.cfg.get("irbe-cert")) - self.pubd_cert = rpki.x509.X509(Auto_update = self.cfg.get("pubd-cert")) - self.pubd_key = rpki.x509.RSA( Auto_update = self.cfg.get("pubd-key")) - self.pubd_crl = rpki.x509.CRL( Auto_update = self.cfg.get("pubd-crl")) - - self.http_server_host = self.cfg.get("server-host", "") - self.http_server_port = self.cfg.getint("server-port") - - self.publication_base = self.cfg.get("publication-base", "publication/") - - self.rrdp_uri_base = self.cfg.get("rrdp-uri-base", - "http://%s/rrdp/" % socket.getfqdn()) - self.rrdp_expiration_interval = rpki.sundial.timedelta.parse(self.cfg.get("rrdp-expiration-interval", "6h")) - self.rrdp_publication_base = self.cfg.get("rrdp-publication-base", - "rrdp-publication/") - - try: - self.session = rpki.pubdb.models.Session.objects.get() - except rpki.pubdb.models.Session.DoesNotExist: - self.session = rpki.pubdb.models.Session.objects.create(uuid = str(uuid.uuid4()), serial = 0) - - rpki.http_simple.server( - host = self.http_server_host, - port = self.http_server_port, - handlers = (("/control", self.control_handler), - ("/client/", self.client_handler))) - - - def control_handler(self, request, q_der): - """ - Process one PDU from the IRBE. """ - - from django.db import transaction, connection - - try: - connection.cursor() # Reconnect to mysqld if necessary - q_cms = rpki.publication_control.cms_msg(DER = q_der) - q_msg = q_cms.unwrap((self.bpki_ta, self.irbe_cert)) - self.irbe_cms_timestamp = q_cms.check_replay(self.irbe_cms_timestamp, "control") - if q_msg.get("type") != "query": - raise rpki.exceptions.BadQuery("Message type is %s, expected query" % q_msg.get("type")) - r_msg = Element(rpki.publication_control.tag_msg, nsmap = rpki.publication_control.nsmap, - type = "reply", version = rpki.publication_control.version) - - try: - q_pdu = None - with transaction.atomic(): - - for q_pdu in q_msg: - if q_pdu.tag != rpki.publication_control.tag_client: - raise rpki.exceptions.BadQuery("PDU is %s, expected client" % q_pdu.tag) - client_handle = q_pdu.get("client_handle") - action = q_pdu.get("action") - if client_handle is None: - logger.info("Control %s request", action) - else: - logger.info("Control %s request for %s", action, client_handle) - - if action in ("get", "list"): - if action == "get": - clients = rpki.pubdb.models.Client.objects.get(client_handle = client_handle), - else: - clients = rpki.pubdb.models.Client.objects.all() - for client in clients: - r_pdu = SubElement(r_msg, q_pdu.tag, action = action, - client_handle = client.client_handle, base_uri = client.base_uri) - if q_pdu.get("tag"): - r_pdu.set("tag", q_pdu.get("tag")) - SubElement(r_pdu, rpki.publication_control.tag_bpki_cert).text = client.bpki_cert.get_Base64() - if client.bpki_glue is not None: - SubElement(r_pdu, rpki.publication_control.tag_bpki_glue).text = client.bpki_glue.get_Base64() - - if action in ("create", "set"): - if action == "create": - client = rpki.pubdb.models.Client(client_handle = client_handle) - else: - client = rpki.pubdb.models.Client.objects.get(client_handle = client_handle) - if q_pdu.get("base_uri"): - client.base_uri = q_pdu.get("base_uri") - bpki_cert = q_pdu.find(rpki.publication_control.tag_bpki_cert) - if bpki_cert is not None: - client.bpki_cert = rpki.x509.X509(Base64 = bpki_cert.text) - bpki_glue = q_pdu.find(rpki.publication_control.tag_bpki_glue) - if bpki_glue is not None: - client.bpki_glue = rpki.x509.X509(Base64 = bpki_glue.text) - if q_pdu.get("clear_replay_protection") == "yes": - client.last_cms_timestamp = None - client.save() - logger.debug("Stored client_handle %s, base_uri %s, bpki_cert %r, bpki_glue %r, last_cms_timestamp %s", - client.client_handle, client.base_uri, client.bpki_cert, client.bpki_glue, - client.last_cms_timestamp) - r_pdu = SubElement(r_msg, q_pdu.tag, action = action, client_handle = client_handle) - if q_pdu.get("tag"): - r_pdu.set("tag", q_pdu.get("tag")) - - if action == "destroy": - rpki.pubdb.models.Client.objects.filter(client_handle = client_handle).delete() - r_pdu = SubElement(r_msg, q_pdu.tag, action = action, client_handle = client_handle) - if q_pdu.get("tag"): - r_pdu.set("tag", q_pdu.get("tag")) - - except Exception, e: - logger.exception("Exception processing PDU %r", q_pdu) - r_pdu = SubElement(r_msg, rpki.publication_control.tag_report_error, error_code = e.__class__.__name__) - r_pdu.text = str(e) - if q_pdu.get("tag") is not None: - r_pdu.set("tag", q_pdu.get("tag")) - - request.send_cms_response(rpki.publication_control.cms_msg().wrap(r_msg, self.pubd_key, self.pubd_cert)) - - except Exception, e: - logger.exception("Unhandled exception processing control query, path %r", request.path) - request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) - - - client_url_regexp = re.compile("/client/([-A-Z0-9_/]+)$", re.I) - - def client_handler(self, request, q_der): - """ - Process one PDU from a client. + Main program for pubd. """ - from django.db import transaction, connection - - try: - connection.cursor() # Reconnect to mysqld if necessary - match = self.client_url_regexp.search(request.path) - if match is None: - raise rpki.exceptions.BadContactURL("Bad path: %s" % request.path) - client = rpki.pubdb.models.Client.objects.get(client_handle = match.group(1)) - q_cms = rpki.publication.cms_msg(DER = q_der) - q_msg = q_cms.unwrap((self.bpki_ta, client.bpki_cert, client.bpki_glue)) - client.last_cms_timestamp = q_cms.check_replay(client.last_cms_timestamp, client.client_handle) - client.save() - if q_msg.get("type") != "query": - raise rpki.exceptions.BadQuery("Message type is %s, expected query" % q_msg.get("type")) - r_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, - type = "reply", version = rpki.publication.version) - delta = None - try: - with transaction.atomic(): - for q_pdu in q_msg: - if q_pdu.get("uri"): - logger.info("Client %s request for %s", q_pdu.tag, q_pdu.get("uri")) - else: - logger.info("Client %s request", q_pdu.tag) + def __init__(self): - if q_pdu.tag == rpki.publication.tag_list: - for obj in client.publishedobject_set.all(): - r_pdu = SubElement(r_msg, q_pdu.tag, uri = obj.uri, hash = obj.hash) + os.environ.update(TZ = "UTC", + DJANGO_SETTINGS_MODULE = "rpki.django_settings.pubd") + time.tzset() + + self.irbe_cms_timestamp = None + + parser = argparse.ArgumentParser(description = __doc__) + parser.add_argument("-c", "--config", + help = "override default location of configuration file") + parser.add_argument("-f", "--foreground", action = "store_true", + help = "do not daemonize") + parser.add_argument("--pidfile", + help = "override default location of pid file") + parser.add_argument("--profile", + help = "enable profiling, saving data to PROFILE") + rpki.log.argparse_setup(parser) + args = parser.parse_args() + + self.profile = args.profile + + rpki.log.init("pubd", args) + + self.cfg = rpki.config.parser(set_filename = args.config, section = "pubd") + self.cfg.set_global_flags() + + if not args.foreground: + rpki.daemonize.daemon(pidfile = args.pidfile) + + if self.profile: + import cProfile + prof = cProfile.Profile() + try: + prof.runcall(self.main) + finally: + prof.dump_stats(self.profile) + logger.info("Dumped profile data to %s", self.profile) + else: + self.main() + + def main(self): + + if self.profile: + logger.info("Running in profile mode with output to %s", self.profile) + + import django + django.setup() + + global rpki # pylint: disable=W0602 + import rpki.pubdb # pylint: disable=W0621 + + self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) + self.irbe_cert = rpki.x509.X509(Auto_update = self.cfg.get("irbe-cert")) + self.pubd_cert = rpki.x509.X509(Auto_update = self.cfg.get("pubd-cert")) + self.pubd_key = rpki.x509.RSA( Auto_update = self.cfg.get("pubd-key")) + self.pubd_crl = rpki.x509.CRL( Auto_update = self.cfg.get("pubd-crl")) + + self.http_server_host = self.cfg.get("server-host", "") + self.http_server_port = self.cfg.getint("server-port") + + self.publication_base = self.cfg.get("publication-base", "publication/") + + self.rrdp_uri_base = self.cfg.get("rrdp-uri-base", + "http://%s/rrdp/" % socket.getfqdn()) + self.rrdp_expiration_interval = rpki.sundial.timedelta.parse(self.cfg.get("rrdp-expiration-interval", "6h")) + self.rrdp_publication_base = self.cfg.get("rrdp-publication-base", + "rrdp-publication/") + + try: + self.session = rpki.pubdb.models.Session.objects.get() + except rpki.pubdb.models.Session.DoesNotExist: + self.session = rpki.pubdb.models.Session.objects.create(uuid = str(uuid.uuid4()), serial = 0) + + rpki.http_simple.server( + host = self.http_server_host, + port = self.http_server_port, + handlers = (("/control", self.control_handler), + ("/client/", self.client_handler))) + + + def control_handler(self, request, q_der): + """ + Process one PDU from the IRBE. + """ + + from django.db import transaction, connection + + try: + connection.cursor() # Reconnect to mysqld if necessary + q_cms = rpki.publication_control.cms_msg(DER = q_der) + q_msg = q_cms.unwrap((self.bpki_ta, self.irbe_cert)) + self.irbe_cms_timestamp = q_cms.check_replay(self.irbe_cms_timestamp, "control") + if q_msg.get("type") != "query": + raise rpki.exceptions.BadQuery("Message type is %s, expected query" % q_msg.get("type")) + r_msg = Element(rpki.publication_control.tag_msg, nsmap = rpki.publication_control.nsmap, + type = "reply", version = rpki.publication_control.version) + + try: + q_pdu = None + with transaction.atomic(): + + for q_pdu in q_msg: + if q_pdu.tag != rpki.publication_control.tag_client: + raise rpki.exceptions.BadQuery("PDU is %s, expected client" % q_pdu.tag) + client_handle = q_pdu.get("client_handle") + action = q_pdu.get("action") + if client_handle is None: + logger.info("Control %s request", action) + else: + logger.info("Control %s request for %s", action, client_handle) + + if action in ("get", "list"): + if action == "get": + clients = rpki.pubdb.models.Client.objects.get(client_handle = client_handle), + else: + clients = rpki.pubdb.models.Client.objects.all() + for client in clients: + r_pdu = SubElement(r_msg, q_pdu.tag, action = action, + client_handle = client.client_handle, base_uri = client.base_uri) + if q_pdu.get("tag"): + r_pdu.set("tag", q_pdu.get("tag")) + SubElement(r_pdu, rpki.publication_control.tag_bpki_cert).text = client.bpki_cert.get_Base64() + if client.bpki_glue is not None: + SubElement(r_pdu, rpki.publication_control.tag_bpki_glue).text = client.bpki_glue.get_Base64() + + if action in ("create", "set"): + if action == "create": + client = rpki.pubdb.models.Client(client_handle = client_handle) + else: + client = rpki.pubdb.models.Client.objects.get(client_handle = client_handle) + if q_pdu.get("base_uri"): + client.base_uri = q_pdu.get("base_uri") + bpki_cert = q_pdu.find(rpki.publication_control.tag_bpki_cert) + if bpki_cert is not None: + client.bpki_cert = rpki.x509.X509(Base64 = bpki_cert.text) + bpki_glue = q_pdu.find(rpki.publication_control.tag_bpki_glue) + if bpki_glue is not None: + client.bpki_glue = rpki.x509.X509(Base64 = bpki_glue.text) + if q_pdu.get("clear_replay_protection") == "yes": + client.last_cms_timestamp = None + client.save() + logger.debug("Stored client_handle %s, base_uri %s, bpki_cert %r, bpki_glue %r, last_cms_timestamp %s", + client.client_handle, client.base_uri, client.bpki_cert, client.bpki_glue, + client.last_cms_timestamp) + r_pdu = SubElement(r_msg, q_pdu.tag, action = action, client_handle = client_handle) + if q_pdu.get("tag"): + r_pdu.set("tag", q_pdu.get("tag")) + + if action == "destroy": + rpki.pubdb.models.Client.objects.filter(client_handle = client_handle).delete() + r_pdu = SubElement(r_msg, q_pdu.tag, action = action, client_handle = client_handle) + if q_pdu.get("tag"): + r_pdu.set("tag", q_pdu.get("tag")) + + except Exception, e: + logger.exception("Exception processing PDU %r", q_pdu) + r_pdu = SubElement(r_msg, rpki.publication_control.tag_report_error, error_code = e.__class__.__name__) + r_pdu.text = str(e) if q_pdu.get("tag") is not None: - r_pdu.set("tag", q_pdu.get("tag")) + r_pdu.set("tag", q_pdu.get("tag")) + + request.send_cms_response(rpki.publication_control.cms_msg().wrap(r_msg, self.pubd_key, self.pubd_cert)) + + except Exception, e: + logger.exception("Unhandled exception processing control query, path %r", request.path) + request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) + + + client_url_regexp = re.compile("/client/([-A-Z0-9_/]+)$", re.I) + + def client_handler(self, request, q_der): + """ + Process one PDU from a client. + """ + + from django.db import transaction, connection + + try: + connection.cursor() # Reconnect to mysqld if necessary + match = self.client_url_regexp.search(request.path) + if match is None: + raise rpki.exceptions.BadContactURL("Bad path: %s" % request.path) + client = rpki.pubdb.models.Client.objects.get(client_handle = match.group(1)) + q_cms = rpki.publication.cms_msg(DER = q_der) + q_msg = q_cms.unwrap((self.bpki_ta, client.bpki_cert, client.bpki_glue)) + client.last_cms_timestamp = q_cms.check_replay(client.last_cms_timestamp, client.client_handle) + client.save() + if q_msg.get("type") != "query": + raise rpki.exceptions.BadQuery("Message type is %s, expected query" % q_msg.get("type")) + r_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, + type = "reply", version = rpki.publication.version) + delta = None + try: + with transaction.atomic(): + for q_pdu in q_msg: + if q_pdu.get("uri"): + logger.info("Client %s request for %s", q_pdu.tag, q_pdu.get("uri")) + else: + logger.info("Client %s request", q_pdu.tag) + + if q_pdu.tag == rpki.publication.tag_list: + for obj in client.publishedobject_set.all(): + r_pdu = SubElement(r_msg, q_pdu.tag, uri = obj.uri, hash = obj.hash) + if q_pdu.get("tag") is not None: + r_pdu.set("tag", q_pdu.get("tag")) + + else: + assert q_pdu.tag in (rpki.publication.tag_publish, rpki.publication.tag_withdraw) + if delta is None: + delta = self.session.new_delta(rpki.sundial.now() + self.rrdp_expiration_interval) + client.check_allowed_uri(q_pdu.get("uri")) + if q_pdu.tag == rpki.publication.tag_publish: + der = q_pdu.text.decode("base64") + logger.info("Publishing %s", rpki.x509.uri_dispatch(q_pdu.get("uri"))(DER = der).tracking_data(q_pdu.get("uri"))) + delta.publish(client, der, q_pdu.get("uri"), q_pdu.get("hash")) + else: + logger.info("Withdrawing %s", q_pdu.get("uri")) + delta.withdraw(client, q_pdu.get("uri"), q_pdu.get("hash")) + r_pdu = SubElement(r_msg, q_pdu.tag, uri = q_pdu.get("uri")) + if q_pdu.get("tag") is not None: + r_pdu.set("tag", q_pdu.get("tag")) + + if delta is not None: + delta.activate() + self.session.generate_snapshot() + self.session.expire_deltas() + + except Exception, e: + logger.exception("Exception processing PDU %r", q_pdu) + r_pdu = SubElement(r_msg, rpki.publication.tag_report_error, error_code = e.__class__.__name__) + r_pdu.text = str(e) + if q_pdu.get("tag") is not None: + r_pdu.set("tag", q_pdu.get("tag")) else: - assert q_pdu.tag in (rpki.publication.tag_publish, rpki.publication.tag_withdraw) - if delta is None: - delta = self.session.new_delta(rpki.sundial.now() + self.rrdp_expiration_interval) - client.check_allowed_uri(q_pdu.get("uri")) - if q_pdu.tag == rpki.publication.tag_publish: - der = q_pdu.text.decode("base64") - logger.info("Publishing %s", rpki.x509.uri_dispatch(q_pdu.get("uri"))(DER = der).tracking_data(q_pdu.get("uri"))) - delta.publish(client, der, q_pdu.get("uri"), q_pdu.get("hash")) - else: - logger.info("Withdrawing %s", q_pdu.get("uri")) - delta.withdraw(client, q_pdu.get("uri"), q_pdu.get("hash")) - r_pdu = SubElement(r_msg, q_pdu.tag, uri = q_pdu.get("uri")) - if q_pdu.get("tag") is not None: - r_pdu.set("tag", q_pdu.get("tag")) - - if delta is not None: - delta.activate() - self.session.generate_snapshot() - self.session.expire_deltas() - - except Exception, e: - logger.exception("Exception processing PDU %r", q_pdu) - r_pdu = SubElement(r_msg, rpki.publication.tag_report_error, error_code = e.__class__.__name__) - r_pdu.text = str(e) - if q_pdu.get("tag") is not None: - r_pdu.set("tag", q_pdu.get("tag")) - - else: - if delta is not None: - self.session.synchronize_rrdp_files(self.rrdp_publication_base, self.rrdp_uri_base) - delta.update_rsync_files(self.publication_base) - - request.send_cms_response(rpki.publication.cms_msg().wrap(r_msg, self.pubd_key, self.pubd_cert, self.pubd_crl)) - - except Exception, e: - logger.exception("Unhandled exception processing client query, path %r", request.path) - request.send_error(500, "Could not process PDU: %s" % e) + if delta is not None: + self.session.synchronize_rrdp_files(self.rrdp_publication_base, self.rrdp_uri_base) + delta.update_rsync_files(self.publication_base) + + request.send_cms_response(rpki.publication.cms_msg().wrap(r_msg, self.pubd_key, self.pubd_cert, self.pubd_crl)) + + except Exception, e: + logger.exception("Unhandled exception processing client query, path %r", request.path) + request.send_error(500, "Could not process PDU: %s" % e) diff --git a/rpki/pubdb/models.py b/rpki/pubdb/models.py index 2b6d67e4..46dcf493 100644 --- a/rpki/pubdb/models.py +++ b/rpki/pubdb/models.py @@ -48,266 +48,266 @@ rrdp_tag_withdraw = rrdp_xmlns + "withdraw" # sure quite where to put it at the moment. def DERSubElement(elt, name, der, attrib = None, **kwargs): - """ - Convenience wrapper around SubElement for use with Base64 text. - """ + """ + Convenience wrapper around SubElement for use with Base64 text. + """ - se = SubElement(elt, name, attrib, **kwargs) - se.text = rpki.x509.base64_with_linebreaks(der) - se.tail = "\n" - return se + se = SubElement(elt, name, attrib, **kwargs) + se.text = rpki.x509.base64_with_linebreaks(der) + se.tail = "\n" + return se class Client(models.Model): - client_handle = models.CharField(unique = True, max_length = 255) - base_uri = models.TextField() - bpki_cert = CertificateField() - bpki_glue = CertificateField(null = True) - last_cms_timestamp = SundialField(blank = True, null = True) + client_handle = models.CharField(unique = True, max_length = 255) + base_uri = models.TextField() + bpki_cert = CertificateField() + bpki_glue = CertificateField(null = True) + last_cms_timestamp = SundialField(blank = True, null = True) - def check_allowed_uri(self, uri): - """ - Make sure that a target URI is within this client's allowed URI space. - """ + def check_allowed_uri(self, uri): + """ + Make sure that a target URI is within this client's allowed URI space. + """ - if not uri.startswith(self.base_uri): - raise rpki.exceptions.ForbiddenURI + if not uri.startswith(self.base_uri): + raise rpki.exceptions.ForbiddenURI class Session(models.Model): - uuid = models.CharField(unique = True, max_length=36) - serial = models.BigIntegerField() - snapshot = models.TextField(blank = True) - hash = models.CharField(max_length = 64, blank = True) - - ## @var keep_all_rrdp_files - # Debugging flag to prevent expiration of old RRDP files. - # This simplifies debugging delta code. Need for this - # may go away once RRDP is fully integrated into rcynic. - keep_all_rrdp_files = False - - def new_delta(self, expires): - """ - Construct a new delta associated with this session. - """ + uuid = models.CharField(unique = True, max_length=36) + serial = models.BigIntegerField() + snapshot = models.TextField(blank = True) + hash = models.CharField(max_length = 64, blank = True) + + ## @var keep_all_rrdp_files + # Debugging flag to prevent expiration of old RRDP files. + # This simplifies debugging delta code. Need for this + # may go away once RRDP is fully integrated into rcynic. + keep_all_rrdp_files = False + + def new_delta(self, expires): + """ + Construct a new delta associated with this session. + """ + + delta = Delta(session = self, + serial = self.serial + 1, + expires = expires) + delta.elt = Element(rrdp_tag_delta, + nsmap = rrdp_nsmap, + version = rrdp_version, + session_id = self.uuid, + serial = str(delta.serial)) + return delta + + + def expire_deltas(self): + """ + Delete deltas whose expiration date has passed. + """ + + self.delta_set.filter(expires__lt = rpki.sundial.now()).delete() + + + def generate_snapshot(self): + """ + Generate an XML snapshot of this session. + """ + + xml = Element(rrdp_tag_snapshot, nsmap = rrdp_nsmap, + version = rrdp_version, + session_id = self.uuid, + serial = str(self.serial)) + xml.text = "\n" + for obj in self.publishedobject_set.all(): + DERSubElement(xml, rrdp_tag_publish, + der = obj.der, + uri = obj.uri) + rpki.relaxng.rrdp.assertValid(xml) + self.snapshot = ElementToString(xml, pretty_print = True) + self.hash = rpki.x509.sha256(self.snapshot).encode("hex") + self.save() + + + @property + def snapshot_fn(self): + return "%s/snapshot/%s.xml" % (self.uuid, self.serial) + + + @property + def notification_fn(self): + return "notify.xml" + + + @staticmethod + def _write_rrdp_file(fn, text, rrdp_publication_base, overwrite = False): + if overwrite or not os.path.exists(os.path.join(rrdp_publication_base, fn)): + tn = os.path.join(rrdp_publication_base, fn + ".%s.tmp" % os.getpid()) + if not os.path.isdir(os.path.dirname(tn)): + os.makedirs(os.path.dirname(tn)) + with open(tn, "w") as f: + f.write(text) + os.rename(tn, os.path.join(rrdp_publication_base, fn)) + + + @staticmethod + def _rrdp_filename_to_uri(fn, rrdp_uri_base): + return "%s/%s" % (rrdp_uri_base.rstrip("/"), fn) + + + def _generate_update_xml(self, rrdp_uri_base): + xml = Element(rrdp_tag_notification, nsmap = rrdp_nsmap, + version = rrdp_version, + session_id = self.uuid, + serial = str(self.serial)) + SubElement(xml, rrdp_tag_snapshot, + uri = self._rrdp_filename_to_uri(self.snapshot_fn, rrdp_uri_base), + hash = self.hash) + for delta in self.delta_set.all(): + SubElement(xml, rrdp_tag_delta, + uri = self._rrdp_filename_to_uri(delta.fn, rrdp_uri_base), + hash = delta.hash, + serial = str(delta.serial)) + rpki.relaxng.rrdp.assertValid(xml) + return ElementToString(xml, pretty_print = True) + + + def synchronize_rrdp_files(self, rrdp_publication_base, rrdp_uri_base): + """ + Write current RRDP files to disk, clean up old files and directories. + """ + + current_filenames = set() + + for delta in self.delta_set.all(): + self._write_rrdp_file(delta.fn, delta.xml, rrdp_publication_base) + current_filenames.add(delta.fn) + + self._write_rrdp_file(self.snapshot_fn, self.snapshot, rrdp_publication_base) + current_filenames.add(self.snapshot_fn) + + self._write_rrdp_file(self.notification_fn, self._generate_update_xml(rrdp_uri_base), + rrdp_publication_base, overwrite = True) + current_filenames.add(self.notification_fn) + + if not self.keep_all_rrdp_files: + for root, dirs, files in os.walk(rrdp_publication_base, topdown = False): + for fn in files: + fn = os.path.join(root, fn) + if fn[len(rrdp_publication_base):].lstrip("/") not in current_filenames: + os.remove(fn) + for dn in dirs: + try: + os.rmdir(os.path.join(root, dn)) + except OSError: + pass - delta = Delta(session = self, - serial = self.serial + 1, - expires = expires) - delta.elt = Element(rrdp_tag_delta, - nsmap = rrdp_nsmap, - version = rrdp_version, - session_id = self.uuid, - serial = str(delta.serial)) - return delta +class Delta(models.Model): + serial = models.BigIntegerField() + xml = models.TextField() + hash = models.CharField(max_length = 64) + expires = SundialField() + session = models.ForeignKey(Session) - def expire_deltas(self): - """ - Delete deltas whose expiration date has passed. - """ - self.delta_set.filter(expires__lt = rpki.sundial.now()).delete() + @staticmethod + def _uri_to_filename(uri, publication_base): + if not uri.startswith("rsync://"): + raise rpki.exceptions.BadURISyntax(uri) + path = uri.split("/")[4:] + path.insert(0, publication_base.rstrip("/")) + filename = "/".join(path) + if "/../" in filename or filename.endswith("/.."): + raise rpki.exceptions.BadURISyntax(filename) + return filename - def generate_snapshot(self): - """ - Generate an XML snapshot of this session. - """ + @property + def fn(self): + return "%s/deltas/%s.xml" % (self.session.uuid, self.serial) - xml = Element(rrdp_tag_snapshot, nsmap = rrdp_nsmap, - version = rrdp_version, - session_id = self.uuid, - serial = str(self.serial)) - xml.text = "\n" - for obj in self.publishedobject_set.all(): - DERSubElement(xml, rrdp_tag_publish, - der = obj.der, - uri = obj.uri) - rpki.relaxng.rrdp.assertValid(xml) - self.snapshot = ElementToString(xml, pretty_print = True) - self.hash = rpki.x509.sha256(self.snapshot).encode("hex") - self.save() - - - @property - def snapshot_fn(self): - return "%s/snapshot/%s.xml" % (self.uuid, self.serial) - - - @property - def notification_fn(self): - return "notify.xml" - - - @staticmethod - def _write_rrdp_file(fn, text, rrdp_publication_base, overwrite = False): - if overwrite or not os.path.exists(os.path.join(rrdp_publication_base, fn)): - tn = os.path.join(rrdp_publication_base, fn + ".%s.tmp" % os.getpid()) - if not os.path.isdir(os.path.dirname(tn)): - os.makedirs(os.path.dirname(tn)) - with open(tn, "w") as f: - f.write(text) - os.rename(tn, os.path.join(rrdp_publication_base, fn)) - - - @staticmethod - def _rrdp_filename_to_uri(fn, rrdp_uri_base): - return "%s/%s" % (rrdp_uri_base.rstrip("/"), fn) - - - def _generate_update_xml(self, rrdp_uri_base): - xml = Element(rrdp_tag_notification, nsmap = rrdp_nsmap, - version = rrdp_version, - session_id = self.uuid, - serial = str(self.serial)) - SubElement(xml, rrdp_tag_snapshot, - uri = self._rrdp_filename_to_uri(self.snapshot_fn, rrdp_uri_base), - hash = self.hash) - for delta in self.delta_set.all(): - SubElement(xml, rrdp_tag_delta, - uri = self._rrdp_filename_to_uri(delta.fn, rrdp_uri_base), - hash = delta.hash, - serial = str(delta.serial)) - rpki.relaxng.rrdp.assertValid(xml) - return ElementToString(xml, pretty_print = True) - - - def synchronize_rrdp_files(self, rrdp_publication_base, rrdp_uri_base): - """ - Write current RRDP files to disk, clean up old files and directories. - """ - current_filenames = set() - - for delta in self.delta_set.all(): - self._write_rrdp_file(delta.fn, delta.xml, rrdp_publication_base) - current_filenames.add(delta.fn) - - self._write_rrdp_file(self.snapshot_fn, self.snapshot, rrdp_publication_base) - current_filenames.add(self.snapshot_fn) - - self._write_rrdp_file(self.notification_fn, self._generate_update_xml(rrdp_uri_base), - rrdp_publication_base, overwrite = True) - current_filenames.add(self.notification_fn) - - if not self.keep_all_rrdp_files: - for root, dirs, files in os.walk(rrdp_publication_base, topdown = False): - for fn in files: - fn = os.path.join(root, fn) - if fn[len(rrdp_publication_base):].lstrip("/") not in current_filenames: - os.remove(fn) - for dn in dirs: - try: - os.rmdir(os.path.join(root, dn)) - except OSError: - pass + def activate(self): + rpki.relaxng.rrdp.assertValid(self.elt) + self.xml = ElementToString(self.elt, pretty_print = True) + self.hash = rpki.x509.sha256(self.xml).encode("hex") + self.save() + self.session.serial += 1 + self.session.save() -class Delta(models.Model): - serial = models.BigIntegerField() - xml = models.TextField() - hash = models.CharField(max_length = 64) - expires = SundialField() - session = models.ForeignKey(Session) - - - @staticmethod - def _uri_to_filename(uri, publication_base): - if not uri.startswith("rsync://"): - raise rpki.exceptions.BadURISyntax(uri) - path = uri.split("/")[4:] - path.insert(0, publication_base.rstrip("/")) - filename = "/".join(path) - if "/../" in filename or filename.endswith("/.."): - raise rpki.exceptions.BadURISyntax(filename) - return filename - - - @property - def fn(self): - return "%s/deltas/%s.xml" % (self.session.uuid, self.serial) - - - def activate(self): - rpki.relaxng.rrdp.assertValid(self.elt) - self.xml = ElementToString(self.elt, pretty_print = True) - self.hash = rpki.x509.sha256(self.xml).encode("hex") - self.save() - self.session.serial += 1 - self.session.save() - - - def publish(self, client, der, uri, obj_hash): - try: - obj = client.publishedobject_set.get(session = self.session, uri = uri) - if obj.hash == obj_hash: - obj.delete() - elif obj_hash is None: - raise rpki.exceptions.ExistingObjectAtURI("Object already published at %s" % uri) - else: - raise rpki.exceptions.DifferentObjectAtURI("Found different object at %s (old %s, new %s)" % (uri, obj.hash, obj_hash)) - except rpki.pubdb.models.PublishedObject.DoesNotExist: - pass - logger.debug("Publishing %s", uri) - PublishedObject.objects.create(session = self.session, client = client, der = der, uri = uri, - hash = rpki.x509.sha256(der).encode("hex")) - se = DERSubElement(self.elt, rrdp_tag_publish, der = der, uri = uri) - if obj_hash is not None: - se.set("hash", obj_hash) - rpki.relaxng.rrdp.assertValid(self.elt) - - - def withdraw(self, client, uri, obj_hash): - obj = client.publishedobject_set.get(session = self.session, uri = uri) - if obj.hash != obj_hash: - raise rpki.exceptions.DifferentObjectAtURI("Found different object at %s (old %s, new %s)" % (uri, obj.hash, obj_hash)) - logger.debug("Withdrawing %s", uri) - obj.delete() - SubElement(self.elt, rrdp_tag_withdraw, uri = uri, hash = obj_hash).tail = "\n" - rpki.relaxng.rrdp.assertValid(self.elt) - - - def update_rsync_files(self, publication_base): - from errno import ENOENT - min_path_len = len(publication_base.rstrip("/")) - for pdu in self.elt: - assert pdu.tag in (rrdp_tag_publish, rrdp_tag_withdraw) - fn = self._uri_to_filename(pdu.get("uri"), publication_base) - if pdu.tag == rrdp_tag_publish: - tn = fn + ".tmp" - dn = os.path.dirname(fn) - if not os.path.isdir(dn): - os.makedirs(dn) - with open(tn, "wb") as f: - f.write(pdu.text.decode("base64")) - os.rename(tn, fn) - else: + def publish(self, client, der, uri, obj_hash): try: - os.remove(fn) - except OSError, e: - if e.errno != ENOENT: - raise - dn = os.path.dirname(fn) - while len(dn) > min_path_len: - try: - os.rmdir(dn) - except OSError: - break - else: - dn = os.path.dirname(dn) - del self.elt + obj = client.publishedobject_set.get(session = self.session, uri = uri) + if obj.hash == obj_hash: + obj.delete() + elif obj_hash is None: + raise rpki.exceptions.ExistingObjectAtURI("Object already published at %s" % uri) + else: + raise rpki.exceptions.DifferentObjectAtURI("Found different object at %s (old %s, new %s)" % (uri, obj.hash, obj_hash)) + except rpki.pubdb.models.PublishedObject.DoesNotExist: + pass + logger.debug("Publishing %s", uri) + PublishedObject.objects.create(session = self.session, client = client, der = der, uri = uri, + hash = rpki.x509.sha256(der).encode("hex")) + se = DERSubElement(self.elt, rrdp_tag_publish, der = der, uri = uri) + if obj_hash is not None: + se.set("hash", obj_hash) + rpki.relaxng.rrdp.assertValid(self.elt) + + + def withdraw(self, client, uri, obj_hash): + obj = client.publishedobject_set.get(session = self.session, uri = uri) + if obj.hash != obj_hash: + raise rpki.exceptions.DifferentObjectAtURI("Found different object at %s (old %s, new %s)" % (uri, obj.hash, obj_hash)) + logger.debug("Withdrawing %s", uri) + obj.delete() + SubElement(self.elt, rrdp_tag_withdraw, uri = uri, hash = obj_hash).tail = "\n" + rpki.relaxng.rrdp.assertValid(self.elt) + + + def update_rsync_files(self, publication_base): + from errno import ENOENT + min_path_len = len(publication_base.rstrip("/")) + for pdu in self.elt: + assert pdu.tag in (rrdp_tag_publish, rrdp_tag_withdraw) + fn = self._uri_to_filename(pdu.get("uri"), publication_base) + if pdu.tag == rrdp_tag_publish: + tn = fn + ".tmp" + dn = os.path.dirname(fn) + if not os.path.isdir(dn): + os.makedirs(dn) + with open(tn, "wb") as f: + f.write(pdu.text.decode("base64")) + os.rename(tn, fn) + else: + try: + os.remove(fn) + except OSError, e: + if e.errno != ENOENT: + raise + dn = os.path.dirname(fn) + while len(dn) > min_path_len: + try: + os.rmdir(dn) + except OSError: + break + else: + dn = os.path.dirname(dn) + del self.elt class PublishedObject(models.Model): - uri = models.CharField(max_length = 255) - der = models.BinaryField() - hash = models.CharField(max_length = 64) - client = models.ForeignKey(Client) - session = models.ForeignKey(Session) - - class Meta: # pylint: disable=C1001,W0232 - unique_together = (("session", "hash"), - ("session", "uri")) + uri = models.CharField(max_length = 255) + der = models.BinaryField() + hash = models.CharField(max_length = 64) + client = models.ForeignKey(Client) + session = models.ForeignKey(Session) + + class Meta: # pylint: disable=C1001,W0232 + unique_together = (("session", "hash"), + ("session", "uri")) diff --git a/rpki/publication.py b/rpki/publication.py index 16824d05..393e078e 100644 --- a/rpki/publication.py +++ b/rpki/publication.py @@ -51,34 +51,34 @@ allowed_content_types = (content_type,) def raise_if_error(pdu): - """ - Raise an appropriate error if this is a <report_error/> PDU. - - As a convenience, this will also accept a <msg/> PDU and raise an - appropriate error if it contains any <report_error/> PDUs or if - the <msg/> is not a reply. - """ - - if pdu.tag == tag_report_error: - code = pdu.get("error_code") - logger.debug("<report_error/> code %r", code) - e = getattr(rpki.exceptions, code, None) - if e is not None and issubclass(e, rpki.exceptions.RPKI_Exception): - raise e(pdu.text) - else: - raise rpki.exceptions.BadPublicationReply("Unexpected response from pubd: %r, %r" % (code, pdu)) - - if pdu.tag == tag_msg: - if pdu.get("type") != "reply": - raise rpki.exceptions.BadPublicationReply("Unexpected response from pubd: expected reply, got %r" % pdu.get("type")) - for p in pdu: - raise_if_error(p) + """ + Raise an appropriate error if this is a <report_error/> PDU. + + As a convenience, this will also accept a <msg/> PDU and raise an + appropriate error if it contains any <report_error/> PDUs or if + the <msg/> is not a reply. + """ + + if pdu.tag == tag_report_error: + code = pdu.get("error_code") + logger.debug("<report_error/> code %r", code) + e = getattr(rpki.exceptions, code, None) + if e is not None and issubclass(e, rpki.exceptions.RPKI_Exception): + raise e(pdu.text) + else: + raise rpki.exceptions.BadPublicationReply("Unexpected response from pubd: %r, %r" % (code, pdu)) + + if pdu.tag == tag_msg: + if pdu.get("type") != "reply": + raise rpki.exceptions.BadPublicationReply("Unexpected response from pubd: expected reply, got %r" % pdu.get("type")) + for p in pdu: + raise_if_error(p) class cms_msg(rpki.x509.XML_CMS_object): - """ - CMS-signed publication PDU. - """ + """ + CMS-signed publication PDU. + """ - encoding = "us-ascii" - schema = rpki.relaxng.publication + encoding = "us-ascii" + schema = rpki.relaxng.publication diff --git a/rpki/publication_control.py b/rpki/publication_control.py index ddb9d417..b0668eef 100644 --- a/rpki/publication_control.py +++ b/rpki/publication_control.py @@ -44,31 +44,31 @@ tag_report_error = rpki.relaxng.publication_control.xmlns + "report_error" def raise_if_error(pdu): - """ - Raise an appropriate error if this is a <report_error/> PDU. + """ + Raise an appropriate error if this is a <report_error/> PDU. - As a convience, this will also accept a <msg/> PDU and raise an - appropriate error if it contains any <report_error/> PDUs. - """ + As a convience, this will also accept a <msg/> PDU and raise an + appropriate error if it contains any <report_error/> PDUs. + """ - if pdu.tag == tag_report_error: - code = pdu.get("error_code") - logger.debug("<report_error/> code %r", code) - e = getattr(rpki.exceptions, code, None) - if e is not None and issubclass(e, rpki.exceptions.RPKI_Exception): - raise e(pdu.text) - else: - raise rpki.exceptions.BadPublicationReply("Unexpected response from pubd: %r, %r" % (code, pdu)) + if pdu.tag == tag_report_error: + code = pdu.get("error_code") + logger.debug("<report_error/> code %r", code) + e = getattr(rpki.exceptions, code, None) + if e is not None and issubclass(e, rpki.exceptions.RPKI_Exception): + raise e(pdu.text) + else: + raise rpki.exceptions.BadPublicationReply("Unexpected response from pubd: %r, %r" % (code, pdu)) - if pdu.tag == tag_msg: - for p in pdu: - raise_if_error(p) + if pdu.tag == tag_msg: + for p in pdu: + raise_if_error(p) class cms_msg(rpki.x509.XML_CMS_object): - """ - CMS-signed publication control PDU. - """ + """ + CMS-signed publication control PDU. + """ - encoding = "us-ascii" - schema = rpki.relaxng.publication_control + encoding = "us-ascii" + schema = rpki.relaxng.publication_control diff --git a/rpki/rcynic.py b/rpki/rcynic.py index a36e4a4e..3307e926 100644 --- a/rpki/rcynic.py +++ b/rpki/rcynic.py @@ -25,142 +25,142 @@ import rpki.resource_set from xml.etree.ElementTree import ElementTree class UnknownObject(rpki.exceptions.RPKI_Exception): - """ - Unrecognized object in rcynic result cache. - """ + """ + Unrecognized object in rcynic result cache. + """ class NotRsyncURI(rpki.exceptions.RPKI_Exception): - """ - URI is not an rsync URI. - """ + """ + URI is not an rsync URI. + """ class rcynic_object(object): - """ - An object read from rcynic cache. - """ + """ + An object read from rcynic cache. + """ - def __init__(self, filename, **kwargs): - self.filename = filename - for k, v in kwargs.iteritems(): - setattr(self, k, v) - self.obj = self.obj_class(DER_file = filename) + def __init__(self, filename, **kwargs): + self.filename = filename + for k, v in kwargs.iteritems(): + setattr(self, k, v) + self.obj = self.obj_class(DER_file = filename) - def __repr__(self): - return "<%s %s %s at 0x%x>" % (self.__class__.__name__, self.uri, self.resources, id(self)) + def __repr__(self): + return "<%s %s %s at 0x%x>" % (self.__class__.__name__, self.uri, self.resources, id(self)) - def show_attrs(self, *attrs): - """ - Print a bunch of object attributes, quietly ignoring any that - might be missing. - """ + def show_attrs(self, *attrs): + """ + Print a bunch of object attributes, quietly ignoring any that + might be missing. + """ - for a in attrs: - try: - print "%s: %s" % (a.capitalize(), getattr(self, a)) - except AttributeError: - pass + for a in attrs: + try: + print "%s: %s" % (a.capitalize(), getattr(self, a)) + except AttributeError: + pass - def show(self): - """ - Print common object attributes. - """ + def show(self): + """ + Print common object attributes. + """ - self.show_attrs("filename", "uri", "status", "timestamp") + self.show_attrs("filename", "uri", "status", "timestamp") class rcynic_certificate(rcynic_object): - """ - A certificate from rcynic cache. - """ - - obj_class = rpki.x509.X509 - - def __init__(self, filename, **kwargs): - rcynic_object.__init__(self, filename, **kwargs) - self.notBefore = self.obj.getNotBefore() - self.notAfter = self.obj.getNotAfter() - self.aia_uri = self.obj.get_aia_uri() - self.sia_directory_uri = self.obj.get_sia_directory_uri() - self.manifest_uri = self.obj.get_sia_manifest_uri() - self.resources = self.obj.get_3779resources() - self.is_ca = self.obj.is_CA() - self.serial = self.obj.getSerial() - self.issuer = self.obj.getIssuer() - self.subject = self.obj.getSubject() - self.ski = self.obj.hSKI() - self.aki = self.obj.hAKI() - - def show(self): """ - Print certificate attributes. + A certificate from rcynic cache. """ - rcynic_object.show(self) - self.show_attrs("notBefore", "notAfter", "aia_uri", "sia_directory_uri", "resources") + obj_class = rpki.x509.X509 + + def __init__(self, filename, **kwargs): + rcynic_object.__init__(self, filename, **kwargs) + self.notBefore = self.obj.getNotBefore() + self.notAfter = self.obj.getNotAfter() + self.aia_uri = self.obj.get_aia_uri() + self.sia_directory_uri = self.obj.get_sia_directory_uri() + self.manifest_uri = self.obj.get_sia_manifest_uri() + self.resources = self.obj.get_3779resources() + self.is_ca = self.obj.is_CA() + self.serial = self.obj.getSerial() + self.issuer = self.obj.getIssuer() + self.subject = self.obj.getSubject() + self.ski = self.obj.hSKI() + self.aki = self.obj.hAKI() + + def show(self): + """ + Print certificate attributes. + """ + + rcynic_object.show(self) + self.show_attrs("notBefore", "notAfter", "aia_uri", "sia_directory_uri", "resources") class rcynic_roa(rcynic_object): - """ - A ROA from rcynic cache. - """ - - obj_class = rpki.x509.ROA - - def __init__(self, filename, **kwargs): - rcynic_object.__init__(self, filename, **kwargs) - self.obj.extract() - self.asID = self.obj.get_POW().getASID() - self.prefix_sets = [] - v4, v6 = self.obj.get_POW().getPrefixes() - if v4: - self.prefix_sets.append(rpki.resource_set.roa_prefix_set_ipv4([ - rpki.resource_set.roa_prefix_ipv4(p[0], p[1], p[2]) for p in v4])) - if v6: - self.prefix_sets.append(rpki.resource_set.roa_prefix_set_ipv6([ - rpki.resource_set.roa_prefix_ipv6(p[0], p[1], p[2]) for p in v6])) - self.ee = rpki.x509.X509(POW = self.obj.get_POW().certs()[0]) - self.notBefore = self.ee.getNotBefore() - self.notAfter = self.ee.getNotAfter() - self.aia_uri = self.ee.get_aia_uri() - self.resources = self.ee.get_3779resources() - self.issuer = self.ee.getIssuer() - self.serial = self.ee.getSerial() - self.subject = self.ee.getSubject() - self.aki = self.ee.hAKI() - self.ski = self.ee.hSKI() - - def show(self): """ - Print ROA attributes. + A ROA from rcynic cache. """ - rcynic_object.show(self) - self.show_attrs("notBefore", "notAfter", "aia_uri", "resources", "asID") - if self.prefix_sets: - print "Prefixes:", ",".join(str(i) for i in self.prefix_sets) + obj_class = rpki.x509.ROA + + def __init__(self, filename, **kwargs): + rcynic_object.__init__(self, filename, **kwargs) + self.obj.extract() + self.asID = self.obj.get_POW().getASID() + self.prefix_sets = [] + v4, v6 = self.obj.get_POW().getPrefixes() + if v4: + self.prefix_sets.append(rpki.resource_set.roa_prefix_set_ipv4([ + rpki.resource_set.roa_prefix_ipv4(p[0], p[1], p[2]) for p in v4])) + if v6: + self.prefix_sets.append(rpki.resource_set.roa_prefix_set_ipv6([ + rpki.resource_set.roa_prefix_ipv6(p[0], p[1], p[2]) for p in v6])) + self.ee = rpki.x509.X509(POW = self.obj.get_POW().certs()[0]) + self.notBefore = self.ee.getNotBefore() + self.notAfter = self.ee.getNotAfter() + self.aia_uri = self.ee.get_aia_uri() + self.resources = self.ee.get_3779resources() + self.issuer = self.ee.getIssuer() + self.serial = self.ee.getSerial() + self.subject = self.ee.getSubject() + self.aki = self.ee.hAKI() + self.ski = self.ee.hSKI() + + def show(self): + """ + Print ROA attributes. + """ + + rcynic_object.show(self) + self.show_attrs("notBefore", "notAfter", "aia_uri", "resources", "asID") + if self.prefix_sets: + print "Prefixes:", ",".join(str(i) for i in self.prefix_sets) class rcynic_ghostbuster(rcynic_object): - """ - Ghostbuster record from the rcynic cache. - """ - - obj_class = rpki.x509.Ghostbuster - - def __init__(self, *args, **kwargs): - rcynic_object.__init__(self, *args, **kwargs) - self.obj.extract() - self.vcard = self.obj.get_content() - self.ee = rpki.x509.X509(POW = self.obj.get_POW().certs()[0]) - self.notBefore = self.ee.getNotBefore() - self.notAfter = self.ee.getNotAfter() - self.aia_uri = self.ee.get_aia_uri() - self.issuer = self.ee.getIssuer() - self.serial = self.ee.getSerial() - self.subject = self.ee.getSubject() - self.aki = self.ee.hAKI() - self.ski = self.ee.hSKI() - - def show(self): - rcynic_object.show(self) - self.show_attrs("notBefore", "notAfter", "vcard") + """ + Ghostbuster record from the rcynic cache. + """ + + obj_class = rpki.x509.Ghostbuster + + def __init__(self, *args, **kwargs): + rcynic_object.__init__(self, *args, **kwargs) + self.obj.extract() + self.vcard = self.obj.get_content() + self.ee = rpki.x509.X509(POW = self.obj.get_POW().certs()[0]) + self.notBefore = self.ee.getNotBefore() + self.notAfter = self.ee.getNotAfter() + self.aia_uri = self.ee.get_aia_uri() + self.issuer = self.ee.getIssuer() + self.serial = self.ee.getSerial() + self.subject = self.ee.getSubject() + self.aki = self.ee.hAKI() + self.ski = self.ee.hSKI() + + def show(self): + rcynic_object.show(self) + self.show_attrs("notBefore", "notAfter", "vcard") file_name_classes = { ".cer" : rcynic_certificate, @@ -168,112 +168,112 @@ file_name_classes = { ".roa" : rcynic_roa } class rcynic_file_iterator(object): - """ - Iterate over files in an rcynic output tree, yielding a Python - representation of each object found. - """ - - def __init__(self, rcynic_root, - authenticated_subdir = "authenticated"): - self.rcynic_dir = os.path.join(rcynic_root, authenticated_subdir) - - def __iter__(self): - for root, dirs, files in os.walk(self.rcynic_dir): # pylint: disable=W0612 - for filename in files: - filename = os.path.join(root, filename) - ext = os.path.splitext(filename)[1] - if ext in file_name_classes: - yield file_name_classes[ext](filename) + """ + Iterate over files in an rcynic output tree, yielding a Python + representation of each object found. + """ + + def __init__(self, rcynic_root, + authenticated_subdir = "authenticated"): + self.rcynic_dir = os.path.join(rcynic_root, authenticated_subdir) + + def __iter__(self): + for root, dirs, files in os.walk(self.rcynic_dir): # pylint: disable=W0612 + for filename in files: + filename = os.path.join(root, filename) + ext = os.path.splitext(filename)[1] + if ext in file_name_classes: + yield file_name_classes[ext](filename) class validation_status_element(object): - def __init__(self, *args, **kwargs): - self.attrs = [] - for k, v in kwargs.iteritems(): - setattr(self, k, v) - # attribute names are saved so that the __repr__ method can - # display the subset of attributes the user specified - self.attrs.append(k) - self._obj = None - - def get_obj(self): - if not self._obj: - self._obj = self.file_class(filename=self.filename, uri=self.uri) - return self._obj - - def __repr__(self): - v = [self.__class__.__name__, 'id=%s' % str(id(self))] - v.extend(['%s=%s' % (x, getattr(self, x)) for x in self.attrs]) - return '<%s>' % (' '.join(v),) - - obj = property(get_obj) + def __init__(self, *args, **kwargs): + self.attrs = [] + for k, v in kwargs.iteritems(): + setattr(self, k, v) + # attribute names are saved so that the __repr__ method can + # display the subset of attributes the user specified + self.attrs.append(k) + self._obj = None + + def get_obj(self): + if not self._obj: + self._obj = self.file_class(filename=self.filename, uri=self.uri) + return self._obj + + def __repr__(self): + v = [self.__class__.__name__, 'id=%s' % str(id(self))] + v.extend(['%s=%s' % (x, getattr(self, x)) for x in self.attrs]) + return '<%s>' % (' '.join(v),) + + obj = property(get_obj) class rcynic_xml_iterator(object): - """ - Iterate over validation_status entries in the XML output from an - rcynic run. Yields a tuple for each entry: - - timestamp, generation, status, object - - where URI, status, and timestamp are the corresponding values from - the XML element, OK is a boolean indicating whether validation was - considered succesful, and object is a Python representation of the - object in question. If OK is True, object will be from rcynic's - authenticated output tree; otherwise, object will be from rcynic's - unauthenticated output tree. - - Note that it is possible for the same URI to appear in more than one - validation_status element; in such cases, the succesful case (OK - True) should be the last entry (as rcynic will stop trying once it - gets a good copy), but there may be multiple failures, which might - or might not have different status codes. - """ - - def __init__(self, rcynic_root, xml_file, - authenticated_old_subdir = "authenticated.old", - unauthenticated_subdir = "unauthenticated"): - self.rcynic_root = rcynic_root - self.xml_file = xml_file - self.authenticated_subdir = os.path.join(rcynic_root, 'authenticated') - self.authenticated_old_subdir = os.path.join(rcynic_root, authenticated_old_subdir) - self.unauthenticated_subdir = os.path.join(rcynic_root, unauthenticated_subdir) - - base_uri = "rsync://" - - def uri_to_filename(self, uri): - if uri.startswith(self.base_uri): - return uri[len(self.base_uri):] - else: - raise NotRsyncURI("Not an rsync URI %r" % uri) - - def __iter__(self): - for validation_status in ElementTree(file=self.xml_file).getroot().getiterator("validation_status"): - timestamp = validation_status.get("timestamp") - status = validation_status.get("status") - uri = validation_status.text.strip() - generation = validation_status.get("generation") - - # determine the path to this object - if status == 'object_accepted': - d = self.authenticated_subdir - elif generation == 'backup': - d = self.authenticated_old_subdir - else: - d = self.unauthenticated_subdir - - filename = os.path.join(d, self.uri_to_filename(uri)) - - ext = os.path.splitext(filename)[1] - if ext in file_name_classes: - yield validation_status_element(timestamp = timestamp, generation = generation, - uri=uri, status = status, filename = filename, - file_class = file_name_classes[ext]) + """ + Iterate over validation_status entries in the XML output from an + rcynic run. Yields a tuple for each entry: + + timestamp, generation, status, object + + where URI, status, and timestamp are the corresponding values from + the XML element, OK is a boolean indicating whether validation was + considered succesful, and object is a Python representation of the + object in question. If OK is True, object will be from rcynic's + authenticated output tree; otherwise, object will be from rcynic's + unauthenticated output tree. + + Note that it is possible for the same URI to appear in more than one + validation_status element; in such cases, the succesful case (OK + True) should be the last entry (as rcynic will stop trying once it + gets a good copy), but there may be multiple failures, which might + or might not have different status codes. + """ + + def __init__(self, rcynic_root, xml_file, + authenticated_old_subdir = "authenticated.old", + unauthenticated_subdir = "unauthenticated"): + self.rcynic_root = rcynic_root + self.xml_file = xml_file + self.authenticated_subdir = os.path.join(rcynic_root, 'authenticated') + self.authenticated_old_subdir = os.path.join(rcynic_root, authenticated_old_subdir) + self.unauthenticated_subdir = os.path.join(rcynic_root, unauthenticated_subdir) + + base_uri = "rsync://" + + def uri_to_filename(self, uri): + if uri.startswith(self.base_uri): + return uri[len(self.base_uri):] + else: + raise NotRsyncURI("Not an rsync URI %r" % uri) + + def __iter__(self): + for validation_status in ElementTree(file=self.xml_file).getroot().getiterator("validation_status"): + timestamp = validation_status.get("timestamp") + status = validation_status.get("status") + uri = validation_status.text.strip() + generation = validation_status.get("generation") + + # determine the path to this object + if status == 'object_accepted': + d = self.authenticated_subdir + elif generation == 'backup': + d = self.authenticated_old_subdir + else: + d = self.unauthenticated_subdir + + filename = os.path.join(d, self.uri_to_filename(uri)) + + ext = os.path.splitext(filename)[1] + if ext in file_name_classes: + yield validation_status_element(timestamp = timestamp, generation = generation, + uri=uri, status = status, filename = filename, + file_class = file_name_classes[ext]) def label_iterator(xml_file): - """ - Returns an iterator which contains all defined labels from an rcynic XML - output file. Each item is a tuple of the form - (label, kind, description). - """ - - for label in ElementTree(file=xml_file).find("labels"): - yield label.tag, label.get("kind"), label.text.strip() + """ + Returns an iterator which contains all defined labels from an rcynic XML + output file. Each item is a tuple of the form + (label, kind, description). + """ + + for label in ElementTree(file=xml_file).find("labels"): + yield label.tag, label.get("kind"), label.text.strip() diff --git a/rpki/relaxng.py b/rpki/relaxng.py index 566be90f..49ea88d8 100644 --- a/rpki/relaxng.py +++ b/rpki/relaxng.py @@ -7,17 +7,17 @@ from rpki.relaxng_parser import RelaxNGParser left_right = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- $Id: left-right.rnc 6137 2015-10-20 19:21:37Z sra $ - + RelaxNG schema for RPKI left-right protocol. - + Copyright (C) 2012- -2014 Dragon Research Labs ("DRL") Portions copyright (C) 2009- -2011 Internet Systems Consortium ("ISC") Portions copyright (C) 2007- -2008 American Registry for Internet Numbers ("ARIN") - + Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notices and this permission notice appear in all copies. - + THE SOFTWARE IS PROVIDED "AS IS" AND DRL, ISC, AND ARIN DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL, @@ -1106,23 +1106,23 @@ left_right = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> myrpki = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- $Id: myrpki.rnc 5876 2014-06-26 19:00:12Z sra $ - + RelaxNG schema for MyRPKI XML messages. - + This message protocol is on its way out, as we're in the process of moving on from the user interface model that produced it, but even after we finish replacing it we'll still need the schema for a while to validate old messages when upgrading. - + libxml2 (including xmllint) only groks the XML syntax of RelaxNG, so run the compact syntax through trang to get XML syntax. - + Copyright (C) 2009-2011 Internet Systems Consortium ("ISC") - + Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. - + THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, @@ -1661,17 +1661,17 @@ oob_setup = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> publication_control = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- $Id: publication-control.rnc 5903 2014-07-18 17:08:13Z sra $ - + RelaxNG schema for RPKI publication protocol. - + Copyright (C) 2012- -2014 Dragon Research Labs ("DRL") Portions copyright (C) 2009- -2011 Internet Systems Consortium ("ISC") Portions copyright (C) 2007- -2008 American Registry for Internet Numbers ("ARIN") - + Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notices and this permission notice appear in all copies. - + THE SOFTWARE IS PROVIDED "AS IS" AND DRL, ISC, AND ARIN DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL, @@ -1735,7 +1735,7 @@ publication_control = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- Base64 encoded DER stuff base64 = xsd:base64Binary { maxLength="512000" } - + Sadly, it turns out that CRLs can in fact get longer than this for an active CA. Remove length limit for now, think about whether to put it back later. --> @@ -1945,29 +1945,29 @@ publication_control = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> publication = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- $Id: publication.rnc 5896 2014-07-15 19:34:32Z sra $ - + RelaxNG schema for RPKI publication protocol, from current I-D. - + Copyright (c) 2014 IETF Trust and the persons identified as authors of the code. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - + * Neither the name of Internet Society, IETF or IETF Trust, nor the names of specific contributors, may be used to endorse or promote products derived from this software without specific prior written permission. - + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS @@ -2150,22 +2150,22 @@ publication = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> router_certificate = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- $Id: router-certificate.rnc 5881 2014-07-03 16:55:02Z sra $ - + RelaxNG schema for BGPSEC router certificate interchange format. - + At least for now, this is a trivial encapsulation of a PKCS #10 request, a set (usually containing exactly one member) of autonomous system numbers, and a router-id. Be warned that this could change radically by the time we have any real operational understanding of how these things will be used, this is just our current best guess to let us move forward on initial coding. - + Copyright (C) 2014 Dragon Research Labs ("DRL") - + Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. - + THE SOFTWARE IS PROVIDED "AS IS" AND DRL DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL BE LIABLE FOR ANY SPECIAL, DIRECT, @@ -2252,15 +2252,15 @@ router_certificate = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> rrdp = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- $Id: rrdp.rnc 6010 2014-11-08 18:01:58Z sra $ - + RelaxNG schema for RPKI Repository Delta Protocol (RRDP). - + Copyright (C) 2014 Dragon Research Labs ("DRL") - + Permission to use, copy, modify, and distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. - + THE SOFTWARE IS PROVIDED "AS IS" AND DRL DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL BE LIABLE FOR ANY SPECIAL, DIRECT, @@ -2406,29 +2406,29 @@ rrdp = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> up_down = RelaxNGParser(r'''<?xml version="1.0" encoding="UTF-8"?> <!-- $Id: up-down.rnc 5881 2014-07-03 16:55:02Z sra $ - + RelaxNG schema for the up-down protocol, extracted from RFC 6492. - + Copyright (c) 2012 IETF Trust and the persons identified as authors of the code. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - + * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - + * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - + * Neither the name of Internet Society, IETF or IETF Trust, nor the names of specific contributors, may be used to endorse or promote products derived from this software without specific prior written permission. - + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS diff --git a/rpki/relaxng_parser.py b/rpki/relaxng_parser.py index 466b1a79..53ec8f0d 100644 --- a/rpki/relaxng_parser.py +++ b/rpki/relaxng_parser.py @@ -22,21 +22,21 @@ from an XML-format RelaxNG schema. import lxml.etree class RelaxNGParser(object): - """ - Parse schema, extract XML namespace and protocol version (if any). - Method calls are just passed along to the parsed RelaxNG schema. - """ + """ + Parse schema, extract XML namespace and protocol version (if any). + Method calls are just passed along to the parsed RelaxNG schema. + """ - def __init__(self, text): - xml = lxml.etree.fromstring(text) - self.schema = lxml.etree.RelaxNG(xml) - ns = xml.get("ns") - self.xmlns = "{" + ns + "}" - self.nsmap = { None : ns } - x = xml.xpath("ns0:define[@name = 'version']/ns0:value", - namespaces = dict(ns0 = "http://relaxng.org/ns/structure/1.0")) - if len(x) == 1: - self.version = x[0].text + def __init__(self, text): + xml = lxml.etree.fromstring(text) + self.schema = lxml.etree.RelaxNG(xml) + ns = xml.get("ns") + self.xmlns = "{" + ns + "}" + self.nsmap = { None : ns } + x = xml.xpath("ns0:define[@name = 'version']/ns0:value", + namespaces = dict(ns0 = "http://relaxng.org/ns/structure/1.0")) + if len(x) == 1: + self.version = x[0].text - def __getattr__(self, name): - return getattr(self.schema, name) + def __getattr__(self, name): + return getattr(self.schema, name) diff --git a/rpki/resource_set.py b/rpki/resource_set.py index 130bf4e7..43dfa9ef 100644 --- a/rpki/resource_set.py +++ b/rpki/resource_set.py @@ -44,780 +44,780 @@ re_prefix_with_maxlen = re.compile("^([0-9:.a-fA-F]+)/([0-9]+)-([0-9]+)$") re_prefix = re.compile("^([0-9:.a-fA-F]+)/([0-9]+)$") class resource_range(object): - """ - Generic resource range type. Assumes underlying type is some kind - of integer. - - This is a virtual class. You probably don't want to use this type - directly. - """ - - def __init__(self, range_min, range_max): - assert range_min.__class__ is range_max.__class__, \ - "Type mismatch, %r doesn't match %r" % (range_min.__class__, range_max.__class__) - assert range_min <= range_max, "Mis-ordered range: %s before %s" % (range_min, range_max) - self.min = range_min - self.max = range_max - - def __cmp__(self, other): - assert self.__class__ is other.__class__, \ - "Type mismatch, comparing %r with %r" % (self.__class__, other.__class__) - return cmp(self.min, other.min) or cmp(self.max, other.max) - -class resource_range_as(resource_range): - """ - Range of Autonomous System Numbers. - - Denotes a single ASN by a range whose min and max values are - identical. - """ - - ## @var datum_type - # Type of underlying data (min and max). - - datum_type = long - - def __init__(self, range_min, range_max): - resource_range.__init__(self, - long(range_min) if isinstance(range_min, int) else range_min, - long(range_max) if isinstance(range_max, int) else range_max) - - def __str__(self): """ - Convert a resource_range_as to string format. - """ - - if self.min == self.max: - return str(self.min) - else: - return str(self.min) + "-" + str(self.max) + Generic resource range type. Assumes underlying type is some kind + of integer. - @classmethod - def parse_str(cls, x): - """ - Parse ASN resource range from text (eg, XML attributes). + This is a virtual class. You probably don't want to use this type + directly. """ - r = re_asn_range.match(x) - if r: - return cls(long(r.group(1)), long(r.group(2))) - else: - return cls(long(x), long(x)) + def __init__(self, range_min, range_max): + assert range_min.__class__ is range_max.__class__, \ + "Type mismatch, %r doesn't match %r" % (range_min.__class__, range_max.__class__) + assert range_min <= range_max, "Mis-ordered range: %s before %s" % (range_min, range_max) + self.min = range_min + self.max = range_max - @classmethod - def from_strings(cls, a, b = None): - """ - Construct ASN range from strings. - """ + def __cmp__(self, other): + assert self.__class__ is other.__class__, \ + "Type mismatch, comparing %r with %r" % (self.__class__, other.__class__) + return cmp(self.min, other.min) or cmp(self.max, other.max) - if b is None: - b = a - return cls(long(a), long(b)) - -class resource_range_ip(resource_range): - """ - Range of (generic) IP addresses. - - Prefixes are converted to ranges on input, and ranges that can be - represented as prefixes are written as prefixes on output. - - This is a virtual class. You probably don't want to use it - directly. - """ - - ## @var datum_type - # Type of underlying data (min and max). - - datum_type = rpki.POW.IPAddress - - def prefixlen(self): - """ - Determine whether a resource_range_ip can be expressed as a - prefix. Returns prefix length if it can, otherwise raises - MustBePrefix exception. - """ - - mask = self.min ^ self.max - if self.min & mask != 0: - raise rpki.exceptions.MustBePrefix - prefixlen = self.min.bits - while mask & 1: - prefixlen -= 1 - mask >>= 1 - if mask: - raise rpki.exceptions.MustBePrefix - return prefixlen - - @property - def can_be_prefix(self): - """ - Boolean property indicating whether this range can be expressed as - a prefix. - - This just calls .prefixlen() to do the work, so that we can keep - the logic in one place. This property is useful primarily in - context where catching an exception isn't practical. - """ - - try: - self.prefixlen() - return True - except rpki.exceptions.MustBePrefix: - return False - - def __str__(self): - """ - Convert a resource_range_ip to string format. - """ - - try: - return str(self.min) + "/" + str(self.prefixlen()) - except rpki.exceptions.MustBePrefix: - return str(self.min) + "-" + str(self.max) - - @classmethod - def parse_str(cls, x): - """ - Parse IP address range or prefix from text (eg, XML attributes). - """ - - r = re_address_range.match(x) - if r: - return cls.from_strings(r.group(1), r.group(2)) - r = re_prefix.match(x) - if r: - a = rpki.POW.IPAddress(r.group(1)) - if cls is resource_range_ip and a.version == 4: - cls = resource_range_ipv4 - if cls is resource_range_ip and a.version == 6: - cls = resource_range_ipv6 - return cls.make_prefix(a, int(r.group(2))) - raise rpki.exceptions.BadIPResource('Bad IP resource "%s"' % x) - - @classmethod - def make_prefix(cls, prefix, prefixlen): - """ - Construct a resource range corresponding to a prefix. - """ - - assert isinstance(prefix, rpki.POW.IPAddress) and isinstance(prefixlen, (int, long)) - assert prefixlen >= 0 and prefixlen <= prefix.bits, "Nonsensical prefix length: %s" % prefixlen - mask = (1 << (prefix.bits - prefixlen)) - 1 - assert (prefix & mask) == 0, "Resource not in canonical form: %s/%s" % (prefix, prefixlen) - return cls(prefix, rpki.POW.IPAddress(prefix | mask)) - - def chop_into_prefixes(self, result): - """ - Chop up a resource_range_ip into ranges that can be represented as - prefixes. +class resource_range_as(resource_range): """ + Range of Autonomous System Numbers. - try: - self.prefixlen() - result.append(self) - except rpki.exceptions.MustBePrefix: - range_min = self.min - range_max = self.max - while range_max >= range_min: - bits = int(math.log(long(range_max - range_min + 1), 2)) - while True: - mask = ~(~0 << bits) - assert range_min + mask <= range_max - if range_min & mask == 0: - break - assert bits > 0 - bits -= 1 - result.append(self.make_prefix(range_min, range_min.bits - bits)) - range_min = range_min + mask + 1 - - @classmethod - def from_strings(cls, a, b = None): + Denotes a single ASN by a range whose min and max values are + identical. """ - Construct IP address range from strings. - """ - - if b is None: - b = a - a = rpki.POW.IPAddress(a) - b = rpki.POW.IPAddress(b) - if a.version != b.version: - raise TypeError - if cls is resource_range_ip: - if a.version == 4: - return resource_range_ipv4(a, b) - if a.version == 6: - return resource_range_ipv6(a, b) - elif a.version == cls.version: - return cls(a, b) - else: - raise TypeError - -class resource_range_ipv4(resource_range_ip): - """ - Range of IPv4 addresses. - """ - - version = 4 - -class resource_range_ipv6(resource_range_ip): - """ - Range of IPv6 addresses. - """ - - version = 6 - -def _rsplit(rset, that): - """ - Utility function to split a resource range into two resource ranges. - """ - this = rset.pop(0) + ## @var datum_type + # Type of underlying data (min and max). - assert type(this) is type(that), "type(this) [%r] is not type(that) [%r]" % (type(this), type(that)) + datum_type = long - assert type(this.min) is type(that.min), "type(this.min) [%r] is not type(that.min) [%r]" % (type(this.min), type(that.min)) - assert type(this.min) is type(this.max), "type(this.min) [%r] is not type(this.max) [%r]" % (type(this.min), type(this.max)) - assert type(that.min) is type(that.max), "type(that.min) [%r] is not type(that.max) [%r]" % (type(that.min), type(that.max)) + def __init__(self, range_min, range_max): + resource_range.__init__(self, + long(range_min) if isinstance(range_min, int) else range_min, + long(range_max) if isinstance(range_max, int) else range_max) - if this.min < that.min: - rset.insert(0, type(this)(this.min, type(that.min)(that.min - 1))) - rset.insert(1, type(this)(that.min, this.max)) + def __str__(self): + """ + Convert a resource_range_as to string format. + """ - else: - assert this.max > that.max - rset.insert(0, type(this)(this.min, that.max)) - rset.insert(1, type(this)(type(that.max)(that.max + 1), this.max)) - -class resource_set(list): - """ - Generic resource set, a list subclass containing resource ranges. - - This is a virtual class. You probably don't want to use it - directly. - """ - - ## @var inherit - # Boolean indicating whether this resource_set uses RFC 3779 inheritance. - - inherit = False - - ## @var canonical - # Whether this resource_set is currently in canonical form. - - canonical = False - - def __init__(self, ini = None, allow_overlap = False): - """ - Initialize a resource_set. - """ - - list.__init__(self) - if isinstance(ini, (int, long)): - ini = str(ini) - if ini is inherit_token: - self.inherit = True - elif isinstance(ini, str) and len(ini): - self.extend(self.parse_str(s) for s in ini.split(",")) - elif isinstance(ini, list): - self.extend(ini) - elif ini is not None and ini != "": - raise ValueError("Unexpected initializer: %s" % str(ini)) - self.canonize(allow_overlap) - - def canonize(self, allow_overlap = False): - """ - Whack this resource_set into canonical form. - """ - - assert not self.inherit or len(self) == 0 - if not self.canonical: - self.sort() - i = 0 - while i + 1 < len(self): - if allow_overlap and self[i].max + 1 >= self[i+1].min: - self[i] = type(self[i])(self[i].min, max(self[i].max, self[i+1].max)) - del self[i+1] - elif self[i].max + 1 == self[i+1].min: - self[i] = type(self[i])(self[i].min, self[i+1].max) - del self[i+1] + if self.min == self.max: + return str(self.min) else: - i += 1 - for i in xrange(0, len(self) - 1): - if self[i].max >= self[i+1].min: - raise rpki.exceptions.ResourceOverlap("Resource overlap: %s %s" % (self[i], self[i+1])) - self.canonical = True + return str(self.min) + "-" + str(self.max) - def append(self, item): - """ - Wrapper around list.append() (q.v.) to reset canonical flag. - """ + @classmethod + def parse_str(cls, x): + """ + Parse ASN resource range from text (eg, XML attributes). + """ - list.append(self, item) - self.canonical = False + r = re_asn_range.match(x) + if r: + return cls(long(r.group(1)), long(r.group(2))) + else: + return cls(long(x), long(x)) - def extend(self, item): - """ - Wrapper around list.extend() (q.v.) to reset canonical flag. - """ + @classmethod + def from_strings(cls, a, b = None): + """ + Construct ASN range from strings. + """ - list.extend(self, item) - self.canonical = False + if b is None: + b = a + return cls(long(a), long(b)) - def __str__(self): - """ - Convert a resource_set to string format. +class resource_range_ip(resource_range): """ + Range of (generic) IP addresses. + + Prefixes are converted to ranges on input, and ranges that can be + represented as prefixes are written as prefixes on output. + + This is a virtual class. You probably don't want to use it + directly. + """ + + ## @var datum_type + # Type of underlying data (min and max). + + datum_type = rpki.POW.IPAddress + + def prefixlen(self): + """ + Determine whether a resource_range_ip can be expressed as a + prefix. Returns prefix length if it can, otherwise raises + MustBePrefix exception. + """ + + mask = self.min ^ self.max + if self.min & mask != 0: + raise rpki.exceptions.MustBePrefix + prefixlen = self.min.bits + while mask & 1: + prefixlen -= 1 + mask >>= 1 + if mask: + raise rpki.exceptions.MustBePrefix + return prefixlen + + @property + def can_be_prefix(self): + """ + Boolean property indicating whether this range can be expressed as + a prefix. + + This just calls .prefixlen() to do the work, so that we can keep + the logic in one place. This property is useful primarily in + context where catching an exception isn't practical. + """ + + try: + self.prefixlen() + return True + except rpki.exceptions.MustBePrefix: + return False + + def __str__(self): + """ + Convert a resource_range_ip to string format. + """ + + try: + return str(self.min) + "/" + str(self.prefixlen()) + except rpki.exceptions.MustBePrefix: + return str(self.min) + "-" + str(self.max) + + @classmethod + def parse_str(cls, x): + """ + Parse IP address range or prefix from text (eg, XML attributes). + """ + + r = re_address_range.match(x) + if r: + return cls.from_strings(r.group(1), r.group(2)) + r = re_prefix.match(x) + if r: + a = rpki.POW.IPAddress(r.group(1)) + if cls is resource_range_ip and a.version == 4: + cls = resource_range_ipv4 + if cls is resource_range_ip and a.version == 6: + cls = resource_range_ipv6 + return cls.make_prefix(a, int(r.group(2))) + raise rpki.exceptions.BadIPResource('Bad IP resource "%s"' % x) + + @classmethod + def make_prefix(cls, prefix, prefixlen): + """ + Construct a resource range corresponding to a prefix. + """ + + assert isinstance(prefix, rpki.POW.IPAddress) and isinstance(prefixlen, (int, long)) + assert prefixlen >= 0 and prefixlen <= prefix.bits, "Nonsensical prefix length: %s" % prefixlen + mask = (1 << (prefix.bits - prefixlen)) - 1 + assert (prefix & mask) == 0, "Resource not in canonical form: %s/%s" % (prefix, prefixlen) + return cls(prefix, rpki.POW.IPAddress(prefix | mask)) + + def chop_into_prefixes(self, result): + """ + Chop up a resource_range_ip into ranges that can be represented as + prefixes. + """ + + try: + self.prefixlen() + result.append(self) + except rpki.exceptions.MustBePrefix: + range_min = self.min + range_max = self.max + while range_max >= range_min: + bits = int(math.log(long(range_max - range_min + 1), 2)) + while True: + mask = ~(~0 << bits) + assert range_min + mask <= range_max + if range_min & mask == 0: + break + assert bits > 0 + bits -= 1 + result.append(self.make_prefix(range_min, range_min.bits - bits)) + range_min = range_min + mask + 1 + + @classmethod + def from_strings(cls, a, b = None): + """ + Construct IP address range from strings. + """ + + if b is None: + b = a + a = rpki.POW.IPAddress(a) + b = rpki.POW.IPAddress(b) + if a.version != b.version: + raise TypeError + if cls is resource_range_ip: + if a.version == 4: + return resource_range_ipv4(a, b) + if a.version == 6: + return resource_range_ipv6(a, b) + elif a.version == cls.version: + return cls(a, b) + else: + raise TypeError - if self.inherit: - return inherit_token - else: - return ",".join(str(x) for x in self) - - def _comm(self, other): +class resource_range_ipv4(resource_range_ip): """ - Like comm(1), sort of. - - Returns a tuple of three resource sets: resources only in self, - resources only in other, and resources in both. Used (not very - efficiently) as the basis for most set operations on resource - sets. + Range of IPv4 addresses. """ - assert not self.inherit - assert type(self) is type(other), "Type mismatch %r %r" % (type(self), type(other)) - set1 = type(self)(self) # clone and whack into canonical form - set2 = type(other)(other) # ditto - only1, only2, both = [], [], [] - while set1 or set2: - if set1 and (not set2 or set1[0].max < set2[0].min): - only1.append(set1.pop(0)) - elif set2 and (not set1 or set2[0].max < set1[0].min): - only2.append(set2.pop(0)) - elif set1[0].min < set2[0].min: - _rsplit(set1, set2[0]) - elif set2[0].min < set1[0].min: - _rsplit(set2, set1[0]) - elif set1[0].max < set2[0].max: - _rsplit(set2, set1[0]) - elif set2[0].max < set1[0].max: - _rsplit(set1, set2[0]) - else: - assert set1[0].min == set2[0].min and set1[0].max == set2[0].max - both.append(set1.pop(0)) - set2.pop(0) - return type(self)(only1), type(self)(only2), type(self)(both) - - def union(self, other): - """ - Set union for resource sets. - """ + version = 4 - assert not self.inherit - assert type(self) is type(other), "Type mismatch: %r %r" % (type(self), type(other)) - set1 = type(self)(self) # clone and whack into canonical form - set2 = type(other)(other) # ditto - result = [] - while set1 or set2: - if set1 and (not set2 or set1[0].max < set2[0].min): - result.append(set1.pop(0)) - elif set2 and (not set1 or set2[0].max < set1[0].min): - result.append(set2.pop(0)) - else: - this = set1.pop(0) - that = set2.pop(0) - assert type(this) is type(that) - range_min = min(this.min, that.min) - range_max = max(this.max, that.max) - result.append(type(this)(range_min, range_max)) - while set1 and set1[0].max <= range_max: - assert set1[0].min >= range_min - del set1[0] - while set2 and set2[0].max <= range_max: - assert set2[0].min >= range_min - del set2[0] - return type(self)(result) - - __or__ = union - - def intersection(self, other): +class resource_range_ipv6(resource_range_ip): """ - Set intersection for resource sets. + Range of IPv6 addresses. """ - return self._comm(other)[2] - - __and__ = intersection + version = 6 - def difference(self, other): +def _rsplit(rset, that): """ - Set difference for resource sets. + Utility function to split a resource range into two resource ranges. """ - return self._comm(other)[0] + this = rset.pop(0) - __sub__ = difference + assert type(this) is type(that), "type(this) [%r] is not type(that) [%r]" % (type(this), type(that)) - def symmetric_difference(self, other): - """ - Set symmetric difference (XOR) for resource sets. - """ + assert type(this.min) is type(that.min), "type(this.min) [%r] is not type(that.min) [%r]" % (type(this.min), type(that.min)) + assert type(this.min) is type(this.max), "type(this.min) [%r] is not type(this.max) [%r]" % (type(this.min), type(this.max)) + assert type(that.min) is type(that.max), "type(that.min) [%r] is not type(that.max) [%r]" % (type(that.min), type(that.max)) - com = self._comm(other) - return com[0] | com[1] + if this.min < that.min: + rset.insert(0, type(this)(this.min, type(that.min)(that.min - 1))) + rset.insert(1, type(this)(that.min, this.max)) - __xor__ = symmetric_difference - - def contains(self, item): - """ - Set membership test for resource sets. - """ - - assert not self.inherit - self.canonize() - if not self: - return False - if type(item) is type(self[0]): - range_min = item.min - range_max = item.max else: - range_min = item - range_max = item - lo = 0 - hi = len(self) - while lo < hi: - mid = (lo + hi) / 2 - if self[mid].max < range_max: - lo = mid + 1 - else: - hi = mid - return lo < len(self) and self[lo].min <= range_min and self[lo].max >= range_max - - __contains__ = contains - - def issubset(self, other): - """ - Test whether self is a subset (possibly improper) of other. - """ - - for i in self: - if not other.contains(i): - return False - return True - - __le__ = issubset - - def issuperset(self, other): - """ - Test whether self is a superset (possibly improper) of other. - """ - - return other.issubset(self) - - __ge__ = issuperset - - def __lt__(self, other): - return not self.issuperset(other) - - def __gt__(self, other): - return not self.issubset(other) - - def __ne__(self, other): - """ - A set with the inherit bit set is always unequal to any other set, because - we can't know the answer here. This is also consistent with __nonzero__ - which returns True for inherit sets, and False for empty sets. - """ - - return self.inherit or other.inherit or list.__ne__(self, other) + assert this.max > that.max + rset.insert(0, type(this)(this.min, that.max)) + rset.insert(1, type(this)(type(that.max)(that.max + 1), this.max)) - def __eq__(self, other): - return not self.__ne__(other) - - def __nonzero__(self): - """ - Tests whether or not this set is empty. Note that sets with the inherit - bit set are considered non-empty, despite having zero length. - """ - - return self.inherit or len(self) - - @classmethod - def from_sql(cls, sql, query, args = None): - """ - Create resource set from an SQL query. - - sql is an object that supports execute() and fetchall() methods - like a DB API 2.0 cursor object. - - query is an SQL query that returns a sequence of (min, max) pairs. - """ - - sql.execute(query, args) - return cls(ini = [cls.range_type(cls.range_type.datum_type(b), - cls.range_type.datum_type(e)) - for (b, e) in sql.fetchall()]) - - @classmethod - def from_django(cls, iterable): - """ - Create resource set from a Django query. - - iterable is something which returns (min, max) pairs. - """ - - return cls(ini = [cls.range_type(cls.range_type.datum_type(b), - cls.range_type.datum_type(e)) - for (b, e) in iterable]) - - @classmethod - def parse_str(cls, s): - """ - Parse resource set from text string (eg, XML attributes). This is - a backwards compatability wrapper, real functionality is now part - of the range classes. +class resource_set(list): """ + Generic resource set, a list subclass containing resource ranges. + + This is a virtual class. You probably don't want to use it + directly. + """ + + ## @var inherit + # Boolean indicating whether this resource_set uses RFC 3779 inheritance. + + inherit = False + + ## @var canonical + # Whether this resource_set is currently in canonical form. + + canonical = False + + def __init__(self, ini = None, allow_overlap = False): + """ + Initialize a resource_set. + """ + + list.__init__(self) + if isinstance(ini, (int, long)): + ini = str(ini) + if ini is inherit_token: + self.inherit = True + elif isinstance(ini, str) and len(ini): + self.extend(self.parse_str(s) for s in ini.split(",")) + elif isinstance(ini, list): + self.extend(ini) + elif ini is not None and ini != "": + raise ValueError("Unexpected initializer: %s" % str(ini)) + self.canonize(allow_overlap) + + def canonize(self, allow_overlap = False): + """ + Whack this resource_set into canonical form. + """ + + assert not self.inherit or len(self) == 0 + if not self.canonical: + self.sort() + i = 0 + while i + 1 < len(self): + if allow_overlap and self[i].max + 1 >= self[i+1].min: + self[i] = type(self[i])(self[i].min, max(self[i].max, self[i+1].max)) + del self[i+1] + elif self[i].max + 1 == self[i+1].min: + self[i] = type(self[i])(self[i].min, self[i+1].max) + del self[i+1] + else: + i += 1 + for i in xrange(0, len(self) - 1): + if self[i].max >= self[i+1].min: + raise rpki.exceptions.ResourceOverlap("Resource overlap: %s %s" % (self[i], self[i+1])) + self.canonical = True + + def append(self, item): + """ + Wrapper around list.append() (q.v.) to reset canonical flag. + """ + + list.append(self, item) + self.canonical = False + + def extend(self, item): + """ + Wrapper around list.extend() (q.v.) to reset canonical flag. + """ + + list.extend(self, item) + self.canonical = False + + def __str__(self): + """ + Convert a resource_set to string format. + """ + + if self.inherit: + return inherit_token + else: + return ",".join(str(x) for x in self) + + def _comm(self, other): + """ + Like comm(1), sort of. + + Returns a tuple of three resource sets: resources only in self, + resources only in other, and resources in both. Used (not very + efficiently) as the basis for most set operations on resource + sets. + """ + + assert not self.inherit + assert type(self) is type(other), "Type mismatch %r %r" % (type(self), type(other)) + set1 = type(self)(self) # clone and whack into canonical form + set2 = type(other)(other) # ditto + only1, only2, both = [], [], [] + while set1 or set2: + if set1 and (not set2 or set1[0].max < set2[0].min): + only1.append(set1.pop(0)) + elif set2 and (not set1 or set2[0].max < set1[0].min): + only2.append(set2.pop(0)) + elif set1[0].min < set2[0].min: + _rsplit(set1, set2[0]) + elif set2[0].min < set1[0].min: + _rsplit(set2, set1[0]) + elif set1[0].max < set2[0].max: + _rsplit(set2, set1[0]) + elif set2[0].max < set1[0].max: + _rsplit(set1, set2[0]) + else: + assert set1[0].min == set2[0].min and set1[0].max == set2[0].max + both.append(set1.pop(0)) + set2.pop(0) + return type(self)(only1), type(self)(only2), type(self)(both) + + def union(self, other): + """ + Set union for resource sets. + """ + + assert not self.inherit + assert type(self) is type(other), "Type mismatch: %r %r" % (type(self), type(other)) + set1 = type(self)(self) # clone and whack into canonical form + set2 = type(other)(other) # ditto + result = [] + while set1 or set2: + if set1 and (not set2 or set1[0].max < set2[0].min): + result.append(set1.pop(0)) + elif set2 and (not set1 or set2[0].max < set1[0].min): + result.append(set2.pop(0)) + else: + this = set1.pop(0) + that = set2.pop(0) + assert type(this) is type(that) + range_min = min(this.min, that.min) + range_max = max(this.max, that.max) + result.append(type(this)(range_min, range_max)) + while set1 and set1[0].max <= range_max: + assert set1[0].min >= range_min + del set1[0] + while set2 and set2[0].max <= range_max: + assert set2[0].min >= range_min + del set2[0] + return type(self)(result) + + __or__ = union + + def intersection(self, other): + """ + Set intersection for resource sets. + """ + + return self._comm(other)[2] + + __and__ = intersection + + def difference(self, other): + """ + Set difference for resource sets. + """ + + return self._comm(other)[0] + + __sub__ = difference + + def symmetric_difference(self, other): + """ + Set symmetric difference (XOR) for resource sets. + """ + + com = self._comm(other) + return com[0] | com[1] + + __xor__ = symmetric_difference + + def contains(self, item): + """ + Set membership test for resource sets. + """ + + assert not self.inherit + self.canonize() + if not self: + return False + if type(item) is type(self[0]): + range_min = item.min + range_max = item.max + else: + range_min = item + range_max = item + lo = 0 + hi = len(self) + while lo < hi: + mid = (lo + hi) / 2 + if self[mid].max < range_max: + lo = mid + 1 + else: + hi = mid + return lo < len(self) and self[lo].min <= range_min and self[lo].max >= range_max - return cls.range_type.parse_str(s) - -class resource_set_as(resource_set): - """ - Autonomous System Number resource set. - """ - - ## @var range_type - # Type of range underlying this type of resource_set. + __contains__ = contains - range_type = resource_range_as + def issubset(self, other): + """ + Test whether self is a subset (possibly improper) of other. + """ -class resource_set_ip(resource_set): - """ - (Generic) IP address resource set. + for i in self: + if not other.contains(i): + return False + return True - This is a virtual class. You probably don't want to use it - directly. - """ + __le__ = issubset - def to_roa_prefix_set(self): - """ - Convert from a resource set to a ROA prefix set. - """ + def issuperset(self, other): + """ + Test whether self is a superset (possibly improper) of other. + """ - prefix_ranges = [] - for r in self: - r.chop_into_prefixes(prefix_ranges) - return self.roa_prefix_set_type([ - self.roa_prefix_set_type.prefix_type(r.min, r.prefixlen()) - for r in prefix_ranges]) + return other.issubset(self) -class resource_set_ipv4(resource_set_ip): - """ - IPv4 address resource set. - """ + __ge__ = issuperset - ## @var range_type - # Type of range underlying this type of resource_set. + def __lt__(self, other): + return not self.issuperset(other) - range_type = resource_range_ipv4 + def __gt__(self, other): + return not self.issubset(other) -class resource_set_ipv6(resource_set_ip): - """ - IPv6 address resource set. - """ + def __ne__(self, other): + """ + A set with the inherit bit set is always unequal to any other set, because + we can't know the answer here. This is also consistent with __nonzero__ + which returns True for inherit sets, and False for empty sets. + """ - ## @var range_type - # Type of range underlying this type of resource_set. + return self.inherit or other.inherit or list.__ne__(self, other) - range_type = resource_range_ipv6 + def __eq__(self, other): + return not self.__ne__(other) -class resource_bag(object): - """ - Container to simplify passing around the usual triple of ASN, IPv4, - and IPv6 resource sets. - """ + def __nonzero__(self): + """ + Tests whether or not this set is empty. Note that sets with the inherit + bit set are considered non-empty, despite having zero length. + """ - ## @var asn - # Set of Autonomous System Number resources. + return self.inherit or len(self) - ## @var v4 - # Set of IPv4 resources. + @classmethod + def from_sql(cls, sql, query, args = None): + """ + Create resource set from an SQL query. - ## @var v6 - # Set of IPv6 resources. + sql is an object that supports execute() and fetchall() methods + like a DB API 2.0 cursor object. - ## @var valid_until - # Expiration date of resources, for setting certificate notAfter field. + query is an SQL query that returns a sequence of (min, max) pairs. + """ - def __init__(self, asn = None, v4 = None, v6 = None, valid_until = None): - self.asn = asn or resource_set_as() - self.v4 = v4 or resource_set_ipv4() - self.v6 = v6 or resource_set_ipv6() - self.valid_until = valid_until + sql.execute(query, args) + return cls(ini = [cls.range_type(cls.range_type.datum_type(b), + cls.range_type.datum_type(e)) + for (b, e) in sql.fetchall()]) - def oversized(self, other): - """ - True iff self is oversized with respect to other. - """ + @classmethod + def from_django(cls, iterable): + """ + Create resource set from a Django query. - return not self.asn.issubset(other.asn) or \ - not self.v4.issubset(other.v4) or \ - not self.v6.issubset(other.v6) + iterable is something which returns (min, max) pairs. + """ - def undersized(self, other): - """ - True iff self is undersized with respect to other. - """ + return cls(ini = [cls.range_type(cls.range_type.datum_type(b), + cls.range_type.datum_type(e)) + for (b, e) in iterable]) - return not other.asn.issubset(self.asn) or \ - not other.v4.issubset(self.v4) or \ - not other.v6.issubset(self.v6) + @classmethod + def parse_str(cls, s): + """ + Parse resource set from text string (eg, XML attributes). This is + a backwards compatability wrapper, real functionality is now part + of the range classes. + """ - @classmethod - def from_inheritance(cls): - """ - Build a resource bag that just inherits everything from its - parent. - """ + return cls.range_type.parse_str(s) - self = cls() - self.asn = resource_set_as() - self.v4 = resource_set_ipv4() - self.v6 = resource_set_ipv6() - self.asn.inherit = True - self.v4.inherit = True - self.v6.inherit = True - return self - - @classmethod - def from_str(cls, text, allow_overlap = False): - """ - Parse a comma-separated text string into a resource_bag. Not - particularly efficient, fix that if and when it becomes an issue. +class resource_set_as(resource_set): """ - - asns = [] - v4s = [] - v6s = [] - for word in text.split(","): - if "." in word: - v4s.append(word) - elif ":" in word: - v6s.append(word) - else: - asns.append(word) - return cls(asn = resource_set_as(",".join(asns), allow_overlap) if asns else None, - v4 = resource_set_ipv4(",".join(v4s), allow_overlap) if v4s else None, - v6 = resource_set_ipv6(",".join(v6s), allow_overlap) if v6s else None) - - @classmethod - def from_POW_rfc3779(cls, resources): + Autonomous System Number resource set. """ - Build a resource_bag from data returned by - rpki.POW.X509.getRFC3779(). - The conversion to long for v4 and v6 is (intended to be) - temporary: in the long run, we should be using rpki.POW.IPAddress - rather than long here. - """ + ## @var range_type + # Type of range underlying this type of resource_set. - asn = inherit_token if resources[0] == "inherit" else [resource_range_as( r[0], r[1]) for r in resources[0] or ()] - v4 = inherit_token if resources[1] == "inherit" else [resource_range_ipv4(r[0], r[1]) for r in resources[1] or ()] - v6 = inherit_token if resources[2] == "inherit" else [resource_range_ipv6(r[0], r[1]) for r in resources[2] or ()] - return cls(resource_set_as(asn) if asn else None, - resource_set_ipv4(v4) if v4 else None, - resource_set_ipv6(v6) if v6 else None) + range_type = resource_range_as - def empty(self): - """ - True iff all resource sets in this bag are empty. +class resource_set_ip(resource_set): """ + (Generic) IP address resource set. - return not self.asn and not self.v4 and not self.v6 - - def __nonzero__(self): - return not self.empty() - - def __eq__(self, other): - return self.asn == other.asn and \ - self.v4 == other.v4 and \ - self.v6 == other.v6 and \ - self.valid_until == other.valid_until - - def __ne__(self, other): - return not (self == other) # pylint: disable=C0325 - - def intersection(self, other): - """ - Compute intersection with another resource_bag. valid_until - attribute (if any) inherits from self. + This is a virtual class. You probably don't want to use it + directly. """ - return self.__class__(self.asn & other.asn, - self.v4 & other.v4, - self.v6 & other.v6, - self.valid_until) + def to_roa_prefix_set(self): + """ + Convert from a resource set to a ROA prefix set. + """ - __and__ = intersection + prefix_ranges = [] + for r in self: + r.chop_into_prefixes(prefix_ranges) + return self.roa_prefix_set_type([ + self.roa_prefix_set_type.prefix_type(r.min, r.prefixlen()) + for r in prefix_ranges]) - def union(self, other): +class resource_set_ipv4(resource_set_ip): """ - Compute union with another resource_bag. valid_until attribute - (if any) inherits from self. + IPv4 address resource set. """ - return self.__class__(self.asn | other.asn, - self.v4 | other.v4, - self.v6 | other.v6, - self.valid_until) + ## @var range_type + # Type of range underlying this type of resource_set. - __or__ = union + range_type = resource_range_ipv4 - def difference(self, other): +class resource_set_ipv6(resource_set_ip): """ - Compute difference against another resource_bag. valid_until - attribute (if any) inherits from self + IPv6 address resource set. """ - return self.__class__(self.asn - other.asn, - self.v4 - other.v4, - self.v6 - other.v6, - self.valid_until) + ## @var range_type + # Type of range underlying this type of resource_set. - __sub__ = difference + range_type = resource_range_ipv6 - def symmetric_difference(self, other): - """ - Compute symmetric difference against another resource_bag. - valid_until attribute (if any) inherits from self +class resource_bag(object): """ - - return self.__class__(self.asn ^ other.asn, - self.v4 ^ other.v4, - self.v6 ^ other.v6, - self.valid_until) - - __xor__ = symmetric_difference - - def __str__(self): - s = "" - if self.asn: - s += "ASN: %s" % self.asn - if self.v4: - if s: - s += ", " - s += "V4: %s" % self.v4 - if self.v6: - if s: - s += ", " - s += "V6: %s" % self.v6 - return s - - def __iter__(self): - for r in self.asn: - yield r - for r in self.v4: - yield r - for r in self.v6: - yield r + Container to simplify passing around the usual triple of ASN, IPv4, + and IPv6 resource sets. + """ + + ## @var asn + # Set of Autonomous System Number resources. + + ## @var v4 + # Set of IPv4 resources. + + ## @var v6 + # Set of IPv6 resources. + + ## @var valid_until + # Expiration date of resources, for setting certificate notAfter field. + + def __init__(self, asn = None, v4 = None, v6 = None, valid_until = None): + self.asn = asn or resource_set_as() + self.v4 = v4 or resource_set_ipv4() + self.v6 = v6 or resource_set_ipv6() + self.valid_until = valid_until + + def oversized(self, other): + """ + True iff self is oversized with respect to other. + """ + + return not self.asn.issubset(other.asn) or \ + not self.v4.issubset(other.v4) or \ + not self.v6.issubset(other.v6) + + def undersized(self, other): + """ + True iff self is undersized with respect to other. + """ + + return not other.asn.issubset(self.asn) or \ + not other.v4.issubset(self.v4) or \ + not other.v6.issubset(self.v6) + + @classmethod + def from_inheritance(cls): + """ + Build a resource bag that just inherits everything from its + parent. + """ + + self = cls() + self.asn = resource_set_as() + self.v4 = resource_set_ipv4() + self.v6 = resource_set_ipv6() + self.asn.inherit = True + self.v4.inherit = True + self.v6.inherit = True + return self + + @classmethod + def from_str(cls, text, allow_overlap = False): + """ + Parse a comma-separated text string into a resource_bag. Not + particularly efficient, fix that if and when it becomes an issue. + """ + + asns = [] + v4s = [] + v6s = [] + for word in text.split(","): + if "." in word: + v4s.append(word) + elif ":" in word: + v6s.append(word) + else: + asns.append(word) + return cls(asn = resource_set_as(",".join(asns), allow_overlap) if asns else None, + v4 = resource_set_ipv4(",".join(v4s), allow_overlap) if v4s else None, + v6 = resource_set_ipv6(",".join(v6s), allow_overlap) if v6s else None) + + @classmethod + def from_POW_rfc3779(cls, resources): + """ + Build a resource_bag from data returned by + rpki.POW.X509.getRFC3779(). + + The conversion to long for v4 and v6 is (intended to be) + temporary: in the long run, we should be using rpki.POW.IPAddress + rather than long here. + """ + + asn = inherit_token if resources[0] == "inherit" else [resource_range_as( r[0], r[1]) for r in resources[0] or ()] + v4 = inherit_token if resources[1] == "inherit" else [resource_range_ipv4(r[0], r[1]) for r in resources[1] or ()] + v6 = inherit_token if resources[2] == "inherit" else [resource_range_ipv6(r[0], r[1]) for r in resources[2] or ()] + return cls(resource_set_as(asn) if asn else None, + resource_set_ipv4(v4) if v4 else None, + resource_set_ipv6(v6) if v6 else None) + + def empty(self): + """ + True iff all resource sets in this bag are empty. + """ + + return not self.asn and not self.v4 and not self.v6 + + def __nonzero__(self): + return not self.empty() + + def __eq__(self, other): + return self.asn == other.asn and \ + self.v4 == other.v4 and \ + self.v6 == other.v6 and \ + self.valid_until == other.valid_until + + def __ne__(self, other): + return not (self == other) # pylint: disable=C0325 + + def intersection(self, other): + """ + Compute intersection with another resource_bag. valid_until + attribute (if any) inherits from self. + """ + + return self.__class__(self.asn & other.asn, + self.v4 & other.v4, + self.v6 & other.v6, + self.valid_until) + + __and__ = intersection + + def union(self, other): + """ + Compute union with another resource_bag. valid_until attribute + (if any) inherits from self. + """ + + return self.__class__(self.asn | other.asn, + self.v4 | other.v4, + self.v6 | other.v6, + self.valid_until) + + __or__ = union + + def difference(self, other): + """ + Compute difference against another resource_bag. valid_until + attribute (if any) inherits from self + """ + + return self.__class__(self.asn - other.asn, + self.v4 - other.v4, + self.v6 - other.v6, + self.valid_until) + + __sub__ = difference + + def symmetric_difference(self, other): + """ + Compute symmetric difference against another resource_bag. + valid_until attribute (if any) inherits from self + """ + + return self.__class__(self.asn ^ other.asn, + self.v4 ^ other.v4, + self.v6 ^ other.v6, + self.valid_until) + + __xor__ = symmetric_difference + + def __str__(self): + s = "" + if self.asn: + s += "ASN: %s" % self.asn + if self.v4: + if s: + s += ", " + s += "V4: %s" % self.v4 + if self.v6: + if s: + s += ", " + s += "V6: %s" % self.v6 + return s + + def __iter__(self): + for r in self.asn: + yield r + for r in self.v4: + yield r + for r in self.v6: + yield r # Sadly, there are enough differences between RFC 3779 and the data # structures in the latest proposed ROA format that we can't just use @@ -828,369 +828,369 @@ class resource_bag(object): # worth. class roa_prefix(object): - """ - ROA prefix. This is similar to the resource_range_ip class, but - differs in that it only represents prefixes, never ranges, and - includes the maximum prefix length as an additional value. - - This is a virtual class, you probably don't want to use it directly. - """ - - ## @var prefix - # The prefix itself, an IP address with bits beyond the prefix - # length zeroed. - - ## @var prefixlen - # (Minimum) prefix length. - - ## @var max_prefixlen - # Maxmimum prefix length. - - def __init__(self, prefix, prefixlen, max_prefixlen = None): """ - Initialize a ROA prefix. max_prefixlen is optional and defaults - to prefixlen. max_prefixlen must not be smaller than prefixlen. - """ - - if max_prefixlen is None: - max_prefixlen = prefixlen - assert max_prefixlen >= prefixlen, "Bad max_prefixlen: %d must not be shorter than %d" % (max_prefixlen, prefixlen) - self.prefix = prefix - self.prefixlen = prefixlen - self.max_prefixlen = max_prefixlen + ROA prefix. This is similar to the resource_range_ip class, but + differs in that it only represents prefixes, never ranges, and + includes the maximum prefix length as an additional value. - def __cmp__(self, other): - """ - Compare two ROA prefix objects. Comparision is based on prefix, - prefixlen, and max_prefixlen, in that order. + This is a virtual class, you probably don't want to use it directly. """ - assert self.__class__ is other.__class__ - return (cmp(self.prefix, other.prefix) or - cmp(self.prefixlen, other.prefixlen) or - cmp(self.max_prefixlen, other.max_prefixlen)) + ## @var prefix + # The prefix itself, an IP address with bits beyond the prefix + # length zeroed. - def __str__(self): - """ - Convert a ROA prefix to string format. - """ - - if self.prefixlen == self.max_prefixlen: - return str(self.prefix) + "/" + str(self.prefixlen) - else: - return str(self.prefix) + "/" + str(self.prefixlen) + "-" + str(self.max_prefixlen) + ## @var prefixlen + # (Minimum) prefix length. - def to_resource_range(self): - """ - Convert this ROA prefix to the equivilent resource_range_ip - object. This is an irreversable transformation because it loses - the max_prefixlen attribute, nothing we can do about that. - """ + ## @var max_prefixlen + # Maxmimum prefix length. - return self.range_type.make_prefix(self.prefix, self.prefixlen) + def __init__(self, prefix, prefixlen, max_prefixlen = None): + """ + Initialize a ROA prefix. max_prefixlen is optional and defaults + to prefixlen. max_prefixlen must not be smaller than prefixlen. + """ - def min(self): - """ - Return lowest address covered by prefix. - """ + if max_prefixlen is None: + max_prefixlen = prefixlen + assert max_prefixlen >= prefixlen, "Bad max_prefixlen: %d must not be shorter than %d" % (max_prefixlen, prefixlen) + self.prefix = prefix + self.prefixlen = prefixlen + self.max_prefixlen = max_prefixlen - return self.prefix + def __cmp__(self, other): + """ + Compare two ROA prefix objects. Comparision is based on prefix, + prefixlen, and max_prefixlen, in that order. + """ - def max(self): - """ - Return highest address covered by prefix. - """ + assert self.__class__ is other.__class__ + return (cmp(self.prefix, other.prefix) or + cmp(self.prefixlen, other.prefixlen) or + cmp(self.max_prefixlen, other.max_prefixlen)) - return self.prefix | ((1 << (self.prefix.bits - self.prefixlen)) - 1) + def __str__(self): + """ + Convert a ROA prefix to string format. + """ - def to_POW_roa_tuple(self): - """ - Convert a resource_range_ip to rpki.POW.ROA.setPrefixes() format. - """ + if self.prefixlen == self.max_prefixlen: + return str(self.prefix) + "/" + str(self.prefixlen) + else: + return str(self.prefix) + "/" + str(self.prefixlen) + "-" + str(self.max_prefixlen) - return self.prefix, self.prefixlen, self.max_prefixlen + def to_resource_range(self): + """ + Convert this ROA prefix to the equivilent resource_range_ip + object. This is an irreversable transformation because it loses + the max_prefixlen attribute, nothing we can do about that. + """ - @classmethod - def parse_str(cls, x): - """ - Parse ROA prefix from text (eg, an XML attribute). - """ + return self.range_type.make_prefix(self.prefix, self.prefixlen) - r = re_prefix_with_maxlen.match(x) - if r: - return cls(rpki.POW.IPAddress(r.group(1)), int(r.group(2)), int(r.group(3))) - r = re_prefix.match(x) - if r: - return cls(rpki.POW.IPAddress(r.group(1)), int(r.group(2))) - raise rpki.exceptions.BadROAPrefix('Bad ROA prefix "%s"' % x) + def min(self): + """ + Return lowest address covered by prefix. + """ -class roa_prefix_ipv4(roa_prefix): - """ - IPv4 ROA prefix. - """ + return self.prefix - ## @var range_type - # Type of corresponding resource_range_ip. + def max(self): + """ + Return highest address covered by prefix. + """ - range_type = resource_range_ipv4 + return self.prefix | ((1 << (self.prefix.bits - self.prefixlen)) - 1) -class roa_prefix_ipv6(roa_prefix): - """ - IPv6 ROA prefix. - """ + def to_POW_roa_tuple(self): + """ + Convert a resource_range_ip to rpki.POW.ROA.setPrefixes() format. + """ - ## @var range_type - # Type of corresponding resource_range_ip. + return self.prefix, self.prefixlen, self.max_prefixlen - range_type = resource_range_ipv6 + @classmethod + def parse_str(cls, x): + """ + Parse ROA prefix from text (eg, an XML attribute). + """ -class roa_prefix_set(list): - """ - Set of ROA prefixes, analogous to the resource_set_ip class. - """ + r = re_prefix_with_maxlen.match(x) + if r: + return cls(rpki.POW.IPAddress(r.group(1)), int(r.group(2)), int(r.group(3))) + r = re_prefix.match(x) + if r: + return cls(rpki.POW.IPAddress(r.group(1)), int(r.group(2))) + raise rpki.exceptions.BadROAPrefix('Bad ROA prefix "%s"' % x) - def __init__(self, ini = None): +class roa_prefix_ipv4(roa_prefix): """ - Initialize a ROA prefix set. + IPv4 ROA prefix. """ - list.__init__(self) - if isinstance(ini, str) and len(ini): - self.extend(self.parse_str(s) for s in ini.split(",")) - elif isinstance(ini, (list, tuple)): - self.extend(ini) - else: - assert ini is None or ini == "", "Unexpected initializer: %s" % str(ini) - self.sort() - - def __str__(self): - """ - Convert a ROA prefix set to string format. - """ + ## @var range_type + # Type of corresponding resource_range_ip. - return ",".join(str(x) for x in self) + range_type = resource_range_ipv4 - @classmethod - def parse_str(cls, s): +class roa_prefix_ipv6(roa_prefix): """ - Parse ROA prefix from text (eg, an XML attribute). - This method is a backwards compatability shim. + IPv6 ROA prefix. """ - return cls.prefix_type.parse_str(s) + ## @var range_type + # Type of corresponding resource_range_ip. - def to_resource_set(self): - """ - Convert a ROA prefix set to a resource set. This is an - irreversable transformation. We have to compute a union here - because ROA prefix sets can include overlaps, while RFC 3779 - resource sets cannot. This is ugly, and there is almost certainly - a more efficient way to do this, but start by getting the output - right before worrying about making it fast or pretty. - """ + range_type = resource_range_ipv6 - r = self.resource_set_type() - s = self.resource_set_type() - s.append(None) - for p in self: - s[0] = p.to_resource_range() - r |= s - return r - - @classmethod - def from_sql(cls, sql, query, args = None): +class roa_prefix_set(list): """ - Create ROA prefix set from an SQL query. - - sql is an object that supports execute() and fetchall() methods - like a DB API 2.0 cursor object. - - query is an SQL query that returns a sequence of (prefix, - prefixlen, max_prefixlen) triples. + Set of ROA prefixes, analogous to the resource_set_ip class. """ - sql.execute(query, args) - return cls([cls.prefix_type(rpki.POW.IPAddress(x), int(y), int(z)) - for (x, y, z) in sql.fetchall()]) - - @classmethod - def from_django(cls, iterable): - """ - Create ROA prefix set from a Django query. + def __init__(self, ini = None): + """ + Initialize a ROA prefix set. + """ - iterable is something which returns (prefix, prefixlen, - max_prefixlen) triples. - """ + list.__init__(self) + if isinstance(ini, str) and len(ini): + self.extend(self.parse_str(s) for s in ini.split(",")) + elif isinstance(ini, (list, tuple)): + self.extend(ini) + else: + assert ini is None or ini == "", "Unexpected initializer: %s" % str(ini) + self.sort() + + def __str__(self): + """ + Convert a ROA prefix set to string format. + """ + + return ",".join(str(x) for x in self) + + @classmethod + def parse_str(cls, s): + """ + Parse ROA prefix from text (eg, an XML attribute). + This method is a backwards compatability shim. + """ + + return cls.prefix_type.parse_str(s) + + def to_resource_set(self): + """ + Convert a ROA prefix set to a resource set. This is an + irreversable transformation. We have to compute a union here + because ROA prefix sets can include overlaps, while RFC 3779 + resource sets cannot. This is ugly, and there is almost certainly + a more efficient way to do this, but start by getting the output + right before worrying about making it fast or pretty. + """ + + r = self.resource_set_type() + s = self.resource_set_type() + s.append(None) + for p in self: + s[0] = p.to_resource_range() + r |= s + return r + + @classmethod + def from_sql(cls, sql, query, args = None): + """ + Create ROA prefix set from an SQL query. + + sql is an object that supports execute() and fetchall() methods + like a DB API 2.0 cursor object. + + query is an SQL query that returns a sequence of (prefix, + prefixlen, max_prefixlen) triples. + """ + + sql.execute(query, args) + return cls([cls.prefix_type(rpki.POW.IPAddress(x), int(y), int(z)) + for (x, y, z) in sql.fetchall()]) + + @classmethod + def from_django(cls, iterable): + """ + Create ROA prefix set from a Django query. + + iterable is something which returns (prefix, prefixlen, + max_prefixlen) triples. + """ + + return cls([cls.prefix_type(rpki.POW.IPAddress(x), int(y), int(z)) + for (x, y, z) in iterable]) + + def to_POW_roa_tuple(self): + """ + Convert ROA prefix set to form used by rpki.POW.ROA.setPrefixes(). + """ + + if self: + return tuple(a.to_POW_roa_tuple() for a in self) + else: + return None - return cls([cls.prefix_type(rpki.POW.IPAddress(x), int(y), int(z)) - for (x, y, z) in iterable]) - def to_POW_roa_tuple(self): +class roa_prefix_set_ipv4(roa_prefix_set): """ - Convert ROA prefix set to form used by rpki.POW.ROA.setPrefixes(). + Set of IPv4 ROA prefixes. """ - if self: - return tuple(a.to_POW_roa_tuple() for a in self) - else: - return None - + ## @var prefix_type + # Type of underlying roa_prefix. -class roa_prefix_set_ipv4(roa_prefix_set): - """ - Set of IPv4 ROA prefixes. - """ + prefix_type = roa_prefix_ipv4 - ## @var prefix_type - # Type of underlying roa_prefix. + ## @var resource_set_type + # Type of corresponding resource_set_ip class. - prefix_type = roa_prefix_ipv4 - - ## @var resource_set_type - # Type of corresponding resource_set_ip class. - - resource_set_type = resource_set_ipv4 + resource_set_type = resource_set_ipv4 # Fix back link from resource_set to roa_prefix resource_set_ipv4.roa_prefix_set_type = roa_prefix_set_ipv4 class roa_prefix_set_ipv6(roa_prefix_set): - """ - Set of IPv6 ROA prefixes. - """ + """ + Set of IPv6 ROA prefixes. + """ - ## @var prefix_type - # Type of underlying roa_prefix. + ## @var prefix_type + # Type of underlying roa_prefix. - prefix_type = roa_prefix_ipv6 + prefix_type = roa_prefix_ipv6 - ## @var resource_set_type - # Type of corresponding resource_set_ip class. + ## @var resource_set_type + # Type of corresponding resource_set_ip class. - resource_set_type = resource_set_ipv6 + resource_set_type = resource_set_ipv6 # Fix back link from resource_set to roa_prefix resource_set_ipv6.roa_prefix_set_type = roa_prefix_set_ipv6 class roa_prefix_bag(object): - """ - Container to simplify passing around the combination of an IPv4 ROA - prefix set and an IPv6 ROA prefix set. - """ + """ + Container to simplify passing around the combination of an IPv4 ROA + prefix set and an IPv6 ROA prefix set. + """ - ## @var v4 - # Set of IPv4 prefixes. + ## @var v4 + # Set of IPv4 prefixes. - ## @var v6 - # Set of IPv6 prefixes. + ## @var v6 + # Set of IPv6 prefixes. - def __init__(self, v4 = None, v6 = None): - self.v4 = v4 or roa_prefix_set_ipv4() - self.v6 = v6 or roa_prefix_set_ipv6() + def __init__(self, v4 = None, v6 = None): + self.v4 = v4 or roa_prefix_set_ipv4() + self.v6 = v6 or roa_prefix_set_ipv6() - def __eq__(self, other): - return self.v4 == other.v4 and self.v6 == other.v6 + def __eq__(self, other): + return self.v4 == other.v4 and self.v6 == other.v6 - def __ne__(self, other): - return not (self == other) # pylint: disable=C0325 + def __ne__(self, other): + return not (self == other) # pylint: disable=C0325 # Test suite for set operations. if __name__ == "__main__": - def testprefix(v): - return " (%s)" % v.to_roa_prefix_set() if isinstance(v, resource_set_ip) else "" - - def test1(t, s1, s2): - if isinstance(s1, str) and isinstance(s2, str): - print "x: ", s1 - print "y: ", s2 - r1 = t(s1) - r2 = t(s2) - print "x: ", r1, testprefix(r1) - print "y: ", r2, testprefix(r2) - v1 = r1._comm(r2) - v2 = r2._comm(r1) - assert v1[0] == v2[1] and v1[1] == v2[0] and v1[2] == v2[2] - for i in r1: assert i in r1 and i.min in r1 and i.max in r1 - for i in r2: assert i in r2 and i.min in r2 and i.max in r2 - for i in v1[0]: assert i in r1 and i not in r2 - for i in v1[1]: assert i not in r1 and i in r2 - for i in v1[2]: assert i in r1 and i in r2 - v1 = r1 | r2 - v2 = r2 | r1 - assert v1 == v2 - print "x|y:", v1, testprefix(v1) - v1 = r1 - r2 - v2 = r2 - r1 - print "x-y:", v1, testprefix(v1) - print "y-x:", v2, testprefix(v2) - v1 = r1 ^ r2 - v2 = r2 ^ r1 - assert v1 == v2 - print "x^y:", v1, testprefix(v1) - v1 = r1 & r2 - v2 = r2 & r1 - assert v1 == v2 - print "x&y:", v1, testprefix(v1) - - def test2(t, s1, s2): - print "x: ", s1 - print "y: ", s2 - r1 = t(s1) - r2 = t(s2) - print "x: ", r1 - print "y: ", r2 - print "x>y:", (r1 > r2) - print "x<y:", (r1 < r2) - test1(t.resource_set_type, - r1.to_resource_set(), - r2.to_resource_set()) - - def test3(t, s1, s2): - test1(t, s1, s2) - r1 = t(s1).to_roa_prefix_set() - r2 = t(s2).to_roa_prefix_set() - print "x: ", r1 - print "y: ", r2 - print "x>y:", (r1 > r2) - print "x<y:", (r1 < r2) - test1(t.roa_prefix_set_type.resource_set_type, - r1.to_resource_set(), - r2.to_resource_set()) - - print - print "Testing set operations on resource sets" - print - test1(resource_set_as, "1,2,3,4,5,6,11,12,13,14,15", "1,2,3,4,5,6,111,121,131,141,151") - print - test1(resource_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.3.0.0/24,10.0.0.77/32") - print - test1(resource_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.0.0.0/24") - print - test1(resource_set_ipv4, "10.0.0.0/24", "10.3.0.0/24,10.0.0.77/32") - print - test1(resource_set_ipv4, "10.0.0.0/24", "10.0.0.0/32,10.0.0.2/32,10.0.0.4/32") - print - print "Testing set operations on ROA prefixes" - print - test2(roa_prefix_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.3.0.0/24,10.0.0.77/32") - print - test2(roa_prefix_set_ipv4, "10.0.0.0/24-32,10.6.0.0/24-32", "10.3.0.0/24,10.0.0.0/16-32") - print - test2(roa_prefix_set_ipv4, "10.3.0.0/24-24,10.0.0.0/16-32", "10.3.0.0/24,10.0.0.0/16-32") - print - test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::2/128") - print - test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::7/128") - print - test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::/120") - print - test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::/120-128") - print - test3(resource_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.3.0.0/24,10.0.0.77/32") - print - test3(resource_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::2/128") - print - test3(resource_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::/120") + def testprefix(v): + return " (%s)" % v.to_roa_prefix_set() if isinstance(v, resource_set_ip) else "" + + def test1(t, s1, s2): + if isinstance(s1, str) and isinstance(s2, str): + print "x: ", s1 + print "y: ", s2 + r1 = t(s1) + r2 = t(s2) + print "x: ", r1, testprefix(r1) + print "y: ", r2, testprefix(r2) + v1 = r1._comm(r2) + v2 = r2._comm(r1) + assert v1[0] == v2[1] and v1[1] == v2[0] and v1[2] == v2[2] + for i in r1: assert i in r1 and i.min in r1 and i.max in r1 + for i in r2: assert i in r2 and i.min in r2 and i.max in r2 + for i in v1[0]: assert i in r1 and i not in r2 + for i in v1[1]: assert i not in r1 and i in r2 + for i in v1[2]: assert i in r1 and i in r2 + v1 = r1 | r2 + v2 = r2 | r1 + assert v1 == v2 + print "x|y:", v1, testprefix(v1) + v1 = r1 - r2 + v2 = r2 - r1 + print "x-y:", v1, testprefix(v1) + print "y-x:", v2, testprefix(v2) + v1 = r1 ^ r2 + v2 = r2 ^ r1 + assert v1 == v2 + print "x^y:", v1, testprefix(v1) + v1 = r1 & r2 + v2 = r2 & r1 + assert v1 == v2 + print "x&y:", v1, testprefix(v1) + + def test2(t, s1, s2): + print "x: ", s1 + print "y: ", s2 + r1 = t(s1) + r2 = t(s2) + print "x: ", r1 + print "y: ", r2 + print "x>y:", (r1 > r2) + print "x<y:", (r1 < r2) + test1(t.resource_set_type, + r1.to_resource_set(), + r2.to_resource_set()) + + def test3(t, s1, s2): + test1(t, s1, s2) + r1 = t(s1).to_roa_prefix_set() + r2 = t(s2).to_roa_prefix_set() + print "x: ", r1 + print "y: ", r2 + print "x>y:", (r1 > r2) + print "x<y:", (r1 < r2) + test1(t.roa_prefix_set_type.resource_set_type, + r1.to_resource_set(), + r2.to_resource_set()) + + print + print "Testing set operations on resource sets" + print + test1(resource_set_as, "1,2,3,4,5,6,11,12,13,14,15", "1,2,3,4,5,6,111,121,131,141,151") + print + test1(resource_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.3.0.0/24,10.0.0.77/32") + print + test1(resource_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.0.0.0/24") + print + test1(resource_set_ipv4, "10.0.0.0/24", "10.3.0.0/24,10.0.0.77/32") + print + test1(resource_set_ipv4, "10.0.0.0/24", "10.0.0.0/32,10.0.0.2/32,10.0.0.4/32") + print + print "Testing set operations on ROA prefixes" + print + test2(roa_prefix_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.3.0.0/24,10.0.0.77/32") + print + test2(roa_prefix_set_ipv4, "10.0.0.0/24-32,10.6.0.0/24-32", "10.3.0.0/24,10.0.0.0/16-32") + print + test2(roa_prefix_set_ipv4, "10.3.0.0/24-24,10.0.0.0/16-32", "10.3.0.0/24,10.0.0.0/16-32") + print + test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::2/128") + print + test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::7/128") + print + test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::/120") + print + test2(roa_prefix_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::/120-128") + print + test3(resource_set_ipv4, "10.0.0.44/32,10.6.0.2/32", "10.3.0.0/24,10.0.0.77/32") + print + test3(resource_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::2/128") + print + test3(resource_set_ipv6, "2002:0a00:002c::1/128", "2002:0a00:002c::/120") diff --git a/rpki/rootd.py b/rpki/rootd.py index 1d4f5659..e3a460f4 100644 --- a/rpki/rootd.py +++ b/rpki/rootd.py @@ -44,416 +44,416 @@ logger = logging.getLogger(__name__) class ReplayTracker(object): - """ - Stash for replay protection timestamps. - """ + """ + Stash for replay protection timestamps. + """ - def __init__(self): - self.cms_timestamp = None + def __init__(self): + self.cms_timestamp = None class main(object): - def root_newer_than_subject(self): - return self.rpki_root_cert.mtime > os.stat(self.rpki_subject_cert_file).st_mtime - - - def get_subject_cert(self): - try: - x = rpki.x509.X509(Auto_file = self.rpki_subject_cert_file) - logger.debug("Read subject cert %s", self.rpki_subject_cert_file) - return x - except IOError: - return None - - - def set_subject_cert(self, cert): - logger.debug("Writing subject cert %s, SKI %s", self.rpki_subject_cert_file, cert.hSKI()) - with open(self.rpki_subject_cert_file, "wb") as f: - f.write(cert.get_DER()) - - - def del_subject_cert(self): - logger.debug("Deleting subject cert %s", self.rpki_subject_cert_file) - os.remove(self.rpki_subject_cert_file) - - - def get_subject_pkcs10(self): - try: - x = rpki.x509.PKCS10(Auto_file = self.rpki_subject_pkcs10) - logger.debug("Read subject PKCS #10 %s", self.rpki_subject_pkcs10) - return x - except IOError: - return None - - - def set_subject_pkcs10(self, pkcs10): - logger.debug("Writing subject PKCS #10 %s", self.rpki_subject_pkcs10) - with open(self.rpki_subject_pkcs10, "wb") as f: - f.write(pkcs10.get_DER()) - - - def del_subject_pkcs10(self): - logger.debug("Deleting subject PKCS #10 %s", self.rpki_subject_pkcs10) - try: - os.remove(self.rpki_subject_pkcs10) - except OSError: - pass - - - def issue_subject_cert_maybe(self, new_pkcs10): - now = rpki.sundial.now() - subject_cert = self.get_subject_cert() - if subject_cert is None: - subject_cert_hash = None - else: - subject_cert_hash = rpki.x509.sha256(subject_cert.get_DER()).encode("hex") - old_pkcs10 = self.get_subject_pkcs10() - if new_pkcs10 is not None and new_pkcs10 != old_pkcs10: - self.set_subject_pkcs10(new_pkcs10) - if subject_cert is not None: - logger.debug("PKCS #10 changed, regenerating subject certificate") + def root_newer_than_subject(self): + return self.rpki_root_cert.mtime > os.stat(self.rpki_subject_cert_file).st_mtime + + + def get_subject_cert(self): + try: + x = rpki.x509.X509(Auto_file = self.rpki_subject_cert_file) + logger.debug("Read subject cert %s", self.rpki_subject_cert_file) + return x + except IOError: + return None + + + def set_subject_cert(self, cert): + logger.debug("Writing subject cert %s, SKI %s", self.rpki_subject_cert_file, cert.hSKI()) + with open(self.rpki_subject_cert_file, "wb") as f: + f.write(cert.get_DER()) + + + def del_subject_cert(self): + logger.debug("Deleting subject cert %s", self.rpki_subject_cert_file) + os.remove(self.rpki_subject_cert_file) + + + def get_subject_pkcs10(self): + try: + x = rpki.x509.PKCS10(Auto_file = self.rpki_subject_pkcs10) + logger.debug("Read subject PKCS #10 %s", self.rpki_subject_pkcs10) + return x + except IOError: + return None + + + def set_subject_pkcs10(self, pkcs10): + logger.debug("Writing subject PKCS #10 %s", self.rpki_subject_pkcs10) + with open(self.rpki_subject_pkcs10, "wb") as f: + f.write(pkcs10.get_DER()) + + + def del_subject_pkcs10(self): + logger.debug("Deleting subject PKCS #10 %s", self.rpki_subject_pkcs10) + try: + os.remove(self.rpki_subject_pkcs10) + except OSError: + pass + + + def issue_subject_cert_maybe(self, new_pkcs10): + now = rpki.sundial.now() + subject_cert = self.get_subject_cert() + if subject_cert is None: + subject_cert_hash = None + else: + subject_cert_hash = rpki.x509.sha256(subject_cert.get_DER()).encode("hex") + old_pkcs10 = self.get_subject_pkcs10() + if new_pkcs10 is not None and new_pkcs10 != old_pkcs10: + self.set_subject_pkcs10(new_pkcs10) + if subject_cert is not None: + logger.debug("PKCS #10 changed, regenerating subject certificate") + self.revoke_subject_cert(now) + subject_cert = None + if subject_cert is not None and subject_cert.getNotAfter() <= now + self.rpki_subject_regen: + logger.debug("Subject certificate has reached expiration threshold, regenerating") + self.revoke_subject_cert(now) + subject_cert = None + if subject_cert is not None and self.root_newer_than_subject(): + logger.debug("Root certificate has changed, regenerating subject") + self.revoke_subject_cert(now) + subject_cert = None + if subject_cert is not None: + return subject_cert, None + pkcs10 = old_pkcs10 if new_pkcs10 is None else new_pkcs10 + if pkcs10 is None: + logger.debug("No PKCS #10 request, can't generate subject certificate yet") + return None, None + resources = self.rpki_root_cert.get_3779resources() + notAfter = now + self.rpki_subject_lifetime + logger.info("Generating subject cert %s with resources %s, expires %s", + self.rpki_subject_cert_uri, resources, notAfter) + req_key = pkcs10.getPublicKey() + req_sia = pkcs10.get_SIA() + self.next_serial_number() + subject_cert = self.rpki_root_cert.issue( + keypair = self.rpki_root_key, + subject_key = req_key, + serial = self.serial_number, + sia = req_sia, + aia = self.rpki_root_cert_uri, + crldp = self.rpki_root_crl_uri, + resources = resources, + notBefore = now, + notAfter = notAfter) + self.set_subject_cert(subject_cert) + pubd_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, + type = "query", version = rpki.publication.version) + pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_subject_cert_uri) + pdu.text = subject_cert.get_Base64() + if subject_cert_hash is not None: + pdu.set("hash", subject_cert_hash) + self.generate_crl_and_manifest(now, pubd_msg) + return subject_cert, pubd_msg + + + def generate_crl_and_manifest(self, now, pubd_msg): + subject_cert = self.get_subject_cert() + self.next_serial_number() + self.next_crl_number() + while self.revoked and self.revoked[0][1] + 2 * self.rpki_subject_regen < now: + del self.revoked[0] + crl = rpki.x509.CRL.generate( + keypair = self.rpki_root_key, + issuer = self.rpki_root_cert, + serial = self.crl_number, + thisUpdate = now, + nextUpdate = now + self.rpki_subject_regen, + revokedCertificates = self.revoked) + crl_hash = self.read_hash_maybe(self.rpki_root_crl_file) + logger.debug("Writing CRL %s", self.rpki_root_crl_file) + with open(self.rpki_root_crl_file, "wb") as f: + f.write(crl.get_DER()) + pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_crl_uri) + pdu.text = crl.get_Base64() + if crl_hash is not None: + pdu.set("hash", crl_hash) + manifest_content = [(os.path.basename(self.rpki_root_crl_uri), crl)] + if subject_cert is not None: + manifest_content.append((os.path.basename(self.rpki_subject_cert_uri), subject_cert)) + manifest_resources = rpki.resource_set.resource_bag.from_inheritance() + manifest_keypair = rpki.x509.RSA.generate() + manifest_cert = self.rpki_root_cert.issue( + keypair = self.rpki_root_key, + subject_key = manifest_keypair.get_public(), + serial = self.serial_number, + sia = (None, None, self.rpki_root_manifest_uri, self.rrdp_notification_uri), + aia = self.rpki_root_cert_uri, + crldp = self.rpki_root_crl_uri, + resources = manifest_resources, + notBefore = now, + notAfter = now + self.rpki_subject_lifetime, + is_ca = False) + manifest = rpki.x509.SignedManifest.build( + serial = self.crl_number, + thisUpdate = now, + nextUpdate = now + self.rpki_subject_regen, + names_and_objs = manifest_content, + keypair = manifest_keypair, + certs = manifest_cert) + mft_hash = self.read_hash_maybe(self.rpki_root_manifest_file) + logger.debug("Writing manifest %s", self.rpki_root_manifest_file) + with open(self.rpki_root_manifest_file, "wb") as f: + f.write(manifest.get_DER()) + pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_manifest_uri) + pdu.text = manifest.get_Base64() + if mft_hash is not None: + pdu.set("hash", mft_hash) + cer_hash = rpki.x509.sha256(self.rpki_root_cert.get_DER()).encode("hex") + if cer_hash != self.rpki_root_cert_hash: + pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_cert_uri) + pdu.text = self.rpki_root_cert.get_Base64() + if self.rpki_root_cert_hash is not None: + pdu.set("hash", self.rpki_root_cert_hash) + self.rpki_root_cert_hash = cer_hash + + + @staticmethod + def read_hash_maybe(fn): + try: + with open(fn, "rb") as f: + return rpki.x509.sha256(f.read()).encode("hex") + except IOError: + return None + + + def revoke_subject_cert(self, now): + self.revoked.append((self.get_subject_cert().getSerial(), now)) + + + def publish(self, q_msg): + if q_msg is None: + return + assert len(q_msg) > 0 + + if not all(q_pdu.get("hash") is not None for q_pdu in q_msg): + logger.debug("Some publication PDUs are missing hashes, checking published data...") + q = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, + type = "query", version = rpki.publication.version) + SubElement(q, rpki.publication.tag_list) + published_hash = dict((r.get("uri"), r.get("hash")) for r in self.call_pubd(q)) + for q_pdu in q_msg: + q_uri = q_pdu.get("uri") + if q_pdu.get("hash") is None and published_hash.get(q_uri) is not None: + logger.debug("Updating hash of %s to %s from previously published data", q_uri, published_hash[q_uri]) + q_pdu.set("hash", published_hash[q_uri]) + + r_msg = self.call_pubd(q_msg) + if len(q_msg) != len(r_msg): + raise rpki.exceptions.BadPublicationReply("Wrong number of response PDUs from pubd: sent %s, got %s" % (len(q_msg), len(r_msg))) + + + def call_pubd(self, q_msg): + for q_pdu in q_msg: + logger.info("Sending %s to pubd", q_pdu.get("uri")) + r_msg = rpki.http_simple.client( + proto_cms_msg = rpki.publication.cms_msg, + client_key = self.rootd_bpki_key, + client_cert = self.rootd_bpki_cert, + client_crl = self.rootd_bpki_crl, + server_ta = self.bpki_ta, + server_cert = self.pubd_bpki_cert, + url = self.pubd_url, + q_msg = q_msg, + replay_track = self.pubd_replay_tracker) + rpki.publication.raise_if_error(r_msg) + return r_msg + + + def compose_response(self, r_msg, pkcs10 = None): + subject_cert, pubd_msg = self.issue_subject_cert_maybe(pkcs10) + bag = self.rpki_root_cert.get_3779resources() + rc = SubElement(r_msg, rpki.up_down.tag_class, + class_name = self.rpki_class_name, + cert_url = str(rpki.up_down.multi_uri(self.rpki_root_cert_uri)), + resource_set_as = str(bag.asn), + resource_set_ipv4 = str(bag.v4), + resource_set_ipv6 = str(bag.v6), + resource_set_notafter = str(bag.valid_until)) + if subject_cert is not None: + c = SubElement(rc, rpki.up_down.tag_certificate, + cert_url = str(rpki.up_down.multi_uri(self.rpki_subject_cert_uri))) + c.text = subject_cert.get_Base64() + SubElement(rc, rpki.up_down.tag_issuer).text = self.rpki_root_cert.get_Base64() + self.publish(pubd_msg) + + + def handle_list(self, q_msg, r_msg): + self.compose_response(r_msg) + + + def handle_issue(self, q_msg, r_msg): + # This is where we'd check q_msg[0].get("class_name") if this weren't rootd. + self.compose_response(r_msg, rpki.x509.PKCS10(Base64 = q_msg[0].text)) + + + def handle_revoke(self, q_msg, r_msg): + class_name = q_msg[0].get("class_name") + ski = q_msg[0].get("ski") + logger.debug("Revocation requested for class %s SKI %s", class_name, ski) + subject_cert = self.get_subject_cert() + if subject_cert is None: + logger.debug("No subject certificate, nothing to revoke") + raise rpki.exceptions.NotInDatabase + if subject_cert.gSKI() != ski: + logger.debug("Subject certificate has different SKI %s, not revoking", subject_cert.gSKI()) + raise rpki.exceptions.NotInDatabase + logger.debug("Revoking certificate %s", ski) + now = rpki.sundial.now() + pubd_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, + type = "query", version = rpki.publication.version) self.revoke_subject_cert(now) - subject_cert = None - if subject_cert is not None and subject_cert.getNotAfter() <= now + self.rpki_subject_regen: - logger.debug("Subject certificate has reached expiration threshold, regenerating") - self.revoke_subject_cert(now) - subject_cert = None - if subject_cert is not None and self.root_newer_than_subject(): - logger.debug("Root certificate has changed, regenerating subject") - self.revoke_subject_cert(now) - subject_cert = None - if subject_cert is not None: - return subject_cert, None - pkcs10 = old_pkcs10 if new_pkcs10 is None else new_pkcs10 - if pkcs10 is None: - logger.debug("No PKCS #10 request, can't generate subject certificate yet") - return None, None - resources = self.rpki_root_cert.get_3779resources() - notAfter = now + self.rpki_subject_lifetime - logger.info("Generating subject cert %s with resources %s, expires %s", - self.rpki_subject_cert_uri, resources, notAfter) - req_key = pkcs10.getPublicKey() - req_sia = pkcs10.get_SIA() - self.next_serial_number() - subject_cert = self.rpki_root_cert.issue( - keypair = self.rpki_root_key, - subject_key = req_key, - serial = self.serial_number, - sia = req_sia, - aia = self.rpki_root_cert_uri, - crldp = self.rpki_root_crl_uri, - resources = resources, - notBefore = now, - notAfter = notAfter) - self.set_subject_cert(subject_cert) - pubd_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, - type = "query", version = rpki.publication.version) - pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_subject_cert_uri) - pdu.text = subject_cert.get_Base64() - if subject_cert_hash is not None: - pdu.set("hash", subject_cert_hash) - self.generate_crl_and_manifest(now, pubd_msg) - return subject_cert, pubd_msg - - - def generate_crl_and_manifest(self, now, pubd_msg): - subject_cert = self.get_subject_cert() - self.next_serial_number() - self.next_crl_number() - while self.revoked and self.revoked[0][1] + 2 * self.rpki_subject_regen < now: - del self.revoked[0] - crl = rpki.x509.CRL.generate( - keypair = self.rpki_root_key, - issuer = self.rpki_root_cert, - serial = self.crl_number, - thisUpdate = now, - nextUpdate = now + self.rpki_subject_regen, - revokedCertificates = self.revoked) - crl_hash = self.read_hash_maybe(self.rpki_root_crl_file) - logger.debug("Writing CRL %s", self.rpki_root_crl_file) - with open(self.rpki_root_crl_file, "wb") as f: - f.write(crl.get_DER()) - pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_crl_uri) - pdu.text = crl.get_Base64() - if crl_hash is not None: - pdu.set("hash", crl_hash) - manifest_content = [(os.path.basename(self.rpki_root_crl_uri), crl)] - if subject_cert is not None: - manifest_content.append((os.path.basename(self.rpki_subject_cert_uri), subject_cert)) - manifest_resources = rpki.resource_set.resource_bag.from_inheritance() - manifest_keypair = rpki.x509.RSA.generate() - manifest_cert = self.rpki_root_cert.issue( - keypair = self.rpki_root_key, - subject_key = manifest_keypair.get_public(), - serial = self.serial_number, - sia = (None, None, self.rpki_root_manifest_uri, self.rrdp_notification_uri), - aia = self.rpki_root_cert_uri, - crldp = self.rpki_root_crl_uri, - resources = manifest_resources, - notBefore = now, - notAfter = now + self.rpki_subject_lifetime, - is_ca = False) - manifest = rpki.x509.SignedManifest.build( - serial = self.crl_number, - thisUpdate = now, - nextUpdate = now + self.rpki_subject_regen, - names_and_objs = manifest_content, - keypair = manifest_keypair, - certs = manifest_cert) - mft_hash = self.read_hash_maybe(self.rpki_root_manifest_file) - logger.debug("Writing manifest %s", self.rpki_root_manifest_file) - with open(self.rpki_root_manifest_file, "wb") as f: - f.write(manifest.get_DER()) - pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_manifest_uri) - pdu.text = manifest.get_Base64() - if mft_hash is not None: - pdu.set("hash", mft_hash) - cer_hash = rpki.x509.sha256(self.rpki_root_cert.get_DER()).encode("hex") - if cer_hash != self.rpki_root_cert_hash: - pdu = SubElement(pubd_msg, rpki.publication.tag_publish, uri = self.rpki_root_cert_uri) - pdu.text = self.rpki_root_cert.get_Base64() - if self.rpki_root_cert_hash is not None: - pdu.set("hash", self.rpki_root_cert_hash) - self.rpki_root_cert_hash = cer_hash - - - @staticmethod - def read_hash_maybe(fn): - try: - with open(fn, "rb") as f: - return rpki.x509.sha256(f.read()).encode("hex") - except IOError: - return None - - - def revoke_subject_cert(self, now): - self.revoked.append((self.get_subject_cert().getSerial(), now)) - - - def publish(self, q_msg): - if q_msg is None: - return - assert len(q_msg) > 0 - - if not all(q_pdu.get("hash") is not None for q_pdu in q_msg): - logger.debug("Some publication PDUs are missing hashes, checking published data...") - q = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, - type = "query", version = rpki.publication.version) - SubElement(q, rpki.publication.tag_list) - published_hash = dict((r.get("uri"), r.get("hash")) for r in self.call_pubd(q)) - for q_pdu in q_msg: - q_uri = q_pdu.get("uri") - if q_pdu.get("hash") is None and published_hash.get(q_uri) is not None: - logger.debug("Updating hash of %s to %s from previously published data", q_uri, published_hash[q_uri]) - q_pdu.set("hash", published_hash[q_uri]) - - r_msg = self.call_pubd(q_msg) - if len(q_msg) != len(r_msg): - raise rpki.exceptions.BadPublicationReply("Wrong number of response PDUs from pubd: sent %s, got %s" % (len(q_msg), len(r_msg))) - - - def call_pubd(self, q_msg): - for q_pdu in q_msg: - logger.info("Sending %s to pubd", q_pdu.get("uri")) - r_msg = rpki.http_simple.client( - proto_cms_msg = rpki.publication.cms_msg, - client_key = self.rootd_bpki_key, - client_cert = self.rootd_bpki_cert, - client_crl = self.rootd_bpki_crl, - server_ta = self.bpki_ta, - server_cert = self.pubd_bpki_cert, - url = self.pubd_url, - q_msg = q_msg, - replay_track = self.pubd_replay_tracker) - rpki.publication.raise_if_error(r_msg) - return r_msg - - - def compose_response(self, r_msg, pkcs10 = None): - subject_cert, pubd_msg = self.issue_subject_cert_maybe(pkcs10) - bag = self.rpki_root_cert.get_3779resources() - rc = SubElement(r_msg, rpki.up_down.tag_class, - class_name = self.rpki_class_name, - cert_url = str(rpki.up_down.multi_uri(self.rpki_root_cert_uri)), - resource_set_as = str(bag.asn), - resource_set_ipv4 = str(bag.v4), - resource_set_ipv6 = str(bag.v6), - resource_set_notafter = str(bag.valid_until)) - if subject_cert is not None: - c = SubElement(rc, rpki.up_down.tag_certificate, - cert_url = str(rpki.up_down.multi_uri(self.rpki_subject_cert_uri))) - c.text = subject_cert.get_Base64() - SubElement(rc, rpki.up_down.tag_issuer).text = self.rpki_root_cert.get_Base64() - self.publish(pubd_msg) - - - def handle_list(self, q_msg, r_msg): - self.compose_response(r_msg) - - - def handle_issue(self, q_msg, r_msg): - # This is where we'd check q_msg[0].get("class_name") if this weren't rootd. - self.compose_response(r_msg, rpki.x509.PKCS10(Base64 = q_msg[0].text)) - - - def handle_revoke(self, q_msg, r_msg): - class_name = q_msg[0].get("class_name") - ski = q_msg[0].get("ski") - logger.debug("Revocation requested for class %s SKI %s", class_name, ski) - subject_cert = self.get_subject_cert() - if subject_cert is None: - logger.debug("No subject certificate, nothing to revoke") - raise rpki.exceptions.NotInDatabase - if subject_cert.gSKI() != ski: - logger.debug("Subject certificate has different SKI %s, not revoking", subject_cert.gSKI()) - raise rpki.exceptions.NotInDatabase - logger.debug("Revoking certificate %s", ski) - now = rpki.sundial.now() - pubd_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, - type = "query", version = rpki.publication.version) - self.revoke_subject_cert(now) - self.del_subject_cert() - self.del_subject_pkcs10() - SubElement(r_msg, q_msg[0].tag, class_name = class_name, ski = ski) - self.generate_crl_and_manifest(now, pubd_msg) - self.publish(pubd_msg) - - - # Need to do something about mapping exceptions to up-down error - # codes, right now everything shows up as "internal error". - # - #exceptions = { - # rpki.exceptions.ClassNameUnknown : 1201, - # rpki.exceptions.NoActiveCA : 1202, - # (rpki.exceptions.ClassNameUnknown, revoke_pdu) : 1301, - # (rpki.exceptions.NotInDatabase, revoke_pdu) : 1302 } - # - # Might be that what we want here is a subclass of - # rpki.exceptions.RPKI_Exception which carries an extra data field - # for the up-down error code, so that we can add the correct code - # when we instantiate it. - # - # There are also a few that are also schema violations, which means - # we'd have to catch them before validating or pick them out of a - # message that failed validation or otherwise break current - # modularity. Maybe an optional pre-validation check method hook in - # rpki.x509.XML_CMS_object which we can use to intercept such things? - - - def handler(self, request, q_der): - try: - q_cms = rpki.up_down.cms_msg(DER = q_der) - q_msg = q_cms.unwrap((self.bpki_ta, self.child_bpki_cert)) - q_type = q_msg.get("type") - logger.info("Serving %s query", q_type) - r_msg = Element(rpki.up_down.tag_message, nsmap = rpki.up_down.nsmap, version = rpki.up_down.version, - sender = q_msg.get("recipient"), recipient = q_msg.get("sender"), type = q_type + "_response") - try: - self.rpkid_cms_timestamp = q_cms.check_replay(self.rpkid_cms_timestamp, request.path) - getattr(self, "handle_" + q_type)(q_msg, r_msg) - except Exception, e: - logger.exception("Exception processing up-down %s message", q_type) - rpki.up_down.generate_error_response_from_exception(r_msg, e, q_type) - request.send_cms_response(rpki.up_down.cms_msg().wrap(r_msg, self.rootd_bpki_key, self.rootd_bpki_cert, - self.rootd_bpki_crl if self.include_bpki_crl else None)) - except Exception, e: - logger.exception("Unhandled exception processing up-down message") - request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) - - - def next_crl_number(self): - if self.crl_number is None: - try: - crl = rpki.x509.CRL(DER_file = self.rpki_root_crl_file) - self.crl_number = crl.getCRLNumber() - except: # pylint: disable=W0702 - self.crl_number = 0 - self.crl_number += 1 - return self.crl_number - - - def next_serial_number(self): - if self.serial_number is None: - subject_cert = self.get_subject_cert() - if subject_cert is not None: - self.serial_number = subject_cert.getSerial() + 1 - else: - self.serial_number = 0 - self.serial_number += 1 - return self.serial_number - - - def __init__(self): - self.serial_number = None - self.crl_number = None - self.revoked = [] - self.rpkid_cms_timestamp = None - self.pubd_replay_tracker = ReplayTracker() - - os.environ["TZ"] = "UTC" - time.tzset() - - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument("-c", "--config", - help = "override default location of configuration file") - parser.add_argument("-f", "--foreground", action = "store_true", - help = "do not daemonize") - parser.add_argument("--pidfile", - help = "override default location of pid file") - rpki.log.argparse_setup(parser) - args = parser.parse_args() - - rpki.log.init("rootd", args) - - self.cfg = rpki.config.parser(set_filename = args.config, section = "rootd") - self.cfg.set_global_flags() - - if not args.foreground: - rpki.daemonize.daemon(pidfile = args.pidfile) - - self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) - self.rootd_bpki_key = rpki.x509.RSA( Auto_update = self.cfg.get("rootd-bpki-key")) - self.rootd_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("rootd-bpki-cert")) - self.rootd_bpki_crl = rpki.x509.CRL( Auto_update = self.cfg.get("rootd-bpki-crl")) - self.child_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("child-bpki-cert")) - - if self.cfg.has_option("pubd-bpki-cert"): - self.pubd_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("pubd-bpki-cert")) - else: - self.pubd_bpki_cert = None - - self.http_server_host = self.cfg.get("server-host", "") - self.http_server_port = self.cfg.getint("server-port") - - self.rpki_class_name = self.cfg.get("rpki-class-name") - - self.rpki_root_key = rpki.x509.RSA( Auto_update = self.cfg.get("rpki-root-key-file")) - self.rpki_root_cert = rpki.x509.X509(Auto_update = self.cfg.get("rpki-root-cert-file")) - self.rpki_root_cert_uri = self.cfg.get("rpki-root-cert-uri") - self.rpki_root_cert_hash = None - - self.rpki_root_manifest_file = self.cfg.get("rpki-root-manifest-file") - self.rpki_root_manifest_uri = self.cfg.get("rpki-root-manifest-uri") - - self.rpki_root_crl_file = self.cfg.get("rpki-root-crl-file") - self.rpki_root_crl_uri = self.cfg.get("rpki-root-crl-uri") - - self.rpki_subject_cert_file = self.cfg.get("rpki-subject-cert-file") - self.rpki_subject_cert_uri = self.cfg.get("rpki-subject-cert-uri") - self.rpki_subject_pkcs10 = self.cfg.get("rpki-subject-pkcs10-file") - self.rpki_subject_lifetime = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-lifetime", "8w")) - self.rpki_subject_regen = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-regen", - self.rpki_subject_lifetime.convert_to_seconds() / 2)) - - self.include_bpki_crl = self.cfg.getboolean("include-bpki-crl", False) - - self.pubd_url = self.cfg.get("pubd-contact-uri") - - self.rrdp_notification_uri = self.cfg.get("rrdp-notification-uri") - - rpki.http_simple.server(host = self.http_server_host, - port = self.http_server_port, - handlers = (("/", self.handler, rpki.up_down.allowed_content_types),)) + self.del_subject_cert() + self.del_subject_pkcs10() + SubElement(r_msg, q_msg[0].tag, class_name = class_name, ski = ski) + self.generate_crl_and_manifest(now, pubd_msg) + self.publish(pubd_msg) + + + # Need to do something about mapping exceptions to up-down error + # codes, right now everything shows up as "internal error". + # + #exceptions = { + # rpki.exceptions.ClassNameUnknown : 1201, + # rpki.exceptions.NoActiveCA : 1202, + # (rpki.exceptions.ClassNameUnknown, revoke_pdu) : 1301, + # (rpki.exceptions.NotInDatabase, revoke_pdu) : 1302 } + # + # Might be that what we want here is a subclass of + # rpki.exceptions.RPKI_Exception which carries an extra data field + # for the up-down error code, so that we can add the correct code + # when we instantiate it. + # + # There are also a few that are also schema violations, which means + # we'd have to catch them before validating or pick them out of a + # message that failed validation or otherwise break current + # modularity. Maybe an optional pre-validation check method hook in + # rpki.x509.XML_CMS_object which we can use to intercept such things? + + + def handler(self, request, q_der): + try: + q_cms = rpki.up_down.cms_msg(DER = q_der) + q_msg = q_cms.unwrap((self.bpki_ta, self.child_bpki_cert)) + q_type = q_msg.get("type") + logger.info("Serving %s query", q_type) + r_msg = Element(rpki.up_down.tag_message, nsmap = rpki.up_down.nsmap, version = rpki.up_down.version, + sender = q_msg.get("recipient"), recipient = q_msg.get("sender"), type = q_type + "_response") + try: + self.rpkid_cms_timestamp = q_cms.check_replay(self.rpkid_cms_timestamp, request.path) + getattr(self, "handle_" + q_type)(q_msg, r_msg) + except Exception, e: + logger.exception("Exception processing up-down %s message", q_type) + rpki.up_down.generate_error_response_from_exception(r_msg, e, q_type) + request.send_cms_response(rpki.up_down.cms_msg().wrap(r_msg, self.rootd_bpki_key, self.rootd_bpki_cert, + self.rootd_bpki_crl if self.include_bpki_crl else None)) + except Exception, e: + logger.exception("Unhandled exception processing up-down message") + request.send_error(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) + + + def next_crl_number(self): + if self.crl_number is None: + try: + crl = rpki.x509.CRL(DER_file = self.rpki_root_crl_file) + self.crl_number = crl.getCRLNumber() + except: # pylint: disable=W0702 + self.crl_number = 0 + self.crl_number += 1 + return self.crl_number + + + def next_serial_number(self): + if self.serial_number is None: + subject_cert = self.get_subject_cert() + if subject_cert is not None: + self.serial_number = subject_cert.getSerial() + 1 + else: + self.serial_number = 0 + self.serial_number += 1 + return self.serial_number + + + def __init__(self): + self.serial_number = None + self.crl_number = None + self.revoked = [] + self.rpkid_cms_timestamp = None + self.pubd_replay_tracker = ReplayTracker() + + os.environ["TZ"] = "UTC" + time.tzset() + + parser = argparse.ArgumentParser(description = __doc__) + parser.add_argument("-c", "--config", + help = "override default location of configuration file") + parser.add_argument("-f", "--foreground", action = "store_true", + help = "do not daemonize") + parser.add_argument("--pidfile", + help = "override default location of pid file") + rpki.log.argparse_setup(parser) + args = parser.parse_args() + + rpki.log.init("rootd", args) + + self.cfg = rpki.config.parser(set_filename = args.config, section = "rootd") + self.cfg.set_global_flags() + + if not args.foreground: + rpki.daemonize.daemon(pidfile = args.pidfile) + + self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) + self.rootd_bpki_key = rpki.x509.RSA( Auto_update = self.cfg.get("rootd-bpki-key")) + self.rootd_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("rootd-bpki-cert")) + self.rootd_bpki_crl = rpki.x509.CRL( Auto_update = self.cfg.get("rootd-bpki-crl")) + self.child_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("child-bpki-cert")) + + if self.cfg.has_option("pubd-bpki-cert"): + self.pubd_bpki_cert = rpki.x509.X509(Auto_update = self.cfg.get("pubd-bpki-cert")) + else: + self.pubd_bpki_cert = None + + self.http_server_host = self.cfg.get("server-host", "") + self.http_server_port = self.cfg.getint("server-port") + + self.rpki_class_name = self.cfg.get("rpki-class-name") + + self.rpki_root_key = rpki.x509.RSA( Auto_update = self.cfg.get("rpki-root-key-file")) + self.rpki_root_cert = rpki.x509.X509(Auto_update = self.cfg.get("rpki-root-cert-file")) + self.rpki_root_cert_uri = self.cfg.get("rpki-root-cert-uri") + self.rpki_root_cert_hash = None + + self.rpki_root_manifest_file = self.cfg.get("rpki-root-manifest-file") + self.rpki_root_manifest_uri = self.cfg.get("rpki-root-manifest-uri") + + self.rpki_root_crl_file = self.cfg.get("rpki-root-crl-file") + self.rpki_root_crl_uri = self.cfg.get("rpki-root-crl-uri") + + self.rpki_subject_cert_file = self.cfg.get("rpki-subject-cert-file") + self.rpki_subject_cert_uri = self.cfg.get("rpki-subject-cert-uri") + self.rpki_subject_pkcs10 = self.cfg.get("rpki-subject-pkcs10-file") + self.rpki_subject_lifetime = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-lifetime", "8w")) + self.rpki_subject_regen = rpki.sundial.timedelta.parse(self.cfg.get("rpki-subject-regen", + self.rpki_subject_lifetime.convert_to_seconds() / 2)) + + self.include_bpki_crl = self.cfg.getboolean("include-bpki-crl", False) + + self.pubd_url = self.cfg.get("pubd-contact-uri") + + self.rrdp_notification_uri = self.cfg.get("rrdp-notification-uri") + + rpki.http_simple.server(host = self.http_server_host, + port = self.http_server_port, + handlers = (("/", self.handler, rpki.up_down.allowed_content_types),)) diff --git a/rpki/rpkic.py b/rpki/rpkic.py index 4b9ffedb..9cde75fb 100644 --- a/rpki/rpkic.py +++ b/rpki/rpkic.py @@ -53,813 +53,812 @@ module_doc = __doc__ class main(Cmd): - prompt = "rpkic> " - - completedefault = Cmd.filename_complete - - # Top-level argparser, for stuff that one might want when starting - # up the interactive command loop. Not sure -i belongs here, but - # it's harmless so leave it here for the moment. - - top_argparser = argparse.ArgumentParser(add_help = False) - top_argparser.add_argument("-c", "--config", - help = "override default location of configuration file") - top_argparser.add_argument("-i", "--identity", "--handle", - help = "set initial entity handdle") - top_argparser.add_argument("--profile", - help = "enable profiling, saving data to PROFILE") - - # Argparser for non-interactive commands (no command loop). - - full_argparser = argparse.ArgumentParser(parents = [top_argparser], - description = module_doc) - argsubparsers = full_argparser.add_subparsers(title = "Commands", metavar = "") - - def __init__(self): - - Cmd.__init__(self) - os.environ["TZ"] = "UTC" - time.tzset() - - # Try parsing just the arguments that make sense if we're - # going to be running an interactive command loop. If that - # parses everything, we're interactive, otherwise, it's either - # a non-interactive command or a parse error, so we let the full - # parser sort that out for us. - - args, argv = self.top_argparser.parse_known_args() - self.interactive = not argv - if not self.interactive: - args = self.full_argparser.parse_args() - - self.cfg_file = args.config - self.handle = args.identity + prompt = "rpkic> " + + completedefault = Cmd.filename_complete + + # Top-level argparser, for stuff that one might want when starting + # up the interactive command loop. Not sure -i belongs here, but + # it's harmless so leave it here for the moment. + + top_argparser = argparse.ArgumentParser(add_help = False) + top_argparser.add_argument("-c", "--config", + help = "override default location of configuration file") + top_argparser.add_argument("-i", "--identity", "--handle", + help = "set initial entity handdle") + top_argparser.add_argument("--profile", + help = "enable profiling, saving data to PROFILE") + + # Argparser for non-interactive commands (no command loop). + + full_argparser = argparse.ArgumentParser(parents = [top_argparser], + description = module_doc) + argsubparsers = full_argparser.add_subparsers(title = "Commands", metavar = "") + + def __init__(self): + + Cmd.__init__(self) + os.environ["TZ"] = "UTC" + time.tzset() + + # Try parsing just the arguments that make sense if we're + # going to be running an interactive command loop. If that + # parses everything, we're interactive, otherwise, it's either + # a non-interactive command or a parse error, so we let the full + # parser sort that out for us. + + args, argv = self.top_argparser.parse_known_args() + self.interactive = not argv + if not self.interactive: + args = self.full_argparser.parse_args() + + self.cfg_file = args.config + self.handle = args.identity - if args.profile: - import cProfile - prof = cProfile.Profile() - try: - prof.runcall(self.main, args) - finally: - prof.dump_stats(args.profile) - print "Dumped profile data to %s" % args.profile - else: - self.main(args) + if args.profile: + import cProfile + prof = cProfile.Profile() + try: + prof.runcall(self.main, args) + finally: + prof.dump_stats(args.profile) + print "Dumped profile data to %s" % args.profile + else: + self.main(args) - def main(self, args): - rpki.log.init("rpkic") - self.read_config() - if self.interactive: - self.cmdloop_with_history() - else: - args.func(self, args) + def main(self, args): + rpki.log.init("rpkic") + self.read_config() + if self.interactive: + self.cmdloop_with_history() + else: + args.func(self, args) - def read_config(self): - global rpki # pylint: disable=W0602 - - try: - cfg = rpki.config.parser(set_filename = self.cfg_file, section = "myrpki") - cfg.set_global_flags() - except IOError, e: - sys.exit("%s: %s" % (e.strerror, e.filename)) + def read_config(self): + global rpki # pylint: disable=W0602 + + try: + cfg = rpki.config.parser(set_filename = self.cfg_file, section = "myrpki") + cfg.set_global_flags() + except IOError, e: + sys.exit("%s: %s" % (e.strerror, e.filename)) - self.histfile = cfg.get("history_file", os.path.expanduser("~/.rpkic_history")) - self.autosync = cfg.getboolean("autosync", True, section = "rpkic") + self.histfile = cfg.get("history_file", os.path.expanduser("~/.rpkic_history")) + self.autosync = cfg.getboolean("autosync", True, section = "rpkic") - os.environ.update(DJANGO_SETTINGS_MODULE = "rpki.django_settings.irdb") + os.environ.update(DJANGO_SETTINGS_MODULE = "rpki.django_settings.irdb") - import django - django.setup() + import django + django.setup() - import rpki.irdb # pylint: disable=W0621 + import rpki.irdb # pylint: disable=W0621 - try: - rpki.irdb.models.ca_certificate_lifetime = rpki.sundial.timedelta.parse( - cfg.get("bpki_ca_certificate_lifetime", section = "rpkic")) - except rpki.config.ConfigParser.Error: - pass + try: + rpki.irdb.models.ca_certificate_lifetime = rpki.sundial.timedelta.parse( + cfg.get("bpki_ca_certificate_lifetime", section = "rpkic")) + except rpki.config.ConfigParser.Error: + pass - try: - rpki.irdb.models.ee_certificate_lifetime = rpki.sundial.timedelta.parse( - cfg.get("bpki_ee_certificate_lifetime", section = "rpkic")) - except rpki.config.ConfigParser.Error: - pass + try: + rpki.irdb.models.ee_certificate_lifetime = rpki.sundial.timedelta.parse( + cfg.get("bpki_ee_certificate_lifetime", section = "rpkic")) + except rpki.config.ConfigParser.Error: + pass - try: - rpki.irdb.models.crl_interval = rpki.sundial.timedelta.parse( - cfg.get("bpki_crl_interval", section = "rpkic")) - except rpki.config.ConfigParser.Error: - pass + try: + rpki.irdb.models.crl_interval = rpki.sundial.timedelta.parse( + cfg.get("bpki_crl_interval", section = "rpkic")) + except rpki.config.ConfigParser.Error: + pass - self.zoo = rpki.irdb.Zookeeper(cfg = cfg, handle = self.handle, logstream = sys.stdout) + self.zoo = rpki.irdb.Zookeeper(cfg = cfg, handle = self.handle, logstream = sys.stdout) - def do_help(self, arg): - """ - List available commands with "help" or detailed help with "help cmd". - """ + def do_help(self, arg): + """ + List available commands with "help" or detailed help with "help cmd". + """ - argv = arg.split() + argv = arg.split() - if not argv: - #return self.full_argparser.print_help() - return self.print_topics( - self.doc_header, - sorted(set(name[3:] for name in self.get_names() - if name.startswith("do_") - and getattr(self, name).__doc__)), - 15, 80) + if not argv: + #return self.full_argparser.print_help() + return self.print_topics( + self.doc_header, + sorted(set(name[3:] for name in self.get_names() + if name.startswith("do_") + and getattr(self, name).__doc__)), + 15, 80) - try: - return getattr(self, "help_" + argv[0])() - except AttributeError: - pass + try: + return getattr(self, "help_" + argv[0])() + except AttributeError: + pass - func = getattr(self, "do_" + argv[0], None) + func = getattr(self, "do_" + argv[0], None) - try: - return func.argparser.print_help() - except AttributeError: - pass + try: + return func.argparser.print_help() + except AttributeError: + pass - try: - return self.stdout.write(func.__doc__ + "\n") - except AttributeError: - pass + try: + return self.stdout.write(func.__doc__ + "\n") + except AttributeError: + pass - self.stdout.write((self.nohelp + "\n") % arg) + self.stdout.write((self.nohelp + "\n") % arg) - def irdb_handle_complete(self, manager, text, line, begidx, endidx): - return [obj.handle for obj in manager.all() if obj.handle and obj.handle.startswith(text)] + def irdb_handle_complete(self, manager, text, line, begidx, endidx): + return [obj.handle for obj in manager.all() if obj.handle and obj.handle.startswith(text)] - @parsecmd(argsubparsers, - cmdarg("handle", help = "new handle")) - def do_select_identity(self, args): - """ - Select an identity handle for use with later commands. - """ + @parsecmd(argsubparsers, + cmdarg("handle", help = "new handle")) + def do_select_identity(self, args): + """ + Select an identity handle for use with later commands. + """ - self.zoo.reset_identity(args.handle) + self.zoo.reset_identity(args.handle) - def complete_select_identity(self, *args): - return self.irdb_handle_complete(rpki.irdb.models.ResourceHolderCA.objects, *args) + def complete_select_identity(self, *args): + return self.irdb_handle_complete(rpki.irdb.models.ResourceHolderCA.objects, *args) - @parsecmd(argsubparsers) - def do_initialize(self, args): - """ - Initialize an RPKI installation. DEPRECATED. + @parsecmd(argsubparsers) + def do_initialize(self, args): + """ + Initialize an RPKI installation. DEPRECATED. - This command reads the configuration file, creates the BPKI and - EntityDB directories, generates the initial BPKI certificates, and - creates an XML file describing the resource-holding aspect of this - RPKI installation. - """ + This command reads the configuration file, creates the BPKI and + EntityDB directories, generates the initial BPKI certificates, and + creates an XML file describing the resource-holding aspect of this + RPKI installation. + """ - rootd_case = self.zoo.run_rootd and self.zoo.handle == self.zoo.cfg.get("handle") + rootd_case = self.zoo.run_rootd and self.zoo.handle == self.zoo.cfg.get("handle") - r = self.zoo.initialize() - r.save("%s.identity.xml" % self.zoo.handle, - None if rootd_case else sys.stdout) + r = self.zoo.initialize() + r.save("%s.identity.xml" % self.zoo.handle, + None if rootd_case else sys.stdout) - if rootd_case: - r = self.zoo.configure_rootd() - if r is not None: - r.save("%s.%s.repository-request.xml" % (self.zoo.handle, self.zoo.handle), sys.stdout) + if rootd_case: + r = self.zoo.configure_rootd() + if r is not None: + r.save("%s.%s.repository-request.xml" % (self.zoo.handle, self.zoo.handle), sys.stdout) - self.zoo.write_bpki_files() + self.zoo.write_bpki_files() - @parsecmd(argsubparsers, - cmdarg("handle", help = "handle of entity to create")) - def do_create_identity(self, args): - """ - Create a new resource-holding entity. + @parsecmd(argsubparsers, + cmdarg("handle", help = "handle of entity to create")) + def do_create_identity(self, args): + """ + Create a new resource-holding entity. - Returns XML file describing the new resource holder. + Returns XML file describing the new resource holder. - This command is idempotent: calling it for a resource holder which - already exists returns the existing identity. - """ + This command is idempotent: calling it for a resource holder which + already exists returns the existing identity. + """ - self.zoo.reset_identity(args.handle) + self.zoo.reset_identity(args.handle) - r = self.zoo.initialize_resource_bpki() - r.save("%s.identity.xml" % self.zoo.handle, sys.stdout) + r = self.zoo.initialize_resource_bpki() + r.save("%s.identity.xml" % self.zoo.handle, sys.stdout) - @parsecmd(argsubparsers) - def do_initialize_server_bpki(self, args): - """ - Initialize server BPKI portion of an RPKI installation. + @parsecmd(argsubparsers) + def do_initialize_server_bpki(self, args): + """ + Initialize server BPKI portion of an RPKI installation. - Reads server configuration from configuration file and creates the - server BPKI objects needed to start daemons. - """ + Reads server configuration from configuration file and creates the + server BPKI objects needed to start daemons. + """ - self.zoo.initialize_server_bpki() - self.zoo.write_bpki_files() + self.zoo.initialize_server_bpki() + self.zoo.write_bpki_files() - @parsecmd(argsubparsers) - def do_update_bpki(self, args): - """ - Update BPKI certificates. Assumes an existing RPKI installation. + @parsecmd(argsubparsers) + def do_update_bpki(self, args): + """ + Update BPKI certificates. Assumes an existing RPKI installation. - Basic plan here is to reissue all BPKI certificates we can, right - now. In the long run we might want to be more clever about only - touching ones that need maintenance, but this will do for a start. + Basic plan here is to reissue all BPKI certificates we can, right + now. In the long run we might want to be more clever about only + touching ones that need maintenance, but this will do for a start. - We also reissue CRLs for all CAs. + We also reissue CRLs for all CAs. - Most likely this should be run under cron. - """ + Most likely this should be run under cron. + """ - self.zoo.update_bpki() - self.zoo.write_bpki_files() - try: - self.zoo.synchronize_bpki() - except Exception, e: - print "Couldn't push updated BPKI material into daemons: %s" % e + self.zoo.update_bpki() + self.zoo.write_bpki_files() + try: + self.zoo.synchronize_bpki() + except Exception, e: + print "Couldn't push updated BPKI material into daemons: %s" % e - @parsecmd(argsubparsers, - cmdarg("--child_handle", help = "override default handle for new child"), - cmdarg("--valid_until", help = "override default validity interval"), - cmdarg("child_xml", help = "XML file containing child's identity")) - def do_configure_child(self, args): - """ - Configure a new child of this RPKI entity. + @parsecmd(argsubparsers, + cmdarg("--child_handle", help = "override default handle for new child"), + cmdarg("--valid_until", help = "override default validity interval"), + cmdarg("child_xml", help = "XML file containing child's identity")) + def do_configure_child(self, args): + """ + Configure a new child of this RPKI entity. - This command extracts the child's data from an XML input file, - cross-certifies the child's resource-holding BPKI certificate, and - generates an XML output file describing the relationship between - the child and this parent, including this parent's BPKI data and - up-down protocol service URI. - """ + This command extracts the child's data from an XML input file, + cross-certifies the child's resource-holding BPKI certificate, and + generates an XML output file describing the relationship between + the child and this parent, including this parent's BPKI data and + up-down protocol service URI. + """ - r, child_handle = self.zoo.configure_child(args.child_xml, args.child_handle, args.valid_until) - r.save("%s.%s.parent-response.xml" % (self.zoo.handle, child_handle), sys.stdout) - self.zoo.synchronize_ca() + r, child_handle = self.zoo.configure_child(args.child_xml, args.child_handle, args.valid_until) + r.save("%s.%s.parent-response.xml" % (self.zoo.handle, child_handle), sys.stdout) + self.zoo.synchronize_ca() - @parsecmd(argsubparsers, - cmdarg("child_handle", help = "handle of child to delete")) - def do_delete_child(self, args): - """ - Delete a child of this RPKI entity. - """ - - try: - self.zoo.delete_child(args.child_handle) - self.zoo.synchronize_ca() - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - print "No such resource holder \"%s\"" % self.zoo.handle - except rpki.irdb.models.Child.DoesNotExist: - print "No such child \"%s\"" % args.child_handle - - def complete_delete_child(self, *args): - return self.irdb_handle_complete(self.zoo.resource_ca.children, *args) - - - @parsecmd(argsubparsers, - cmdarg("--parent_handle", help = "override default handle for new parent"), - cmdarg("parent_xml", help = "XML file containing parent's response")) - def do_configure_parent(self, args): - """ - Configure a new parent of this RPKI entity. - - This command reads the parent's response XML, extracts the - parent's BPKI and service URI information, cross-certifies the - parent's BPKI data into this entity's BPKI, and checks for offers - or referrals of publication service. If a publication offer or - referral is present, we generate a request-for-service message to - that repository, in case the user wants to avail herself of the - referral or offer. - - We do NOT attempt automatic synchronization with rpkid at the - completion of this command, because synchronization at this point - will usually fail due to the repository not being set up yet. If - you know what you are doing and for some reason really want to - synchronize here, run the synchronize command yourself. - """ - - r, parent_handle = self.zoo.configure_parent(args.parent_xml, args.parent_handle) - r.save("%s.%s.repository-request.xml" % (self.zoo.handle, parent_handle), sys.stdout) - - - @parsecmd(argsubparsers, - cmdarg("parent_handle", help = "handle of parent to delete")) - def do_delete_parent(self, args): - """ - Delete a parent of this RPKI entity. - """ - - try: - self.zoo.delete_parent(args.parent_handle) - self.zoo.synchronize_ca() - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - print "No such resource holder \"%s\"" % self.zoo.handle - except rpki.irdb.models.Parent.DoesNotExist: - print "No such parent \"%s\"" % args.parent_handle - - def complete_delete_parent(self, *args): - return self.irdb_handle_complete(self.zoo.resource_ca.parents, *args) - - - @parsecmd(argsubparsers) - def do_configure_root(self, args): - """ - Configure the current resource holding identity as a root. - - This configures rpkid to talk to rootd as (one of) its parent(s). - Returns repository request XML file like configure_parent does. - """ - - r = self.zoo.configure_rootd() - if r is not None: - r.save("%s.%s.repository-request.xml" % (self.zoo.handle, self.zoo.handle), sys.stdout) - self.zoo.write_bpki_files() - - - @parsecmd(argsubparsers) - def do_delete_root(self, args): - """ - Delete local RPKI root as parent of the current entity. - - This tells the current rpkid identity (<tenant/>) to stop talking to - rootd. - """ - - try: - self.zoo.delete_rootd() - self.zoo.synchronize_ca() - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - print "No such resource holder \"%s\"" % self.zoo.handle - except rpki.irdb.models.Rootd.DoesNotExist: - print "No associated rootd" - - - @parsecmd(argsubparsers, - cmdarg("--flat", help = "use flat publication scheme", action = "store_true"), - cmdarg("--sia_base", help = "override SIA base value"), - cmdarg("client_xml", help = "XML file containing client request")) - def do_configure_publication_client(self, args): - """ - Configure publication server to know about a new client. - - This command reads the client's request for service, - cross-certifies the client's BPKI data, and generates a response - message containing the repository's BPKI data and service URI. - """ - - r, client_handle = self.zoo.configure_publication_client(args.client_xml, args.sia_base, args.flat) - r.save("%s.repository-response.xml" % client_handle.replace("/", "."), sys.stdout) - try: - self.zoo.synchronize_pubd() - except rpki.irdb.models.Repository.DoesNotExist: - pass - - - @parsecmd(argsubparsers, - cmdarg("client_handle", help = "handle of client to delete")) - def do_delete_publication_client(self, args): - """ - Delete a publication client of this RPKI entity. - """ - - try: - self.zoo.delete_publication_client(args.client_handle) - self.zoo.synchronize_pubd() - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - print "No such resource holder \"%s\"" % self.zoo.handle - except rpki.irdb.models.Client.DoesNotExist: - print "No such client \"%s\"" % args.client_handle - - def complete_delete_publication_client(self, *args): - return self.irdb_handle_complete(self.zoo.server_ca.clients, *args) - - - @parsecmd(argsubparsers, - cmdarg("--parent_handle", help = "override default parent handle"), - cmdarg("repository_xml", help = "XML file containing repository response")) - def do_configure_repository(self, args): - """ - Configure a publication repository for this RPKI entity. - - This command reads the repository's response to this entity's - request for publication service, extracts and cross-certifies the - BPKI data and service URI, and links the repository data with the - corresponding parent data in our local database. - """ - - self.zoo.configure_repository(args.repository_xml, args.parent_handle) - self.zoo.synchronize_ca() - - - @parsecmd(argsubparsers, - cmdarg("repository_handle", help = "handle of repository to delete")) - def do_delete_repository(self, args): - """ - Delete a repository of this RPKI entity. - """ - - try: - self.zoo.delete_repository(args.repository_handle) - self.zoo.synchronize_ca() - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - print "No such resource holder \"%s\"" % self.zoo.handle - except rpki.irdb.models.Repository.DoesNotExist: - print "No such repository \"%s\"" % args.repository_handle - - def complete_delete_repository(self, *args): - return self.irdb_handle_complete(self.zoo.resource_ca.repositories, *args) - - - @parsecmd(argsubparsers) - def do_delete_identity(self, args): - """ - Delete the current RPKI identity (rpkid <tenant/> object). - """ - - try: - self.zoo.delete_tenant() - self.zoo.synchronize_deleted_ca() - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - print "No such resource holder \"%s\"" % self.zoo.handle - - - @parsecmd(argsubparsers, - cmdarg("--valid_until", help = "override default new validity interval"), - cmdarg("child_handle", help = "handle of child to renew")) - def do_renew_child(self, args): - """ - Update validity period for one child entity. - """ - - self.zoo.renew_children(args.child_handle, args.valid_until) - self.zoo.synchronize_ca() - if self.autosync: - self.zoo.run_rpkid_now() - - def complete_renew_child(self, *args): - return self.irdb_handle_complete(self.zoo.resource_ca.children, *args) - - - @parsecmd(argsubparsers, - cmdarg("--valid_until", help = "override default new validity interval")) - def do_renew_all_children(self, args): - """ - Update validity period for all child entities. - """ - - self.zoo.renew_children(None, args.valid_until) - self.zoo.synchronize_ca() - if self.autosync: - self.zoo.run_rpkid_now() - - - @parsecmd(argsubparsers, - cmdarg("prefixes_csv", help = "CSV file listing prefixes")) - def do_load_prefixes(self, args): - """ - Load prefixes into IRDB from CSV file. - """ - - self.zoo.load_prefixes(args.prefixes_csv, True) - if self.autosync: - self.zoo.run_rpkid_now() - - - @parsecmd(argsubparsers) - def do_show_child_resources(self, args): - """ - Show resources assigned to children. - """ - - for child in self.zoo.resource_ca.children.all(): - resources = child.resource_bag - print "Child:", child.handle - if resources.asn: - print " ASN:", resources.asn - if resources.v4: - print " IPv4:", resources.v4 - if resources.v6: - print " IPv6:", resources.v6 - - - @parsecmd(argsubparsers) - def do_show_roa_requests(self, args): - """ - Show ROA requests. - """ - - for roa_request in self.zoo.resource_ca.roa_requests.all(): - prefixes = roa_request.roa_prefix_bag - print "ASN: ", roa_request.asn - if prefixes.v4: - print " IPv4:", prefixes.v4 - if prefixes.v6: - print " IPv6:", prefixes.v6 - - - @parsecmd(argsubparsers) - def do_show_ghostbuster_requests(self, args): - """ - Show Ghostbuster requests. - """ - - for ghostbuster_request in self.zoo.resource_ca.ghostbuster_requests.all(): - print "Parent:", ghostbuster_request.parent or "*" - print ghostbuster_request.vcard - - - @parsecmd(argsubparsers) - def do_show_received_resources(self, args): - """ - Show resources received by this entity from its parent(s). - """ - - q_msg = self.zoo._compose_left_right_query() - SubElement(q_msg, rpki.left_right.tag_list_received_resources, tenant_handle = self.zoo.handle) - - for r_pdu in self.zoo.call_rpkid(q_msg): - - print "Parent: ", r_pdu.get("parent_handle") - print " notBefore:", r_pdu.get("notBefore") - print " notAfter: ", r_pdu.get("notAfter") - print " URI: ", r_pdu.get("uri") - print " SIA URI: ", r_pdu.get("sia_uri") - print " AIA URI: ", r_pdu.get("aia_uri") - print " ASN: ", r_pdu.get("asn") - print " IPv4: ", r_pdu.get("ipv4") - print " IPv6: ", r_pdu.get("ipv6") - - - @parsecmd(argsubparsers) - def do_show_published_objects(self, args): - """ - Show published objects. - """ - - q_msg = self.zoo._compose_left_right_query() - SubElement(q_msg, rpki.left_right.tag_list_published_objects, tenant_handle = self.zoo.handle) - - for r_pdu in self.zoo.call_rpkid(q_msg): - uri = r_pdu.get("uri") - track = rpki.x509.uri_dispatch(uri)(Base64 = r_pdu.text).tracking_data(uri) - child_handle = r_pdu.get("child_handle") - - if child_handle is None: - print track - else: - print track, child_handle - - - @parsecmd(argsubparsers) - def do_show_bpki(self, args): - """ - Show this entity's BPKI objects. - """ - - print "Self: ", self.zoo.resource_ca.handle - print " notBefore:", self.zoo.resource_ca.certificate.getNotBefore() - print " notAfter: ", self.zoo.resource_ca.certificate.getNotAfter() - print " Subject: ", self.zoo.resource_ca.certificate.getSubject() - print " SKI: ", self.zoo.resource_ca.certificate.hSKI() - for bsc in self.zoo.resource_ca.bscs.all(): - print "BSC: ", bsc.handle - print " notBefore:", bsc.certificate.getNotBefore() - print " notAfter: ", bsc.certificate.getNotAfter() - print " Subject: ", bsc.certificate.getSubject() - print " SKI: ", bsc.certificate.hSKI() - for parent in self.zoo.resource_ca.parents.all(): - print "Parent: ", parent.handle - print " notBefore:", parent.certificate.getNotBefore() - print " notAfter: ", parent.certificate.getNotAfter() - print " Subject: ", parent.certificate.getSubject() - print " SKI: ", parent.certificate.hSKI() - print " URL: ", parent.service_uri - for child in self.zoo.resource_ca.children.all(): - print "Child: ", child.handle - print " notBefore:", child.certificate.getNotBefore() - print " notAfter: ", child.certificate.getNotAfter() - print " Subject: ", child.certificate.getSubject() - print " SKI: ", child.certificate.hSKI() - for repository in self.zoo.resource_ca.repositories.all(): - print "Repository: ", repository.handle - print " notBefore:", repository.certificate.getNotBefore() - print " notAfter: ", repository.certificate.getNotAfter() - print " Subject: ", repository.certificate.getSubject() - print " SKI: ", repository.certificate.hSKI() - print " URL: ", repository.service_uri - - - @parsecmd(argsubparsers, - cmdarg("asns_csv", help = "CSV file listing ASNs")) - def do_load_asns(self, args): - """ - Load ASNs into IRDB from CSV file. - """ - - self.zoo.load_asns(args.asns_csv, True) - if self.autosync: - self.zoo.run_rpkid_now() - - - @parsecmd(argsubparsers, - cmdarg("roa_requests_csv", help = "CSV file listing ROA requests")) - def do_load_roa_requests(self, args): - """ - Load ROA requests into IRDB from CSV file. - """ - - self.zoo.load_roa_requests(args.roa_requests_csv) - if self.autosync: - self.zoo.run_rpkid_now() - - - @parsecmd(argsubparsers, - cmdarg("ghostbuster_requests", help = "file listing Ghostbuster requests as a sequence of VCards")) - def do_load_ghostbuster_requests(self, args): - """ - Load Ghostbuster requests into IRDB from file. - """ - - self.zoo.load_ghostbuster_requests(args.ghostbuster_requests) - if self.autosync: - self.zoo.run_rpkid_now() - - - @parsecmd(argsubparsers, - cmdarg("--valid_until", help = "override default validity interval"), - cmdarg("router_certificate_request_xml", help = "file containing XML router certificate request")) - def do_add_router_certificate_request(self, args): - """ - Load router certificate request(s) into IRDB from XML file. - """ - - self.zoo.add_router_certificate_request(args.router_certificate_request_xml, args.valid_until) - if self.autosync: - self.zoo.run_rpkid_now() - - @parsecmd(argsubparsers, - cmdarg("gski", help = "g(SKI) of router certificate request to delete")) - def do_delete_router_certificate_request(self, args): - """ - Delete a router certificate request from the IRDB. - """ - - try: - self.zoo.delete_router_certificate_request(args.gski) - if self.autosync: - self.zoo.run_rpkid_now() - except rpki.irdb.models.ResourceHolderCA.DoesNotExist: - print "No such resource holder \"%s\"" % self.zoo.handle - except rpki.irdb.models.EECertificateRequest.DoesNotExist: - print "No certificate request matching g(SKI) \"%s\"" % args.gski - - def complete_delete_router_certificate_request(self, text, line, begidx, endidx): - return [obj.gski for obj in self.zoo.resource_ca.ee_certificate_requests.all() - if obj.gski and obj.gski.startswith(text)] - - - @parsecmd(argsubparsers) - def do_show_router_certificate_requests(self, args): - """ - Show this entity's router certificate requests. - """ - - for req in self.zoo.resource_ca.ee_certificate_requests.all(): - print "%s %s %s %s" % (req.gski, req.valid_until, req.cn, req.sn) - + @parsecmd(argsubparsers, + cmdarg("child_handle", help = "handle of child to delete")) + def do_delete_child(self, args): + """ + Delete a child of this RPKI entity. + """ + + try: + self.zoo.delete_child(args.child_handle) + self.zoo.synchronize_ca() + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + print "No such resource holder \"%s\"" % self.zoo.handle + except rpki.irdb.models.Child.DoesNotExist: + print "No such child \"%s\"" % args.child_handle + + def complete_delete_child(self, *args): + return self.irdb_handle_complete(self.zoo.resource_ca.children, *args) + + + @parsecmd(argsubparsers, + cmdarg("--parent_handle", help = "override default handle for new parent"), + cmdarg("parent_xml", help = "XML file containing parent's response")) + def do_configure_parent(self, args): + """ + Configure a new parent of this RPKI entity. + + This command reads the parent's response XML, extracts the + parent's BPKI and service URI information, cross-certifies the + parent's BPKI data into this entity's BPKI, and checks for offers + or referrals of publication service. If a publication offer or + referral is present, we generate a request-for-service message to + that repository, in case the user wants to avail herself of the + referral or offer. + + We do NOT attempt automatic synchronization with rpkid at the + completion of this command, because synchronization at this point + will usually fail due to the repository not being set up yet. If + you know what you are doing and for some reason really want to + synchronize here, run the synchronize command yourself. + """ + + r, parent_handle = self.zoo.configure_parent(args.parent_xml, args.parent_handle) + r.save("%s.%s.repository-request.xml" % (self.zoo.handle, parent_handle), sys.stdout) + + + @parsecmd(argsubparsers, + cmdarg("parent_handle", help = "handle of parent to delete")) + def do_delete_parent(self, args): + """ + Delete a parent of this RPKI entity. + """ + + try: + self.zoo.delete_parent(args.parent_handle) + self.zoo.synchronize_ca() + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + print "No such resource holder \"%s\"" % self.zoo.handle + except rpki.irdb.models.Parent.DoesNotExist: + print "No such parent \"%s\"" % args.parent_handle + + def complete_delete_parent(self, *args): + return self.irdb_handle_complete(self.zoo.resource_ca.parents, *args) + + + @parsecmd(argsubparsers) + def do_configure_root(self, args): + """ + Configure the current resource holding identity as a root. + + This configures rpkid to talk to rootd as (one of) its parent(s). + Returns repository request XML file like configure_parent does. + """ + + r = self.zoo.configure_rootd() + if r is not None: + r.save("%s.%s.repository-request.xml" % (self.zoo.handle, self.zoo.handle), sys.stdout) + self.zoo.write_bpki_files() + + + @parsecmd(argsubparsers) + def do_delete_root(self, args): + """ + Delete local RPKI root as parent of the current entity. + + This tells the current rpkid identity (<tenant/>) to stop talking to + rootd. + """ + + try: + self.zoo.delete_rootd() + self.zoo.synchronize_ca() + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + print "No such resource holder \"%s\"" % self.zoo.handle + except rpki.irdb.models.Rootd.DoesNotExist: + print "No associated rootd" + + + @parsecmd(argsubparsers, + cmdarg("--flat", help = "use flat publication scheme", action = "store_true"), + cmdarg("--sia_base", help = "override SIA base value"), + cmdarg("client_xml", help = "XML file containing client request")) + def do_configure_publication_client(self, args): + """ + Configure publication server to know about a new client. + + This command reads the client's request for service, + cross-certifies the client's BPKI data, and generates a response + message containing the repository's BPKI data and service URI. + """ + + r, client_handle = self.zoo.configure_publication_client(args.client_xml, args.sia_base, args.flat) + r.save("%s.repository-response.xml" % client_handle.replace("/", "."), sys.stdout) + try: + self.zoo.synchronize_pubd() + except rpki.irdb.models.Repository.DoesNotExist: + pass + + + @parsecmd(argsubparsers, + cmdarg("client_handle", help = "handle of client to delete")) + def do_delete_publication_client(self, args): + """ + Delete a publication client of this RPKI entity. + """ + + try: + self.zoo.delete_publication_client(args.client_handle) + self.zoo.synchronize_pubd() + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + print "No such resource holder \"%s\"" % self.zoo.handle + except rpki.irdb.models.Client.DoesNotExist: + print "No such client \"%s\"" % args.client_handle + + def complete_delete_publication_client(self, *args): + return self.irdb_handle_complete(self.zoo.server_ca.clients, *args) + + + @parsecmd(argsubparsers, + cmdarg("--parent_handle", help = "override default parent handle"), + cmdarg("repository_xml", help = "XML file containing repository response")) + def do_configure_repository(self, args): + """ + Configure a publication repository for this RPKI entity. + + This command reads the repository's response to this entity's + request for publication service, extracts and cross-certifies the + BPKI data and service URI, and links the repository data with the + corresponding parent data in our local database. + """ + + self.zoo.configure_repository(args.repository_xml, args.parent_handle) + self.zoo.synchronize_ca() + + + @parsecmd(argsubparsers, + cmdarg("repository_handle", help = "handle of repository to delete")) + def do_delete_repository(self, args): + """ + Delete a repository of this RPKI entity. + """ + + try: + self.zoo.delete_repository(args.repository_handle) + self.zoo.synchronize_ca() + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + print "No such resource holder \"%s\"" % self.zoo.handle + except rpki.irdb.models.Repository.DoesNotExist: + print "No such repository \"%s\"" % args.repository_handle + + def complete_delete_repository(self, *args): + return self.irdb_handle_complete(self.zoo.resource_ca.repositories, *args) + + + @parsecmd(argsubparsers) + def do_delete_identity(self, args): + """ + Delete the current RPKI identity (rpkid <tenant/> object). + """ + + try: + self.zoo.delete_tenant() + self.zoo.synchronize_deleted_ca() + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + print "No such resource holder \"%s\"" % self.zoo.handle + + + @parsecmd(argsubparsers, + cmdarg("--valid_until", help = "override default new validity interval"), + cmdarg("child_handle", help = "handle of child to renew")) + def do_renew_child(self, args): + """ + Update validity period for one child entity. + """ + + self.zoo.renew_children(args.child_handle, args.valid_until) + self.zoo.synchronize_ca() + if self.autosync: + self.zoo.run_rpkid_now() + + def complete_renew_child(self, *args): + return self.irdb_handle_complete(self.zoo.resource_ca.children, *args) + + + @parsecmd(argsubparsers, + cmdarg("--valid_until", help = "override default new validity interval")) + def do_renew_all_children(self, args): + """ + Update validity period for all child entities. + """ + + self.zoo.renew_children(None, args.valid_until) + self.zoo.synchronize_ca() + if self.autosync: + self.zoo.run_rpkid_now() + + + @parsecmd(argsubparsers, + cmdarg("prefixes_csv", help = "CSV file listing prefixes")) + def do_load_prefixes(self, args): + """ + Load prefixes into IRDB from CSV file. + """ + + self.zoo.load_prefixes(args.prefixes_csv, True) + if self.autosync: + self.zoo.run_rpkid_now() + + + @parsecmd(argsubparsers) + def do_show_child_resources(self, args): + """ + Show resources assigned to children. + """ + + for child in self.zoo.resource_ca.children.all(): + resources = child.resource_bag + print "Child:", child.handle + if resources.asn: + print " ASN:", resources.asn + if resources.v4: + print " IPv4:", resources.v4 + if resources.v6: + print " IPv6:", resources.v6 + + + @parsecmd(argsubparsers) + def do_show_roa_requests(self, args): + """ + Show ROA requests. + """ + + for roa_request in self.zoo.resource_ca.roa_requests.all(): + prefixes = roa_request.roa_prefix_bag + print "ASN: ", roa_request.asn + if prefixes.v4: + print " IPv4:", prefixes.v4 + if prefixes.v6: + print " IPv6:", prefixes.v6 + + + @parsecmd(argsubparsers) + def do_show_ghostbuster_requests(self, args): + """ + Show Ghostbuster requests. + """ + + for ghostbuster_request in self.zoo.resource_ca.ghostbuster_requests.all(): + print "Parent:", ghostbuster_request.parent or "*" + print ghostbuster_request.vcard + + + @parsecmd(argsubparsers) + def do_show_received_resources(self, args): + """ + Show resources received by this entity from its parent(s). + """ + + q_msg = self.zoo._compose_left_right_query() + SubElement(q_msg, rpki.left_right.tag_list_received_resources, tenant_handle = self.zoo.handle) + + for r_pdu in self.zoo.call_rpkid(q_msg): + + print "Parent: ", r_pdu.get("parent_handle") + print " notBefore:", r_pdu.get("notBefore") + print " notAfter: ", r_pdu.get("notAfter") + print " URI: ", r_pdu.get("uri") + print " SIA URI: ", r_pdu.get("sia_uri") + print " AIA URI: ", r_pdu.get("aia_uri") + print " ASN: ", r_pdu.get("asn") + print " IPv4: ", r_pdu.get("ipv4") + print " IPv6: ", r_pdu.get("ipv6") + + + @parsecmd(argsubparsers) + def do_show_published_objects(self, args): + """ + Show published objects. + """ + + q_msg = self.zoo._compose_left_right_query() + SubElement(q_msg, rpki.left_right.tag_list_published_objects, tenant_handle = self.zoo.handle) + + for r_pdu in self.zoo.call_rpkid(q_msg): + uri = r_pdu.get("uri") + track = rpki.x509.uri_dispatch(uri)(Base64 = r_pdu.text).tracking_data(uri) + child_handle = r_pdu.get("child_handle") + + if child_handle is None: + print track + else: + print track, child_handle + + + @parsecmd(argsubparsers) + def do_show_bpki(self, args): + """ + Show this entity's BPKI objects. + """ + + print "Self: ", self.zoo.resource_ca.handle + print " notBefore:", self.zoo.resource_ca.certificate.getNotBefore() + print " notAfter: ", self.zoo.resource_ca.certificate.getNotAfter() + print " Subject: ", self.zoo.resource_ca.certificate.getSubject() + print " SKI: ", self.zoo.resource_ca.certificate.hSKI() + for bsc in self.zoo.resource_ca.bscs.all(): + print "BSC: ", bsc.handle + print " notBefore:", bsc.certificate.getNotBefore() + print " notAfter: ", bsc.certificate.getNotAfter() + print " Subject: ", bsc.certificate.getSubject() + print " SKI: ", bsc.certificate.hSKI() + for parent in self.zoo.resource_ca.parents.all(): + print "Parent: ", parent.handle + print " notBefore:", parent.certificate.getNotBefore() + print " notAfter: ", parent.certificate.getNotAfter() + print " Subject: ", parent.certificate.getSubject() + print " SKI: ", parent.certificate.hSKI() + print " URL: ", parent.service_uri + for child in self.zoo.resource_ca.children.all(): + print "Child: ", child.handle + print " notBefore:", child.certificate.getNotBefore() + print " notAfter: ", child.certificate.getNotAfter() + print " Subject: ", child.certificate.getSubject() + print " SKI: ", child.certificate.hSKI() + for repository in self.zoo.resource_ca.repositories.all(): + print "Repository: ", repository.handle + print " notBefore:", repository.certificate.getNotBefore() + print " notAfter: ", repository.certificate.getNotAfter() + print " Subject: ", repository.certificate.getSubject() + print " SKI: ", repository.certificate.hSKI() + print " URL: ", repository.service_uri + + + @parsecmd(argsubparsers, + cmdarg("asns_csv", help = "CSV file listing ASNs")) + def do_load_asns(self, args): + """ + Load ASNs into IRDB from CSV file. + """ + + self.zoo.load_asns(args.asns_csv, True) + if self.autosync: + self.zoo.run_rpkid_now() + + + @parsecmd(argsubparsers, + cmdarg("roa_requests_csv", help = "CSV file listing ROA requests")) + def do_load_roa_requests(self, args): + """ + Load ROA requests into IRDB from CSV file. + """ + + self.zoo.load_roa_requests(args.roa_requests_csv) + if self.autosync: + self.zoo.run_rpkid_now() + + + @parsecmd(argsubparsers, + cmdarg("ghostbuster_requests", help = "file listing Ghostbuster requests as a sequence of VCards")) + def do_load_ghostbuster_requests(self, args): + """ + Load Ghostbuster requests into IRDB from file. + """ + + self.zoo.load_ghostbuster_requests(args.ghostbuster_requests) + if self.autosync: + self.zoo.run_rpkid_now() + + + @parsecmd(argsubparsers, + cmdarg("--valid_until", help = "override default validity interval"), + cmdarg("router_certificate_request_xml", help = "file containing XML router certificate request")) + def do_add_router_certificate_request(self, args): + """ + Load router certificate request(s) into IRDB from XML file. + """ + + self.zoo.add_router_certificate_request(args.router_certificate_request_xml, args.valid_until) + if self.autosync: + self.zoo.run_rpkid_now() + + @parsecmd(argsubparsers, + cmdarg("gski", help = "g(SKI) of router certificate request to delete")) + def do_delete_router_certificate_request(self, args): + """ + Delete a router certificate request from the IRDB. + """ + + try: + self.zoo.delete_router_certificate_request(args.gski) + if self.autosync: + self.zoo.run_rpkid_now() + except rpki.irdb.models.ResourceHolderCA.DoesNotExist: + print "No such resource holder \"%s\"" % self.zoo.handle + except rpki.irdb.models.EECertificateRequest.DoesNotExist: + print "No certificate request matching g(SKI) \"%s\"" % args.gski + + def complete_delete_router_certificate_request(self, text, line, begidx, endidx): + return [obj.gski for obj in self.zoo.resource_ca.ee_certificate_requests.all() + if obj.gski and obj.gski.startswith(text)] + + + @parsecmd(argsubparsers) + def do_show_router_certificate_requests(self, args): + """ + Show this entity's router certificate requests. + """ + + for req in self.zoo.resource_ca.ee_certificate_requests.all(): + print "%s %s %s %s" % (req.gski, req.valid_until, req.cn, req.sn) + - # What about updates? Validity interval, change router-id, change - # ASNs. Not sure what this looks like yet, blunder ahead with the - # core code while mulling over the UI. + # What about updates? Validity interval, change router-id, change + # ASNs. Not sure what this looks like yet, blunder ahead with the + # core code while mulling over the UI. - @parsecmd(argsubparsers) - def do_synchronize(self, args): - """ - Whack daemons to match IRDB. + @parsecmd(argsubparsers) + def do_synchronize(self, args): + """ + Whack daemons to match IRDB. - This command may be replaced by implicit synchronization embedded - in of other commands, haven't decided yet. - """ + This command may be replaced by implicit synchronization embedded + in of other commands, haven't decided yet. + """ - self.zoo.synchronize() + self.zoo.synchronize() - @parsecmd(argsubparsers) - def do_force_publication(self, args): - """ - Whack rpkid to force (re)publication of everything. + @parsecmd(argsubparsers) + def do_force_publication(self, args): + """ + Whack rpkid to force (re)publication of everything. - This is not usually necessary, as rpkid automatically publishes - changes it makes, but this command can be useful occasionally when - a fault or configuration error has left rpkid holding data which - it has not been able to publish. - """ + This is not usually necessary, as rpkid automatically publishes + changes it makes, but this command can be useful occasionally when + a fault or configuration error has left rpkid holding data which + it has not been able to publish. + """ - self.zoo.publish_world_now() + self.zoo.publish_world_now() - @parsecmd(argsubparsers) - def do_force_reissue(self, args): - """ - Whack rpkid to force reissuance of everything. + @parsecmd(argsubparsers) + def do_force_reissue(self, args): + """ + Whack rpkid to force reissuance of everything. - This is not usually necessary, as rpkid reissues automatically - objects automatically as needed, but this command can be useful - occasionally when a fault or configuration error has prevented - rpkid from reissuing when it should have. - """ + This is not usually necessary, as rpkid reissues automatically + objects automatically as needed, but this command can be useful + occasionally when a fault or configuration error has prevented + rpkid from reissuing when it should have. + """ - self.zoo.reissue() + self.zoo.reissue() - @parsecmd(argsubparsers) - def do_up_down_rekey(self, args): - """ - Initiate a "rekey" operation. + @parsecmd(argsubparsers) + def do_up_down_rekey(self, args): + """ + Initiate a "rekey" operation. - This tells rpkid to generate new keys for each certificate issued - to it via the up-down protocol. + This tells rpkid to generate new keys for each certificate issued + to it via the up-down protocol. - Rekeying is the first stage of a key rollover operation. You will - need to follow it up later with a "revoke" operation to clean up - the old keys - """ + Rekeying is the first stage of a key rollover operation. You will + need to follow it up later with a "revoke" operation to clean up + the old keys + """ - self.zoo.rekey() + self.zoo.rekey() - @parsecmd(argsubparsers) - def do_up_down_revoke(self, args): - """ - Initiate a "revoke" operation. + @parsecmd(argsubparsers) + def do_up_down_revoke(self, args): + """ + Initiate a "revoke" operation. - This tells rpkid to clean up old keys formerly used by - certificates issued to it via the up-down protocol. + This tells rpkid to clean up old keys formerly used by + certificates issued to it via the up-down protocol. - This is the cleanup stage of a key rollover operation. - """ + This is the cleanup stage of a key rollover operation. + """ - self.zoo.revoke() + self.zoo.revoke() - @parsecmd(argsubparsers) - def do_revoke_forgotten(self, args): - """ - Initiate a "revoke_forgotten" operation. + @parsecmd(argsubparsers) + def do_revoke_forgotten(self, args): + """ + Initiate a "revoke_forgotten" operation. - This tells rpkid to ask its parent to revoke certificates for - which rpkid does not know the private keys. + This tells rpkid to ask its parent to revoke certificates for + which rpkid does not know the private keys. - This should never happen during ordinary operation, but can happen - if rpkid is misconfigured or its database has been damaged, so we - need a way to resynchronize rpkid with its parent in such cases. - We could do this automatically, but as we don't know the precise - cause of the failure we don't know if it's recoverable locally - (eg, from an SQL backup), so we require a manual trigger before - discarding possibly-useful certificates. - """ + This should never happen during ordinary operation, but can happen + if rpkid is misconfigured or its database has been damaged, so we + need a way to resynchronize rpkid with its parent in such cases. + We could do this automatically, but as we don't know the precise + cause of the failure we don't know if it's recoverable locally + (eg, from an SQL backup), so we require a manual trigger before + discarding possibly-useful certificates. + """ - self.zoo.revoke_forgotten() + self.zoo.revoke_forgotten() - @parsecmd(argsubparsers) - def do_clear_all_sql_cms_replay_protection(self, args): - """ - Tell rpkid and pubd to clear replay protection. + @parsecmd(argsubparsers) + def do_clear_all_sql_cms_replay_protection(self, args): + """ + Tell rpkid and pubd to clear replay protection. - This clears the replay protection timestamps stored in SQL for all - entities known to rpkid and pubd. This is a fairly blunt - instrument, but as we don't expect this to be necessary except in - the case of gross misconfiguration, it should suffice - """ + This clears the replay protection timestamps stored in SQL for all + entities known to rpkid and pubd. This is a fairly blunt + instrument, but as we don't expect this to be necessary except in + the case of gross misconfiguration, it should suffice + """ - self.zoo.clear_all_sql_cms_replay_protection() + self.zoo.clear_all_sql_cms_replay_protection() - @parsecmd(argsubparsers) - def do_version(self, args): - """ - Show current software version number. - """ + @parsecmd(argsubparsers) + def do_version(self, args): + """ + Show current software version number. + """ - print rpki.version.VERSION + print rpki.version.VERSION - @parsecmd(argsubparsers) - def do_list_tenant_handles(self, args): - """ - List all <tenant/> handles in this rpkid instance. - """ - - for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): - print ca.handle + @parsecmd(argsubparsers) + def do_list_tenant_handles(self, args): + """ + List all <tenant/> handles in this rpkid instance. + """ + for ca in rpki.irdb.models.ResourceHolderCA.objects.all(): + print ca.handle diff --git a/rpki/rpkid.py b/rpki/rpkid.py index da6141ea..c0ddbd58 100644 --- a/rpki/rpkid.py +++ b/rpki/rpkid.py @@ -56,681 +56,681 @@ logger = logging.getLogger(__name__) class main(object): - """ - Main program for rpkid. - """ - - def __init__(self): - - os.environ.update(TZ = "UTC", - DJANGO_SETTINGS_MODULE = "rpki.django_settings.rpkid") - time.tzset() - - self.irdbd_cms_timestamp = None - self.irbe_cms_timestamp = None - - self.task_queue = [] - self.task_event = tornado.locks.Event() - - self.http_client_serialize = weakref.WeakValueDictionary() - - parser = argparse.ArgumentParser(description = __doc__) - parser.add_argument("-c", "--config", - help = "override default location of configuration file") - parser.add_argument("-f", "--foreground", action = "store_true", - help = "do not daemonize") - parser.add_argument("--pidfile", - help = "override default location of pid file") - parser.add_argument("--profile", - help = "enable profiling, saving data to PROFILE") - rpki.log.argparse_setup(parser) - args = parser.parse_args() - - self.profile = args.profile - - rpki.log.init("rpkid", args) - - self.cfg = rpki.config.parser(set_filename = args.config, section = "rpkid") - self.cfg.set_global_flags() - - if not args.foreground: - rpki.daemonize.daemon(pidfile = args.pidfile) - - if self.profile: - import cProfile - prof = cProfile.Profile() - try: - prof.runcall(self.main) - finally: - prof.dump_stats(self.profile) - logger.info("Dumped profile data to %s", self.profile) - else: - self.main() - - def main(self): - - startup_msg = self.cfg.get("startup-message", "") - if startup_msg: - logger.info(startup_msg) - - if self.profile: - logger.info("Running in profile mode with output to %s", self.profile) - - logger.debug("Initializing Django") - import django - django.setup() - - logger.debug("Initializing rpkidb...") - global rpki # pylint: disable=W0602 - import rpki.rpkidb # pylint: disable=W0621 - - logger.debug("Initializing rpkidb...done") - - self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) - self.irdb_cert = rpki.x509.X509(Auto_update = self.cfg.get("irdb-cert")) - self.irbe_cert = rpki.x509.X509(Auto_update = self.cfg.get("irbe-cert")) - self.rpkid_cert = rpki.x509.X509(Auto_update = self.cfg.get("rpkid-cert")) - self.rpkid_key = rpki.x509.RSA( Auto_update = self.cfg.get("rpkid-key")) - - self.irdb_url = self.cfg.get("irdb-url") - - self.http_server_host = self.cfg.get("server-host", "") - self.http_server_port = self.cfg.getint("server-port") - - self.use_internal_cron = self.cfg.getboolean("use-internal-cron", True) - - self.initial_delay = random.randint(self.cfg.getint("initial-delay-min", 10), - self.cfg.getint("initial-delay-max", 120)) - - # Should be much longer in production - self.cron_period = self.cfg.getint("cron-period", 120) - - if self.use_internal_cron: - logger.debug("Scheduling initial cron pass in %s seconds", self.initial_delay) - tornado.ioloop.IOLoop.current().spawn_callback(self.cron_loop) - - logger.debug("Scheduling task loop") - tornado.ioloop.IOLoop.current().spawn_callback(self.task_loop) - - rpkid = self - - class LeftRightHandler(tornado.web.RequestHandler): # pylint: disable=W0223 - @tornado.gen.coroutine - def post(self): - yield rpkid.left_right_handler(self) - - class UpDownHandler(tornado.web.RequestHandler): # pylint: disable=W0223 - @tornado.gen.coroutine - def post(self, tenant_handle, child_handle): # pylint: disable=W0221 - yield rpkid.up_down_handler(self, tenant_handle, child_handle) - - class CronjobHandler(tornado.web.RequestHandler): # pylint: disable=W0223 - @tornado.gen.coroutine - def post(self): - yield rpkid.cronjob_handler(self) - - application = tornado.web.Application(( - (r"/left-right", LeftRightHandler), - (r"/up-down/([-a-zA-Z0-9_]+)/([-a-zA-Z0-9_]+)", UpDownHandler), - (r"/cronjob", CronjobHandler))) - - application.listen( - address = self.http_server_host, - port = self.http_server_port) - - tornado.ioloop.IOLoop.current().start() - - def task_add(self, tasks): """ - Add zero or more tasks to the task queue. + Main program for rpkid. """ - for task in tasks: - if task in self.task_queue: - logger.debug("Task %r already queued", task) - else: - logger.debug("Adding %r to task queue", task) - self.task_queue.append(task) + def __init__(self): - def task_run(self): - """ - Kick the task loop to notice recently added tasks. - """ + os.environ.update(TZ = "UTC", + DJANGO_SETTINGS_MODULE = "rpki.django_settings.rpkid") + time.tzset() - self.task_event.set() + self.irdbd_cms_timestamp = None + self.irbe_cms_timestamp = None - @tornado.gen.coroutine - def task_loop(self): - """ - Asynchronous infinite loop to run background tasks. + self.task_queue = [] + self.task_event = tornado.locks.Event() - This code is a bit finicky, because it's managing a collection of - Future objects which are running independently of the control flow - here, and the wave function doesn't collapse until we do a yield. + self.http_client_serialize = weakref.WeakValueDictionary() - So we keep this brutally simple and don't try to hide too much of - it in the AbstractTask class. For similar reasons, AbstractTask - sets aside a .future instance variable for this method's use. - """ + parser = argparse.ArgumentParser(description = __doc__) + parser.add_argument("-c", "--config", + help = "override default location of configuration file") + parser.add_argument("-f", "--foreground", action = "store_true", + help = "do not daemonize") + parser.add_argument("--pidfile", + help = "override default location of pid file") + parser.add_argument("--profile", + help = "enable profiling, saving data to PROFILE") + rpki.log.argparse_setup(parser) + args = parser.parse_args() - logger.debug("Starting task loop") - task_event_future = None - - while True: - while None in self.task_queue: - self.task_queue.remove(None) - - futures = [] - for task in self.task_queue: - if task.future is None: - task.future = task.start() - futures.append(task.future) - if task_event_future is None: - task_event_future = self.task_event.wait() - futures.append(task_event_future) - iterator = tornado.gen.WaitIterator(*futures) - - while not iterator.done(): - yield iterator.next() - if iterator.current_future is task_event_future: - self.task_event.clear() - task_event_future = None - break - else: - task = self.task_queue[iterator.current_index] - task.future = None - waiting = task.waiting() - if not waiting: - self.task_queue[iterator.current_index] = None - for task in self.task_queue: - if task is not None and not task.runnable.is_set(): - logger.debug("Reenabling task %r", task) - task.runnable.set() - if waiting: - break - - @tornado.gen.coroutine - def cron_loop(self): - """ - Asynchronous infinite loop to drive cron cycle. - """ + self.profile = args.profile - logger.debug("cron_loop(): Starting") - assert self.use_internal_cron - logger.debug("cron_loop(): Startup delay %d seconds", self.initial_delay) - yield tornado.gen.sleep(self.initial_delay) - while True: - logger.debug("cron_loop(): Running") - yield self.cron_run() - logger.debug("cron_loop(): Sleeping %d seconds", self.cron_period) - yield tornado.gen.sleep(self.cron_period) - - @tornado.gen.coroutine - def cron_run(self): - """ - Schedule periodic tasks. - """ + rpki.log.init("rpkid", args) - now = rpki.sundial.now() - logger.debug("Starting cron run") - try: - tenants = rpki.rpkidb.models.Tenant.objects.all() - except: - logger.exception("Error pulling tenants from SQL, maybe SQL server is down?") - else: - tasks = tuple(task for tenant in tenants for task in tenant.cron_tasks(self)) - self.task_add(tasks) - futures = [task.wait() for task in tasks] - self.task_run() - yield futures - logger.info("Finished cron run started at %s", now) - - @tornado.gen.coroutine - def cronjob_handler(self, handler): - """ - External trigger to schedule periodic tasks. Obsolete for - produciton use, but portions of the test framework still use this. - """ + self.cfg = rpki.config.parser(set_filename = args.config, section = "rpkid") + self.cfg.set_global_flags() - if self.use_internal_cron: - handler.set_status(500, "Running cron internally") - else: - logger.debug("Starting externally triggered cron") - yield self.cron() - handler.set_status(200) - handler.finish() + if not args.foreground: + rpki.daemonize.daemon(pidfile = args.pidfile) - @tornado.gen.coroutine - def http_fetch(self, request, serialize_on_full_url = False): - """ - Wrapper around tornado.httpclient.AsyncHTTPClient() which - serializes requests to any particular HTTP server, to avoid - spurious CMS replay errors. - """ + if self.profile: + import cProfile + prof = cProfile.Profile() + try: + prof.runcall(self.main) + finally: + prof.dump_stats(self.profile) + logger.info("Dumped profile data to %s", self.profile) + else: + self.main() - # The current definition of "particular HTTP server" is based only - # on the "netloc" portion of the URL, which could in theory could - # cause deadlocks in a loopback scenario; no such deadlocks have - # shown up in testing, but if such a thing were to occur, it would - # look like an otherwise inexplicable HTTP timeout. The solution, - # should this occur, would be to use the entire URL as the lookup - # key, perhaps only for certain protocols. - # - # The reason for the current scheme is that at least one protocol - # (publication) uses RESTful URLs but has a single service-wide - # CMS replay detection database, which translates to meaning that - # we need to serialize all requests for that service, not just - # requests to a particular URL. - - if serialize_on_full_url: - netlock = request.url - else: - netlock = urlparse.urlparse(request.url).netloc - - try: - lock = self.http_client_serialize[netlock] - except KeyError: - lock = self.http_client_serialize[netlock] = tornado.locks.Lock() - - http_client = tornado.httpclient.AsyncHTTPClient() - - with (yield lock.acquire()): - response = yield http_client.fetch(request) - - raise tornado.gen.Return(response) - - @staticmethod - def _compose_left_right_query(): - """ - Compose top level element of a left-right query to irdbd. - """ + def main(self): - return Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, - type = "query", version = rpki.left_right.version) + startup_msg = self.cfg.get("startup-message", "") + if startup_msg: + logger.info(startup_msg) - @tornado.gen.coroutine - def irdb_query(self, q_msg): - """ - Perform an IRDB callback query. - """ + if self.profile: + logger.info("Running in profile mode with output to %s", self.profile) - q_tags = set(q_pdu.tag for q_pdu in q_msg) + logger.debug("Initializing Django") + import django + django.setup() - q_der = rpki.left_right.cms_msg().wrap(q_msg, self.rpkid_key, self.rpkid_cert) + logger.debug("Initializing rpkidb...") + global rpki # pylint: disable=W0602 + import rpki.rpkidb # pylint: disable=W0621 - http_request = tornado.httpclient.HTTPRequest( - url = self.irdb_url, - method = "POST", - body = q_der, - headers = { "Content-Type" : rpki.left_right.content_type }) + logger.debug("Initializing rpkidb...done") - http_response = yield self.http_fetch(http_request) + self.bpki_ta = rpki.x509.X509(Auto_update = self.cfg.get("bpki-ta")) + self.irdb_cert = rpki.x509.X509(Auto_update = self.cfg.get("irdb-cert")) + self.irbe_cert = rpki.x509.X509(Auto_update = self.cfg.get("irbe-cert")) + self.rpkid_cert = rpki.x509.X509(Auto_update = self.cfg.get("rpkid-cert")) + self.rpkid_key = rpki.x509.RSA( Auto_update = self.cfg.get("rpkid-key")) - # Tornado already checked http_response.code for us + self.irdb_url = self.cfg.get("irdb-url") - content_type = http_response.headers.get("Content-Type") + self.http_server_host = self.cfg.get("server-host", "") + self.http_server_port = self.cfg.getint("server-port") + + self.use_internal_cron = self.cfg.getboolean("use-internal-cron", True) + + self.initial_delay = random.randint(self.cfg.getint("initial-delay-min", 10), + self.cfg.getint("initial-delay-max", 120)) + + # Should be much longer in production + self.cron_period = self.cfg.getint("cron-period", 120) + + if self.use_internal_cron: + logger.debug("Scheduling initial cron pass in %s seconds", self.initial_delay) + tornado.ioloop.IOLoop.current().spawn_callback(self.cron_loop) + + logger.debug("Scheduling task loop") + tornado.ioloop.IOLoop.current().spawn_callback(self.task_loop) + + rpkid = self + + class LeftRightHandler(tornado.web.RequestHandler): # pylint: disable=W0223 + @tornado.gen.coroutine + def post(self): + yield rpkid.left_right_handler(self) + + class UpDownHandler(tornado.web.RequestHandler): # pylint: disable=W0223 + @tornado.gen.coroutine + def post(self, tenant_handle, child_handle): # pylint: disable=W0221 + yield rpkid.up_down_handler(self, tenant_handle, child_handle) + + class CronjobHandler(tornado.web.RequestHandler): # pylint: disable=W0223 + @tornado.gen.coroutine + def post(self): + yield rpkid.cronjob_handler(self) + + application = tornado.web.Application(( + (r"/left-right", LeftRightHandler), + (r"/up-down/([-a-zA-Z0-9_]+)/([-a-zA-Z0-9_]+)", UpDownHandler), + (r"/cronjob", CronjobHandler))) + + application.listen( + address = self.http_server_host, + port = self.http_server_port) + + tornado.ioloop.IOLoop.current().start() + + def task_add(self, tasks): + """ + Add zero or more tasks to the task queue. + """ + + for task in tasks: + if task in self.task_queue: + logger.debug("Task %r already queued", task) + else: + logger.debug("Adding %r to task queue", task) + self.task_queue.append(task) + + def task_run(self): + """ + Kick the task loop to notice recently added tasks. + """ + + self.task_event.set() + + @tornado.gen.coroutine + def task_loop(self): + """ + Asynchronous infinite loop to run background tasks. + + This code is a bit finicky, because it's managing a collection of + Future objects which are running independently of the control flow + here, and the wave function doesn't collapse until we do a yield. + + So we keep this brutally simple and don't try to hide too much of + it in the AbstractTask class. For similar reasons, AbstractTask + sets aside a .future instance variable for this method's use. + """ + + logger.debug("Starting task loop") + task_event_future = None + + while True: + while None in self.task_queue: + self.task_queue.remove(None) + + futures = [] + for task in self.task_queue: + if task.future is None: + task.future = task.start() + futures.append(task.future) + if task_event_future is None: + task_event_future = self.task_event.wait() + futures.append(task_event_future) + iterator = tornado.gen.WaitIterator(*futures) + + while not iterator.done(): + yield iterator.next() + if iterator.current_future is task_event_future: + self.task_event.clear() + task_event_future = None + break + else: + task = self.task_queue[iterator.current_index] + task.future = None + waiting = task.waiting() + if not waiting: + self.task_queue[iterator.current_index] = None + for task in self.task_queue: + if task is not None and not task.runnable.is_set(): + logger.debug("Reenabling task %r", task) + task.runnable.set() + if waiting: + break + + @tornado.gen.coroutine + def cron_loop(self): + """ + Asynchronous infinite loop to drive cron cycle. + """ + + logger.debug("cron_loop(): Starting") + assert self.use_internal_cron + logger.debug("cron_loop(): Startup delay %d seconds", self.initial_delay) + yield tornado.gen.sleep(self.initial_delay) + while True: + logger.debug("cron_loop(): Running") + yield self.cron_run() + logger.debug("cron_loop(): Sleeping %d seconds", self.cron_period) + yield tornado.gen.sleep(self.cron_period) + + @tornado.gen.coroutine + def cron_run(self): + """ + Schedule periodic tasks. + """ + + now = rpki.sundial.now() + logger.debug("Starting cron run") + try: + tenants = rpki.rpkidb.models.Tenant.objects.all() + except: + logger.exception("Error pulling tenants from SQL, maybe SQL server is down?") + else: + tasks = tuple(task for tenant in tenants for task in tenant.cron_tasks(self)) + self.task_add(tasks) + futures = [task.wait() for task in tasks] + self.task_run() + yield futures + logger.info("Finished cron run started at %s", now) + + @tornado.gen.coroutine + def cronjob_handler(self, handler): + """ + External trigger to schedule periodic tasks. Obsolete for + produciton use, but portions of the test framework still use this. + """ + + if self.use_internal_cron: + handler.set_status(500, "Running cron internally") + else: + logger.debug("Starting externally triggered cron") + yield self.cron() + handler.set_status(200) + handler.finish() + + @tornado.gen.coroutine + def http_fetch(self, request, serialize_on_full_url = False): + """ + Wrapper around tornado.httpclient.AsyncHTTPClient() which + serializes requests to any particular HTTP server, to avoid + spurious CMS replay errors. + """ + + # The current definition of "particular HTTP server" is based only + # on the "netloc" portion of the URL, which could in theory could + # cause deadlocks in a loopback scenario; no such deadlocks have + # shown up in testing, but if such a thing were to occur, it would + # look like an otherwise inexplicable HTTP timeout. The solution, + # should this occur, would be to use the entire URL as the lookup + # key, perhaps only for certain protocols. + # + # The reason for the current scheme is that at least one protocol + # (publication) uses RESTful URLs but has a single service-wide + # CMS replay detection database, which translates to meaning that + # we need to serialize all requests for that service, not just + # requests to a particular URL. + + if serialize_on_full_url: + netlock = request.url + else: + netlock = urlparse.urlparse(request.url).netloc - if content_type not in rpki.left_right.allowed_content_types: - raise rpki.exceptions.BadContentType("HTTP Content-Type %r, expected %r" % (rpki.left_right.content_type, content_type)) + try: + lock = self.http_client_serialize[netlock] + except KeyError: + lock = self.http_client_serialize[netlock] = tornado.locks.Lock() - r_der = http_response.body + http_client = tornado.httpclient.AsyncHTTPClient() - r_cms = rpki.left_right.cms_msg(DER = r_der) - r_msg = r_cms.unwrap((self.bpki_ta, self.irdb_cert)) + with (yield lock.acquire()): + response = yield http_client.fetch(request) - self.irdbd_cms_timestamp = r_cms.check_replay(self.irdbd_cms_timestamp, self.irdb_url) + raise tornado.gen.Return(response) - #rpki.left_right.check_response(r_msg) + @staticmethod + def _compose_left_right_query(): + """ + Compose top level element of a left-right query to irdbd. + """ - if r_msg.get("type") != "reply" or not all(r_pdu.tag in q_tags for r_pdu in r_msg): - raise rpki.exceptions.BadIRDBReply("Unexpected response to IRDB query: %s" % r_cms.pretty_print_content()) + return Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, + type = "query", version = rpki.left_right.version) - raise tornado.gen.Return(r_msg) + @tornado.gen.coroutine + def irdb_query(self, q_msg): + """ + Perform an IRDB callback query. + """ - @tornado.gen.coroutine - def irdb_query_child_resources(self, tenant_handle, child_handle): - """ - Ask IRDB about a child's resources. - """ + q_tags = set(q_pdu.tag for q_pdu in q_msg) - q_msg = self._compose_left_right_query() - SubElement(q_msg, rpki.left_right.tag_list_resources, tenant_handle = tenant_handle, child_handle = child_handle) + q_der = rpki.left_right.cms_msg().wrap(q_msg, self.rpkid_key, self.rpkid_cert) - r_msg = yield self.irdb_query(q_msg) + http_request = tornado.httpclient.HTTPRequest( + url = self.irdb_url, + method = "POST", + body = q_der, + headers = { "Content-Type" : rpki.left_right.content_type }) - if len(r_msg) != 1: - raise rpki.exceptions.BadIRDBReply("Expected exactly one PDU from IRDB: %s" % r_msg.pretty_print_content()) + http_response = yield self.http_fetch(http_request) - bag = rpki.resource_set.resource_bag( - asn = rpki.resource_set.resource_set_as(r_msg[0].get("asn")), - v4 = rpki.resource_set.resource_set_ipv4(r_msg[0].get("ipv4")), - v6 = rpki.resource_set.resource_set_ipv6(r_msg[0].get("ipv6")), - valid_until = rpki.sundial.datetime.fromXMLtime(r_msg[0].get("valid_until"))) + # Tornado already checked http_response.code for us - raise tornado.gen.Return(bag) + content_type = http_response.headers.get("Content-Type") - @tornado.gen.coroutine - def irdb_query_roa_requests(self, tenant_handle): - """ - Ask IRDB about self's ROA requests. - """ + if content_type not in rpki.left_right.allowed_content_types: + raise rpki.exceptions.BadContentType("HTTP Content-Type %r, expected %r" % (rpki.left_right.content_type, content_type)) - q_msg = self._compose_left_right_query() - SubElement(q_msg, rpki.left_right.tag_list_roa_requests, tenant_handle = tenant_handle) - r_msg = yield self.irdb_query(q_msg) - raise tornado.gen.Return(r_msg) + r_der = http_response.body - @tornado.gen.coroutine - def irdb_query_ghostbuster_requests(self, tenant_handle, parent_handles): - """ - Ask IRDB about self's ghostbuster record requests. - """ + r_cms = rpki.left_right.cms_msg(DER = r_der) + r_msg = r_cms.unwrap((self.bpki_ta, self.irdb_cert)) - q_msg = self._compose_left_right_query() - for parent_handle in parent_handles: - SubElement(q_msg, rpki.left_right.tag_list_ghostbuster_requests, - tenant_handle = tenant_handle, parent_handle = parent_handle) - r_msg = yield self.irdb_query(q_msg) - raise tornado.gen.Return(r_msg) + self.irdbd_cms_timestamp = r_cms.check_replay(self.irdbd_cms_timestamp, self.irdb_url) - @tornado.gen.coroutine - def irdb_query_ee_certificate_requests(self, tenant_handle): - """ - Ask IRDB about self's EE certificate requests. - """ + #rpki.left_right.check_response(r_msg) - q_msg = self._compose_left_right_query() - SubElement(q_msg, rpki.left_right.tag_list_ee_certificate_requests, tenant_handle = tenant_handle) - r_msg = yield self.irdb_query(q_msg) - raise tornado.gen.Return(r_msg) + if r_msg.get("type") != "reply" or not all(r_pdu.tag in q_tags for r_pdu in r_msg): + raise rpki.exceptions.BadIRDBReply("Unexpected response to IRDB query: %s" % r_cms.pretty_print_content()) - @property - def left_right_models(self): - """ - Map element tag to rpkidb model. - """ + raise tornado.gen.Return(r_msg) - try: - return self._left_right_models - except AttributeError: - import rpki.rpkidb.models # pylint: disable=W0621 - self._left_right_models = { - rpki.left_right.tag_tenant : rpki.rpkidb.models.Tenant, - rpki.left_right.tag_bsc : rpki.rpkidb.models.BSC, - rpki.left_right.tag_parent : rpki.rpkidb.models.Parent, - rpki.left_right.tag_child : rpki.rpkidb.models.Child, - rpki.left_right.tag_repository : rpki.rpkidb.models.Repository } - return self._left_right_models - - @property - def left_right_trivial_handlers(self): - """ - Map element tag to bound handler methods for trivial PDU types. - """ + @tornado.gen.coroutine + def irdb_query_child_resources(self, tenant_handle, child_handle): + """ + Ask IRDB about a child's resources. + """ - try: - return self._left_right_trivial_handlers - except AttributeError: - self._left_right_trivial_handlers = { - rpki.left_right.tag_list_published_objects : self.handle_list_published_objects, - rpki.left_right.tag_list_received_resources : self.handle_list_received_resources } - return self._left_right_trivial_handlers + q_msg = self._compose_left_right_query() + SubElement(q_msg, rpki.left_right.tag_list_resources, tenant_handle = tenant_handle, child_handle = child_handle) - def handle_list_published_objects(self, q_pdu, r_msg): - """ - <list_published_objects/> server. - """ + r_msg = yield self.irdb_query(q_msg) - tenant_handle = q_pdu.get("tenant_handle") - msg_tag = q_pdu.get("tag") - - kw = dict(tenant_handle = tenant_handle) - if msg_tag is not None: - kw.update(tag = msg_tag) - - for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__tenant__tenant_handle = tenant_handle, state = "active"): - SubElement(r_msg, rpki.left_right.tag_list_published_objects, - uri = ca_detail.crl_uri, **kw).text = ca_detail.latest_crl.get_Base64() - SubElement(r_msg, rpki.left_right.tag_list_published_objects, - uri = ca_detail.manifest_uri, **kw).text = ca_detail.latest_manifest.get_Base64() - for c in ca_detail.child_certs.all(): - SubElement(r_msg, rpki.left_right.tag_list_published_objects, - uri = c.uri, child_handle = c.child.child_handle, **kw).text = c.cert.get_Base64() - for r in ca_detail.roas.filter(roa__isnull = False): - SubElement(r_msg, rpki.left_right.tag_list_published_objects, - uri = r.uri, **kw).text = r.roa.get_Base64() - for g in ca_detail.ghostbusters.all(): - SubElement(r_msg, rpki.left_right.tag_list_published_objects, - uri = g.uri, **kw).text = g.ghostbuster.get_Base64() - for c in ca_detail.ee_certificates.all(): - SubElement(r_msg, rpki.left_right.tag_list_published_objects, - uri = c.uri, **kw).text = c.cert.get_Base64() - - def handle_list_received_resources(self, q_pdu, r_msg): - """ - <list_received_resources/> server. - """ + if len(r_msg) != 1: + raise rpki.exceptions.BadIRDBReply("Expected exactly one PDU from IRDB: %s" % r_msg.pretty_print_content()) - logger.debug(".handle_list_received_resources() %s", ElementToString(q_pdu)) - tenant_handle = q_pdu.get("tenant_handle") - msg_tag = q_pdu.get("tag") - for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__tenant__tenant_handle = tenant_handle, - state = "active", latest_ca_cert__isnull = False): - cert = ca_detail.latest_ca_cert - resources = cert.get_3779resources() - r_pdu = SubElement(r_msg, rpki.left_right.tag_list_received_resources, - tenant_handle = tenant_handle, - parent_handle = ca_detail.ca.parent.parent_handle, - uri = ca_detail.ca_cert_uri, - notBefore = str(cert.getNotBefore()), - notAfter = str(cert.getNotAfter()), - sia_uri = cert.get_sia_directory_uri(), - aia_uri = cert.get_aia_uri(), - asn = str(resources.asn), - ipv4 = str(resources.v4), - ipv6 = str(resources.v6)) - if msg_tag is not None: - r_pdu.set("tag", msg_tag) - - @tornado.gen.coroutine - def left_right_handler(self, handler): - """ - Process one left-right message. - """ + bag = rpki.resource_set.resource_bag( + asn = rpki.resource_set.resource_set_as(r_msg[0].get("asn")), + v4 = rpki.resource_set.resource_set_ipv4(r_msg[0].get("ipv4")), + v6 = rpki.resource_set.resource_set_ipv6(r_msg[0].get("ipv6")), + valid_until = rpki.sundial.datetime.fromXMLtime(r_msg[0].get("valid_until"))) - logger.debug("Entering left_right_handler()") + raise tornado.gen.Return(bag) - content_type = handler.request.headers["Content-Type"] - if content_type not in rpki.left_right.allowed_content_types: - handler.set_status(415, "No handler for Content-Type %s" % content_type) - handler.finish() - return + @tornado.gen.coroutine + def irdb_query_roa_requests(self, tenant_handle): + """ + Ask IRDB about self's ROA requests. + """ - handler.set_header("Content-Type", rpki.left_right.content_type) + q_msg = self._compose_left_right_query() + SubElement(q_msg, rpki.left_right.tag_list_roa_requests, tenant_handle = tenant_handle) + r_msg = yield self.irdb_query(q_msg) + raise tornado.gen.Return(r_msg) - try: - q_cms = rpki.left_right.cms_msg(DER = handler.request.body) - q_msg = q_cms.unwrap((self.bpki_ta, self.irbe_cert)) - r_msg = Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, - type = "reply", version = rpki.left_right.version) - self.irbe_cms_timestamp = q_cms.check_replay(self.irbe_cms_timestamp, handler.request.path) + @tornado.gen.coroutine + def irdb_query_ghostbuster_requests(self, tenant_handle, parent_handles): + """ + Ask IRDB about self's ghostbuster record requests. + """ - assert q_msg.tag.startswith(rpki.left_right.xmlns) - assert all(q_pdu.tag.startswith(rpki.left_right.xmlns) for q_pdu in q_msg) + q_msg = self._compose_left_right_query() + for parent_handle in parent_handles: + SubElement(q_msg, rpki.left_right.tag_list_ghostbuster_requests, + tenant_handle = tenant_handle, parent_handle = parent_handle) + r_msg = yield self.irdb_query(q_msg) + raise tornado.gen.Return(r_msg) - if q_msg.get("version") != rpki.left_right.version: - raise rpki.exceptions.BadQuery("Unrecognized protocol version") + @tornado.gen.coroutine + def irdb_query_ee_certificate_requests(self, tenant_handle): + """ + Ask IRDB about self's EE certificate requests. + """ - if q_msg.get("type") != "query": - raise rpki.exceptions.BadQuery("Message type is not query") + q_msg = self._compose_left_right_query() + SubElement(q_msg, rpki.left_right.tag_list_ee_certificate_requests, tenant_handle = tenant_handle) + r_msg = yield self.irdb_query(q_msg) + raise tornado.gen.Return(r_msg) - for q_pdu in q_msg: + @property + def left_right_models(self): + """ + Map element tag to rpkidb model. + """ try: - action = q_pdu.get("action") - model = self.left_right_models.get(q_pdu.tag) - - if q_pdu.tag in self.left_right_trivial_handlers: - self.left_right_trivial_handlers[q_pdu.tag](q_pdu, r_msg) - - elif action in ("get", "list"): - for obj in model.objects.xml_list(q_pdu): - obj.xml_template.encode(obj, q_pdu, r_msg) - - elif action == "destroy": - obj = model.objects.xml_get_for_delete(q_pdu) - yield obj.xml_pre_delete_hook(self) - obj.delete() - obj.xml_template.acknowledge(obj, q_pdu, r_msg) + return self._left_right_models + except AttributeError: + import rpki.rpkidb.models # pylint: disable=W0621 + self._left_right_models = { + rpki.left_right.tag_tenant : rpki.rpkidb.models.Tenant, + rpki.left_right.tag_bsc : rpki.rpkidb.models.BSC, + rpki.left_right.tag_parent : rpki.rpkidb.models.Parent, + rpki.left_right.tag_child : rpki.rpkidb.models.Child, + rpki.left_right.tag_repository : rpki.rpkidb.models.Repository } + return self._left_right_models + + @property + def left_right_trivial_handlers(self): + """ + Map element tag to bound handler methods for trivial PDU types. + """ - elif action in ("create", "set"): - obj = model.objects.xml_get_or_create(q_pdu) - obj.xml_template.decode(obj, q_pdu) - obj.xml_pre_save_hook(q_pdu) - obj.save() - yield obj.xml_post_save_hook(self, q_pdu) - obj.xml_template.acknowledge(obj, q_pdu, r_msg) + try: + return self._left_right_trivial_handlers + except AttributeError: + self._left_right_trivial_handlers = { + rpki.left_right.tag_list_published_objects : self.handle_list_published_objects, + rpki.left_right.tag_list_received_resources : self.handle_list_received_resources } + return self._left_right_trivial_handlers + + def handle_list_published_objects(self, q_pdu, r_msg): + """ + <list_published_objects/> server. + """ + + tenant_handle = q_pdu.get("tenant_handle") + msg_tag = q_pdu.get("tag") + + kw = dict(tenant_handle = tenant_handle) + if msg_tag is not None: + kw.update(tag = msg_tag) + + for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__tenant__tenant_handle = tenant_handle, state = "active"): + SubElement(r_msg, rpki.left_right.tag_list_published_objects, + uri = ca_detail.crl_uri, **kw).text = ca_detail.latest_crl.get_Base64() + SubElement(r_msg, rpki.left_right.tag_list_published_objects, + uri = ca_detail.manifest_uri, **kw).text = ca_detail.latest_manifest.get_Base64() + for c in ca_detail.child_certs.all(): + SubElement(r_msg, rpki.left_right.tag_list_published_objects, + uri = c.uri, child_handle = c.child.child_handle, **kw).text = c.cert.get_Base64() + for r in ca_detail.roas.filter(roa__isnull = False): + SubElement(r_msg, rpki.left_right.tag_list_published_objects, + uri = r.uri, **kw).text = r.roa.get_Base64() + for g in ca_detail.ghostbusters.all(): + SubElement(r_msg, rpki.left_right.tag_list_published_objects, + uri = g.uri, **kw).text = g.ghostbuster.get_Base64() + for c in ca_detail.ee_certificates.all(): + SubElement(r_msg, rpki.left_right.tag_list_published_objects, + uri = c.uri, **kw).text = c.cert.get_Base64() + + def handle_list_received_resources(self, q_pdu, r_msg): + """ + <list_received_resources/> server. + """ + + logger.debug(".handle_list_received_resources() %s", ElementToString(q_pdu)) + tenant_handle = q_pdu.get("tenant_handle") + msg_tag = q_pdu.get("tag") + for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__tenant__tenant_handle = tenant_handle, + state = "active", latest_ca_cert__isnull = False): + cert = ca_detail.latest_ca_cert + resources = cert.get_3779resources() + r_pdu = SubElement(r_msg, rpki.left_right.tag_list_received_resources, + tenant_handle = tenant_handle, + parent_handle = ca_detail.ca.parent.parent_handle, + uri = ca_detail.ca_cert_uri, + notBefore = str(cert.getNotBefore()), + notAfter = str(cert.getNotAfter()), + sia_uri = cert.get_sia_directory_uri(), + aia_uri = cert.get_aia_uri(), + asn = str(resources.asn), + ipv4 = str(resources.v4), + ipv6 = str(resources.v6)) + if msg_tag is not None: + r_pdu.set("tag", msg_tag) + + @tornado.gen.coroutine + def left_right_handler(self, handler): + """ + Process one left-right message. + """ + + logger.debug("Entering left_right_handler()") + + content_type = handler.request.headers["Content-Type"] + if content_type not in rpki.left_right.allowed_content_types: + handler.set_status(415, "No handler for Content-Type %s" % content_type) + handler.finish() + return + + handler.set_header("Content-Type", rpki.left_right.content_type) - else: - raise rpki.exceptions.BadQuery("Unrecognized action %r" % action) + try: + q_cms = rpki.left_right.cms_msg(DER = handler.request.body) + q_msg = q_cms.unwrap((self.bpki_ta, self.irbe_cert)) + r_msg = Element(rpki.left_right.tag_msg, nsmap = rpki.left_right.nsmap, + type = "reply", version = rpki.left_right.version) + self.irbe_cms_timestamp = q_cms.check_replay(self.irbe_cms_timestamp, handler.request.path) + + assert q_msg.tag.startswith(rpki.left_right.xmlns) + assert all(q_pdu.tag.startswith(rpki.left_right.xmlns) for q_pdu in q_msg) + + if q_msg.get("version") != rpki.left_right.version: + raise rpki.exceptions.BadQuery("Unrecognized protocol version") + + if q_msg.get("type") != "query": + raise rpki.exceptions.BadQuery("Message type is not query") + + for q_pdu in q_msg: + + try: + action = q_pdu.get("action") + model = self.left_right_models.get(q_pdu.tag) + + if q_pdu.tag in self.left_right_trivial_handlers: + self.left_right_trivial_handlers[q_pdu.tag](q_pdu, r_msg) + + elif action in ("get", "list"): + for obj in model.objects.xml_list(q_pdu): + obj.xml_template.encode(obj, q_pdu, r_msg) + + elif action == "destroy": + obj = model.objects.xml_get_for_delete(q_pdu) + yield obj.xml_pre_delete_hook(self) + obj.delete() + obj.xml_template.acknowledge(obj, q_pdu, r_msg) + + elif action in ("create", "set"): + obj = model.objects.xml_get_or_create(q_pdu) + obj.xml_template.decode(obj, q_pdu) + obj.xml_pre_save_hook(q_pdu) + obj.save() + yield obj.xml_post_save_hook(self, q_pdu) + obj.xml_template.acknowledge(obj, q_pdu, r_msg) + + else: + raise rpki.exceptions.BadQuery("Unrecognized action %r" % action) + + except Exception, e: + if not isinstance(e, rpki.exceptions.NotFound): + logger.exception("Unhandled exception serving left-right PDU %r", q_pdu) + error_tenant_handle = q_pdu.get("tenant_handle") + error_tag = q_pdu.get("tag") + r_pdu = SubElement(r_msg, rpki.left_right.tag_report_error, error_code = e.__class__.__name__) + r_pdu.text = str(e) + if error_tag is not None: + r_pdu.set("tag", error_tag) + if error_tenant_handle is not None: + r_pdu.set("tenant_handle", error_tenant_handle) + break + + handler.set_status(200) + handler.finish(rpki.left_right.cms_msg().wrap(r_msg, self.rpkid_key, self.rpkid_cert)) + logger.debug("Normal exit from left_right_handler()") except Exception, e: - if not isinstance(e, rpki.exceptions.NotFound): - logger.exception("Unhandled exception serving left-right PDU %r", q_pdu) - error_tenant_handle = q_pdu.get("tenant_handle") - error_tag = q_pdu.get("tag") - r_pdu = SubElement(r_msg, rpki.left_right.tag_report_error, error_code = e.__class__.__name__) - r_pdu.text = str(e) - if error_tag is not None: - r_pdu.set("tag", error_tag) - if error_tenant_handle is not None: - r_pdu.set("tenant_handle", error_tenant_handle) - break - - handler.set_status(200) - handler.finish(rpki.left_right.cms_msg().wrap(r_msg, self.rpkid_key, self.rpkid_cert)) - logger.debug("Normal exit from left_right_handler()") - - except Exception, e: - logger.exception("Unhandled exception serving left-right request") - handler.set_status(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) - handler.finish() - - @tornado.gen.coroutine - def up_down_handler(self, handler, tenant_handle, child_handle): - """ - Process one up-down PDU. - """ + logger.exception("Unhandled exception serving left-right request") + handler.set_status(500, "Unhandled exception %s: %s" % (e.__class__.__name__, e)) + handler.finish() - logger.debug("Entering up_down_handler()") + @tornado.gen.coroutine + def up_down_handler(self, handler, tenant_handle, child_handle): + """ + Process one up-down PDU. + """ - content_type = handler.request.headers["Content-Type"] - if content_type not in rpki.up_down.allowed_content_types: - handler.set_status(415, "No handler for Content-Type %s" % content_type) - handler.finish() - return + logger.debug("Entering up_down_handler()") - try: - child = rpki.rpkidb.models.Child.objects.get(tenant__tenant_handle = tenant_handle, child_handle = child_handle) - q_der = handler.request.body - r_der = yield child.serve_up_down(self, q_der) - handler.set_header("Content-Type", rpki.up_down.content_type) - handler.set_status(200) - handler.finish(r_der) + content_type = handler.request.headers["Content-Type"] + if content_type not in rpki.up_down.allowed_content_types: + handler.set_status(415, "No handler for Content-Type %s" % content_type) + handler.finish() + return - except rpki.rpkidb.models.Child.DoesNotExist: - logger.info("Child %r of tenant %r not found", child_handle, tenant_handle) - handler.set_status(400, "Child %r not found" % child_handle) - handler.finish() + try: + child = rpki.rpkidb.models.Child.objects.get(tenant__tenant_handle = tenant_handle, child_handle = child_handle) + q_der = handler.request.body + r_der = yield child.serve_up_down(self, q_der) + handler.set_header("Content-Type", rpki.up_down.content_type) + handler.set_status(200) + handler.finish(r_der) + + except rpki.rpkidb.models.Child.DoesNotExist: + logger.info("Child %r of tenant %r not found", child_handle, tenant_handle) + handler.set_status(400, "Child %r not found" % child_handle) + handler.finish() - except Exception, e: - logger.exception("Unhandled exception processing up-down request") - handler.set_status(400, "Could not process PDU: %s" % e) - handler.finish() + except Exception, e: + logger.exception("Unhandled exception processing up-down request") + handler.set_status(400, "Could not process PDU: %s" % e) + handler.finish() class publication_queue(object): - """ - Utility to simplify publication from within rpkid. - - General idea here is to accumulate a collection of objects to be - published, in one or more repositories, each potentially with its - own completion callback. Eventually we want to publish everything - we've accumulated, at which point we need to iterate over the - collection and do repository.call_pubd() for each repository. - """ - - replace = True - - def __init__(self, rpkid): - self.rpkid = rpkid - self.clear() - - def clear(self): - self.repositories = {} - self.msgs = {} - self.handlers = {} - if self.replace: - self.uris = {} - - def queue(self, uri, repository, handler = None, - old_obj = None, new_obj = None, old_hash = None): - - assert old_obj is not None or new_obj is not None or old_hash is not None - assert old_obj is None or old_hash is None - assert old_obj is None or isinstance(old_obj, rpki.x509.uri_dispatch(uri)) - assert new_obj is None or isinstance(new_obj, rpki.x509.uri_dispatch(uri)) - - logger.debug("Queuing publication action: uri %s, old %r, new %r, hash %s", - uri, old_obj, new_obj, old_hash) - - # id(repository) may need to change to repository.peer_contact_uri - # once we convert from our custom SQL cache to Django ORM. - - rid = id(repository) - if rid not in self.repositories: - self.repositories[rid] = repository - self.msgs[rid] = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, - type = "query", version = rpki.publication.version) - - if self.replace and uri in self.uris: - logger.debug("Removing publication duplicate %r", self.uris[uri]) - old_pdu = self.uris.pop(uri) - self.msgs[rid].remove(old_pdu) - pdu_hash = old_pdu.get("hash") - elif old_hash is not None: - pdu_hash = old_hash - elif old_obj is None: - pdu_hash = None - else: - pdu_hash = rpki.x509.sha256(old_obj.get_DER()).encode("hex") - - if new_obj is None: - pdu = SubElement(self.msgs[rid], rpki.publication.tag_withdraw, uri = uri, hash = pdu_hash) - else: - pdu = SubElement(self.msgs[rid], rpki.publication.tag_publish, uri = uri) - pdu.text = new_obj.get_Base64() - if pdu_hash is not None: - pdu.set("hash", pdu_hash) - - if handler is not None: - tag = str(id(pdu)) - self.handlers[tag] = handler - pdu.set("tag", tag) - - if self.replace: - self.uris[uri] = pdu - - @tornado.gen.coroutine - def call_pubd(self): - for rid in self.repositories: - logger.debug("Calling pubd[%r]", self.repositories[rid]) - yield self.repositories[rid].call_pubd(self.rpkid, self.msgs[rid], self.handlers) - self.clear() - - @property - def size(self): - return sum(len(self.msgs[rid]) for rid in self.repositories) - - def empty(self): - return not self.msgs + """ + Utility to simplify publication from within rpkid. + + General idea here is to accumulate a collection of objects to be + published, in one or more repositories, each potentially with its + own completion callback. Eventually we want to publish everything + we've accumulated, at which point we need to iterate over the + collection and do repository.call_pubd() for each repository. + """ + + replace = True + + def __init__(self, rpkid): + self.rpkid = rpkid + self.clear() + + def clear(self): + self.repositories = {} + self.msgs = {} + self.handlers = {} + if self.replace: + self.uris = {} + + def queue(self, uri, repository, handler = None, + old_obj = None, new_obj = None, old_hash = None): + + assert old_obj is not None or new_obj is not None or old_hash is not None + assert old_obj is None or old_hash is None + assert old_obj is None or isinstance(old_obj, rpki.x509.uri_dispatch(uri)) + assert new_obj is None or isinstance(new_obj, rpki.x509.uri_dispatch(uri)) + + logger.debug("Queuing publication action: uri %s, old %r, new %r, hash %s", + uri, old_obj, new_obj, old_hash) + + # id(repository) may need to change to repository.peer_contact_uri + # once we convert from our custom SQL cache to Django ORM. + + rid = id(repository) + if rid not in self.repositories: + self.repositories[rid] = repository + self.msgs[rid] = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, + type = "query", version = rpki.publication.version) + + if self.replace and uri in self.uris: + logger.debug("Removing publication duplicate %r", self.uris[uri]) + old_pdu = self.uris.pop(uri) + self.msgs[rid].remove(old_pdu) + pdu_hash = old_pdu.get("hash") + elif old_hash is not None: + pdu_hash = old_hash + elif old_obj is None: + pdu_hash = None + else: + pdu_hash = rpki.x509.sha256(old_obj.get_DER()).encode("hex") + + if new_obj is None: + pdu = SubElement(self.msgs[rid], rpki.publication.tag_withdraw, uri = uri, hash = pdu_hash) + else: + pdu = SubElement(self.msgs[rid], rpki.publication.tag_publish, uri = uri) + pdu.text = new_obj.get_Base64() + if pdu_hash is not None: + pdu.set("hash", pdu_hash) + + if handler is not None: + tag = str(id(pdu)) + self.handlers[tag] = handler + pdu.set("tag", tag) + + if self.replace: + self.uris[uri] = pdu + + @tornado.gen.coroutine + def call_pubd(self): + for rid in self.repositories: + logger.debug("Calling pubd[%r]", self.repositories[rid]) + yield self.repositories[rid].call_pubd(self.rpkid, self.msgs[rid], self.handlers) + self.clear() + + @property + def size(self): + return sum(len(self.msgs[rid]) for rid in self.repositories) + + def empty(self): + return not self.msgs diff --git a/rpki/rpkid_tasks.py b/rpki/rpkid_tasks.py index 642d5dda..5c28afc3 100644 --- a/rpki/rpkid_tasks.py +++ b/rpki/rpkid_tasks.py @@ -43,497 +43,497 @@ logger = logging.getLogger(__name__) task_classes = () def queue_task(cls): - """ - Class decorator to add a new task class to task_classes. - """ + """ + Class decorator to add a new task class to task_classes. + """ - global task_classes - task_classes += (cls,) - return cls + global task_classes + task_classes += (cls,) + return cls class AbstractTask(object): - """ - Abstract base class for rpkid scheduler task objects. - """ - - ## @var timeslice - # How long before a task really should consider yielding the CPU to - # let something else run. - - timeslice = rpki.sundial.timedelta(seconds = 15) - - def __init__(self, rpkid, tenant, description = None): - self.rpkid = rpkid - self.tenant = tenant - self.description = description - self.runnable = tornado.locks.Event() - self.done_this = None - self.done_next = None - self.due_date = None - self.started = False - self.runnable.set() - self.clear() - - # This field belongs to the rpkid task_loop(), don't touch. - self.future = None - - def __repr__(self): - return rpki.log.log_repr(self, self.description) - - @tornado.gen.coroutine - def start(self): - try: - logger.debug("%r: Starting", self) - self.due_date = rpki.sundial.now() + self.timeslice - self.clear() - self.started = True - yield self.main() - except: - logger.exception("%r: Unhandled exception", self) - #raise - finally: - logger.debug("%r: Exiting", self) - self.due_date = None - self.started = False - self.clear() - if self.done_this is not None: - self.done_this.notify_all() - self.done_this = self.done_next - self.done_next = None - - def wait(self): - done = "done_next" if self.started else "done_this" - condition = getattr(self, done) - if condition is None: - condition = tornado.locks.Condition() - setattr(self, done, condition) - future = condition.wait() - return future - - def waiting(self): - return self.done_this is not None - - @tornado.gen.coroutine - def postpone(self): - logger.debug("%r: Postponing", self) - self.due_date = None - self.runnable.clear() - yield self.runnable.wait() - logger.debug("%r: Resuming", self) - self.due_date = rpki.sundial.now() + self.timeslice - - @property - def overdue(self): - return rpki.sundial.now() > self.due_date - - @tornado.gen.coroutine - def main(self): - raise NotImplementedError - - def clear(self): - pass + """ + Abstract base class for rpkid scheduler task objects. + """ + + ## @var timeslice + # How long before a task really should consider yielding the CPU to + # let something else run. + + timeslice = rpki.sundial.timedelta(seconds = 15) + + def __init__(self, rpkid, tenant, description = None): + self.rpkid = rpkid + self.tenant = tenant + self.description = description + self.runnable = tornado.locks.Event() + self.done_this = None + self.done_next = None + self.due_date = None + self.started = False + self.runnable.set() + self.clear() + + # This field belongs to the rpkid task_loop(), don't touch. + self.future = None + + def __repr__(self): + return rpki.log.log_repr(self, self.description) + + @tornado.gen.coroutine + def start(self): + try: + logger.debug("%r: Starting", self) + self.due_date = rpki.sundial.now() + self.timeslice + self.clear() + self.started = True + yield self.main() + except: + logger.exception("%r: Unhandled exception", self) + #raise + finally: + logger.debug("%r: Exiting", self) + self.due_date = None + self.started = False + self.clear() + if self.done_this is not None: + self.done_this.notify_all() + self.done_this = self.done_next + self.done_next = None + + def wait(self): + done = "done_next" if self.started else "done_this" + condition = getattr(self, done) + if condition is None: + condition = tornado.locks.Condition() + setattr(self, done, condition) + future = condition.wait() + return future + + def waiting(self): + return self.done_this is not None + + @tornado.gen.coroutine + def postpone(self): + logger.debug("%r: Postponing", self) + self.due_date = None + self.runnable.clear() + yield self.runnable.wait() + logger.debug("%r: Resuming", self) + self.due_date = rpki.sundial.now() + self.timeslice + + @property + def overdue(self): + return rpki.sundial.now() > self.due_date + + @tornado.gen.coroutine + def main(self): + raise NotImplementedError + + def clear(self): + pass @queue_task class PollParentTask(AbstractTask): - """ - Run the regular client poll cycle with each of this tenant's - parents, in turn. - """ - - @tornado.gen.coroutine - def main(self): - logger.debug("%r: Polling parents", self) + """ + Run the regular client poll cycle with each of this tenant's + parents, in turn. + """ + + @tornado.gen.coroutine + def main(self): + logger.debug("%r: Polling parents", self) + + for parent in self.tenant.parents.all(): + try: + logger.debug("%r: Executing list query", self) + r_msg = yield parent.up_down_list_query(rpkid = self.rpkid) + except: + logger.exception("%r: Couldn't get resource class list from parent %r, skipping", self, parent) + continue + + logger.debug("%r: Parsing list response", self) + + ca_map = dict((ca.parent_resource_class, ca) for ca in parent.cas.all()) + + for rc in r_msg.getiterator(rpki.up_down.tag_class): + try: + class_name = rc.get("class_name") + ca = ca_map.pop(class_name, None) + if ca is None: + logger.debug("%r: Creating new CA for resource class %r", self, class_name) + yield rpki.rpkidb.models.CA.create(rpkid = self.rpkid, parent = parent, rc = rc) + else: + logger.debug("%r: Checking updates for existing CA %r for resource class %r", self, ca, class_name) + yield ca.check_for_updates(rpkid = self.rpkid, parent = parent, rc = rc) + except: + logger.exception("Couldn't update resource class %r, skipping", class_name) + + for ca, class_name in ca_map.iteritems(): + logger.debug("%r: Destroying orphaned CA %r for resource class %r", self, ca, class_name) + yield ca.destroy(parent) - for parent in self.tenant.parents.all(): - try: - logger.debug("%r: Executing list query", self) - r_msg = yield parent.up_down_list_query(rpkid = self.rpkid) - except: - logger.exception("%r: Couldn't get resource class list from parent %r, skipping", self, parent) - continue - logger.debug("%r: Parsing list response", self) - - ca_map = dict((ca.parent_resource_class, ca) for ca in parent.cas.all()) +@queue_task +class UpdateChildrenTask(AbstractTask): + """ + Check for updated IRDB data for all of this tenant's children and + issue new certs as necessary. Must handle changes both in + resources and in expiration date. + """ + + @tornado.gen.coroutine + def main(self): + logger.debug("%r: Updating children", self) + now = rpki.sundial.now() + rsn = now + rpki.sundial.timedelta(seconds = self.tenant.regen_margin) + publisher = rpki.rpkid.publication_queue(self.rpkid) + + for child in self.tenant.children.all(): + try: + if self.overdue: + yield publisher.call_pubd() + yield self.postpone() + + child_certs = list(child.child_certs.filter(ca_detail__state = "active")) + + if child_certs: + irdb_resources = yield self.rpkid.irdb_query_child_resources(child.tenant.tenant_handle, child.child_handle) + + for child_cert in child_certs: + ca_detail = child_cert.ca_detail + old_resources = child_cert.cert.get_3779resources() + new_resources = old_resources & irdb_resources & ca_detail.latest_ca_cert.get_3779resources() + old_aia = child_cert.cert.get_AIA()[0] + new_aia = ca_detail.ca_cert_uri + + assert child_cert.gski == child_cert.cert.gSKI() + + if new_resources.empty(): + logger.debug("Resources shrank to the null set, revoking and withdrawing child %s certificate g(SKI) %s", child.child_handle, child_cert.gski) + child_cert.revoke(publisher = publisher) + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + elif (old_resources != new_resources or old_aia != new_aia or (old_resources.valid_until < rsn and irdb_resources.valid_until > now and old_resources.valid_until != irdb_resources.valid_until)): + logger.debug("Need to reissue child %s certificate g(SKI) %s", child.child_handle, child_cert.gski) + if old_resources != new_resources: + logger.debug("Child %s g(SKI) %s resources changed: old %s new %s", child.child_handle, child_cert.gski, old_resources, new_resources) + if old_resources.valid_until != irdb_resources.valid_until: + logger.debug("Child %s g(SKI) %s validity changed: old %s new %s", child.child_handle, child_cert.gski, old_resources.valid_until, irdb_resources.valid_until) + + new_resources.valid_until = irdb_resources.valid_until + child_cert.reissue(ca_detail = ca_detail, resources = new_resources, publisher = publisher) + + elif old_resources.valid_until < now: + logger.debug("Child %s certificate g(SKI) %s has expired: cert.valid_until %s, irdb.valid_until %s", child.child_handle, child_cert.gski, old_resources.valid_until, irdb_resources.valid_until) + child_cert.delete() + publisher.queue(uri = child_cert.uri, old_obj = child_cert.cert, repository = ca_detail.ca.parent.repository) + ca_detail.generate_manifest(publisher = publisher) + + except: + logger.exception("%r: Couldn't update child %r, skipping", self, child) - for rc in r_msg.getiterator(rpki.up_down.tag_class): try: - class_name = rc.get("class_name") - ca = ca_map.pop(class_name, None) - if ca is None: - logger.debug("%r: Creating new CA for resource class %r", self, class_name) - yield rpki.rpkidb.models.CA.create(rpkid = self.rpkid, parent = parent, rc = rc) - else: - logger.debug("%r: Checking updates for existing CA %r for resource class %r", self, ca, class_name) - yield ca.check_for_updates(rpkid = self.rpkid, parent = parent, rc = rc) + yield publisher.call_pubd() except: - logger.exception("Couldn't update resource class %r, skipping", class_name) - - for ca, class_name in ca_map.iteritems(): - logger.debug("%r: Destroying orphaned CA %r for resource class %r", self, ca, class_name) - yield ca.destroy(parent) + logger.exception("%r: Couldn't publish, skipping", self) @queue_task -class UpdateChildrenTask(AbstractTask): - """ - Check for updated IRDB data for all of this tenant's children and - issue new certs as necessary. Must handle changes both in - resources and in expiration date. - """ - - @tornado.gen.coroutine - def main(self): - logger.debug("%r: Updating children", self) - now = rpki.sundial.now() - rsn = now + rpki.sundial.timedelta(seconds = self.tenant.regen_margin) - publisher = rpki.rpkid.publication_queue(self.rpkid) - - for child in self.tenant.children.all(): - try: - if self.overdue: - yield publisher.call_pubd() - yield self.postpone() - - child_certs = list(child.child_certs.filter(ca_detail__state = "active")) - - if child_certs: - irdb_resources = yield self.rpkid.irdb_query_child_resources(child.tenant.tenant_handle, child.child_handle) - - for child_cert in child_certs: - ca_detail = child_cert.ca_detail - old_resources = child_cert.cert.get_3779resources() - new_resources = old_resources & irdb_resources & ca_detail.latest_ca_cert.get_3779resources() - old_aia = child_cert.cert.get_AIA()[0] - new_aia = ca_detail.ca_cert_uri - - assert child_cert.gski == child_cert.cert.gSKI() - - if new_resources.empty(): - logger.debug("Resources shrank to the null set, revoking and withdrawing child %s certificate g(SKI) %s", child.child_handle, child_cert.gski) - child_cert.revoke(publisher = publisher) - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - elif (old_resources != new_resources or old_aia != new_aia or (old_resources.valid_until < rsn and irdb_resources.valid_until > now and old_resources.valid_until != irdb_resources.valid_until)): - logger.debug("Need to reissue child %s certificate g(SKI) %s", child.child_handle, child_cert.gski) - if old_resources != new_resources: - logger.debug("Child %s g(SKI) %s resources changed: old %s new %s", child.child_handle, child_cert.gski, old_resources, new_resources) - if old_resources.valid_until != irdb_resources.valid_until: - logger.debug("Child %s g(SKI) %s validity changed: old %s new %s", child.child_handle, child_cert.gski, old_resources.valid_until, irdb_resources.valid_until) - - new_resources.valid_until = irdb_resources.valid_until - child_cert.reissue(ca_detail = ca_detail, resources = new_resources, publisher = publisher) - - elif old_resources.valid_until < now: - logger.debug("Child %s certificate g(SKI) %s has expired: cert.valid_until %s, irdb.valid_until %s", child.child_handle, child_cert.gski, old_resources.valid_until, irdb_resources.valid_until) - child_cert.delete() - publisher.queue(uri = child_cert.uri, old_obj = child_cert.cert, repository = ca_detail.ca.parent.repository) - ca_detail.generate_manifest(publisher = publisher) - - except: - logger.exception("%r: Couldn't update child %r, skipping", self, child) - - try: - yield publisher.call_pubd() - except: - logger.exception("%r: Couldn't publish, skipping", self) +class UpdateROAsTask(AbstractTask): + """ + Generate or update ROAs for this tenant. + """ + def clear(self): + self.publisher = None + self.ca_details = None + + @tornado.gen.coroutine + def main(self): + logger.debug("%r: Updating ROAs", self) + + try: + r_msg = yield self.rpkid.irdb_query_roa_requests(self.tenant.tenant_handle) + except: + logger.exception("Could not fetch ROA requests for %s, skipping", self.tenant.tenant_handle) + return + + logger.debug("%r: Received response to query for ROA requests: %r", self, r_msg) + + roas = {} + seen = set() + orphans = [] + updates = [] + self.publisher = rpki.rpkid.publication_queue(self.rpkid) + self.ca_details = set() + + for roa in self.tenant.roas.all(): + k = (roa.asn, str(roa.ipv4), str(roa.ipv6)) + if k not in roas: + roas[k] = roa + elif (roa.roa is not None and roa.cert is not None and roa.ca_detail is not None and roa.ca_detail.state == "active" and (roas[k].roa is None or roas[k].cert is None or roas[k].ca_detail is None or roas[k].ca_detail.state != "active")): + orphans.append(roas[k]) + roas[k] = roa + else: + orphans.append(roa) + + for r_pdu in r_msg: + k = (r_pdu.get("asn"), r_pdu.get("ipv4"), r_pdu.get("ipv6")) + if k in seen: + logger.warning("%r: Skipping duplicate ROA request %r", self, r_pdu) + else: + seen.add(k) + roa = roas.pop(k, None) + if roa is None: + roa = rpki.rpkidb.models.ROA(tenant = self.tenant, asn = long(r_pdu.get("asn")), ipv4 = r_pdu.get("ipv4"), ipv6 = r_pdu.get("ipv6")) + logger.debug("%r: Created new %r", self, roa) + else: + logger.debug("%r: Found existing %r", self, roa) + updates.append(roa) + + orphans.extend(roas.itervalues()) + + while updates: + if self.overdue: + yield self.publish() + yield self.postpone() + roa = updates.pop(0) + try: + roa.update(publisher = self.publisher, fast = True) + self.ca_details.add(roa.ca_detail) + except rpki.exceptions.NoCoveringCertForROA: + logger.warning("%r: No covering certificate for %r, skipping", self, roa) + except: + logger.exception("%r: Could not update %r, skipping", self, roa) + + for roa in orphans: + try: + self.ca_details.add(roa.ca_detail) + roa.revoke(publisher = self.publisher, fast = True) + except: + logger.exception("%r: Could not revoke %r", self, roa) -@queue_task -class UpdateROAsTask(AbstractTask): - """ - Generate or update ROAs for this tenant. - """ - - def clear(self): - self.publisher = None - self.ca_details = None - - @tornado.gen.coroutine - def main(self): - logger.debug("%r: Updating ROAs", self) - - try: - r_msg = yield self.rpkid.irdb_query_roa_requests(self.tenant.tenant_handle) - except: - logger.exception("Could not fetch ROA requests for %s, skipping", self.tenant.tenant_handle) - return - - logger.debug("%r: Received response to query for ROA requests: %r", self, r_msg) - - roas = {} - seen = set() - orphans = [] - updates = [] - self.publisher = rpki.rpkid.publication_queue(self.rpkid) - self.ca_details = set() - - for roa in self.tenant.roas.all(): - k = (roa.asn, str(roa.ipv4), str(roa.ipv6)) - if k not in roas: - roas[k] = roa - elif (roa.roa is not None and roa.cert is not None and roa.ca_detail is not None and roa.ca_detail.state == "active" and (roas[k].roa is None or roas[k].cert is None or roas[k].ca_detail is None or roas[k].ca_detail.state != "active")): - orphans.append(roas[k]) - roas[k] = roa - else: - orphans.append(roa) - - for r_pdu in r_msg: - k = (r_pdu.get("asn"), r_pdu.get("ipv4"), r_pdu.get("ipv6")) - if k in seen: - logger.warning("%r: Skipping duplicate ROA request %r", self, r_pdu) - else: - seen.add(k) - roa = roas.pop(k, None) - if roa is None: - roa = rpki.rpkidb.models.ROA(tenant = self.tenant, asn = long(r_pdu.get("asn")), ipv4 = r_pdu.get("ipv4"), ipv6 = r_pdu.get("ipv6")) - logger.debug("%r: Created new %r", self, roa) - else: - logger.debug("%r: Found existing %r", self, roa) - updates.append(roa) - - orphans.extend(roas.itervalues()) - - while updates: - if self.overdue: yield self.publish() - yield self.postpone() - roa = updates.pop(0) - try: - roa.update(publisher = self.publisher, fast = True) - self.ca_details.add(roa.ca_detail) - except rpki.exceptions.NoCoveringCertForROA: - logger.warning("%r: No covering certificate for %r, skipping", self, roa) - except: - logger.exception("%r: Could not update %r, skipping", self, roa) - - for roa in orphans: - try: - self.ca_details.add(roa.ca_detail) - roa.revoke(publisher = self.publisher, fast = True) - except: - logger.exception("%r: Could not revoke %r", self, roa) - - yield self.publish() - - @tornado.gen.coroutine - def publish(self): - if not self.publisher.empty(): - for ca_detail in self.ca_details: - logger.debug("%r: Generating new CRL for %r", self, ca_detail) - ca_detail.generate_crl(publisher = self.publisher) - logger.debug("%r: Generating new manifest for %r", self, ca_detail) - ca_detail.generate_manifest(publisher = self.publisher) - yield self.publisher.call_pubd() - self.ca_details.clear() + + @tornado.gen.coroutine + def publish(self): + if not self.publisher.empty(): + for ca_detail in self.ca_details: + logger.debug("%r: Generating new CRL for %r", self, ca_detail) + ca_detail.generate_crl(publisher = self.publisher) + logger.debug("%r: Generating new manifest for %r", self, ca_detail) + ca_detail.generate_manifest(publisher = self.publisher) + yield self.publisher.call_pubd() + self.ca_details.clear() @queue_task class UpdateGhostbustersTask(AbstractTask): - """ - Generate or update Ghostbuster records for this tenant. - - This was originally based on the ROA update code. It's possible - that both could benefit from refactoring, but at this point the - potential scaling issues for ROAs completely dominate structure of - the ROA code, and aren't relevant here unless someone is being - exceptionally silly. - """ - - @tornado.gen.coroutine - def main(self): - logger.debug("%r: Updating Ghostbuster records", self) - parent_handles = set(p.parent_handle for p in self.tenant.parents.all()) - - try: - r_msg = yield self.rpkid.irdb_query_ghostbuster_requests(self.tenant.tenant_handle, parent_handles) - - ghostbusters = {} - orphans = [] - publisher = rpki.rpkid.publication_queue(self.rpkid) - ca_details = set() - seen = set() - - for ghostbuster in self.tenant.ghostbusters.all(): - k = (ghostbuster.ca_detail.pk, ghostbuster.vcard) - if ghostbuster.ca_detail.state != "active" or k in ghostbusters: - orphans.append(ghostbuster) - else: - ghostbusters[k] = ghostbuster - - for r_pdu in r_msg: + """ + Generate or update Ghostbuster records for this tenant. + + This was originally based on the ROA update code. It's possible + that both could benefit from refactoring, but at this point the + potential scaling issues for ROAs completely dominate structure of + the ROA code, and aren't relevant here unless someone is being + exceptionally silly. + """ + + @tornado.gen.coroutine + def main(self): + logger.debug("%r: Updating Ghostbuster records", self) + parent_handles = set(p.parent_handle for p in self.tenant.parents.all()) + try: - self.tenant.parents.get(parent_handle = r_pdu.get("parent_handle")) - except rpki.rpkidb.models.Parent.DoesNotExist: - logger.warning("%r: Unknown parent_handle %r in Ghostbuster request, skipping", self, r_pdu.get("parent_handle")) - continue - k = (r_pdu.get("parent_handle"), r_pdu.text) - if k in seen: - logger.warning("%r: Skipping duplicate Ghostbuster request %r", self, r_pdu) - continue - seen.add(k) - for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__parent_handle = r_pdu.get("parent_handle"), ca__parent__tenant = self.tenant, state = "active"): - ghostbuster = ghostbusters.pop((ca_detail.pk, r_pdu.text), None) - if ghostbuster is None: - ghostbuster = rpki.rpkidb.models.Ghostbuster(tenant = self.tenant, ca_detail = ca_detail, vcard = r_pdu.text) - logger.debug("%r: Created new %r for %r", self, ghostbuster, r_pdu.get("parent_handle")) - else: - logger.debug("%r: Found existing %r for %s", self, ghostbuster, r_pdu.get("parent_handle")) - ghostbuster.update(publisher = publisher, fast = True) - ca_details.add(ca_detail) - - orphans.extend(ghostbusters.itervalues()) - for ghostbuster in orphans: - ca_details.add(ghostbuster.ca_detail) - ghostbuster.revoke(publisher = publisher, fast = True) - - for ca_detail in ca_details: - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - yield publisher.call_pubd() - - except: - logger.exception("Could not update Ghostbuster records for %s, skipping", self.tenant.tenant_handle) + r_msg = yield self.rpkid.irdb_query_ghostbuster_requests(self.tenant.tenant_handle, parent_handles) + + ghostbusters = {} + orphans = [] + publisher = rpki.rpkid.publication_queue(self.rpkid) + ca_details = set() + seen = set() + + for ghostbuster in self.tenant.ghostbusters.all(): + k = (ghostbuster.ca_detail.pk, ghostbuster.vcard) + if ghostbuster.ca_detail.state != "active" or k in ghostbusters: + orphans.append(ghostbuster) + else: + ghostbusters[k] = ghostbuster + + for r_pdu in r_msg: + try: + self.tenant.parents.get(parent_handle = r_pdu.get("parent_handle")) + except rpki.rpkidb.models.Parent.DoesNotExist: + logger.warning("%r: Unknown parent_handle %r in Ghostbuster request, skipping", self, r_pdu.get("parent_handle")) + continue + k = (r_pdu.get("parent_handle"), r_pdu.text) + if k in seen: + logger.warning("%r: Skipping duplicate Ghostbuster request %r", self, r_pdu) + continue + seen.add(k) + for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__parent_handle = r_pdu.get("parent_handle"), ca__parent__tenant = self.tenant, state = "active"): + ghostbuster = ghostbusters.pop((ca_detail.pk, r_pdu.text), None) + if ghostbuster is None: + ghostbuster = rpki.rpkidb.models.Ghostbuster(tenant = self.tenant, ca_detail = ca_detail, vcard = r_pdu.text) + logger.debug("%r: Created new %r for %r", self, ghostbuster, r_pdu.get("parent_handle")) + else: + logger.debug("%r: Found existing %r for %s", self, ghostbuster, r_pdu.get("parent_handle")) + ghostbuster.update(publisher = publisher, fast = True) + ca_details.add(ca_detail) + + orphans.extend(ghostbusters.itervalues()) + for ghostbuster in orphans: + ca_details.add(ghostbuster.ca_detail) + ghostbuster.revoke(publisher = publisher, fast = True) + + for ca_detail in ca_details: + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + yield publisher.call_pubd() + + except: + logger.exception("Could not update Ghostbuster records for %s, skipping", self.tenant.tenant_handle) @queue_task class UpdateEECertificatesTask(AbstractTask): - """ - Generate or update EE certificates for this self. - - Not yet sure what kind of scaling constraints this task might have, - so keeping it simple for initial version, we can optimize later. - """ - - @tornado.gen.coroutine - def main(self): - logger.debug("%r: Updating EE certificates", self) - - try: - r_msg = yield self.rpkid.irdb_query_ee_certificate_requests(self.tenant.tenant_handle) - - publisher = rpki.rpkid.publication_queue(self.rpkid) - - existing = dict() - for ee in self.tenant.ee_certificates.all(): - gski = ee.gski - if gski not in existing: - existing[gski] = set() - existing[gski].add(ee) - - ca_details = set() - - for r_pdu in r_msg: - gski = r_pdu.get("gski") - ees = existing.pop(gski, ()) - - resources = rpki.resource_set.resource_bag( - asn = rpki.resource_set.resource_set_as(r_pdu.get("asn")), - v4 = rpki.resource_set.resource_set_ipv4(r_pdu.get("ipv4")), - v6 = rpki.resource_set.resource_set_ipv6(r_pdu.get("ipv6")), - valid_until = rpki.sundial.datetime.fromXMLtime(r_pdu.get("valid_until"))) - covering = self.tenant.find_covering_ca_details(resources) - ca_details.update(covering) - - for ee in ees: - if ee.ca_detail in covering: - logger.debug("Updating existing EE certificate for %s %s", gski, resources) - ee.reissue(resources = resources, publisher = publisher) - covering.remove(ee.ca_detail) - else: - logger.debug("Existing EE certificate for %s %s is no longer covered", gski, resources) - ee.revoke(publisher = publisher) - - subject_name = rpki.x509.X501DN.from_cn(r_pdu.get("cn"), r_pdu.get("sn")) - subject_key = rpki.x509.PKCS10(Base64 = r_pdu[0].text).getPublicKey() - - for ca_detail in covering: - logger.debug("No existing EE certificate for %s %s", gski, resources) - rpki.rpkidb.models.EECertificate.create( # sic: class method, not Django manager method (for now, anyway) - ca_detail = ca_detail, - subject_name = subject_name, - subject_key = subject_key, - resources = resources, - publisher = publisher, - eku = r_pdu.get("eku", "").split(",") or None) - - # Anything left is an orphan - for ees in existing.values(): - for ee in ees: - ca_details.add(ee.ca_detail) - ee.revoke(publisher = publisher) - - for ca_detail in ca_details: - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - yield publisher.call_pubd() - - except: - logger.exception("Could not update EE certificates for %s, skipping", self.tenant.tenant_handle) + """ + Generate or update EE certificates for this self. + Not yet sure what kind of scaling constraints this task might have, + so keeping it simple for initial version, we can optimize later. + """ + + @tornado.gen.coroutine + def main(self): + logger.debug("%r: Updating EE certificates", self) -@queue_task -class RegenerateCRLsAndManifestsTask(AbstractTask): - """ - Generate new CRLs and manifests as necessary for all of this tenant's - CAs. Extracting nextUpdate from a manifest is hard at the moment - due to implementation silliness, so for now we generate a new - manifest whenever we generate a new CRL - - This code also cleans up tombstones left behind by revoked ca_detail - objects, since we're walking through the relevant portions of the - database anyway. - """ - - @tornado.gen.coroutine - def main(self): - logger.debug("%r: Regenerating CRLs and manifests", self) - - try: - now = rpki.sundial.now() - crl_interval = rpki.sundial.timedelta(seconds = self.tenant.crl_interval) - regen_margin = max(rpki.sundial.timedelta(seconds = self.rpkid.cron_period) * 2, crl_interval / 4) - publisher = rpki.rpkid.publication_queue(self.rpkid) - - for ca in rpki.rpkidb.models.CA.objects.filter(parent__tenant = self.tenant): try: - for ca_detail in ca.ca_details.filter(state = "revoked"): - if now > ca_detail.latest_crl.getNextUpdate(): - ca_detail.destroy(ca = ca, publisher = publisher) - for ca_detail in ca.ca_details.filter(state__in = ("active", "deprecated")): - if now + regen_margin > ca_detail.latest_crl.getNextUpdate(): - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) + r_msg = yield self.rpkid.irdb_query_ee_certificate_requests(self.tenant.tenant_handle) + + publisher = rpki.rpkid.publication_queue(self.rpkid) + + existing = dict() + for ee in self.tenant.ee_certificates.all(): + gski = ee.gski + if gski not in existing: + existing[gski] = set() + existing[gski].add(ee) + + ca_details = set() + + for r_pdu in r_msg: + gski = r_pdu.get("gski") + ees = existing.pop(gski, ()) + + resources = rpki.resource_set.resource_bag( + asn = rpki.resource_set.resource_set_as(r_pdu.get("asn")), + v4 = rpki.resource_set.resource_set_ipv4(r_pdu.get("ipv4")), + v6 = rpki.resource_set.resource_set_ipv6(r_pdu.get("ipv6")), + valid_until = rpki.sundial.datetime.fromXMLtime(r_pdu.get("valid_until"))) + covering = self.tenant.find_covering_ca_details(resources) + ca_details.update(covering) + + for ee in ees: + if ee.ca_detail in covering: + logger.debug("Updating existing EE certificate for %s %s", gski, resources) + ee.reissue(resources = resources, publisher = publisher) + covering.remove(ee.ca_detail) + else: + logger.debug("Existing EE certificate for %s %s is no longer covered", gski, resources) + ee.revoke(publisher = publisher) + + subject_name = rpki.x509.X501DN.from_cn(r_pdu.get("cn"), r_pdu.get("sn")) + subject_key = rpki.x509.PKCS10(Base64 = r_pdu[0].text).getPublicKey() + + for ca_detail in covering: + logger.debug("No existing EE certificate for %s %s", gski, resources) + rpki.rpkidb.models.EECertificate.create( # sic: class method, not Django manager method (for now, anyway) + ca_detail = ca_detail, + subject_name = subject_name, + subject_key = subject_key, + resources = resources, + publisher = publisher, + eku = r_pdu.get("eku", "").split(",") or None) + + # Anything left is an orphan + for ees in existing.values(): + for ee in ees: + ca_details.add(ee.ca_detail) + ee.revoke(publisher = publisher) + + for ca_detail in ca_details: + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + yield publisher.call_pubd() + except: - logger.exception("%r: Couldn't regenerate CRLs and manifests for CA %r, skipping", self, ca) + logger.exception("Could not update EE certificates for %s, skipping", self.tenant.tenant_handle) - yield publisher.call_pubd() - except: - logger.exception("%r: Couldn't publish updated CRLs and manifests, skipping", self) +@queue_task +class RegenerateCRLsAndManifestsTask(AbstractTask): + """ + Generate new CRLs and manifests as necessary for all of this tenant's + CAs. Extracting nextUpdate from a manifest is hard at the moment + due to implementation silliness, so for now we generate a new + manifest whenever we generate a new CRL + + This code also cleans up tombstones left behind by revoked ca_detail + objects, since we're walking through the relevant portions of the + database anyway. + """ + + @tornado.gen.coroutine + def main(self): + logger.debug("%r: Regenerating CRLs and manifests", self) + + try: + now = rpki.sundial.now() + crl_interval = rpki.sundial.timedelta(seconds = self.tenant.crl_interval) + regen_margin = max(rpki.sundial.timedelta(seconds = self.rpkid.cron_period) * 2, crl_interval / 4) + publisher = rpki.rpkid.publication_queue(self.rpkid) + + for ca in rpki.rpkidb.models.CA.objects.filter(parent__tenant = self.tenant): + try: + for ca_detail in ca.ca_details.filter(state = "revoked"): + if now > ca_detail.latest_crl.getNextUpdate(): + ca_detail.destroy(ca = ca, publisher = publisher) + for ca_detail in ca.ca_details.filter(state__in = ("active", "deprecated")): + if now + regen_margin > ca_detail.latest_crl.getNextUpdate(): + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + except: + logger.exception("%r: Couldn't regenerate CRLs and manifests for CA %r, skipping", self, ca) + + yield publisher.call_pubd() + + except: + logger.exception("%r: Couldn't publish updated CRLs and manifests, skipping", self) @queue_task class CheckFailedPublication(AbstractTask): - """ - Periodic check for objects we tried to publish but failed (eg, due - to pubd being down or unreachable). - """ - - @tornado.gen.coroutine - def main(self): - logger.debug("%r: Checking for failed publication actions", self) - - try: - publisher = rpki.rpkid.publication_queue(self.rpkid) - for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__tenant = self.tenant, state = "active"): - ca_detail.check_failed_publication(publisher) - yield publisher.call_pubd() - - except: - logger.exception("%r: Couldn't run failed publications, skipping", self) + """ + Periodic check for objects we tried to publish but failed (eg, due + to pubd being down or unreachable). + """ + + @tornado.gen.coroutine + def main(self): + logger.debug("%r: Checking for failed publication actions", self) + + try: + publisher = rpki.rpkid.publication_queue(self.rpkid) + for ca_detail in rpki.rpkidb.models.CADetail.objects.filter(ca__parent__tenant = self.tenant, state = "active"): + ca_detail.check_failed_publication(publisher) + yield publisher.call_pubd() + + except: + logger.exception("%r: Couldn't run failed publications, skipping", self) diff --git a/rpki/rpkidb/models.py b/rpki/rpkidb/models.py index 1a293360..ab16a176 100644 --- a/rpki/rpkidb/models.py +++ b/rpki/rpkidb/models.py @@ -5,7 +5,6 @@ Django ORM models for rpkid. from __future__ import unicode_literals import logging -import base64 import tornado.gen import tornado.web @@ -35,189 +34,189 @@ logger = logging.getLogger(__name__) # very simple change given migrations. class XMLTemplate(object): - """ - Encapsulate all the voodoo for transcoding between lxml and ORM. - """ - - # Type map to simplify declaration of Base64 sub-elements. - - element_type = dict(bpki_cert = rpki.x509.X509, - bpki_glue = rpki.x509.X509, - pkcs10_request = rpki.x509.PKCS10, - signing_cert = rpki.x509.X509, - signing_cert_crl = rpki.x509.CRL) - - - def __init__(self, name, attributes = (), booleans = (), elements = (), readonly = (), handles = ()): - self.name = name - self.handles = handles - self.attributes = attributes - self.booleans = booleans - self.elements = elements - self.readonly = readonly - - - def encode(self, obj, q_pdu, r_msg): - """ - Encode an ORM object as XML. """ - - r_pdu = SubElement(r_msg, rpki.left_right.xmlns + self.name, nsmap = rpki.left_right.nsmap, action = q_pdu.get("action")) - if self.name != "tenant": - r_pdu.set("tenant_handle", obj.tenant.tenant_handle) - r_pdu.set(self.name + "_handle", getattr(obj, self.name + "_handle")) - if q_pdu.get("tag"): - r_pdu.set("tag", q_pdu.get("tag")) - for h in self.handles: - k = h.xml_template.name - v = getattr(obj, k) - if v is not None: - r_pdu.set(k + "_handle", getattr(v, k + "_handle")) - for k in self.attributes: - v = getattr(obj, k) - if v is not None: - r_pdu.set(k, str(v)) - for k in self.booleans: - if getattr(obj, k): - r_pdu.set(k, "yes") - for k in self.elements + self.readonly: - v = getattr(obj, k) - if v is not None and not v.empty(): - SubElement(r_pdu, rpki.left_right.xmlns + k).text = v.get_Base64() - logger.debug("XMLTemplate.encode(): %s", ElementToString(r_pdu)) - - - def acknowledge(self, obj, q_pdu, r_msg): - """ - Add an acknowledgement PDU in response to a create, set, or - destroy action. - - This includes a bit of special-case code for BSC objects which has - to go somewhere; we could handle it via some kind method of - call-out to the BSC model, but it's not worth building a general - mechanism for one case, so we do it inline and have done. - """ - - assert q_pdu.tag == rpki.left_right.xmlns + self.name - action = q_pdu.get("action") - r_pdu = SubElement(r_msg, rpki.left_right.xmlns + self.name, nsmap = rpki.left_right.nsmap, action = action) - if self.name != "tenant": - r_pdu.set("tenant_handle", obj.tenant.tenant_handle) - r_pdu.set(self.name + "_handle", getattr(obj, self.name + "_handle")) - if q_pdu.get("tag"): - r_pdu.set("tag", q_pdu.get("tag")) - if self.name == "bsc" and action != "destroy" and obj.pkcs10_request is not None: - assert not obj.pkcs10_request.empty() - SubElement(r_pdu, rpki.left_right.xmlns + "pkcs10_request").text = obj.pkcs10_request.get_Base64() - logger.debug("XMLTemplate.acknowledge(): %s", ElementToString(r_pdu)) - - - def decode(self, obj, q_pdu): - """ - Decode XML into an ORM object. - """ - - logger.debug("XMLTemplate.decode(): %r %s", obj, ElementToString(q_pdu)) - assert q_pdu.tag == rpki.left_right.xmlns + self.name - for h in self.handles: - k = h.xml_template.name - v = q_pdu.get(k + "_handle") - if v is not None: - setattr(obj, k, h.objects.get(**{k + "_handle" : v, "tenant" : obj.tenant})) - for k in self.attributes: - v = q_pdu.get(k) - if v is not None: - v.encode("ascii") - if v.isdigit(): - v = long(v) - setattr(obj, k, v) - for k in self.booleans: - v = q_pdu.get(k) - if v is not None: - setattr(obj, k, v == "yes") - for k in self.elements: - v = q_pdu.findtext(rpki.left_right.xmlns + k) - if v and v.strip(): - setattr(obj, k, self.element_type[k](Base64 = v)) + Encapsulate all the voodoo for transcoding between lxml and ORM. + """ + + # Type map to simplify declaration of Base64 sub-elements. + + element_type = dict(bpki_cert = rpki.x509.X509, + bpki_glue = rpki.x509.X509, + pkcs10_request = rpki.x509.PKCS10, + signing_cert = rpki.x509.X509, + signing_cert_crl = rpki.x509.CRL) + + + def __init__(self, name, attributes = (), booleans = (), elements = (), readonly = (), handles = ()): + self.name = name + self.handles = handles + self.attributes = attributes + self.booleans = booleans + self.elements = elements + self.readonly = readonly + + + def encode(self, obj, q_pdu, r_msg): + """ + Encode an ORM object as XML. + """ + + r_pdu = SubElement(r_msg, rpki.left_right.xmlns + self.name, nsmap = rpki.left_right.nsmap, action = q_pdu.get("action")) + if self.name != "tenant": + r_pdu.set("tenant_handle", obj.tenant.tenant_handle) + r_pdu.set(self.name + "_handle", getattr(obj, self.name + "_handle")) + if q_pdu.get("tag"): + r_pdu.set("tag", q_pdu.get("tag")) + for h in self.handles: + k = h.xml_template.name + v = getattr(obj, k) + if v is not None: + r_pdu.set(k + "_handle", getattr(v, k + "_handle")) + for k in self.attributes: + v = getattr(obj, k) + if v is not None: + r_pdu.set(k, str(v)) + for k in self.booleans: + if getattr(obj, k): + r_pdu.set(k, "yes") + for k in self.elements + self.readonly: + v = getattr(obj, k) + if v is not None and not v.empty(): + SubElement(r_pdu, rpki.left_right.xmlns + k).text = v.get_Base64() + logger.debug("XMLTemplate.encode(): %s", ElementToString(r_pdu)) + + + def acknowledge(self, obj, q_pdu, r_msg): + """ + Add an acknowledgement PDU in response to a create, set, or + destroy action. + + This includes a bit of special-case code for BSC objects which has + to go somewhere; we could handle it via some kind method of + call-out to the BSC model, but it's not worth building a general + mechanism for one case, so we do it inline and have done. + """ + + assert q_pdu.tag == rpki.left_right.xmlns + self.name + action = q_pdu.get("action") + r_pdu = SubElement(r_msg, rpki.left_right.xmlns + self.name, nsmap = rpki.left_right.nsmap, action = action) + if self.name != "tenant": + r_pdu.set("tenant_handle", obj.tenant.tenant_handle) + r_pdu.set(self.name + "_handle", getattr(obj, self.name + "_handle")) + if q_pdu.get("tag"): + r_pdu.set("tag", q_pdu.get("tag")) + if self.name == "bsc" and action != "destroy" and obj.pkcs10_request is not None: + assert not obj.pkcs10_request.empty() + SubElement(r_pdu, rpki.left_right.xmlns + "pkcs10_request").text = obj.pkcs10_request.get_Base64() + logger.debug("XMLTemplate.acknowledge(): %s", ElementToString(r_pdu)) + + + def decode(self, obj, q_pdu): + """ + Decode XML into an ORM object. + """ + + logger.debug("XMLTemplate.decode(): %r %s", obj, ElementToString(q_pdu)) + assert q_pdu.tag == rpki.left_right.xmlns + self.name + for h in self.handles: + k = h.xml_template.name + v = q_pdu.get(k + "_handle") + if v is not None: + setattr(obj, k, h.objects.get(**{k + "_handle" : v, "tenant" : obj.tenant})) + for k in self.attributes: + v = q_pdu.get(k) + if v is not None: + v.encode("ascii") + if v.isdigit(): + v = long(v) + setattr(obj, k, v) + for k in self.booleans: + v = q_pdu.get(k) + if v is not None: + setattr(obj, k, v == "yes") + for k in self.elements: + v = q_pdu.findtext(rpki.left_right.xmlns + k) + if v and v.strip(): + setattr(obj, k, self.element_type[k](Base64 = v)) class XMLManager(models.Manager): # pylint: disable=W0232 - """ - Add a few methods which locate or create an object or objects - corresponding to the handles in an XML element, as appropriate. - - This assumes that models which use it have an "xml_template" - class attribute holding an XMLTemplate object (above). - """ - - def xml_get_or_create(self, xml): - name = self.model.xml_template.name - action = xml.get("action") - assert xml.tag == rpki.left_right.xmlns + name and action in ("create", "set") - d = { name + "_handle" : xml.get(name + "_handle") } - if name != "tenant" and action != "create": - d["tenant__tenant_handle"] = xml.get("tenant_handle") - logger.debug("XMLManager.xml_get_or_create(): name %s action %s filter %r", name, action, d) - result = self.model(**d) if action == "create" else self.get(**d) - if name != "tenant" and action == "create": - result.tenant = Tenant.objects.get(tenant_handle = xml.get("tenant_handle")) - logger.debug("XMLManager.xml_get_or_create(): name %s action %s filter %r result %r", name, action, d, result) - return result - - def xml_list(self, xml): - name = self.model.xml_template.name - action = xml.get("action") - assert xml.tag == rpki.left_right.xmlns + name and action in ("get", "list") - d = {} - if action == "get": - d[name + "_handle"] = xml.get(name + "_handle") - if name != "tenant": - d["tenant__tenant_handle"] = xml.get("tenant_handle") - logger.debug("XMLManager.xml_list(): name %s action %s filter %r", name, action, d) - result = self.filter(**d) if d else self.all() - logger.debug("XMLManager.xml_list(): name %s action %s filter %r result %r", name, action, d, result) - return result - - def xml_get_for_delete(self, xml): - name = self.model.xml_template.name - action = xml.get("action") - assert xml.tag == rpki.left_right.xmlns + name and action == "destroy" - d = { name + "_handle" : xml.get(name + "_handle") } - if name != "tenant": - d["tenant__tenant_handle"] = xml.get("tenant_handle") - logger.debug("XMLManager.xml_get_for_delete(): name %s action %s filter %r", name, action, d) - result = self.get(**d) - logger.debug("XMLManager.xml_get_for_delete(): name %s action %s filter %r result %r", name, action, d, result) - return result + """ + Add a few methods which locate or create an object or objects + corresponding to the handles in an XML element, as appropriate. + + This assumes that models which use it have an "xml_template" + class attribute holding an XMLTemplate object (above). + """ + + def xml_get_or_create(self, xml): + name = self.model.xml_template.name + action = xml.get("action") + assert xml.tag == rpki.left_right.xmlns + name and action in ("create", "set") + d = { name + "_handle" : xml.get(name + "_handle") } + if name != "tenant" and action != "create": + d["tenant__tenant_handle"] = xml.get("tenant_handle") + logger.debug("XMLManager.xml_get_or_create(): name %s action %s filter %r", name, action, d) + result = self.model(**d) if action == "create" else self.get(**d) + if name != "tenant" and action == "create": + result.tenant = Tenant.objects.get(tenant_handle = xml.get("tenant_handle")) + logger.debug("XMLManager.xml_get_or_create(): name %s action %s filter %r result %r", name, action, d, result) + return result + + def xml_list(self, xml): + name = self.model.xml_template.name + action = xml.get("action") + assert xml.tag == rpki.left_right.xmlns + name and action in ("get", "list") + d = {} + if action == "get": + d[name + "_handle"] = xml.get(name + "_handle") + if name != "tenant": + d["tenant__tenant_handle"] = xml.get("tenant_handle") + logger.debug("XMLManager.xml_list(): name %s action %s filter %r", name, action, d) + result = self.filter(**d) if d else self.all() + logger.debug("XMLManager.xml_list(): name %s action %s filter %r result %r", name, action, d, result) + return result + + def xml_get_for_delete(self, xml): + name = self.model.xml_template.name + action = xml.get("action") + assert xml.tag == rpki.left_right.xmlns + name and action == "destroy" + d = { name + "_handle" : xml.get(name + "_handle") } + if name != "tenant": + d["tenant__tenant_handle"] = xml.get("tenant_handle") + logger.debug("XMLManager.xml_get_for_delete(): name %s action %s filter %r", name, action, d) + result = self.get(**d) + logger.debug("XMLManager.xml_get_for_delete(): name %s action %s filter %r result %r", name, action, d, result) + return result def xml_hooks(cls): - """ - Class decorator to add default XML hooks. - """ + """ + Class decorator to add default XML hooks. + """ - # Maybe inheritance from an abstract model would work here. Then - # again, maybe we could use this decorator to do something prettier - # for the XMLTemplate setup. Whatever. Gussie up later. + # Maybe inheritance from an abstract model would work here. Then + # again, maybe we could use this decorator to do something prettier + # for the XMLTemplate setup. Whatever. Gussie up later. - def default_xml_pre_save_hook(self, q_pdu): - logger.debug("default_xml_pre_save_hook()") + def default_xml_pre_save_hook(self, q_pdu): + logger.debug("default_xml_pre_save_hook()") - @tornado.gen.coroutine - def default_xml_post_save_hook(self, rpkid, q_pdu): - logger.debug("default_xml_post_save_hook()") + @tornado.gen.coroutine + def default_xml_post_save_hook(self, rpkid, q_pdu): + logger.debug("default_xml_post_save_hook()") - @tornado.gen.coroutine - def default_xml_pre_delete_hook(self, rpkid): - logger.debug("default_xml_pre_delete_hook()") + @tornado.gen.coroutine + def default_xml_pre_delete_hook(self, rpkid): + logger.debug("default_xml_pre_delete_hook()") - for name, method in (("xml_pre_save_hook", default_xml_pre_save_hook), - ("xml_post_save_hook", default_xml_post_save_hook), - ("xml_pre_delete_hook", default_xml_pre_delete_hook)): - if not hasattr(cls, name): - setattr(cls, name, method) + for name, method in (("xml_pre_save_hook", default_xml_pre_save_hook), + ("xml_post_save_hook", default_xml_post_save_hook), + ("xml_pre_delete_hook", default_xml_pre_delete_hook)): + if not hasattr(cls, name): + setattr(cls, name, method) - return cls + return cls # Models. @@ -227,2128 +226,2129 @@ def xml_hooks(cls): @xml_hooks class Tenant(models.Model): - tenant_handle = models.SlugField(max_length = 255) - use_hsm = models.BooleanField(default = False) - crl_interval = models.BigIntegerField(null = True) - regen_margin = models.BigIntegerField(null = True) - bpki_cert = CertificateField(null = True) - bpki_glue = CertificateField(null = True) - objects = XMLManager() - - xml_template = XMLTemplate( - name = "tenant", - attributes = ("crl_interval", "regen_margin"), - booleans = ("use_hsm",), - elements = ("bpki_cert", "bpki_glue")) - - @tornado.gen.coroutine - def xml_pre_delete_hook(self, rpkid): - yield [parent.destroy() for parent in self.parents.all()] - - @tornado.gen.coroutine - def xml_post_save_hook(self, rpkid, q_pdu): - rekey = q_pdu.get("rekey") - revoke = q_pdu.get("revoke") - reissue = q_pdu.get("reissue") - revoke_forgotten = q_pdu.get("revoke_forgotten") - - if q_pdu.get("clear_replay_protection"): - for parent in self.parents.all(): - parent.clear_replay_protection() - for child in self.children.all(): - child.clear_replay_protection() - for repository in self.repositories.all(): - repository.clear_replay_protection() - - futures = [] - - if rekey or revoke or reissue or revoke_forgotten: - for parent in self.parents.all(): - if rekey: - futures.append(parent.serve_rekey(rpkid)) - if revoke: - futures.append(parent.serve_revoke(rpkid)) - if reissue: - futures.append(parent.serve_reissue(rpkid)) - if revoke_forgotten: - futures.append(parent.serve_revoke_forgotten(rpkid)) - - if q_pdu.get("publish_world_now"): - futures.append(self.serve_publish_world_now(rpkid)) - if q_pdu.get("run_now"): - futures.append(self.serve_run_now(rpkid)) - - yield futures - - - @tornado.gen.coroutine - def serve_publish_world_now(self, rpkid): - publisher = rpki.rpkid.publication_queue(rpkid) - repositories = set() - objects = dict() - - for parent in self.parents.all(): - - repository = parent.repository - if repository.peer_contact_uri in repositories: - continue - repositories.add(repository.peer_contact_uri) - q_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, - type = "query", version = rpki.publication.version) - SubElement(q_msg, rpki.publication.tag_list, tag = "list") - - r_msg = yield repository.call_pubd(rpkid, q_msg, length_check = False) - - for r_pdu in r_msg: - assert r_pdu.tag == rpki.publication.tag_list - if r_pdu.get("uri") in objects: - logger.warning("pubd reported multiple published copies of URI %r, this makes no sense, blundering onwards", r_pdu.get("uri")) - else: - objects[r_pdu.get("uri")] = (r_pdu.get("hash"), repository) - - def reconcile(uri, obj, repository): - h, r = objects.pop(uri, (None, None)) - if h is not None: - assert r == repository - publisher.queue(uri = uri, new_obj = obj, old_hash = h, repository = repository) - - for ca_detail in CADetail.objects.filter(ca__parent__tenant = self, state = "active"): - repository = ca_detail.ca.parent.repository - reconcile(uri = ca_detail.crl_uri, obj = ca_detail.latest_crl, repository = repository) - reconcile(uri = ca_detail.manifest_uri, obj = ca_detail.latest_manifest, repository = repository) - for c in ca_detail.child_certs.all(): - reconcile(uri = c.uri, obj = c.cert, repository = repository) - for r in ca_detail.roas.filter(roa__isnull = False): - reconcile(uri = r.uri, obj = r.roa, repository = repository) - for g in ca_detail.ghostbusters.all(): - reconcile(uri = g.uri, obj = g.ghostbuster, repository = repository) - for c in ca_detail.ee_certificates.all(): - reconcile(uri = c.uri, obj = c.cert, repository = repository) - for u in objects: - h, r = objects[u] - publisher.queue(uri = u, old_hash = h, repository = r) - - yield publisher.call_pubd() - - - @tornado.gen.coroutine - def serve_run_now(self, rpkid): - logger.debug("Forced immediate run of periodic actions for tenant %s[%r]", self.tenant_handle, self) - tasks = self.cron_tasks(rpkid) - rpkid.task_add(tasks) - futures = [task.wait() for task in tasks] - rpkid.task_run() - yield futures - - - def cron_tasks(self, rpkid): - try: - return self._cron_tasks - except AttributeError: - self._cron_tasks = tuple(task(rpkid, self) for task in rpki.rpkid_tasks.task_classes) - return self._cron_tasks - - - def find_covering_ca_details(self, resources): - """ - Return all active CADetails for this <tenant/> which cover a - particular set of resources. - - If we expected there to be a large number of CADetails, we - could add index tables and write fancy SQL query to do this, but - for the expected common case where there are only one or two - active CADetails per <tenant/>, it's probably not worth it. In - any case, this is an optimization we can leave for later. - """ - - return set(ca_detail - for ca_detail in CADetail.objects.filter(ca__parent__tenant = self, state = "active") - if ca_detail.covers(resources)) - - -@xml_hooks -class BSC(models.Model): - bsc_handle = models.SlugField(max_length = 255) - private_key_id = RSAPrivateKeyField() - pkcs10_request = PKCS10Field() - hash_alg = EnumField(choices = ("sha256",), default = "sha256") - signing_cert = CertificateField(null = True) - signing_cert_crl = CRLField(null = True) - tenant = models.ForeignKey(Tenant, related_name = "bscs") - objects = XMLManager() - - class Meta: # pylint: disable=C1001,W0232 - unique_together = ("tenant", "bsc_handle") - - xml_template = XMLTemplate( - name = "bsc", - elements = ("signing_cert", "signing_cert_crl"), - readonly = ("pkcs10_request",)) - - - def xml_pre_save_hook(self, q_pdu): - # Handle key generation, only supports RSA with SHA-256 for now. - if q_pdu.get("generate_keypair"): - assert q_pdu.get("key_type") in (None, "rsa") and q_pdu.get("hash_alg") in (None, "sha256") - self.private_key_id = rpki.x509.RSA.generate(keylength = int(q_pdu.get("key_length", 2048))) - self.pkcs10_request = rpki.x509.PKCS10.create(keypair = self.private_key_id) - - -@xml_hooks -class Repository(models.Model): - repository_handle = models.SlugField(max_length = 255) - peer_contact_uri = models.TextField(null = True) - rrdp_notification_uri = models.TextField(null = True) - bpki_cert = CertificateField(null = True) - bpki_glue = CertificateField(null = True) - last_cms_timestamp = SundialField(null = True) - bsc = models.ForeignKey(BSC, related_name = "repositories") - tenant = models.ForeignKey(Tenant, related_name = "repositories") - objects = XMLManager() - - class Meta: # pylint: disable=C1001,W0232 - unique_together = ("tenant", "repository_handle") + tenant_handle = models.SlugField(max_length = 255) + use_hsm = models.BooleanField(default = False) + crl_interval = models.BigIntegerField(null = True) + regen_margin = models.BigIntegerField(null = True) + bpki_cert = CertificateField(null = True) + bpki_glue = CertificateField(null = True) + objects = XMLManager() + + xml_template = XMLTemplate( + name = "tenant", + attributes = ("crl_interval", "regen_margin"), + booleans = ("use_hsm",), + elements = ("bpki_cert", "bpki_glue")) + + @tornado.gen.coroutine + def xml_pre_delete_hook(self, rpkid): + yield [parent.destroy() for parent in self.parents.all()] + + @tornado.gen.coroutine + def xml_post_save_hook(self, rpkid, q_pdu): + rekey = q_pdu.get("rekey") + revoke = q_pdu.get("revoke") + reissue = q_pdu.get("reissue") + revoke_forgotten = q_pdu.get("revoke_forgotten") + + if q_pdu.get("clear_replay_protection"): + for parent in self.parents.all(): + parent.clear_replay_protection() + for child in self.children.all(): + child.clear_replay_protection() + for repository in self.repositories.all(): + repository.clear_replay_protection() + + futures = [] + + if rekey or revoke or reissue or revoke_forgotten: + for parent in self.parents.all(): + if rekey: + futures.append(parent.serve_rekey(rpkid)) + if revoke: + futures.append(parent.serve_revoke(rpkid)) + if reissue: + futures.append(parent.serve_reissue(rpkid)) + if revoke_forgotten: + futures.append(parent.serve_revoke_forgotten(rpkid)) + + if q_pdu.get("publish_world_now"): + futures.append(self.serve_publish_world_now(rpkid)) + if q_pdu.get("run_now"): + futures.append(self.serve_run_now(rpkid)) + + yield futures + + + @tornado.gen.coroutine + def serve_publish_world_now(self, rpkid): + publisher = rpki.rpkid.publication_queue(rpkid) + repositories = set() + objects = dict() + + for parent in self.parents.all(): + + repository = parent.repository + if repository.peer_contact_uri in repositories: + continue + repositories.add(repository.peer_contact_uri) + q_msg = Element(rpki.publication.tag_msg, nsmap = rpki.publication.nsmap, + type = "query", version = rpki.publication.version) + SubElement(q_msg, rpki.publication.tag_list, tag = "list") + + r_msg = yield repository.call_pubd(rpkid, q_msg, length_check = False) + + for r_pdu in r_msg: + assert r_pdu.tag == rpki.publication.tag_list + if r_pdu.get("uri") in objects: + logger.warning("pubd reported multiple published copies of URI %r, this makes no sense, blundering onwards", r_pdu.get("uri")) + else: + objects[r_pdu.get("uri")] = (r_pdu.get("hash"), repository) + + def reconcile(uri, obj, repository): + h, r = objects.pop(uri, (None, None)) + if h is not None: + assert r == repository + publisher.queue(uri = uri, new_obj = obj, old_hash = h, repository = repository) + + for ca_detail in CADetail.objects.filter(ca__parent__tenant = self, state = "active"): + repository = ca_detail.ca.parent.repository + reconcile(uri = ca_detail.crl_uri, obj = ca_detail.latest_crl, repository = repository) + reconcile(uri = ca_detail.manifest_uri, obj = ca_detail.latest_manifest, repository = repository) + for c in ca_detail.child_certs.all(): + reconcile(uri = c.uri, obj = c.cert, repository = repository) + for r in ca_detail.roas.filter(roa__isnull = False): + reconcile(uri = r.uri, obj = r.roa, repository = repository) + for g in ca_detail.ghostbusters.all(): + reconcile(uri = g.uri, obj = g.ghostbuster, repository = repository) + for c in ca_detail.ee_certificates.all(): + reconcile(uri = c.uri, obj = c.cert, repository = repository) + for u in objects: + h, r = objects[u] + publisher.queue(uri = u, old_hash = h, repository = r) - xml_template = XMLTemplate( - name = "repository", - handles = (BSC,), - attributes = ("peer_contact_uri", "rrdp_notification_uri"), - elements = ("bpki_cert", "bpki_glue")) + yield publisher.call_pubd() - @tornado.gen.coroutine - def xml_post_save_hook(self, rpkid, q_pdu): - if q_pdu.get("clear_replay_protection"): - self.clear_replay_protection() + @tornado.gen.coroutine + def serve_run_now(self, rpkid): + logger.debug("Forced immediate run of periodic actions for tenant %s[%r]", self.tenant_handle, self) + tasks = self.cron_tasks(rpkid) + rpkid.task_add(tasks) + futures = [task.wait() for task in tasks] + rpkid.task_run() + yield futures - def clear_replay_protection(self): - self.last_cms_timestamp = None - self.save() + def cron_tasks(self, rpkid): + try: + return self._cron_tasks + except AttributeError: + self._cron_tasks = tuple(task(rpkid, self) for task in rpki.rpkid_tasks.task_classes) + return self._cron_tasks - @tornado.gen.coroutine - def call_pubd(self, rpkid, q_msg, handlers = {}, length_check = True): # pylint: disable=W0102 - """ - Send a message to publication daemon and return the response. + def find_covering_ca_details(self, resources): + """ + Return all active CADetails for this <tenant/> which cover a + particular set of resources. - As a convenience, attempting to send an empty message returns - immediate success without sending anything. + If we expected there to be a large number of CADetails, we + could add index tables and write fancy SQL query to do this, but + for the expected common case where there are only one or two + active CADetails per <tenant/>, it's probably not worth it. In + any case, this is an optimization we can leave for later. + """ - handlers is a dict of handler functions to process the response - PDUs. If the tag value in the response PDU appears in the dict, - the associated handler is called to process the PDU. If no tag - matches, a default handler is called to check for errors; a - handler value of False suppresses calling of the default handler. - """ + return set(ca_detail + for ca_detail in CADetail.objects.filter(ca__parent__tenant = self, state = "active") + if ca_detail.covers(resources)) - if len(q_msg) == 0: - return - for q_pdu in q_msg: - logger.info("Sending %r to pubd", q_pdu) +@xml_hooks +class BSC(models.Model): + bsc_handle = models.SlugField(max_length = 255) + private_key_id = RSAPrivateKeyField() + pkcs10_request = PKCS10Field() + hash_alg = EnumField(choices = ("sha256",), default = "sha256") + signing_cert = CertificateField(null = True) + signing_cert_crl = CRLField(null = True) + tenant = models.ForeignKey(Tenant, related_name = "bscs") + objects = XMLManager() - q_der = rpki.publication.cms_msg().wrap(q_msg, self.bsc.private_key_id, self.bsc.signing_cert, self.bsc.signing_cert_crl) + class Meta: # pylint: disable=C1001,W0232 + unique_together = ("tenant", "bsc_handle") - http_request = tornado.httpclient.HTTPRequest( - url = self.peer_contact_uri, - method = "POST", - body = q_der, - headers = { "Content-Type" : rpki.publication.content_type }) + xml_template = XMLTemplate( + name = "bsc", + elements = ("signing_cert", "signing_cert_crl"), + readonly = ("pkcs10_request",)) - http_response = yield rpkid.http_fetch(http_request) - # Tornado already checked http_response.code for us + def xml_pre_save_hook(self, q_pdu): + # Handle key generation, only supports RSA with SHA-256 for now. + if q_pdu.get("generate_keypair"): + assert q_pdu.get("key_type") in (None, "rsa") and q_pdu.get("hash_alg") in (None, "sha256") + self.private_key_id = rpki.x509.RSA.generate(keylength = int(q_pdu.get("key_length", 2048))) + self.pkcs10_request = rpki.x509.PKCS10.create(keypair = self.private_key_id) - content_type = http_response.headers.get("Content-Type") - if content_type not in rpki.publication.allowed_content_types: - raise rpki.exceptions.BadContentType("HTTP Content-Type %r, expected %r" % (rpki.publication.content_type, content_type)) +@xml_hooks +class Repository(models.Model): + repository_handle = models.SlugField(max_length = 255) + peer_contact_uri = models.TextField(null = True) + rrdp_notification_uri = models.TextField(null = True) + bpki_cert = CertificateField(null = True) + bpki_glue = CertificateField(null = True) + last_cms_timestamp = SundialField(null = True) + bsc = models.ForeignKey(BSC, related_name = "repositories") + tenant = models.ForeignKey(Tenant, related_name = "repositories") + objects = XMLManager() - r_der = http_response.body - r_cms = rpki.publication.cms_msg(DER = r_der) - r_msg = r_cms.unwrap((rpkid.bpki_ta, self.tenant.bpki_cert, self.tenant.bpki_glue, self.bpki_cert, self.bpki_glue)) - r_cms.check_replay_sql(self, self.peer_contact_uri) + class Meta: # pylint: disable=C1001,W0232 + unique_together = ("tenant", "repository_handle") - for r_pdu in r_msg: - handler = handlers.get(r_pdu.get("tag"), rpki.publication.raise_if_error) - if handler: - logger.debug("Calling pubd handler %r", handler) - handler(r_pdu) + xml_template = XMLTemplate( + name = "repository", + handles = (BSC,), + attributes = ("peer_contact_uri", "rrdp_notification_uri"), + elements = ("bpki_cert", "bpki_glue")) - if length_check and len(q_msg) != len(r_msg): - raise rpki.exceptions.BadPublicationReply("Wrong number of response PDUs from pubd: sent %r, got %r" % (q_msg, r_msg)) - raise tornado.gen.Return(r_msg) + @tornado.gen.coroutine + def xml_post_save_hook(self, rpkid, q_pdu): + if q_pdu.get("clear_replay_protection"): + self.clear_replay_protection() -@xml_hooks -class Parent(models.Model): - parent_handle = models.SlugField(max_length = 255) - bpki_cert = CertificateField(null = True) - bpki_glue = CertificateField(null = True) - peer_contact_uri = models.TextField(null = True) - sia_base = models.TextField(null = True) - sender_name = models.TextField(null = True) - recipient_name = models.TextField(null = True) - last_cms_timestamp = SundialField(null = True) - tenant = models.ForeignKey(Tenant, related_name = "parents") - bsc = models.ForeignKey(BSC, related_name = "parents") - repository = models.ForeignKey(Repository, related_name = "parents") - objects = XMLManager() - - class Meta: # pylint: disable=C1001,W0232 - unique_together = ("tenant", "parent_handle") - - xml_template = XMLTemplate( - name = "parent", - handles = (BSC, Repository), - attributes = ("peer_contact_uri", "sia_base", "sender_name", "recipient_name"), - elements = ("bpki_cert", "bpki_glue")) - - - @tornado.gen.coroutine - def xml_pre_delete_hook(self, rpkid): - yield self.destroy(rpkid, delete_parent = False) - - @tornado.gen.coroutine - def xml_post_save_hook(self, rpkid, q_pdu): - if q_pdu.get("clear_replay_protection"): - self.clear_replay_protection() - futures = [] - if q_pdu.get("rekey"): - futures.append(self.serve_rekey(rpkid)) - if q_pdu.get("revoke"): - futures.append(self.serve_revoke(rpkid)) - if q_pdu.get("reissue"): - futures.append(self.serve_reissue(rpkid)) - if q_pdu.get("revoke_forgotten"): - futures.append(self.serve_revoke_forgotten(rpkid)) - yield futures - - @tornado.gen.coroutine - def serve_rekey(self, rpkid): - yield [ca.rekey() for ca in self.cas.all()] - - @tornado.gen.coroutine - def serve_revoke(self, rpkid): - yield [ca.revoke() for ca in self.cas.all()] - - @tornado.gen.coroutine - def serve_reissue(self, rpkid): - yield [ca.reissue() for ca in self.cas.all()] - - def clear_replay_protection(self): - self.last_cms_timestamp = None - self.save() - - - @tornado.gen.coroutine - def get_skis(self, rpkid): - """ - Fetch SKIs that this parent thinks we have. In theory this should - agree with our own database, but in practice stuff can happen, so - sometimes we need to know what our parent thinks. + def clear_replay_protection(self): + self.last_cms_timestamp = None + self.save() - Result is a dictionary with the resource class name as key and a - set of SKIs as value. - This, like everything else dealing with SKIs in the up-down - protocol, is mis-named: we're really dealing with g(SKI) values, - not raw SKI values. Sorry. - """ + @tornado.gen.coroutine + def call_pubd(self, rpkid, q_msg, handlers = {}, length_check = True): # pylint: disable=W0102 + """ + Send a message to publication daemon and return the response. - r_msg = yield self.up_down_list_query(rpkid = rpkid) + As a convenience, attempting to send an empty message returns + immediate success without sending anything. - ski_map = {} + handlers is a dict of handler functions to process the response + PDUs. If the tag value in the response PDU appears in the dict, + the associated handler is called to process the PDU. If no tag + matches, a default handler is called to check for errors; a + handler value of False suppresses calling of the default handler. + """ - for rc in r_msg.getiterator(rpki.up_down.tag_class): - skis = set() - for c in rc.getiterator(rpki.up_down.tag_certificate): - skis.add(rpki.x509.X509(Base64 = c.text).gSKI()) - ski_map[rc.get("class_name")] = skis + if len(q_msg) == 0: + return - raise tornado.gen.Return(ski_map) + for q_pdu in q_msg: + logger.info("Sending %r to pubd", q_pdu) + q_der = rpki.publication.cms_msg().wrap(q_msg, self.bsc.private_key_id, self.bsc.signing_cert, self.bsc.signing_cert_crl) - @tornado.gen.coroutine - def revoke_skis(self, rpkid, rc_name, skis_to_revoke): - """ - Revoke a set of SKIs within a particular resource class. - """ + http_request = tornado.httpclient.HTTPRequest( + url = self.peer_contact_uri, + method = "POST", + body = q_der, + headers = { "Content-Type" : rpki.publication.content_type }) - for ski in skis_to_revoke: - logger.debug("Asking parent %r to revoke class %r, g(SKI) %s", self, rc_name, ski) - yield self.up_down_revoke_query(rpkid = rpkid, class_name = rc_name, ski = ski) + http_response = yield rpkid.http_fetch(http_request) + # Tornado already checked http_response.code for us - @tornado.gen.coroutine - def serve_revoke_forgotten(self, rpkid): - """ - Handle a left-right revoke_forgotten action for this parent. - - This is a bit fiddly: we have to compare the result of an up-down - list query with what we have locally and identify the SKIs of any - certificates that have gone missing. This should never happen in - ordinary operation, but can arise if we have somehow lost a - private key, in which case there is nothing more we can do with - the issued cert, so we have to clear it. As this really is not - supposed to happen, we don't clear it automatically, instead we - require an explicit trigger. - """ + content_type = http_response.headers.get("Content-Type") - skis_from_parent = yield self.get_skis(rpkid) - for rc_name, skis_to_revoke in skis_from_parent.iteritems(): - for ca_detail in CADetail.objects.filter(ca__parent = self).exclude(state = "revoked"): - skis_to_revoke.discard(ca_detail.latest_ca_cert.gSKI()) - yield self.revoke_skis(rpkid, rc_name, skis_to_revoke) + if content_type not in rpki.publication.allowed_content_types: + raise rpki.exceptions.BadContentType("HTTP Content-Type %r, expected %r" % (rpki.publication.content_type, content_type)) + r_der = http_response.body + r_cms = rpki.publication.cms_msg(DER = r_der) + r_msg = r_cms.unwrap((rpkid.bpki_ta, self.tenant.bpki_cert, self.tenant.bpki_glue, self.bpki_cert, self.bpki_glue)) + r_cms.check_replay_sql(self, self.peer_contact_uri) - @tornado.gen.coroutine - def destroy(self, rpkid, delete_parent = True): - """ - Delete all the CA stuff under this parent, and perhaps the parent - itself. - """ + for r_pdu in r_msg: + handler = handlers.get(r_pdu.get("tag"), rpki.publication.raise_if_error) + if handler: + logger.debug("Calling pubd handler %r", handler) + handler(r_pdu) - yield [ca.destroy(self) for ca in self.cas()] - yield self.serve_revoke_forgotten(rpkid) - if delete_parent: - self.delete() + if length_check and len(q_msg) != len(r_msg): + raise rpki.exceptions.BadPublicationReply("Wrong number of response PDUs from pubd: sent %r, got %r" % (q_msg, r_msg)) + raise tornado.gen.Return(r_msg) - def _compose_up_down_query(self, query_type): - return Element(rpki.up_down.tag_message, nsmap = rpki.up_down.nsmap, version = rpki.up_down.version, - sender = self.sender_name, recipient = self.recipient_name, type = query_type) + +@xml_hooks +class Parent(models.Model): + parent_handle = models.SlugField(max_length = 255) + bpki_cert = CertificateField(null = True) + bpki_glue = CertificateField(null = True) + peer_contact_uri = models.TextField(null = True) + sia_base = models.TextField(null = True) + sender_name = models.TextField(null = True) + recipient_name = models.TextField(null = True) + last_cms_timestamp = SundialField(null = True) + tenant = models.ForeignKey(Tenant, related_name = "parents") + bsc = models.ForeignKey(BSC, related_name = "parents") + repository = models.ForeignKey(Repository, related_name = "parents") + objects = XMLManager() + + class Meta: # pylint: disable=C1001,W0232 + unique_together = ("tenant", "parent_handle") + + xml_template = XMLTemplate( + name = "parent", + handles = (BSC, Repository), + attributes = ("peer_contact_uri", "sia_base", "sender_name", "recipient_name"), + elements = ("bpki_cert", "bpki_glue")) + + + @tornado.gen.coroutine + def xml_pre_delete_hook(self, rpkid): + yield self.destroy(rpkid, delete_parent = False) + + @tornado.gen.coroutine + def xml_post_save_hook(self, rpkid, q_pdu): + if q_pdu.get("clear_replay_protection"): + self.clear_replay_protection() + futures = [] + if q_pdu.get("rekey"): + futures.append(self.serve_rekey(rpkid)) + if q_pdu.get("revoke"): + futures.append(self.serve_revoke(rpkid)) + if q_pdu.get("reissue"): + futures.append(self.serve_reissue(rpkid)) + if q_pdu.get("revoke_forgotten"): + futures.append(self.serve_revoke_forgotten(rpkid)) + yield futures + + @tornado.gen.coroutine + def serve_rekey(self, rpkid): + yield [ca.rekey() for ca in self.cas.all()] + + @tornado.gen.coroutine + def serve_revoke(self, rpkid): + yield [ca.revoke() for ca in self.cas.all()] + + @tornado.gen.coroutine + def serve_reissue(self, rpkid): + yield [ca.reissue() for ca in self.cas.all()] + + def clear_replay_protection(self): + self.last_cms_timestamp = None + self.save() + + + @tornado.gen.coroutine + def get_skis(self, rpkid): + """ + Fetch SKIs that this parent thinks we have. In theory this should + agree with our own database, but in practice stuff can happen, so + sometimes we need to know what our parent thinks. + + Result is a dictionary with the resource class name as key and a + set of SKIs as value. + + This, like everything else dealing with SKIs in the up-down + protocol, is mis-named: we're really dealing with g(SKI) values, + not raw SKI values. Sorry. + """ + + r_msg = yield self.up_down_list_query(rpkid = rpkid) + + ski_map = {} + + for rc in r_msg.getiterator(rpki.up_down.tag_class): + skis = set() + for c in rc.getiterator(rpki.up_down.tag_certificate): + skis.add(rpki.x509.X509(Base64 = c.text).gSKI()) + ski_map[rc.get("class_name")] = skis + + raise tornado.gen.Return(ski_map) + + + @tornado.gen.coroutine + def revoke_skis(self, rpkid, rc_name, skis_to_revoke): + """ + Revoke a set of SKIs within a particular resource class. + """ + + for ski in skis_to_revoke: + logger.debug("Asking parent %r to revoke class %r, g(SKI) %s", self, rc_name, ski) + yield self.up_down_revoke_query(rpkid = rpkid, class_name = rc_name, ski = ski) + + + @tornado.gen.coroutine + def serve_revoke_forgotten(self, rpkid): + """ + Handle a left-right revoke_forgotten action for this parent. + + This is a bit fiddly: we have to compare the result of an up-down + list query with what we have locally and identify the SKIs of any + certificates that have gone missing. This should never happen in + ordinary operation, but can arise if we have somehow lost a + private key, in which case there is nothing more we can do with + the issued cert, so we have to clear it. As this really is not + supposed to happen, we don't clear it automatically, instead we + require an explicit trigger. + """ + + skis_from_parent = yield self.get_skis(rpkid) + for rc_name, skis_to_revoke in skis_from_parent.iteritems(): + for ca_detail in CADetail.objects.filter(ca__parent = self).exclude(state = "revoked"): + skis_to_revoke.discard(ca_detail.latest_ca_cert.gSKI()) + yield self.revoke_skis(rpkid, rc_name, skis_to_revoke) + + + @tornado.gen.coroutine + def destroy(self, rpkid, delete_parent = True): + """ + Delete all the CA stuff under this parent, and perhaps the parent + itself. + """ + + yield [ca.destroy(self) for ca in self.cas()] + yield self.serve_revoke_forgotten(rpkid) + if delete_parent: + self.delete() - @tornado.gen.coroutine - def up_down_list_query(self, rpkid): - q_msg = self._compose_up_down_query("list") - r_msg = yield self.query_up_down(rpkid, q_msg) - raise tornado.gen.Return(r_msg) + def _compose_up_down_query(self, query_type): + return Element(rpki.up_down.tag_message, nsmap = rpki.up_down.nsmap, version = rpki.up_down.version, + sender = self.sender_name, recipient = self.recipient_name, type = query_type) - @tornado.gen.coroutine - def up_down_issue_query(self, rpkid, ca, ca_detail): - logger.debug("Parent.up_down_issue_query(): caRepository %r rpkiManifest %r rpkiNotify %r", - ca.sia_uri, ca_detail.manifest_uri, ca.parent.repository.rrdp_notification_uri) - pkcs10 = rpki.x509.PKCS10.create( - keypair = ca_detail.private_key_id, - is_ca = True, - caRepository = ca.sia_uri, - rpkiManifest = ca_detail.manifest_uri, - rpkiNotify = ca.parent.repository.rrdp_notification_uri) - q_msg = self._compose_up_down_query("issue") - q_pdu = SubElement(q_msg, rpki.up_down.tag_request, class_name = ca.parent_resource_class) - q_pdu.text = pkcs10.get_Base64() - r_msg = yield self.query_up_down(rpkid, q_msg) - raise tornado.gen.Return(r_msg) + @tornado.gen.coroutine + def up_down_list_query(self, rpkid): + q_msg = self._compose_up_down_query("list") + r_msg = yield self.query_up_down(rpkid, q_msg) + raise tornado.gen.Return(r_msg) - @tornado.gen.coroutine - def up_down_revoke_query(self, rpkid, class_name, ski): - q_msg = self._compose_up_down_query("revoke") - SubElement(q_msg, rpki.up_down.tag_key, class_name = class_name, ski = ski) - r_msg = yield self.query_up_down(rpkid, q_msg) - raise tornado.gen.Return(r_msg) + @tornado.gen.coroutine + def up_down_issue_query(self, rpkid, ca, ca_detail): + logger.debug("Parent.up_down_issue_query(): caRepository %r rpkiManifest %r rpkiNotify %r", + ca.sia_uri, ca_detail.manifest_uri, ca.parent.repository.rrdp_notification_uri) + pkcs10 = rpki.x509.PKCS10.create( + keypair = ca_detail.private_key_id, + is_ca = True, + caRepository = ca.sia_uri, + rpkiManifest = ca_detail.manifest_uri, + rpkiNotify = ca.parent.repository.rrdp_notification_uri) + q_msg = self._compose_up_down_query("issue") + q_pdu = SubElement(q_msg, rpki.up_down.tag_request, class_name = ca.parent_resource_class) + q_pdu.text = pkcs10.get_Base64() + r_msg = yield self.query_up_down(rpkid, q_msg) + raise tornado.gen.Return(r_msg) - @tornado.gen.coroutine - def query_up_down(self, rpkid, q_msg): + @tornado.gen.coroutine + def up_down_revoke_query(self, rpkid, class_name, ski): + q_msg = self._compose_up_down_query("revoke") + SubElement(q_msg, rpki.up_down.tag_key, class_name = class_name, ski = ski) + r_msg = yield self.query_up_down(rpkid, q_msg) + raise tornado.gen.Return(r_msg) - if self.bsc is None: - raise rpki.exceptions.BSCNotFound("Could not find BSC") - if self.bsc.signing_cert is None: - raise rpki.exceptions.BSCNotReady("BSC %r is not yet usable" % self.bsc.bsc_handle) + @tornado.gen.coroutine + def query_up_down(self, rpkid, q_msg): - q_der = rpki.up_down.cms_msg().wrap(q_msg, self.bsc.private_key_id, self.bsc.signing_cert, self.bsc.signing_cert_crl) + if self.bsc is None: + raise rpki.exceptions.BSCNotFound("Could not find BSC") - http_request = tornado.httpclient.HTTPRequest( - url = self.peer_contact_uri, - method = "POST", - body = q_der, - headers = { "Content-Type" : rpki.up_down.content_type }) + if self.bsc.signing_cert is None: + raise rpki.exceptions.BSCNotReady("BSC %r is not yet usable" % self.bsc.bsc_handle) - http_response = yield rpkid.http_fetch(http_request) + q_der = rpki.up_down.cms_msg().wrap(q_msg, self.bsc.private_key_id, self.bsc.signing_cert, self.bsc.signing_cert_crl) - # Tornado already checked http_response.code for us + http_request = tornado.httpclient.HTTPRequest( + url = self.peer_contact_uri, + method = "POST", + body = q_der, + headers = { "Content-Type" : rpki.up_down.content_type }) - content_type = http_response.headers.get("Content-Type") + http_response = yield rpkid.http_fetch(http_request) + + # Tornado already checked http_response.code for us - if content_type not in rpki.up_down.allowed_content_types: - raise rpki.exceptions.BadContentType("HTTP Content-Type %r, expected %r" % (rpki.up_down.content_type, content_type)) + content_type = http_response.headers.get("Content-Type") - r_der = http_response.body - r_cms = rpki.up_down.cms_msg(DER = r_der) - r_msg = r_cms.unwrap((rpkid.bpki_ta, self.tenant.bpki_cert, self.tenant.bpki_glue, self.bpki_cert, self.bpki_glue)) - r_cms.check_replay_sql(self, self.peer_contact_uri) - rpki.up_down.check_response(r_msg, q_msg.get("type")) + if content_type not in rpki.up_down.allowed_content_types: + raise rpki.exceptions.BadContentType("HTTP Content-Type %r, expected %r" % (rpki.up_down.content_type, content_type)) - raise tornado.gen.Return(r_msg) + r_der = http_response.body + r_cms = rpki.up_down.cms_msg(DER = r_der) + r_msg = r_cms.unwrap((rpkid.bpki_ta, self.tenant.bpki_cert, self.tenant.bpki_glue, self.bpki_cert, self.bpki_glue)) + r_cms.check_replay_sql(self, self.peer_contact_uri) + rpki.up_down.check_response(r_msg, q_msg.get("type")) + raise tornado.gen.Return(r_msg) - def construct_sia_uri(self, rc): - """ - Construct the sia_uri value for a CA under this parent given - configured information and the parent's up-down protocol - list_response PDU. - """ - sia_uri = rc.get("suggested_sia_head", "") - if not sia_uri.startswith("rsync://") or not sia_uri.startswith(self.sia_base): - sia_uri = self.sia_base - if not sia_uri.endswith("/"): - raise rpki.exceptions.BadURISyntax("SIA URI must end with a slash: %s" % sia_uri) - return sia_uri + def construct_sia_uri(self, rc): + """ + Construct the sia_uri value for a CA under this parent given + configured information and the parent's up-down protocol + list_response PDU. + """ + + sia_uri = rc.get("suggested_sia_head", "") + if not sia_uri.startswith("rsync://") or not sia_uri.startswith(self.sia_base): + sia_uri = self.sia_base + if not sia_uri.endswith("/"): + raise rpki.exceptions.BadURISyntax("SIA URI must end with a slash: %s" % sia_uri) + return sia_uri class CA(models.Model): - last_crl_sn = models.BigIntegerField(default = 1) - last_manifest_sn = models.BigIntegerField(default = 1) - next_manifest_update = SundialField(null = True) - next_crl_update = SundialField(null = True) - last_issued_sn = models.BigIntegerField(default = 1) - sia_uri = models.TextField(null = True) - parent_resource_class = models.TextField(null = True) # Not sure this should allow NULL - parent = models.ForeignKey(Parent, related_name = "cas") - - # So it turns out that there's always a 1:1 mapping between the - # class_name we receive from our parent and the class_name we issue - # to our children: in spite of the obfuscated way that we used to - # handle class names, we never actually added a way for the back-end - # to create new classes. Not clear we want to encourage this, but - # if we wanted to support it, simple approach would probably be an - # optional class_name attribute in the left-right <list_resources/> - # response; if not present, we'd use parent's class_name as now, - # otherwise we'd use the supplied class_name. - - # ca_obj has a zillion properties encoding various specialized - # ca_detail queries. ORM query syntax probably renders this OBE, - # but need to translate in existing code. - # - #def pending_ca_details(self): return self.ca_details.filter(state = "pending") - #def active_ca_detail(self): return self.ca_details.get(state = "active") - #def deprecated_ca_details(self): return self.ca_details.filter(state = "deprecated") - #def active_or_deprecated_ca_details(self): return self.ca_details.filter(state__in = ("active", "deprecated")) - #def revoked_ca_details(self): return self.ca_details.filter(state = "revoked") - #def issue_response_candidate_ca_details(self): return self.ca_details.exclude(state = "revoked") - - - @tornado.gen.coroutine - def check_for_updates(self, rpkid, parent, rc): - """ - Parent has signaled continued existance of a resource class we - already knew about, so we need to check for an updated - certificate, changes in resource coverage, revocation and reissue - with the same key, etc. - """ - - logger.debug("check_for_updates()") - sia_uri = parent.construct_sia_uri(rc) - sia_uri_changed = self.sia_uri != sia_uri - - if sia_uri_changed: - logger.debug("SIA changed: was %s now %s", self.sia_uri, sia_uri) - self.sia_uri = sia_uri + last_crl_sn = models.BigIntegerField(default = 1) + last_manifest_sn = models.BigIntegerField(default = 1) + next_manifest_update = SundialField(null = True) + next_crl_update = SundialField(null = True) + last_issued_sn = models.BigIntegerField(default = 1) + sia_uri = models.TextField(null = True) + parent_resource_class = models.TextField(null = True) # Not sure this should allow NULL + parent = models.ForeignKey(Parent, related_name = "cas") + + # So it turns out that there's always a 1:1 mapping between the + # class_name we receive from our parent and the class_name we issue + # to our children: in spite of the obfuscated way that we used to + # handle class names, we never actually added a way for the back-end + # to create new classes. Not clear we want to encourage this, but + # if we wanted to support it, simple approach would probably be an + # optional class_name attribute in the left-right <list_resources/> + # response; if not present, we'd use parent's class_name as now, + # otherwise we'd use the supplied class_name. + + # ca_obj has a zillion properties encoding various specialized + # ca_detail queries. ORM query syntax probably renders this OBE, + # but need to translate in existing code. + # + #def pending_ca_details(self): return self.ca_details.filter(state = "pending") + #def active_ca_detail(self): return self.ca_details.get(state = "active") + #def deprecated_ca_details(self): return self.ca_details.filter(state = "deprecated") + #def active_or_deprecated_ca_details(self): return self.ca_details.filter(state__in = ("active", "deprecated")) + #def revoked_ca_details(self): return self.ca_details.filter(state = "revoked") + #def issue_response_candidate_ca_details(self): return self.ca_details.exclude(state = "revoked") + + + @tornado.gen.coroutine + def check_for_updates(self, rpkid, parent, rc): + """ + Parent has signaled continued existance of a resource class we + already knew about, so we need to check for an updated + certificate, changes in resource coverage, revocation and reissue + with the same key, etc. + """ + + logger.debug("check_for_updates()") + sia_uri = parent.construct_sia_uri(rc) + sia_uri_changed = self.sia_uri != sia_uri + + if sia_uri_changed: + logger.debug("SIA changed: was %s now %s", self.sia_uri, sia_uri) + self.sia_uri = sia_uri + + class_name = rc.get("class_name") + + rc_resources = rpki.resource_set.resource_bag( + rc.get("resource_set_as"), + rc.get("resource_set_ipv4"), + rc.get("resource_set_ipv6"), + rc.get("resource_set_notafter")) + + cert_map = {} + + for c in rc.getiterator(rpki.up_down.tag_certificate): + x = rpki.x509.X509(Base64 = c.text) + u = rpki.up_down.multi_uri(c.get("cert_url")).rsync() + cert_map[x.gSKI()] = (x, u) + + ca_details = self.ca_details.exclude(state = "revoked") + + if not ca_details: + logger.warning("Existing resource class %s to %s from %s with no certificates, rekeying", + class_name, parent.tenant.tenant_handle, parent.parent_handle) + yield self.rekey(rpkid) + return + + for ca_detail in ca_details: + + rc_cert, rc_cert_uri = cert_map.pop(ca_detail.public_key.gSKI(), (None, None)) + + if rc_cert is None: + logger.warning("g(SKI) %s in resource class %s is in database but missing from list_response to %s from %s, " + "maybe parent certificate went away?", + ca_detail.public_key.gSKI(), class_name, parent.tenant.tenant_handle, parent.parent_handle) + publisher = rpki.rpkid.publication_queue(rpkid) + ca_detail.destroy(ca = ca_detail.ca, publisher = publisher) + yield publisher.call_pubd() + continue + + if ca_detail.state == "active" and ca_detail.ca_cert_uri != rc_cert_uri: + logger.debug("AIA changed: was %s now %s", ca_detail.ca_cert_uri, rc_cert_uri) + ca_detail.ca_cert_uri = rc_cert_uri + ca_detail.save() + + if ca_detail.state not in ("pending", "active"): + continue + + if ca_detail.state == "pending": + current_resources = rpki.resource_set.resource_bag() + else: + current_resources = ca_detail.latest_ca_cert.get_3779resources() + + if (ca_detail.state == "pending" or + sia_uri_changed or + ca_detail.latest_ca_cert != rc_cert or + ca_detail.latest_ca_cert.getNotAfter() != rc_resources.valid_until or + current_resources.undersized(rc_resources) or + current_resources.oversized(rc_resources)): + + yield ca_detail.update( + rpkid = rpkid, + parent = parent, + ca = self, + rc = rc, + sia_uri_changed = sia_uri_changed, + old_resources = current_resources) + + if cert_map: + logger.warning("Unknown certificate g(SKI)%s %s in resource class %s in list_response to %s from %s, maybe you want to \"revoke_forgotten\"?", + "" if len(cert_map) == 1 else "s", ", ".join(cert_map), class_name, parent.tenant.tenant_handle, parent.parent_handle) + + + # Called from exactly one place, in rpki.rpkid_tasks.PollParentTask.class_loop(). + # Might want to refactor. + + @classmethod + @tornado.gen.coroutine + def create(cls, rpkid, parent, rc): + """ + Parent has signaled existance of a new resource class, so we need + to create and set up a corresponding CA object. + """ - class_name = rc.get("class_name") + self = cls.objects.create(parent = parent, + parent_resource_class = rc.get("class_name"), + sia_uri = parent.construct_sia_uri(rc)) + + ca_detail = CADetail.create(self) + + logger.debug("Sending issue request to %r from %r", parent, self.create) - rc_resources = rpki.resource_set.resource_bag( - rc.get("resource_set_as"), - rc.get("resource_set_ipv4"), - rc.get("resource_set_ipv6"), - rc.get("resource_set_notafter")) + r_msg = yield parent.up_down_issue_query(rpkid = rpkid, ca = self, ca_detail = ca_detail) - cert_map = {} + c = r_msg[0][0] - for c in rc.getiterator(rpki.up_down.tag_certificate): - x = rpki.x509.X509(Base64 = c.text) - u = rpki.up_down.multi_uri(c.get("cert_url")).rsync() - cert_map[x.gSKI()] = (x, u) + logger.debug("CA %r received certificate %s", self, c.get("cert_url")) - ca_details = self.ca_details.exclude(state = "revoked") + yield ca_detail.activate( + rpkid = rpkid, + ca = self, + cert = rpki.x509.X509(Base64 = c.text), + uri = c.get("cert_url")) - if not ca_details: - logger.warning("Existing resource class %s to %s from %s with no certificates, rekeying", - class_name, parent.tenant.tenant_handle, parent.parent_handle) - yield self.rekey(rpkid) - return - for ca_detail in ca_details: + @tornado.gen.coroutine + def destroy(self, rpkid, parent): + """ + The list of current resource classes received from parent does not + include the class corresponding to this CA, so we need to delete + it (and its little dog too...). - rc_cert, rc_cert_uri = cert_map.pop(ca_detail.public_key.gSKI(), (None, None)) + All certs published by this CA are now invalid, so need to + withdraw them, the CRL, and the manifest from the repository, + delete all child_cert and ca_detail records associated with this + CA, then finally delete this CA itself. + """ - if rc_cert is None: - logger.warning("g(SKI) %s in resource class %s is in database but missing from list_response to %s from %s, " - "maybe parent certificate went away?", - ca_detail.public_key.gSKI(), class_name, parent.tenant.tenant_handle, parent.parent_handle) publisher = rpki.rpkid.publication_queue(rpkid) - ca_detail.destroy(ca = ca_detail.ca, publisher = publisher) - yield publisher.call_pubd() - continue - - if ca_detail.state == "active" and ca_detail.ca_cert_uri != rc_cert_uri: - logger.debug("AIA changed: was %s now %s", ca_detail.ca_cert_uri, rc_cert_uri) - ca_detail.ca_cert_uri = rc_cert_uri - ca_detail.save() - - if ca_detail.state not in ("pending", "active"): - continue - - if ca_detail.state == "pending": - current_resources = rpki.resource_set.resource_bag() - else: - current_resources = ca_detail.latest_ca_cert.get_3779resources() - - if (ca_detail.state == "pending" or - sia_uri_changed or - ca_detail.latest_ca_cert != rc_cert or - ca_detail.latest_ca_cert.getNotAfter() != rc_resources.valid_until or - current_resources.undersized(rc_resources) or - current_resources.oversized(rc_resources)): - - yield ca_detail.update( - rpkid = rpkid, - parent = parent, - ca = self, - rc = rc, - sia_uri_changed = sia_uri_changed, - old_resources = current_resources) - - if cert_map: - logger.warning("Unknown certificate g(SKI)%s %s in resource class %s in list_response to %s from %s, maybe you want to \"revoke_forgotten\"?", - "" if len(cert_map) == 1 else "s", ", ".join(cert_map), class_name, parent.tenant.tenant_handle, parent.parent_handle) - - - # Called from exactly one place, in rpki.rpkid_tasks.PollParentTask.class_loop(). - # Might want to refactor. - - @classmethod - @tornado.gen.coroutine - def create(cls, rpkid, parent, rc): - """ - Parent has signaled existance of a new resource class, so we need - to create and set up a corresponding CA object. - """ - self = cls.objects.create(parent = parent, - parent_resource_class = rc.get("class_name"), - sia_uri = parent.construct_sia_uri(rc)) + for ca_detail in self.ca_details.all(): + ca_detail.destroy(ca = self, publisher = publisher, allow_failure = True) - ca_detail = CADetail.create(self) + try: + yield publisher.call_pubd() - logger.debug("Sending issue request to %r from %r", parent, self.create) + except: + logger.exception("Could not delete CA %r, skipping", self) - r_msg = yield parent.up_down_issue_query(rpkid = rpkid, ca = self, ca_detail = ca_detail) + else: + logger.debug("Deleting %r", self) + self.delete() - c = r_msg[0][0] - logger.debug("CA %r received certificate %s", self, c.get("cert_url")) + def next_serial_number(self): + """ + Allocate a certificate serial number. + """ - yield ca_detail.activate( - rpkid = rpkid, - ca = self, - cert = rpki.x509.X509(Base64 = c.text), - uri = c.get("cert_url")) + self.last_issued_sn += 1 + self.save() + return self.last_issued_sn - @tornado.gen.coroutine - def destroy(self, rpkid, parent): - """ - The list of current resource classes received from parent does not - include the class corresponding to this CA, so we need to delete - it (and its little dog too...). - - All certs published by this CA are now invalid, so need to - withdraw them, the CRL, and the manifest from the repository, - delete all child_cert and ca_detail records associated with this - CA, then finally delete this CA itself. - """ + def next_manifest_number(self): + """ + Allocate a manifest serial number. + """ - publisher = rpki.rpkid.publication_queue(rpkid) + self.last_manifest_sn += 1 + self.save() + return self.last_manifest_sn - for ca_detail in self.ca_details.all(): - ca_detail.destroy(ca = self, publisher = publisher, allow_failure = True) - try: - yield publisher.call_pubd() + def next_crl_number(self): + """ + Allocate a CRL serial number. + """ - except: - logger.exception("Could not delete CA %r, skipping", self) + self.last_crl_sn += 1 + self.save() + return self.last_crl_sn - else: - logger.debug("Deleting %r", self) - self.delete() + @tornado.gen.coroutine + def rekey(self, rpkid): + """ + Initiate a rekey operation for this CA. Generate a new keypair. + Request cert from parent using new keypair. Mark result as our + active ca_detail. Reissue all child certs issued by this CA using + the new ca_detail. + """ - def next_serial_number(self): - """ - Allocate a certificate serial number. - """ + try: + old_detail = self.ca_details.get(state = "active") + except CADetail.DoesNotExist: + old_detail = None - self.last_issued_sn += 1 - self.save() - return self.last_issued_sn + new_detail = CADetail.create(ca = self) # sic: class method, not manager function (for now, anyway) + logger.debug("Sending issue request to %r from %r", self.parent, self.rekey) - def next_manifest_number(self): - """ - Allocate a manifest serial number. - """ + r_msg = yield self.parent.up_down_issue_query(rpkid = rpkid, ca = self, ca_detail = new_detail) - self.last_manifest_sn += 1 - self.save() - return self.last_manifest_sn + c = r_msg[0][0] + logger.debug("CA %r received certificate %s", self, c.get("cert_url")) - def next_crl_number(self): - """ - Allocate a CRL serial number. - """ + yield new_detail.activate( + rpkid = rpkid, + ca = self, + cert = rpki.x509.X509(Base64 = c.text), + uri = c.get("cert_url"), + predecessor = old_detail) - self.last_crl_sn += 1 - self.save() - return self.last_crl_sn + @tornado.gen.coroutine + def revoke(self, revoke_all = False): + """ + Revoke deprecated ca_detail objects associated with this CA, or + all ca_details associated with this CA if revoke_all is set. + """ - @tornado.gen.coroutine - def rekey(self, rpkid): - """ - Initiate a rekey operation for this CA. Generate a new keypair. - Request cert from parent using new keypair. Mark result as our - active ca_detail. Reissue all child certs issued by this CA using - the new ca_detail. - """ + if revoke_all: + ca_details = self.ca_details.all() + else: + ca_details = self.ca_details.filter(state = "deprecated") - try: - old_detail = self.ca_details.get(state = "active") - except CADetail.DoesNotExist: - old_detail = None + yield [ca_detail.revoke() for ca_detail in ca_details] - new_detail = CADetail.create(ca = self) # sic: class method, not manager function (for now, anyway) - logger.debug("Sending issue request to %r from %r", self.parent, self.rekey) + @tornado.gen.coroutine + def reissue(self): + """ + Reissue all current certificates issued by this CA. + """ - r_msg = yield self.parent.up_down_issue_query(rpkid = rpkid, ca = self, ca_detail = new_detail) + ca_detail = self.ca_details.get(state = "active") + if ca_detail: + yield ca_detail.reissue() - c = r_msg[0][0] - logger.debug("CA %r received certificate %s", self, c.get("cert_url")) +class CADetail(models.Model): + public_key = PublicKeyField(null = True) + private_key_id = RSAPrivateKeyField(null = True) + latest_crl = CRLField(null = True) + crl_published = SundialField(null = True) + latest_ca_cert = CertificateField(null = True) + manifest_private_key_id = RSAPrivateKeyField(null = True) + manifest_public_key = PublicKeyField(null = True) + latest_manifest_cert = CertificateField(null = True) + latest_manifest = ManifestField(null = True) + manifest_published = SundialField(null = True) + state = EnumField(choices = ("pending", "active", "deprecated", "revoked")) + ca_cert_uri = models.TextField(null = True) + ca = models.ForeignKey(CA, related_name = "ca_details") - yield new_detail.activate( - rpkid = rpkid, - ca = self, - cert = rpki.x509.X509(Base64 = c.text), - uri = c.get("cert_url"), - predecessor = old_detail) + # Like the old ca_obj class, the old ca_detail_obj class had ten + # zillion properties and methods encapsulating SQL queries. + # Translate as we go. - @tornado.gen.coroutine - def revoke(self, revoke_all = False): - """ - Revoke deprecated ca_detail objects associated with this CA, or - all ca_details associated with this CA if revoke_all is set. - """ - if revoke_all: - ca_details = self.ca_details.all() - else: - ca_details = self.ca_details.filter(state = "deprecated") + @property + def crl_uri(self): + """ + Return publication URI for this ca_detail's CRL. + """ - yield [ca_detail.revoke() for ca_detail in ca_details] + return self.ca.sia_uri + self.crl_uri_tail - @tornado.gen.coroutine - def reissue(self): - """ - Reissue all current certificates issued by this CA. - """ + @property + def crl_uri_tail(self): + """ + Return tail (filename portion) of publication URI for this ca_detail's CRL. + """ - ca_detail = self.ca_details.get(state = "active") - if ca_detail: - yield ca_detail.reissue() + return self.public_key.gSKI() + ".crl" -class CADetail(models.Model): - public_key = PublicKeyField(null = True) - private_key_id = RSAPrivateKeyField(null = True) - latest_crl = CRLField(null = True) - crl_published = SundialField(null = True) - latest_ca_cert = CertificateField(null = True) - manifest_private_key_id = RSAPrivateKeyField(null = True) - manifest_public_key = PublicKeyField(null = True) - latest_manifest_cert = CertificateField(null = True) - latest_manifest = ManifestField(null = True) - manifest_published = SundialField(null = True) - state = EnumField(choices = ("pending", "active", "deprecated", "revoked")) - ca_cert_uri = models.TextField(null = True) - ca = models.ForeignKey(CA, related_name = "ca_details") - - - # Like the old ca_obj class, the old ca_detail_obj class had ten - # zillion properties and methods encapsulating SQL queries. - # Translate as we go. - - - @property - def crl_uri(self): - """ - Return publication URI for this ca_detail's CRL. - """ + @property + def manifest_uri(self): + """ + Return publication URI for this ca_detail's manifest. + """ - return self.ca.sia_uri + self.crl_uri_tail + return self.ca.sia_uri + self.public_key.gSKI() + ".mft" - @property - def crl_uri_tail(self): - """ - Return tail (filename portion) of publication URI for this ca_detail's CRL. - """ + def has_expired(self): + """ + Return whether this ca_detail's certificate has expired. + """ - return self.public_key.gSKI() + ".crl" + return self.latest_ca_cert.getNotAfter() <= rpki.sundial.now() - @property - def manifest_uri(self): - """ - Return publication URI for this ca_detail's manifest. - """ + def covers(self, target): + """ + Test whether this ca-detail covers a given set of resources. + """ - return self.ca.sia_uri + self.public_key.gSKI() + ".mft" + assert not target.asn.inherit and not target.v4.inherit and not target.v6.inherit + me = self.latest_ca_cert.get_3779resources() + return target.asn <= me.asn and target.v4 <= me.v4 and target.v6 <= me.v6 - def has_expired(self): - """ - Return whether this ca_detail's certificate has expired. - """ + @tornado.gen.coroutine + def activate(self, rpkid, ca, cert, uri, predecessor = None): + """ + Activate this ca_detail. + """ - return self.latest_ca_cert.getNotAfter() <= rpki.sundial.now() + publisher = rpki.rpkid.publication_queue(rpkid) + self.latest_ca_cert = cert + self.ca_cert_uri = uri + self.generate_manifest_cert() + self.state = "active" + self.generate_crl(publisher = publisher) + self.generate_manifest(publisher = publisher) + self.save() + + if predecessor is not None: + predecessor.state = "deprecated" + predecessor.save() + for child_cert in predecessor.child_certs.all(): + child_cert.reissue(ca_detail = self, publisher = publisher) + for roa in predecessor.roas.all(): + roa.regenerate(publisher = publisher) + for ghostbuster in predecessor.ghostbusters.all(): + ghostbuster.regenerate(publisher = publisher) + predecessor.generate_crl(publisher = publisher) + predecessor.generate_manifest(publisher = publisher) + yield publisher.call_pubd() - def covers(self, target): - """ - Test whether this ca-detail covers a given set of resources. - """ - assert not target.asn.inherit and not target.v4.inherit and not target.v6.inherit - me = self.latest_ca_cert.get_3779resources() - return target.asn <= me.asn and target.v4 <= me.v4 and target.v6 <= me.v6 + def destroy(self, ca, publisher, allow_failure = False): + """ + Delete this ca_detail and all of the certs it issued. + If allow_failure is true, we clean up as much as we can but don't + raise an exception. + """ - @tornado.gen.coroutine - def activate(self, rpkid, ca, cert, uri, predecessor = None): - """ - Activate this ca_detail. - """ + repository = ca.parent.repository + handler = False if allow_failure else None + for child_cert in self.child_certs.all(): + publisher.queue(uri = child_cert.uri, old_obj = child_cert.cert, repository = repository, handler = handler) + child_cert.delete() + for roa in self.roas.all(): + roa.revoke(publisher = publisher, allow_failure = allow_failure, fast = True) + for ghostbuster in self.ghostbusters.all(): + ghostbuster.revoke(publisher = publisher, allow_failure = allow_failure, fast = True) + if self.latest_manifest is not None: + publisher.queue(uri = self.manifest_uri, old_obj = self.latest_manifest, repository = repository, handler = handler) + if self.latest_crl is not None: + publisher.queue(uri = self.crl_uri, old_obj = self.latest_crl, repository = repository, handler = handler) + for cert in self.revoked_certs.all(): # + self.child_certs.all() + logger.debug("Deleting %r", cert) + cert.delete() + logger.debug("Deleting %r", self) + self.delete() - publisher = rpki.rpkid.publication_queue(rpkid) - self.latest_ca_cert = cert - self.ca_cert_uri = uri - self.generate_manifest_cert() - self.state = "active" - self.generate_crl(publisher = publisher) - self.generate_manifest(publisher = publisher) - self.save() - - if predecessor is not None: - predecessor.state = "deprecated" - predecessor.save() - for child_cert in predecessor.child_certs.all(): - child_cert.reissue(ca_detail = self, publisher = publisher) - for roa in predecessor.roas.all(): - roa.regenerate(publisher = publisher) - for ghostbuster in predecessor.ghostbusters.all(): - ghostbuster.regenerate(publisher = publisher) - predecessor.generate_crl(publisher = publisher) - predecessor.generate_manifest(publisher = publisher) - - yield publisher.call_pubd() - - - def destroy(self, ca, publisher, allow_failure = False): - """ - Delete this ca_detail and all of the certs it issued. - If allow_failure is true, we clean up as much as we can but don't - raise an exception. - """ + @tornado.gen.coroutine + def revoke(self, rpkid): + """ + Request revocation of all certificates whose g(SKI) matches the key + for this ca_detail. - repository = ca.parent.repository - handler = False if allow_failure else None - for child_cert in self.child_certs.all(): - publisher.queue(uri = child_cert.uri, old_obj = child_cert.cert, repository = repository, handler = handler) - child_cert.delete() - for roa in self.roas.all(): - roa.revoke(publisher = publisher, allow_failure = allow_failure, fast = True) - for ghostbuster in self.ghostbusters.all(): - ghostbuster.revoke(publisher = publisher, allow_failure = allow_failure, fast = True) - if self.latest_manifest is not None: - publisher.queue(uri = self.manifest_uri, old_obj = self.latest_manifest, repository = repository, handler = handler) - if self.latest_crl is not None: - publisher.queue(uri = self.crl_uri, old_obj = self.latest_crl, repository = repository, handler = handler) - for cert in self.revoked_certs.all(): # + self.child_certs.all() - logger.debug("Deleting %r", cert) - cert.delete() - logger.debug("Deleting %r", self) - self.delete() - - - @tornado.gen.coroutine - def revoke(self, rpkid): - """ - Request revocation of all certificates whose g(SKI) matches the key - for this ca_detail. + Tasks: - Tasks: + - Request revocation of old keypair by parent. - - Request revocation of old keypair by parent. + - Revoke all child certs issued by the old keypair. - - Revoke all child certs issued by the old keypair. + - Generate a final CRL, signed with the old keypair, listing all + the revoked certs, with a next CRL time after the last cert or + CRL signed by the old keypair will have expired. - - Generate a final CRL, signed with the old keypair, listing all - the revoked certs, with a next CRL time after the last cert or - CRL signed by the old keypair will have expired. + - Generate a corresponding final manifest. - - Generate a corresponding final manifest. + - Destroy old keypairs. - - Destroy old keypairs. + - Leave final CRL and manifest in place until their nextupdate + time has passed. + """ - - Leave final CRL and manifest in place until their nextupdate - time has passed. - """ + gski = self.latest_ca_cert.gSKI() - gski = self.latest_ca_cert.gSKI() + logger.debug("Asking parent to revoke CA certificate matching g(SKI) = %s", gski) - logger.debug("Asking parent to revoke CA certificate matching g(SKI) = %s", gski) + r_msg = yield self.ca.parent.up_down_revoke_query(rpkid = rpkid, class_name = self.ca.parent_resource_class, ski = gski) - r_msg = yield self.ca.parent.up_down_revoke_query(rpkid = rpkid, class_name = self.ca.parent_resource_class, ski = gski) + if r_msg[0].get("class_name") != self.ca.parent_resource_class: + raise rpki.exceptions.ResourceClassMismatch - if r_msg[0].get("class_name") != self.ca.parent_resource_class: - raise rpki.exceptions.ResourceClassMismatch + if r_msg[0].get("ski") != gski: + raise rpki.exceptions.SKIMismatch - if r_msg[0].get("ski") != gski: - raise rpki.exceptions.SKIMismatch + logger.debug("Parent revoked g(SKI) %s, starting cleanup", gski) - logger.debug("Parent revoked g(SKI) %s, starting cleanup", gski) + crl_interval = rpki.sundial.timedelta(seconds = self.ca.parent.tenant.crl_interval) - crl_interval = rpki.sundial.timedelta(seconds = self.ca.parent.tenant.crl_interval) + nextUpdate = rpki.sundial.now() - nextUpdate = rpki.sundial.now() + if self.latest_manifest is not None: + self.latest_manifest.extract_if_needed() + nextUpdate = nextUpdate.later(self.latest_manifest.getNextUpdate()) - if self.latest_manifest is not None: - self.latest_manifest.extract_if_needed() - nextUpdate = nextUpdate.later(self.latest_manifest.getNextUpdate()) + if self.latest_crl is not None: + nextUpdate = nextUpdate.later(self.latest_crl.getNextUpdate()) - if self.latest_crl is not None: - nextUpdate = nextUpdate.later(self.latest_crl.getNextUpdate()) + publisher = rpki.rpkid.publication_queue(rpkid) - publisher = rpki.rpkid.publication_queue(rpkid) + for child_cert in self.child_certs.all(): + nextUpdate = nextUpdate.later(child_cert.cert.getNotAfter()) + child_cert.revoke(publisher = publisher) - for child_cert in self.child_certs.all(): - nextUpdate = nextUpdate.later(child_cert.cert.getNotAfter()) - child_cert.revoke(publisher = publisher) + for roa in self.roas.all(): + nextUpdate = nextUpdate.later(roa.cert.getNotAfter()) + roa.revoke(publisher = publisher) - for roa in self.roas.all(): - nextUpdate = nextUpdate.later(roa.cert.getNotAfter()) - roa.revoke(publisher = publisher) + for ghostbuster in self.ghostbusters.all(): + nextUpdate = nextUpdate.later(ghostbuster.cert.getNotAfter()) + ghostbuster.revoke(publisher = publisher) - for ghostbuster in self.ghostbusters.all(): - nextUpdate = nextUpdate.later(ghostbuster.cert.getNotAfter()) - ghostbuster.revoke(publisher = publisher) + nextUpdate += crl_interval - nextUpdate += crl_interval + self.generate_crl(publisher = publisher, nextUpdate = nextUpdate) + self.generate_manifest(publisher = publisher, nextUpdate = nextUpdate) + self.private_key_id = None + self.manifest_private_key_id = None + self.manifest_public_key = None + self.latest_manifest_cert = None + self.state = "revoked" + self.save() - self.generate_crl(publisher = publisher, nextUpdate = nextUpdate) - self.generate_manifest(publisher = publisher, nextUpdate = nextUpdate) - self.private_key_id = None - self.manifest_private_key_id = None - self.manifest_public_key = None - self.latest_manifest_cert = None - self.state = "revoked" - self.save() + yield publisher.call_pubd() - yield publisher.call_pubd() + @tornado.gen.coroutine + def update(self, rpkid, parent, ca, rc, sia_uri_changed, old_resources): + """ + Need to get a new certificate for this ca_detail and perhaps frob + children of this ca_detail. + """ - @tornado.gen.coroutine - def update(self, rpkid, parent, ca, rc, sia_uri_changed, old_resources): - """ - Need to get a new certificate for this ca_detail and perhaps frob - children of this ca_detail. - """ + logger.debug("Sending issue request to %r from %r", parent, self.update) - logger.debug("Sending issue request to %r from %r", parent, self.update) + r_msg = yield parent.up_down_issue_query(rpkid = rpkid, ca = ca, ca_detail = self) - r_msg = yield parent.up_down_issue_query(rpkid = rpkid, ca = ca, ca_detail = self) + c = r_msg[0][0] - c = r_msg[0][0] + cert = rpki.x509.X509(Base64 = c.text) + cert_url = c.get("cert_url") - cert = rpki.x509.X509(Base64 = c.text) - cert_url = c.get("cert_url") + logger.debug("CA %r received certificate %s", self, cert_url) - logger.debug("CA %r received certificate %s", self, cert_url) + if self.state == "pending": + yield self.activate(rpkid = rpkid, ca = ca, cert = cert, uri = cert_url) + return - if self.state == "pending": - yield self.activate(rpkid = rpkid, ca = ca, cert = cert, uri = cert_url) - return + validity_changed = self.latest_ca_cert is None or self.latest_ca_cert.getNotAfter() != cert.getNotAfter() - validity_changed = self.latest_ca_cert is None or self.latest_ca_cert.getNotAfter() != cert.getNotAfter() + publisher = rpki.rpkid.publication_queue(rpkid) - publisher = rpki.rpkid.publication_queue(rpkid) + if self.latest_ca_cert != cert: + self.latest_ca_cert = cert + self.save() + self.generate_manifest_cert() + self.generate_crl(publisher = publisher) + self.generate_manifest(publisher = publisher) - if self.latest_ca_cert != cert: - self.latest_ca_cert = cert - self.save() - self.generate_manifest_cert() - self.generate_crl(publisher = publisher) - self.generate_manifest(publisher = publisher) + new_resources = self.latest_ca_cert.get_3779resources() - new_resources = self.latest_ca_cert.get_3779resources() + if sia_uri_changed or old_resources.oversized(new_resources): + for child_cert in self.child_certs.all(): + child_resources = child_cert.cert.get_3779resources() + if sia_uri_changed or child_resources.oversized(new_resources): + child_cert.reissue(ca_detail = self, resources = child_resources & new_resources, publisher = publisher) - if sia_uri_changed or old_resources.oversized(new_resources): - for child_cert in self.child_certs.all(): - child_resources = child_cert.cert.get_3779resources() - if sia_uri_changed or child_resources.oversized(new_resources): - child_cert.reissue(ca_detail = self, resources = child_resources & new_resources, publisher = publisher) + if sia_uri_changed or validity_changed or old_resources.oversized(new_resources): + for roa in self.roas.all(): + roa.update(publisher = publisher, fast = True) - if sia_uri_changed or validity_changed or old_resources.oversized(new_resources): - for roa in self.roas.all(): - roa.update(publisher = publisher, fast = True) + if sia_uri_changed or validity_changed: + for ghostbuster in self.ghostbusters.all(): + ghostbuster.update(publisher = publisher, fast = True) - if sia_uri_changed or validity_changed: - for ghostbuster in self.ghostbusters.all(): - ghostbuster.update(publisher = publisher, fast = True) + yield publisher.call_pubd() - yield publisher.call_pubd() + @classmethod + def create(cls, ca): + """ + Create a new ca_detail object for a specified CA. + """ + + cer_keypair = rpki.x509.RSA.generate() + mft_keypair = rpki.x509.RSA.generate() + return cls.objects.create( + ca = ca, + state = "pending", + private_key_id = cer_keypair, + public_key = cer_keypair.get_public(), + manifest_private_key_id = mft_keypair, + manifest_public_key = mft_keypair.get_public()) + + + def issue_ee(self, ca, resources, subject_key, sia, + cn = None, sn = None, notAfter = None, eku = None): + """ + Issue a new EE certificate. + """ + + if notAfter is None: + notAfter = self.latest_ca_cert.getNotAfter() + return self.latest_ca_cert.issue( + keypair = self.private_key_id, + subject_key = subject_key, + serial = ca.next_serial_number(), + sia = sia, + aia = self.ca_cert_uri, + crldp = self.crl_uri, + resources = resources, + notAfter = notAfter, + is_ca = False, + cn = cn, + sn = sn, + eku = eku) + + + def generate_manifest_cert(self): + """ + Generate a new manifest certificate for this ca_detail. + """ + + resources = rpki.resource_set.resource_bag.from_inheritance() + self.latest_manifest_cert = self.issue_ee( + ca = self.ca, + resources = resources, + subject_key = self.manifest_public_key, + sia = (None, None, self.manifest_uri, self.ca.parent.repository.rrdp_notification_uri)) + + + def issue(self, ca, child, subject_key, sia, resources, publisher, child_cert = None): + """ + Issue a new certificate to a child. Optional child_cert argument + specifies an existing child_cert object to update in place; if not + specified, we create a new one. Returns the child_cert object + containing the newly issued cert. + """ + + self.check_failed_publication(publisher) + cert = self.latest_ca_cert.issue( + keypair = self.private_key_id, + subject_key = subject_key, + serial = ca.next_serial_number(), + aia = self.ca_cert_uri, + crldp = self.crl_uri, + sia = sia, + resources = resources, + notAfter = resources.valid_until) + if child_cert is None: + old_cert = None + child_cert = ChildCert(child = child, ca_detail = self, cert = cert) + logger.debug("Created new child_cert %r", child_cert) + else: + old_cert = child_cert.cert + child_cert.cert = cert + child_cert.ca_detail = self + logger.debug("Reusing existing child_cert %r", child_cert) + child_cert.gski = cert.gSKI() + child_cert.published = rpki.sundial.now() + child_cert.save() + publisher.queue( + uri = child_cert.uri, + old_obj = old_cert, + new_obj = child_cert.cert, + repository = ca.parent.repository, + handler = child_cert.published_callback) + self.generate_manifest(publisher = publisher) + return child_cert + + + def generate_crl(self, publisher, nextUpdate = None): + """ + Generate a new CRL for this ca_detail. At the moment this is + unconditional, that is, it is up to the caller to decide whether a + new CRL is needed. + """ + + self.check_failed_publication(publisher) + crl_interval = rpki.sundial.timedelta(seconds = self.ca.parent.tenant.crl_interval) + now = rpki.sundial.now() + if nextUpdate is None: + nextUpdate = now + crl_interval + certlist = [] + for revoked_cert in self.revoked_certs.all(): + if now > revoked_cert.expires + crl_interval: + revoked_cert.delete() + else: + certlist.append((revoked_cert.serial, revoked_cert.revoked)) + certlist.sort() + old_crl = self.latest_crl + self.latest_crl = rpki.x509.CRL.generate( + keypair = self.private_key_id, + issuer = self.latest_ca_cert, + serial = self.ca.next_crl_number(), + thisUpdate = now, + nextUpdate = nextUpdate, + revokedCertificates = certlist) + self.crl_published = now + self.save() + publisher.queue( + uri = self.crl_uri, + old_obj = old_crl, + new_obj = self.latest_crl, + repository = self.ca.parent.repository, + handler = self.crl_published_callback) + + + def crl_published_callback(self, pdu): + """ + Check result of CRL publication. + """ + + rpki.publication.raise_if_error(pdu) + self.crl_published = None + self.save() + + + def generate_manifest(self, publisher, nextUpdate = None): + """ + Generate a new manifest for this ca_detail. + """ + + self.check_failed_publication(publisher) + + crl_interval = rpki.sundial.timedelta(seconds = self.ca.parent.tenant.crl_interval) + now = rpki.sundial.now() + uri = self.manifest_uri + if nextUpdate is None: + nextUpdate = now + crl_interval + if (self.latest_manifest_cert is None or + (self.latest_manifest_cert.getNotAfter() < nextUpdate and + self.latest_manifest_cert.getNotAfter() < self.latest_ca_cert.getNotAfter())): + logger.debug("Generating EE certificate for %s", uri) + self.generate_manifest_cert() + logger.debug("Latest CA cert notAfter %s, new %s EE notAfter %s", + self.latest_ca_cert.getNotAfter(), uri, self.latest_manifest_cert.getNotAfter()) + logger.debug("Constructing manifest object list for %s", uri) + objs = [(self.crl_uri_tail, self.latest_crl)] + objs.extend((c.uri_tail, c.cert) for c in self.child_certs.all()) + objs.extend((r.uri_tail, r.roa) for r in self.roas.filter(roa__isnull = False)) + objs.extend((g.uri_tail, g.ghostbuster) for g in self.ghostbusters.all()) + objs.extend((e.uri_tail, e.cert) for e in self.ee_certificates.all()) + logger.debug("Building manifest object %s", uri) + old_manifest = self.latest_manifest + self.latest_manifest = rpki.x509.SignedManifest.build( + serial = self.ca.next_manifest_number(), + thisUpdate = now, + nextUpdate = nextUpdate, + names_and_objs = objs, + keypair = self.manifest_private_key_id, + certs = self.latest_manifest_cert) + logger.debug("Manifest generation took %s", rpki.sundial.now() - now) + self.manifest_published = now + self.save() + publisher.queue( + uri = uri, + old_obj = old_manifest, + new_obj = self.latest_manifest, + repository = self.ca.parent.repository, + handler = self.manifest_published_callback) + + + def manifest_published_callback(self, pdu): + """ + Check result of manifest publication. + """ + + rpki.publication.raise_if_error(pdu) + self.manifest_published = None + self.save() + + + @tornado.gen.coroutine + def reissue(self, rpkid): + """ + Reissue all current certificates issued by this ca_detail. + """ - @classmethod - def create(cls, ca): - """ - Create a new ca_detail object for a specified CA. - """ + publisher = rpki.rpkid.publication_queue(rpkid) + self.check_failed_publication(publisher) + for roa in self.roas.all(): + roa.regenerate(publisher, fast = True) + for ghostbuster in self.ghostbusters.all(): + ghostbuster.regenerate(publisher, fast = True) + for ee_certificate in self.ee_certificates.all(): + ee_certificate.reissue(publisher, force = True) + for child_cert in self.child_certs.all(): + child_cert.reissue(self, publisher, force = True) + self.generate_manifest_cert() + self.save() + self.generate_crl(publisher = publisher) + self.generate_manifest(publisher = publisher) + self.save() + yield publisher.call_pubd() - cer_keypair = rpki.x509.RSA.generate() - mft_keypair = rpki.x509.RSA.generate() - return cls.objects.create( - ca = ca, - state = "pending", - private_key_id = cer_keypair, - public_key = cer_keypair.get_public(), - manifest_private_key_id = mft_keypair, - manifest_public_key = mft_keypair.get_public()) + def check_failed_publication(self, publisher, check_all = True): + """ + Check for failed publication of objects issued by this ca_detail. + + All publishable objects have timestamp fields recording time of + last attempted publication, and callback methods which clear these + timestamps once publication has succeeded. Our task here is to + look for objects issued by this ca_detail which have timestamps + set (indicating that they have not been published) and for which + the timestamps are not very recent (for some definition of very + recent -- intent is to allow a bit of slack in case pubd is just + being slow). In such cases, we want to retry publication. + + As an optimization, we can probably skip checking other products + if manifest and CRL have been published, thus saving ourselves + several complex SQL queries. Not sure yet whether this + optimization is worthwhile. + + For the moment we check everything without optimization, because + it simplifies testing. + + For the moment our definition of staleness is hardwired; this + should become configurable. + """ + + logger.debug("Checking for failed publication for %r", self) + + stale = rpki.sundial.now() - rpki.sundial.timedelta(seconds = 60) + repository = self.ca.parent.repository + if self.latest_crl is not None and self.crl_published is not None and self.crl_published < stale: + logger.debug("Retrying publication for %s", self.crl_uri) + publisher.queue(uri = self.crl_uri, + new_obj = self.latest_crl, + repository = repository, + handler = self.crl_published_callback) + if self.latest_manifest is not None and self.manifest_published is not None and self.manifest_published < stale: + logger.debug("Retrying publication for %s", self.manifest_uri) + publisher.queue(uri = self.manifest_uri, + new_obj = self.latest_manifest, + repository = repository, + handler = self.manifest_published_callback) + if not check_all: + return + for child_cert in self.child_certs.filter(published__isnull = False, published__lt = stale): + logger.debug("Retrying publication for %s", child_cert) + publisher.queue( + uri = child_cert.uri, + new_obj = child_cert.cert, + repository = repository, + handler = child_cert.published_callback) + for roa in self.roas.filter(published__isnull = False, published__lt = stale): + logger.debug("Retrying publication for %s", roa) + publisher.queue( + uri = roa.uri, + new_obj = roa.roa, + repository = repository, + handler = roa.published_callback) + for ghostbuster in self.ghostbusters.filter(published__isnull = False, published__lt = stale): + logger.debug("Retrying publication for %s", ghostbuster) + publisher.queue( + uri = ghostbuster.uri, + new_obj = ghostbuster.ghostbuster, + repository = repository, + handler = ghostbuster.published_callback) + for ee_cert in self.ee_certificates.filter(published__isnull = False, published__lt = stale): + logger.debug("Retrying publication for %s", ee_cert) + publisher.queue( + uri = ee_cert.uri, + new_obj = ee_cert.cert, + repository = repository, + handler = ee_cert.published_callback) - def issue_ee(self, ca, resources, subject_key, sia, - cn = None, sn = None, notAfter = None, eku = None): - """ - Issue a new EE certificate. - """ - if notAfter is None: - notAfter = self.latest_ca_cert.getNotAfter() - return self.latest_ca_cert.issue( - keypair = self.private_key_id, - subject_key = subject_key, - serial = ca.next_serial_number(), - sia = sia, - aia = self.ca_cert_uri, - crldp = self.crl_uri, - resources = resources, - notAfter = notAfter, - is_ca = False, - cn = cn, - sn = sn, - eku = eku) - - - def generate_manifest_cert(self): - """ - Generate a new manifest certificate for this ca_detail. - """ +@xml_hooks +class Child(models.Model): + child_handle = models.SlugField(max_length = 255) + bpki_cert = CertificateField(null = True) + bpki_glue = CertificateField(null = True) + last_cms_timestamp = SundialField(null = True) + tenant = models.ForeignKey(Tenant, related_name = "children") + bsc = models.ForeignKey(BSC, related_name = "children") + objects = XMLManager() - resources = rpki.resource_set.resource_bag.from_inheritance() - self.latest_manifest_cert = self.issue_ee( - ca = self.ca, - resources = resources, - subject_key = self.manifest_public_key, - sia = (None, None, self.manifest_uri, self.ca.parent.repository.rrdp_notification_uri)) + class Meta: # pylint: disable=C1001,W0232 + unique_together = ("tenant", "child_handle") + xml_template = XMLTemplate( + name = "child", + handles = (BSC,), + elements = ("bpki_cert", "bpki_glue")) - def issue(self, ca, child, subject_key, sia, resources, publisher, child_cert = None): - """ - Issue a new certificate to a child. Optional child_cert argument - specifies an existing child_cert object to update in place; if not - specified, we create a new one. Returns the child_cert object - containing the newly issued cert. - """ - self.check_failed_publication(publisher) - cert = self.latest_ca_cert.issue( - keypair = self.private_key_id, - subject_key = subject_key, - serial = ca.next_serial_number(), - aia = self.ca_cert_uri, - crldp = self.crl_uri, - sia = sia, - resources = resources, - notAfter = resources.valid_until) - if child_cert is None: - old_cert = None - child_cert = ChildCert(child = child, ca_detail = self, cert = cert) - logger.debug("Created new child_cert %r", child_cert) - else: - old_cert = child_cert.cert - child_cert.cert = cert - child_cert.ca_detail = self - logger.debug("Reusing existing child_cert %r", child_cert) - child_cert.gski = cert.gSKI() - child_cert.published = rpki.sundial.now() - child_cert.save() - publisher.queue( - uri = child_cert.uri, - old_obj = old_cert, - new_obj = child_cert.cert, - repository = ca.parent.repository, - handler = child_cert.published_callback) - self.generate_manifest(publisher = publisher) - return child_cert - - - def generate_crl(self, publisher, nextUpdate = None): - """ - Generate a new CRL for this ca_detail. At the moment this is - unconditional, that is, it is up to the caller to decide whether a - new CRL is needed. - """ + @tornado.gen.coroutine + def xml_pre_delete_hook(self, rpkid): + publisher = rpki.rpkid.publication_queue(rpkid) + for child_cert in self.child_certs.all(): + child_cert.revoke(publisher = publisher, generate_crl_and_manifest = True) + yield publisher.call_pubd() - self.check_failed_publication(publisher) - crl_interval = rpki.sundial.timedelta(seconds = self.ca.parent.tenant.crl_interval) - now = rpki.sundial.now() - if nextUpdate is None: - nextUpdate = now + crl_interval - certlist = [] - for revoked_cert in self.revoked_certs.all(): - if now > revoked_cert.expires + crl_interval: - revoked_cert.delete() - else: - certlist.append((revoked_cert.serial, revoked_cert.revoked)) - certlist.sort() - old_crl = self.latest_crl - self.latest_crl = rpki.x509.CRL.generate( - keypair = self.private_key_id, - issuer = self.latest_ca_cert, - serial = self.ca.next_crl_number(), - thisUpdate = now, - nextUpdate = nextUpdate, - revokedCertificates = certlist) - self.crl_published = now - self.save() - publisher.queue( - uri = self.crl_uri, - old_obj = old_crl, - new_obj = self.latest_crl, - repository = self.ca.parent.repository, - handler = self.crl_published_callback) - - - def crl_published_callback(self, pdu): - """ - Check result of CRL publication. - """ - rpki.publication.raise_if_error(pdu) - self.crl_published = None - self.save() + @tornado.gen.coroutine + def xml_post_save_hook(self, rpkid, q_pdu): + if q_pdu.get("clear_replay_protection"): + self.clear_replay_protection() + if q_pdu.get("reissue"): + yield self.serve_reissue(rpkid) - def generate_manifest(self, publisher, nextUpdate = None): - """ - Generate a new manifest for this ca_detail. - """ + def serve_reissue(self, rpkid): + publisher = rpki.rpkid.publication_queue(rpkid) + for child_cert in self.child_certs.all(): + child_cert.reissue(child_cert.ca_detail, publisher, force = True) + yield publisher.call_pubd() - self.check_failed_publication(publisher) - - crl_interval = rpki.sundial.timedelta(seconds = self.ca.parent.tenant.crl_interval) - now = rpki.sundial.now() - uri = self.manifest_uri - if nextUpdate is None: - nextUpdate = now + crl_interval - if (self.latest_manifest_cert is None or - (self.latest_manifest_cert.getNotAfter() < nextUpdate and - self.latest_manifest_cert.getNotAfter() < self.latest_ca_cert.getNotAfter())): - logger.debug("Generating EE certificate for %s", uri) - self.generate_manifest_cert() - logger.debug("Latest CA cert notAfter %s, new %s EE notAfter %s", - self.latest_ca_cert.getNotAfter(), uri, self.latest_manifest_cert.getNotAfter()) - logger.debug("Constructing manifest object list for %s", uri) - objs = [(self.crl_uri_tail, self.latest_crl)] - objs.extend((c.uri_tail, c.cert) for c in self.child_certs.all()) - objs.extend((r.uri_tail, r.roa) for r in self.roas.filter(roa__isnull = False)) - objs.extend((g.uri_tail, g.ghostbuster) for g in self.ghostbusters.all()) - objs.extend((e.uri_tail, e.cert) for e in self.ee_certificates.all()) - logger.debug("Building manifest object %s", uri) - old_manifest = self.latest_manifest - self.latest_manifest = rpki.x509.SignedManifest.build( - serial = self.ca.next_manifest_number(), - thisUpdate = now, - nextUpdate = nextUpdate, - names_and_objs = objs, - keypair = self.manifest_private_key_id, - certs = self.latest_manifest_cert) - logger.debug("Manifest generation took %s", rpki.sundial.now() - now) - self.manifest_published = now - self.save() - publisher.queue(uri = uri, - old_obj = old_manifest, - new_obj = self.latest_manifest, - repository = self.ca.parent.repository, - handler = self.manifest_published_callback) - - - def manifest_published_callback(self, pdu): - """ - Check result of manifest publication. - """ - rpki.publication.raise_if_error(pdu) - self.manifest_published = None - self.save() + def clear_replay_protection(self): + self.last_cms_timestamp = None + self.save() - @tornado.gen.coroutine - def reissue(self, rpkid): - """ - Reissue all current certificates issued by this ca_detail. - """ + @tornado.gen.coroutine + def up_down_handle_list(self, rpkid, q_msg, r_msg): - publisher = rpki.rpkid.publication_queue(rpkid) - self.check_failed_publication(publisher) - for roa in self.roas.all(): - roa.regenerate(publisher, fast = True) - for ghostbuster in self.ghostbusters.all(): - ghostbuster.regenerate(publisher, fast = True) - for ee_certificate in self.ee_certificates.all(): - ee_certificate.reissue(publisher, force = True) - for child_cert in self.child_certs.all(): - child_cert.reissue(self, publisher, force = True) - self.generate_manifest_cert() - self.save() - self.generate_crl(publisher = publisher) - self.generate_manifest(publisher = publisher) - self.save() - yield publisher.call_pubd() - - - def check_failed_publication(self, publisher, check_all = True): - """ - Check for failed publication of objects issued by this ca_detail. - - All publishable objects have timestamp fields recording time of - last attempted publication, and callback methods which clear these - timestamps once publication has succeeded. Our task here is to - look for objects issued by this ca_detail which have timestamps - set (indicating that they have not been published) and for which - the timestamps are not very recent (for some definition of very - recent -- intent is to allow a bit of slack in case pubd is just - being slow). In such cases, we want to retry publication. - - As an optimization, we can probably skip checking other products - if manifest and CRL have been published, thus saving ourselves - several complex SQL queries. Not sure yet whether this - optimization is worthwhile. - - For the moment we check everything without optimization, because - it simplifies testing. - - For the moment our definition of staleness is hardwired; this - should become configurable. - """ + irdb_resources = yield rpkid.irdb_query_child_resources(self.tenant.tenant_handle, self.child_handle) - logger.debug("Checking for failed publication for %r", self) - - stale = rpki.sundial.now() - rpki.sundial.timedelta(seconds = 60) - repository = self.ca.parent.repository - if self.latest_crl is not None and self.crl_published is not None and self.crl_published < stale: - logger.debug("Retrying publication for %s", self.crl_uri) - publisher.queue(uri = self.crl_uri, - new_obj = self.latest_crl, - repository = repository, - handler = self.crl_published_callback) - if self.latest_manifest is not None and self.manifest_published is not None and self.manifest_published < stale: - logger.debug("Retrying publication for %s", self.manifest_uri) - publisher.queue(uri = self.manifest_uri, - new_obj = self.latest_manifest, - repository = repository, - handler = self.manifest_published_callback) - if not check_all: - return - for child_cert in self.child_certs.filter(published__isnull = False, published__lt = stale): - logger.debug("Retrying publication for %s", child_cert) - publisher.queue( - uri = child_cert.uri, - new_obj = child_cert.cert, - repository = repository, - handler = child_cert.published_callback) - for roa in self.roas.filter(published__isnull = False, published__lt = stale): - logger.debug("Retrying publication for %s", roa) - publisher.queue( - uri = roa.uri, - new_obj = roa.roa, - repository = repository, - handler = roa.published_callback) - for ghostbuster in self.ghostbusters.filter(published__isnull = False, published__lt = stale): - logger.debug("Retrying publication for %s", ghostbuster) - publisher.queue( - uri = ghostbuster.uri, - new_obj = ghostbuster.ghostbuster, - repository = repository, - handler = ghostbuster.published_callback) - for ee_cert in self.ee_certificates.filter(published__isnull = False, published__lt = stale): - logger.debug("Retrying publication for %s", ee_cert) - publisher.queue( - uri = ee_cert.uri, - new_obj = ee_cert.cert, - repository = repository, - handler = ee_cert.published_callback) + if irdb_resources.valid_until < rpki.sundial.now(): + logger.debug("Child %s's resources expired %s", self.child_handle, irdb_resources.valid_until) + else: -@xml_hooks -class Child(models.Model): - child_handle = models.SlugField(max_length = 255) - bpki_cert = CertificateField(null = True) - bpki_glue = CertificateField(null = True) - last_cms_timestamp = SundialField(null = True) - tenant = models.ForeignKey(Tenant, related_name = "children") - bsc = models.ForeignKey(BSC, related_name = "children") - objects = XMLManager() + for ca_detail in CADetail.objects.filter(ca__parent__tenant = self.tenant, state = "active"): + resources = ca_detail.latest_ca_cert.get_3779resources() & irdb_resources - class Meta: # pylint: disable=C1001,W0232 - unique_together = ("tenant", "child_handle") + if resources.empty(): + logger.debug("No overlap between received resources and what child %s should get ([%s], [%s])", + self.child_handle, ca_detail.latest_ca_cert.get_3779resources(), irdb_resources) + continue - xml_template = XMLTemplate( - name = "child", - handles = (BSC,), - elements = ("bpki_cert", "bpki_glue")) + rc = SubElement(r_msg, rpki.up_down.tag_class, + class_name = ca_detail.ca.parent_resource_class, + cert_url = ca_detail.ca_cert_uri, + resource_set_as = str(resources.asn), + resource_set_ipv4 = str(resources.v4), + resource_set_ipv6 = str(resources.v6), + resource_set_notafter = str(resources.valid_until)) + for child_cert in self.child_certs.filter(ca_detail = ca_detail): + c = SubElement(rc, rpki.up_down.tag_certificate, cert_url = child_cert.uri) + c.text = child_cert.cert.get_Base64() + SubElement(rc, rpki.up_down.tag_issuer).text = ca_detail.latest_ca_cert.get_Base64() - @tornado.gen.coroutine - def xml_pre_delete_hook(self, rpkid): - publisher = rpki.rpkid.publication_queue(rpkid) - for child_cert in self.child_certs.all(): - child_cert.revoke(publisher = publisher, generate_crl_and_manifest = True) - yield publisher.call_pubd() + @tornado.gen.coroutine + def up_down_handle_issue(self, rpkid, q_msg, r_msg): - @tornado.gen.coroutine - def xml_post_save_hook(self, rpkid, q_pdu): - if q_pdu.get("clear_replay_protection"): - self.clear_replay_protection() - if q_pdu.get("reissue"): - yield self.serve_reissue(rpkid) + req = q_msg[0] + assert req.tag == rpki.up_down.tag_request + # Subsetting not yet implemented, this is the one place where we have to handle it, by reporting that we're lame. - def serve_reissue(self, rpkid): - publisher = rpki.rpkid.publication_queue(rpkid) - for child_cert in self.child_certs.all(): - child_cert.reissue(child_cert.ca_detail, publisher, force = True) - yield publisher.call_pubd() + if any(req.get(a) for a in ("req_resource_set_as", "req_resource_set_ipv4", "req_resource_set_ipv6")): + raise rpki.exceptions.NotImplementedYet("req_* attributes not implemented yet, sorry") + class_name = req.get("class_name") + pkcs10 = rpki.x509.PKCS10(Base64 = req.text) + pkcs10.check_valid_request_ca() + ca_detail = CADetail.objects.get(ca__parent__tenant = self.tenant, state = "active", + ca__parent_resource_class = class_name) - def clear_replay_protection(self): - self.last_cms_timestamp = None - self.save() + irdb_resources = yield rpkid.irdb_query_child_resources(self.tenant.tenant_handle, self.child_handle) + if irdb_resources.valid_until < rpki.sundial.now(): + raise rpki.exceptions.IRDBExpired("IRDB entry for child %s expired %s" % ( + self.child_handle, irdb_resources.valid_until)) - @tornado.gen.coroutine - def up_down_handle_list(self, rpkid, q_msg, r_msg): + resources = irdb_resources & ca_detail.latest_ca_cert.get_3779resources() + resources.valid_until = irdb_resources.valid_until + req_key = pkcs10.getPublicKey() + req_sia = pkcs10.get_SIA() - irdb_resources = yield rpkid.irdb_query_child_resources(self.tenant.tenant_handle, self.child_handle) + # Generate new cert or regenerate old one if necessary - if irdb_resources.valid_until < rpki.sundial.now(): - logger.debug("Child %s's resources expired %s", self.child_handle, irdb_resources.valid_until) + publisher = rpki.rpkid.publication_queue(rpkid) - else: + try: + child_cert = self.child_certs.get(ca_detail = ca_detail, gski = req_key.gSKI()) - for ca_detail in CADetail.objects.filter(ca__parent__tenant = self.tenant, state = "active"): - resources = ca_detail.latest_ca_cert.get_3779resources() & irdb_resources + except ChildCert.DoesNotExist: + child_cert = ca_detail.issue( + ca = ca_detail.ca, + child = self, + subject_key = req_key, + sia = req_sia, + resources = resources, + publisher = publisher) - if resources.empty(): - logger.debug("No overlap between received resources and what child %s should get ([%s], [%s])", - self.child_handle, ca_detail.latest_ca_cert.get_3779resources(), irdb_resources) - continue + else: + child_cert = child_cert.reissue( + ca_detail = ca_detail, + sia = req_sia, + resources = resources, + publisher = publisher) + + yield publisher.call_pubd() rc = SubElement(r_msg, rpki.up_down.tag_class, - class_name = ca_detail.ca.parent_resource_class, + class_name = class_name, cert_url = ca_detail.ca_cert_uri, resource_set_as = str(resources.asn), resource_set_ipv4 = str(resources.v4), resource_set_ipv6 = str(resources.v6), resource_set_notafter = str(resources.valid_until)) - - for child_cert in self.child_certs.filter(ca_detail = ca_detail): - c = SubElement(rc, rpki.up_down.tag_certificate, cert_url = child_cert.uri) - c.text = child_cert.cert.get_Base64() + c = SubElement(rc, rpki.up_down.tag_certificate, cert_url = child_cert.uri) + c.text = child_cert.cert.get_Base64() SubElement(rc, rpki.up_down.tag_issuer).text = ca_detail.latest_ca_cert.get_Base64() - @tornado.gen.coroutine - def up_down_handle_issue(self, rpkid, q_msg, r_msg): - - req = q_msg[0] - assert req.tag == rpki.up_down.tag_request - - # Subsetting not yet implemented, this is the one place where we have to handle it, by reporting that we're lame. - - if any(req.get(a) for a in ("req_resource_set_as", "req_resource_set_ipv4", "req_resource_set_ipv6")): - raise rpki.exceptions.NotImplementedYet("req_* attributes not implemented yet, sorry") - - class_name = req.get("class_name") - pkcs10 = rpki.x509.PKCS10(Base64 = req.text) - pkcs10.check_valid_request_ca() - ca_detail = CADetail.objects.get(ca__parent__tenant = self.tenant, state = "active", - ca__parent_resource_class = class_name) - - irdb_resources = yield rpkid.irdb_query_child_resources(self.tenant.tenant_handle, self.child_handle) - - if irdb_resources.valid_until < rpki.sundial.now(): - raise rpki.exceptions.IRDBExpired("IRDB entry for child %s expired %s" % ( - self.child_handle, irdb_resources.valid_until)) - - resources = irdb_resources & ca_detail.latest_ca_cert.get_3779resources() - resources.valid_until = irdb_resources.valid_until - req_key = pkcs10.getPublicKey() - req_sia = pkcs10.get_SIA() - - # Generate new cert or regenerate old one if necessary - - publisher = rpki.rpkid.publication_queue(rpkid) - - try: - child_cert = self.child_certs.get(ca_detail = ca_detail, gski = req_key.gSKI()) - - except ChildCert.DoesNotExist: - child_cert = ca_detail.issue( - ca = ca_detail.ca, - child = self, - subject_key = req_key, - sia = req_sia, - resources = resources, - publisher = publisher) - - else: - child_cert = child_cert.reissue( - ca_detail = ca_detail, - sia = req_sia, - resources = resources, - publisher = publisher) - - yield publisher.call_pubd() - - rc = SubElement(r_msg, rpki.up_down.tag_class, - class_name = class_name, - cert_url = ca_detail.ca_cert_uri, - resource_set_as = str(resources.asn), - resource_set_ipv4 = str(resources.v4), - resource_set_ipv6 = str(resources.v6), - resource_set_notafter = str(resources.valid_until)) - c = SubElement(rc, rpki.up_down.tag_certificate, cert_url = child_cert.uri) - c.text = child_cert.cert.get_Base64() - SubElement(rc, rpki.up_down.tag_issuer).text = ca_detail.latest_ca_cert.get_Base64() - + @tornado.gen.coroutine + def up_down_handle_revoke(self, rpkid, q_msg, r_msg): + key = q_msg[0] + assert key.tag == rpki.up_down.tag_key + class_name = key.get("class_name") + publisher = rpki.rpkid.publication_queue(rpkid) + for child_cert in ChildCert.objects.filter(ca_detail__ca__parent__tenant = self.tenant, + ca_detail__ca__parent_resource_class = class_name, + gski = key.get("ski")): + child_cert.revoke(publisher = publisher) + yield publisher.call_pubd() + SubElement(r_msg, key.tag, class_name = class_name, ski = key.get("ski")) - @tornado.gen.coroutine - def up_down_handle_revoke(self, rpkid, q_msg, r_msg): - key = q_msg[0] - assert key.tag == rpki.up_down.tag_key - class_name = key.get("class_name") - publisher = rpki.rpkid.publication_queue(rpkid) - for child_cert in ChildCert.objects.filter(ca_detail__ca__parent__tenant = self.tenant, - ca_detail__ca__parent_resource_class = class_name, - gski = key.get("ski")): - child_cert.revoke(publisher = publisher) - yield publisher.call_pubd() - SubElement(r_msg, key.tag, class_name = class_name, ski = key.get("ski")) + @tornado.gen.coroutine + def serve_up_down(self, rpkid, q_der): + """ + Outer layer of server handling for one up-down PDU from this child. + """ - @tornado.gen.coroutine - def serve_up_down(self, rpkid, q_der): - """ - Outer layer of server handling for one up-down PDU from this child. - """ + if self.bsc is None: + raise rpki.exceptions.BSCNotFound("Could not find BSC") - if self.bsc is None: - raise rpki.exceptions.BSCNotFound("Could not find BSC") + q_cms = rpki.up_down.cms_msg(DER = q_der) + q_msg = q_cms.unwrap((rpkid.bpki_ta, self.tenant.bpki_cert, self.tenant.bpki_glue, self.bpki_cert, self.bpki_glue)) + q_cms.check_replay_sql(self, "child", self.child_handle) + q_type = q_msg.get("type") - q_cms = rpki.up_down.cms_msg(DER = q_der) - q_msg = q_cms.unwrap((rpkid.bpki_ta, self.tenant.bpki_cert, self.tenant.bpki_glue, self.bpki_cert, self.bpki_glue)) - q_cms.check_replay_sql(self, "child", self.child_handle) - q_type = q_msg.get("type") + logger.info("Serving %s query from child %s [sender %s, recipient %s]", + q_type, self.child_handle, q_msg.get("sender"), q_msg.get("recipient")) - logger.info("Serving %s query from child %s [sender %s, recipient %s]", - q_type, self.child_handle, q_msg.get("sender"), q_msg.get("recipient")) + if rpki.up_down.enforce_strict_up_down_xml_sender and q_msg.get("sender") != self.child_handle: + raise rpki.exceptions.BadSender("Unexpected XML sender %s" % q_msg.get("sender")) - if rpki.up_down.enforce_strict_up_down_xml_sender and q_msg.get("sender") != self.child_handle: - raise rpki.exceptions.BadSender("Unexpected XML sender %s" % q_msg.get("sender")) + r_msg = Element(rpki.up_down.tag_message, nsmap = rpki.up_down.nsmap, version = rpki.up_down.version, + sender = q_msg.get("recipient"), recipient = q_msg.get("sender"), type = q_type + "_response") - r_msg = Element(rpki.up_down.tag_message, nsmap = rpki.up_down.nsmap, version = rpki.up_down.version, - sender = q_msg.get("recipient"), recipient = q_msg.get("sender"), type = q_type + "_response") + try: + yield getattr(self, "up_down_handle_" + q_type)(rpkid, q_msg, r_msg) - try: - yield getattr(self, "up_down_handle_" + q_type)(rpkid, q_msg, r_msg) + except Exception, e: + logger.exception("Unhandled exception serving child %r", self) + rpki.up_down.generate_error_response_from_exception(r_msg, e, q_type) - except Exception, e: - logger.exception("Unhandled exception serving child %r", self) - rpki.up_down.generate_error_response_from_exception(r_msg, e, q_type) - - r_der = rpki.up_down.cms_msg().wrap(r_msg, self.bsc.private_key_id, self.bsc.signing_cert, self.bsc.signing_cert_crl) - raise tornado.gen.Return(r_der) + r_der = rpki.up_down.cms_msg().wrap(r_msg, self.bsc.private_key_id, self.bsc.signing_cert, self.bsc.signing_cert_crl) + raise tornado.gen.Return(r_der) class ChildCert(models.Model): - cert = CertificateField() - published = SundialField(null = True) - gski = models.CharField(max_length = 27) # Assumes SHA-1 -- SHA-256 would be 43, SHA-512 would be 86, etc. - child = models.ForeignKey(Child, related_name = "child_certs") - ca_detail = models.ForeignKey(CADetail, related_name = "child_certs") - - - @property - def uri_tail(self): - """ - Return the tail (filename) portion of the URI for this child_cert. - """ - - return self.gski + ".cer" - - - @property - def uri(self): - """ - Return the publication URI for this child_cert. - """ - - return self.ca_detail.ca.sia_uri + self.uri_tail - - - def revoke(self, publisher, generate_crl_and_manifest = True): - """ - Revoke a child cert. - """ - - ca_detail = self.ca_detail - logger.debug("Revoking %r %r", self, self.uri) - RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) - publisher.queue(uri = self.uri, old_obj = self.cert, repository = ca_detail.ca.parent.repository) - self.delete() - if generate_crl_and_manifest: - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - - def reissue(self, ca_detail, publisher, resources = None, sia = None, force = False): - """ - Reissue an existing child cert, reusing the public key. If the - child cert we would generate is identical to the one we already - have, we just return the one we already have. If we have to - revoke the old child cert when generating the new one, we have to - generate a new child_cert_obj, so calling code that needs the - updated child_cert_obj must use the return value from this method. - """ - - ca = ca_detail.ca - child = self.child - old_resources = self.cert.get_3779resources() - old_sia = self.cert.get_SIA() - old_aia = self.cert.get_AIA()[0] - old_ca_detail = self.ca_detail - needed = False - if resources is None: - resources = old_resources - if sia is None: - sia = old_sia - assert resources.valid_until is not None and old_resources.valid_until is not None - if resources.asn != old_resources.asn or resources.v4 != old_resources.v4 or resources.v6 != old_resources.v6: - logger.debug("Resources changed for %r: old %s new %s", self, old_resources, resources) - needed = True - if resources.valid_until != old_resources.valid_until: - logger.debug("Validity changed for %r: old %s new %s", - self, old_resources.valid_until, resources.valid_until) - needed = True - if sia != old_sia: - logger.debug("SIA changed for %r: old %r new %r", self, old_sia, sia) - needed = True - if ca_detail != old_ca_detail: - logger.debug("Issuer changed for %r: old %r new %r", self, old_ca_detail, ca_detail) - needed = True - if ca_detail.ca_cert_uri != old_aia: - logger.debug("AIA changed for %r: old %r new %r", self, old_aia, ca_detail.ca_cert_uri) - needed = True - must_revoke = old_resources.oversized(resources) or old_resources.valid_until > resources.valid_until - if must_revoke: - logger.debug("Must revoke any existing cert(s) for %r", self) - needed = True - if not needed and force: - logger.debug("No change needed for %r, forcing reissuance anyway", self) - needed = True - if not needed: - logger.debug("No change to %r", self) - return self - if must_revoke: - for x in child.child_certs.filter(ca_detail = ca_detail, gski = self.gski): - logger.debug("Revoking child_cert %r", x) - x.revoke(publisher = publisher) - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - child_cert = ca_detail.issue( - ca = ca, - child = child, - subject_key = self.cert.getPublicKey(), - sia = sia, - resources = resources, - child_cert = None if must_revoke else self, - publisher = publisher) - logger.debug("New child_cert %r uri %s", child_cert, child_cert.uri) - return child_cert - - - def published_callback(self, pdu): - """ - Publication callback: check result and mark published. - """ - - rpki.publication.raise_if_error(pdu) - self.published = None - self.save() + cert = CertificateField() + published = SundialField(null = True) + gski = models.CharField(max_length = 27) # Assumes SHA-1 -- SHA-256 would be 43, SHA-512 would be 86, etc. + child = models.ForeignKey(Child, related_name = "child_certs") + ca_detail = models.ForeignKey(CADetail, related_name = "child_certs") + + + @property + def uri_tail(self): + """ + Return the tail (filename) portion of the URI for this child_cert. + """ + + return self.gski + ".cer" + + + @property + def uri(self): + """ + Return the publication URI for this child_cert. + """ + + return self.ca_detail.ca.sia_uri + self.uri_tail + + + def revoke(self, publisher, generate_crl_and_manifest = True): + """ + Revoke a child cert. + """ + + ca_detail = self.ca_detail + logger.debug("Revoking %r %r", self, self.uri) + RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) + publisher.queue(uri = self.uri, old_obj = self.cert, repository = ca_detail.ca.parent.repository) + self.delete() + if generate_crl_and_manifest: + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + + def reissue(self, ca_detail, publisher, resources = None, sia = None, force = False): + """ + Reissue an existing child cert, reusing the public key. If the + child cert we would generate is identical to the one we already + have, we just return the one we already have. If we have to + revoke the old child cert when generating the new one, we have to + generate a new child_cert_obj, so calling code that needs the + updated child_cert_obj must use the return value from this method. + """ + + ca = ca_detail.ca + child = self.child + old_resources = self.cert.get_3779resources() + old_sia = self.cert.get_SIA() + old_aia = self.cert.get_AIA()[0] + old_ca_detail = self.ca_detail + needed = False + if resources is None: + resources = old_resources + if sia is None: + sia = old_sia + assert resources.valid_until is not None and old_resources.valid_until is not None + if resources.asn != old_resources.asn or resources.v4 != old_resources.v4 or resources.v6 != old_resources.v6: + logger.debug("Resources changed for %r: old %s new %s", self, old_resources, resources) + needed = True + if resources.valid_until != old_resources.valid_until: + logger.debug("Validity changed for %r: old %s new %s", + self, old_resources.valid_until, resources.valid_until) + needed = True + if sia != old_sia: + logger.debug("SIA changed for %r: old %r new %r", self, old_sia, sia) + needed = True + if ca_detail != old_ca_detail: + logger.debug("Issuer changed for %r: old %r new %r", self, old_ca_detail, ca_detail) + needed = True + if ca_detail.ca_cert_uri != old_aia: + logger.debug("AIA changed for %r: old %r new %r", self, old_aia, ca_detail.ca_cert_uri) + needed = True + must_revoke = old_resources.oversized(resources) or old_resources.valid_until > resources.valid_until + if must_revoke: + logger.debug("Must revoke any existing cert(s) for %r", self) + needed = True + if not needed and force: + logger.debug("No change needed for %r, forcing reissuance anyway", self) + needed = True + if not needed: + logger.debug("No change to %r", self) + return self + if must_revoke: + for x in child.child_certs.filter(ca_detail = ca_detail, gski = self.gski): + logger.debug("Revoking child_cert %r", x) + x.revoke(publisher = publisher) + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + child_cert = ca_detail.issue( + ca = ca, + child = child, + subject_key = self.cert.getPublicKey(), + sia = sia, + resources = resources, + child_cert = None if must_revoke else self, + publisher = publisher) + logger.debug("New child_cert %r uri %s", child_cert, child_cert.uri) + return child_cert + + + def published_callback(self, pdu): + """ + Publication callback: check result and mark published. + """ + + rpki.publication.raise_if_error(pdu) + self.published = None + self.save() class EECertificate(models.Model): - gski = models.CharField(max_length = 27) # Assumes SHA-1 -- SHA-256 would be 43, SHA-512 would be 86, etc. - cert = CertificateField() - published = SundialField(null = True) - tenant = models.ForeignKey(Tenant, related_name = "ee_certificates") - ca_detail = models.ForeignKey(CADetail, related_name = "ee_certificates") - - - @property - def uri(self): - """ - Return the publication URI for this ee_cert_obj. - """ - - return self.ca_detail.ca.sia_uri + self.uri_tail - - - @property - def uri_tail(self): - """ - Return the tail (filename portion) of the publication URI for this - ee_cert_obj. - """ - - return self.gski + ".cer" - - - @classmethod - def create(cls, ca_detail, subject_name, subject_key, resources, publisher, eku = None): - """ - Generate a new EE certificate. - """ - - # The low-level X.509 code really ought to supply the singleton - # tuple wrapper when handed a string, but that yak will need to - # wait until another day for its shave. - - cn, sn = subject_name.extract_cn_and_sn() - sia = (None, None, - (ca_detail.ca.sia_uri + subject_key.gSKI() + ".cer",), - (ca_detail.ca.parent.repository.rrdp_notification_uri,)) - cert = ca_detail.issue_ee( - ca = ca_detail.ca, - subject_key = subject_key, - sia = sia, - resources = resources, - notAfter = resources.valid_until, - cn = cn, - sn = sn, - eku = eku) - self = cls(tenant = ca_detail.ca.parent.tenant, ca_detail = ca_detail, cert = cert, gski = subject_key.gSKI()) - publisher.queue( - uri = self.uri, - new_obj = self.cert, - repository = ca_detail.ca.parent.repository, - handler = self.published_callback) - self.save() - ca_detail.generate_manifest(publisher = publisher) - logger.debug("New ee_cert %r", self) - return self - - - def revoke(self, publisher, generate_crl_and_manifest = True): - """ - Revoke and withdraw an EE certificate. - """ - - ca_detail = self.ca_detail - logger.debug("Revoking %r %r", self, self.uri) - RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) - publisher.queue(uri = self.uri, old_obj = self.cert, repository = ca_detail.ca.parent.repository) - self.delete() - if generate_crl_and_manifest: - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - - def reissue(self, publisher, ca_detail = None, resources = None, force = False): - """ - Reissue an existing EE cert, reusing the public key. If the EE - cert we would generate is identical to the one we already have, we - just return; if we need to reissue, we reuse this ee_cert_obj and - just update its contents, as the publication URI will not have - changed. - """ - - needed = False - old_cert = self.cert - old_ca_detail = self.ca_detail - if ca_detail is None: - ca_detail = old_ca_detail - assert ca_detail.ca is old_ca_detail.ca - old_resources = old_cert.get_3779resources() - if resources is None: - resources = old_resources - assert resources.valid_until is not None and old_resources.valid_until is not None - assert ca_detail.covers(resources) - if ca_detail != self.ca_detail: - logger.debug("ca_detail changed for %r: old %r new %r", self, self.ca_detail, ca_detail) - needed = True - if ca_detail.ca_cert_uri != old_cert.get_AIA()[0]: - logger.debug("AIA changed for %r: old %s new %s", self, old_cert.get_AIA()[0], ca_detail.ca_cert_uri) - needed = True - if resources.valid_until != old_resources.valid_until: - logger.debug("Validity changed for %r: old %s new %s", self, old_resources.valid_until, resources.valid_until) - needed = True - if resources.asn != old_resources.asn or resources.v4 != old_resources.v4 or resources.v6 != old_resources.v6: - logger.debug("Resources changed for %r: old %s new %s", self, old_resources, resources) - needed = True - must_revoke = old_resources.oversized(resources) or old_resources.valid_until > resources.valid_until - if must_revoke: - logger.debug("Must revoke existing cert(s) for %r", self) - needed = True - if not needed and force: - logger.debug("No change needed for %r, forcing reissuance anyway", self) - needed = True - if not needed: - logger.debug("No change to %r", self) - return - cn, sn = self.cert.getSubject().extract_cn_and_sn() - self.cert = ca_detail.issue_ee( - ca = ca_detail.ca, - subject_key = self.cert.getPublicKey(), - eku = self.cert.get_EKU(), - sia = (None, None, self.uri, ca_detail.ca.parent.repository.rrdp_notification_uri), - resources = resources, - notAfter = resources.valid_until, - cn = cn, - sn = sn) - self.save() - publisher.queue( - uri = self.uri, - old_obj = old_cert, - new_obj = self.cert, - repository = ca_detail.ca.parent.repository, - handler = self.published_callback) - if must_revoke: - RevokedCert.revoke(cert = old_cert.cert, ca_detail = old_ca_detail) - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - - def published_callback(self, pdu): - """ - Publication callback: check result and mark published. - """ - - rpki.publication.raise_if_error(pdu) - self.published = None - self.save() + gski = models.CharField(max_length = 27) # Assumes SHA-1 -- SHA-256 would be 43, SHA-512 would be 86, etc. + cert = CertificateField() + published = SundialField(null = True) + tenant = models.ForeignKey(Tenant, related_name = "ee_certificates") + ca_detail = models.ForeignKey(CADetail, related_name = "ee_certificates") + + + @property + def uri(self): + """ + Return the publication URI for this ee_cert_obj. + """ + + return self.ca_detail.ca.sia_uri + self.uri_tail + + + @property + def uri_tail(self): + """ + Return the tail (filename portion) of the publication URI for this + ee_cert_obj. + """ + + return self.gski + ".cer" + + + @classmethod + def create(cls, ca_detail, subject_name, subject_key, resources, publisher, eku = None): + """ + Generate a new EE certificate. + """ + + # The low-level X.509 code really ought to supply the singleton + # tuple wrapper when handed a string, but that yak will need to + # wait until another day for its shave. + + cn, sn = subject_name.extract_cn_and_sn() + sia = (None, None, + (ca_detail.ca.sia_uri + subject_key.gSKI() + ".cer",), + (ca_detail.ca.parent.repository.rrdp_notification_uri,)) + cert = ca_detail.issue_ee( + ca = ca_detail.ca, + subject_key = subject_key, + sia = sia, + resources = resources, + notAfter = resources.valid_until, + cn = cn, + sn = sn, + eku = eku) + self = cls(tenant = ca_detail.ca.parent.tenant, ca_detail = ca_detail, cert = cert, gski = subject_key.gSKI()) + publisher.queue( + uri = self.uri, + new_obj = self.cert, + repository = ca_detail.ca.parent.repository, + handler = self.published_callback) + self.save() + ca_detail.generate_manifest(publisher = publisher) + logger.debug("New ee_cert %r", self) + return self + + + def revoke(self, publisher, generate_crl_and_manifest = True): + """ + Revoke and withdraw an EE certificate. + """ + + ca_detail = self.ca_detail + logger.debug("Revoking %r %r", self, self.uri) + RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) + publisher.queue(uri = self.uri, old_obj = self.cert, repository = ca_detail.ca.parent.repository) + self.delete() + if generate_crl_and_manifest: + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + + def reissue(self, publisher, ca_detail = None, resources = None, force = False): + """ + Reissue an existing EE cert, reusing the public key. If the EE + cert we would generate is identical to the one we already have, we + just return; if we need to reissue, we reuse this ee_cert_obj and + just update its contents, as the publication URI will not have + changed. + """ + + needed = False + old_cert = self.cert + old_ca_detail = self.ca_detail + if ca_detail is None: + ca_detail = old_ca_detail + assert ca_detail.ca is old_ca_detail.ca + old_resources = old_cert.get_3779resources() + if resources is None: + resources = old_resources + assert resources.valid_until is not None and old_resources.valid_until is not None + assert ca_detail.covers(resources) + if ca_detail != self.ca_detail: + logger.debug("ca_detail changed for %r: old %r new %r", self, self.ca_detail, ca_detail) + needed = True + if ca_detail.ca_cert_uri != old_cert.get_AIA()[0]: + logger.debug("AIA changed for %r: old %s new %s", self, old_cert.get_AIA()[0], ca_detail.ca_cert_uri) + needed = True + if resources.valid_until != old_resources.valid_until: + logger.debug("Validity changed for %r: old %s new %s", self, old_resources.valid_until, resources.valid_until) + needed = True + if resources.asn != old_resources.asn or resources.v4 != old_resources.v4 or resources.v6 != old_resources.v6: + logger.debug("Resources changed for %r: old %s new %s", self, old_resources, resources) + needed = True + must_revoke = old_resources.oversized(resources) or old_resources.valid_until > resources.valid_until + if must_revoke: + logger.debug("Must revoke existing cert(s) for %r", self) + needed = True + if not needed and force: + logger.debug("No change needed for %r, forcing reissuance anyway", self) + needed = True + if not needed: + logger.debug("No change to %r", self) + return + cn, sn = self.cert.getSubject().extract_cn_and_sn() + self.cert = ca_detail.issue_ee( + ca = ca_detail.ca, + subject_key = self.cert.getPublicKey(), + eku = self.cert.get_EKU(), + sia = (None, None, self.uri, ca_detail.ca.parent.repository.rrdp_notification_uri), + resources = resources, + notAfter = resources.valid_until, + cn = cn, + sn = sn) + self.save() + publisher.queue( + uri = self.uri, + old_obj = old_cert, + new_obj = self.cert, + repository = ca_detail.ca.parent.repository, + handler = self.published_callback) + if must_revoke: + RevokedCert.revoke(cert = old_cert.cert, ca_detail = old_ca_detail) + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + + def published_callback(self, pdu): + """ + Publication callback: check result and mark published. + """ + + rpki.publication.raise_if_error(pdu) + self.published = None + self.save() class Ghostbuster(models.Model): - vcard = models.TextField() - cert = CertificateField() - ghostbuster = GhostbusterField() - published = SundialField(null = True) - tenant = models.ForeignKey(Tenant, related_name = "ghostbusters") - ca_detail = models.ForeignKey(CADetail, related_name = "ghostbusters") - - - def update(self, publisher, fast = False): - """ - Bring this ghostbuster_obj up to date if necesssary. - """ - - if self.ghostbuster is None: - logger.debug("Ghostbuster record doesn't exist, generating") - return self.generate(publisher = publisher, fast = fast) - - now = rpki.sundial.now() - regen_time = self.cert.getNotAfter() - rpki.sundial.timedelta(seconds = self.tenant.regen_margin) - - if now > regen_time and self.cert.getNotAfter() < self.ca_detail.latest_ca_cert.getNotAfter(): - logger.debug("%r past threshold %s, regenerating", self, regen_time) - return self.regenerate(publisher = publisher, fast = fast) - - if now > regen_time: - logger.warning("%r is past threshold %s but so is issuer %r, can't regenerate", self, regen_time, self.ca_detail) - - if self.cert.get_AIA()[0] != self.ca_detail.ca_cert_uri: - logger.debug("%r AIA changed, regenerating", self) - return self.regenerate(publisher = publisher, fast = fast) - - - def generate(self, publisher, fast = False): - """ - Generate a Ghostbuster record - - Once we have the right covering certificate, we generate the - ghostbuster payload, generate a new EE certificate, use the EE - certificate to sign the ghostbuster payload, publish the result, - then throw away the private key for the EE cert. This is modeled - after the way we handle ROAs. - - If fast is set, we leave generating the new manifest for our - caller to handle, presumably at the end of a bulk operation. - """ - - resources = rpki.resource_set.resource_bag.from_inheritance() - keypair = rpki.x509.RSA.generate() - self.cert = self.ca_detail.issue_ee( - ca = self.ca_detail.ca, - resources = resources, - subject_key = keypair.get_public(), - sia = (None, None, self.uri_from_key(keypair), self.ca_detail.ca.parent.repository.rrdp_notification_uri)) - self.ghostbuster = rpki.x509.Ghostbuster.build(self.vcard, keypair, (self.cert,)) - self.published = rpki.sundial.now() - self.save() - logger.debug("Generating Ghostbuster record %r", self.uri) - publisher.queue( - uri = self.uri, - new_obj = self.ghostbuster, - repository = self.ca_detail.ca.parent.repository, - handler = self.published_callback) - if not fast: - self.ca_detail.generate_manifest(publisher = publisher) - - - def published_callback(self, pdu): - """ - Check publication result. - """ - - rpki.publication.raise_if_error(pdu) - self.published = None - self.save() - - - def revoke(self, publisher, regenerate = False, allow_failure = False, fast = False): - """ - Withdraw Ghostbuster associated with this ghostbuster_obj. - - In order to preserve make-before-break properties without - duplicating code, this method also handles generating a - replacement ghostbuster when requested. - - If allow_failure is set, failing to withdraw the ghostbuster will not be - considered an error. - - If fast is set, SQL actions will be deferred, on the assumption - that our caller will handle regenerating CRL and manifest and - flushing the SQL cache. - """ - - ca_detail = self.ca_detail - logger.debug("%s %r, ca_detail %r state is %s", - "Regenerating" if regenerate else "Not regenerating", - self, ca_detail, ca_detail.state) - if regenerate: - self.generate(publisher = publisher, fast = fast) - logger.debug("Withdrawing %r %s and revoking its EE cert", self, self.uri) - RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) - publisher.queue(uri = self.uri, - old_obj = self.ghostbuster, - repository = ca_detail.ca.parent.repository, - handler = False if allow_failure else None) - if not regenerate: - self.delete() - if not fast: - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - - def regenerate(self, publisher, fast = False): - """ - Reissue Ghostbuster associated with this ghostbuster_obj. - """ + vcard = models.TextField() + cert = CertificateField() + ghostbuster = GhostbusterField() + published = SundialField(null = True) + tenant = models.ForeignKey(Tenant, related_name = "ghostbusters") + ca_detail = models.ForeignKey(CADetail, related_name = "ghostbusters") + + + def update(self, publisher, fast = False): + """ + Bring this ghostbuster_obj up to date if necesssary. + """ + + if self.ghostbuster is None: + logger.debug("Ghostbuster record doesn't exist, generating") + return self.generate(publisher = publisher, fast = fast) + + now = rpki.sundial.now() + regen_time = self.cert.getNotAfter() - rpki.sundial.timedelta(seconds = self.tenant.regen_margin) + + if now > regen_time and self.cert.getNotAfter() < self.ca_detail.latest_ca_cert.getNotAfter(): + logger.debug("%r past threshold %s, regenerating", self, regen_time) + return self.regenerate(publisher = publisher, fast = fast) + + if now > regen_time: + logger.warning("%r is past threshold %s but so is issuer %r, can't regenerate", self, regen_time, self.ca_detail) + + if self.cert.get_AIA()[0] != self.ca_detail.ca_cert_uri: + logger.debug("%r AIA changed, regenerating", self) + return self.regenerate(publisher = publisher, fast = fast) + + + def generate(self, publisher, fast = False): + """ + Generate a Ghostbuster record + + Once we have the right covering certificate, we generate the + ghostbuster payload, generate a new EE certificate, use the EE + certificate to sign the ghostbuster payload, publish the result, + then throw away the private key for the EE cert. This is modeled + after the way we handle ROAs. + + If fast is set, we leave generating the new manifest for our + caller to handle, presumably at the end of a bulk operation. + """ + + resources = rpki.resource_set.resource_bag.from_inheritance() + keypair = rpki.x509.RSA.generate() + self.cert = self.ca_detail.issue_ee( + ca = self.ca_detail.ca, + resources = resources, + subject_key = keypair.get_public(), + sia = (None, None, self.uri_from_key(keypair), self.ca_detail.ca.parent.repository.rrdp_notification_uri)) + self.ghostbuster = rpki.x509.Ghostbuster.build(self.vcard, keypair, (self.cert,)) + self.published = rpki.sundial.now() + self.save() + logger.debug("Generating Ghostbuster record %r", self.uri) + publisher.queue( + uri = self.uri, + new_obj = self.ghostbuster, + repository = self.ca_detail.ca.parent.repository, + handler = self.published_callback) + if not fast: + self.ca_detail.generate_manifest(publisher = publisher) + + + def published_callback(self, pdu): + """ + Check publication result. + """ + + rpki.publication.raise_if_error(pdu) + self.published = None + self.save() + + + def revoke(self, publisher, regenerate = False, allow_failure = False, fast = False): + """ + Withdraw Ghostbuster associated with this ghostbuster_obj. + + In order to preserve make-before-break properties without + duplicating code, this method also handles generating a + replacement ghostbuster when requested. + + If allow_failure is set, failing to withdraw the ghostbuster will not be + considered an error. + + If fast is set, SQL actions will be deferred, on the assumption + that our caller will handle regenerating CRL and manifest and + flushing the SQL cache. + """ + + ca_detail = self.ca_detail + logger.debug("%s %r, ca_detail %r state is %s", + "Regenerating" if regenerate else "Not regenerating", + self, ca_detail, ca_detail.state) + if regenerate: + self.generate(publisher = publisher, fast = fast) + logger.debug("Withdrawing %r %s and revoking its EE cert", self, self.uri) + RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) + publisher.queue(uri = self.uri, + old_obj = self.ghostbuster, + repository = ca_detail.ca.parent.repository, + handler = False if allow_failure else None) + if not regenerate: + self.delete() + if not fast: + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + + def regenerate(self, publisher, fast = False): + """ + Reissue Ghostbuster associated with this ghostbuster_obj. + """ + + if self.ghostbuster is None: + self.generate(publisher = publisher, fast = fast) + else: + self.revoke(publisher = publisher, regenerate = True, fast = fast) - if self.ghostbuster is None: - self.generate(publisher = publisher, fast = fast) - else: - self.revoke(publisher = publisher, regenerate = True, fast = fast) + def uri_from_key(self, key): + """ + Return publication URI for a public key. + """ - def uri_from_key(self, key): - """ - Return publication URI for a public key. - """ + return self.ca_detail.ca.sia_uri + key.gSKI() + ".gbr" - return self.ca_detail.ca.sia_uri + key.gSKI() + ".gbr" + @property + def uri(self): + """ + Return the publication URI for this ghostbuster_obj's ghostbuster. + """ - @property - def uri(self): - """ - Return the publication URI for this ghostbuster_obj's ghostbuster. - """ + return self.ca_detail.ca.sia_uri + self.uri_tail - return self.ca_detail.ca.sia_uri + self.uri_tail + @property + def uri_tail(self): + """ + Return the tail (filename portion) of the publication URI for this + ghostbuster_obj's ghostbuster. + """ - @property - def uri_tail(self): - """ - Return the tail (filename portion) of the publication URI for this - ghostbuster_obj's ghostbuster. - """ - - return self.cert.gSKI() + ".gbr" + return self.cert.gSKI() + ".gbr" class RevokedCert(models.Model): - serial = models.BigIntegerField() - revoked = SundialField() - expires = SundialField() - ca_detail = models.ForeignKey(CADetail, related_name = "revoked_certs") + serial = models.BigIntegerField() + revoked = SundialField() + expires = SundialField() + ca_detail = models.ForeignKey(CADetail, related_name = "revoked_certs") - @classmethod - def revoke(cls, cert, ca_detail): - """ - Revoke a certificate. - """ + @classmethod + def revoke(cls, cert, ca_detail): + """ + Revoke a certificate. + """ - return cls.objects.create( - serial = cert.getSerial(), - expires = cert.getNotAfter(), - revoked = rpki.sundial.now(), - ca_detail = ca_detail) + return cls.objects.create( + serial = cert.getSerial(), + expires = cert.getNotAfter(), + revoked = rpki.sundial.now(), + ca_detail = ca_detail) class ROA(models.Model): - asn = models.BigIntegerField() - ipv4 = models.TextField(null = True) - ipv6 = models.TextField(null = True) - cert = CertificateField() - roa = ROAField() - published = SundialField(null = True) - tenant = models.ForeignKey(Tenant, related_name = "roas") - ca_detail = models.ForeignKey(CADetail, related_name = "roas") - - - def update(self, publisher, fast = False): - """ - Bring ROA up to date if necesssary. - """ - - if self.roa is None: - logger.debug("%r doesn't exist, generating", self) - return self.generate(publisher = publisher, fast = fast) - - if self.ca_detail is None: - logger.debug("%r has no associated ca_detail, generating", self) - return self.generate(publisher = publisher, fast = fast) - - if self.ca_detail.state != "active": - logger.debug("ca_detail associated with %r not active (state %s), regenerating", self, self.ca_detail.state) - return self.regenerate(publisher = publisher, fast = fast) - - now = rpki.sundial.now() - regen_time = self.cert.getNotAfter() - rpki.sundial.timedelta(seconds = self.tenant.regen_margin) - - if now > regen_time and self.cert.getNotAfter() < self.ca_detail.latest_ca_cert.getNotAfter(): - logger.debug("%r past threshold %s, regenerating", self, regen_time) - return self.regenerate(publisher = publisher, fast = fast) - - if now > regen_time: - logger.warning("%r is past threshold %s but so is issuer %r, can't regenerate", self, regen_time, self.ca_detail) - - ca_resources = self.ca_detail.latest_ca_cert.get_3779resources() - ee_resources = self.cert.get_3779resources() - - if ee_resources.oversized(ca_resources): - logger.debug("%r oversized with respect to CA, regenerating", self) - return self.regenerate(publisher = publisher, fast = fast) - - v4 = rpki.resource_set.resource_set_ipv4(self.ipv4) - v6 = rpki.resource_set.resource_set_ipv6(self.ipv6) - - if ee_resources.v4 != v4 or ee_resources.v6 != v6: - logger.debug("%r resources do not match EE, regenerating", self) - return self.regenerate(publisher = publisher, fast = fast) - - if self.cert.get_AIA()[0] != self.ca_detail.ca_cert_uri: - logger.debug("%r AIA changed, regenerating", self) - return self.regenerate(publisher = publisher, fast = fast) - - - def generate(self, publisher, fast = False): - """ - Generate a ROA. - - At present we have no way of performing a direct lookup from a - desired set of resources to a covering certificate, so we have to - search. This could be quite slow if we have a lot of active - ca_detail objects. Punt on the issue for now, revisit if - profiling shows this as a hotspot. - - Once we have the right covering certificate, we generate the ROA - payload, generate a new EE certificate, use the EE certificate to - sign the ROA payload, publish the result, then throw away the - private key for the EE cert, all per the ROA specification. This - implies that generating a lot of ROAs will tend to thrash - /dev/random, but there is not much we can do about that. - - If fast is set, we leave generating the new manifest for our - caller to handle, presumably at the end of a bulk operation. - """ - - if self.ipv4 is None and self.ipv6 is None: - raise rpki.exceptions.EmptyROAPrefixList - - v4 = rpki.resource_set.resource_set_ipv4(self.ipv4) - v6 = rpki.resource_set.resource_set_ipv6(self.ipv6) - - # http://stackoverflow.com/questions/26270042/how-do-you-catch-this-exception - # "Django is amazing when its not terrifying." - try: - ca_detail = self.ca_detail - except CADetail.DoesNotExist: - ca_detail = None - - if ca_detail is not None and ca_detail.state == "active" and not ca_detail.has_expired(): - logger.debug("Keeping old ca_detail %r for ROA %r", ca_detail, self) - else: - logger.debug("Searching for new ca_detail for ROA %r", self) - for ca_detail in CADetail.objects.filter(ca__parent__tenant = self.tenant, state = "active"): - resources = ca_detail.latest_ca_cert.get_3779resources() - if not ca_detail.has_expired() and v4.issubset(resources.v4) and v6.issubset(resources.v6): - logger.debug("Using new ca_detail %r for ROA %r", ca_detail, self) - self.ca_detail = ca_detail - break - else: - raise rpki.exceptions.NoCoveringCertForROA("Could not find a certificate covering %r" % self) - - resources = rpki.resource_set.resource_bag(v4 = v4, v6 = v6) - keypair = rpki.x509.RSA.generate() - - self.cert = self.ca_detail.issue_ee( - ca = self.ca_detail.ca, - resources = resources, - subject_key = keypair.get_public(), - sia = (None, None, self.uri_from_key(keypair), self.ca_detail.ca.parent.repository.rrdp_notification_uri)) - self.roa = rpki.x509.ROA.build(self.asn, - rpki.resource_set.roa_prefix_set_ipv4(self.ipv4), - rpki.resource_set.roa_prefix_set_ipv6(self.ipv6), - keypair, - (self.cert,)) - self.published = rpki.sundial.now() - self.save() - - logger.debug("Generating %r URI %s", self, self.uri) - publisher.queue(uri = self.uri, new_obj = self.roa, - repository = self.ca_detail.ca.parent.repository, - handler = self.published_callback) - if not fast: - self.ca_detail.generate_manifest(publisher = publisher) - - - def published_callback(self, pdu): - """ - Check publication result. - """ - - rpki.publication.raise_if_error(pdu) - self.published = None - self.save() - - - def revoke(self, publisher, regenerate = False, allow_failure = False, fast = False): - """ - Withdraw ROA associated with this roa_obj. - - In order to preserve make-before-break properties without - duplicating code, this method also handles generating a - replacement ROA when requested. - - If allow_failure is set, failing to withdraw the ROA will not be - considered an error. - - If fast is set, SQL actions will be deferred, on the assumption - that our caller will handle regenerating CRL and manifest and - flushing the SQL cache. - """ - - ca_detail = self.ca_detail - logger.debug("%s %r, ca_detail %r state is %s", - "Regenerating" if regenerate else "Not regenerating", - self, ca_detail, ca_detail.state) - if regenerate: - self.generate(publisher = publisher, fast = fast) - logger.debug("Withdrawing %r %s and revoking its EE cert", self, self.uri) - RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) - publisher.queue(uri = self.uri, old_obj = self.roa, - repository = ca_detail.ca.parent.repository, - handler = False if allow_failure else None) - if not regenerate: - self.delete() - if not fast: - ca_detail.generate_crl(publisher = publisher) - ca_detail.generate_manifest(publisher = publisher) - - - def regenerate(self, publisher, fast = False): - """ - Reissue ROA associated with this roa_obj. - """ + asn = models.BigIntegerField() + ipv4 = models.TextField(null = True) + ipv6 = models.TextField(null = True) + cert = CertificateField() + roa = ROAField() + published = SundialField(null = True) + tenant = models.ForeignKey(Tenant, related_name = "roas") + ca_detail = models.ForeignKey(CADetail, related_name = "roas") + + + def update(self, publisher, fast = False): + """ + Bring ROA up to date if necesssary. + """ + + if self.roa is None: + logger.debug("%r doesn't exist, generating", self) + return self.generate(publisher = publisher, fast = fast) + + if self.ca_detail is None: + logger.debug("%r has no associated ca_detail, generating", self) + return self.generate(publisher = publisher, fast = fast) + + if self.ca_detail.state != "active": + logger.debug("ca_detail associated with %r not active (state %s), regenerating", self, self.ca_detail.state) + return self.regenerate(publisher = publisher, fast = fast) + + now = rpki.sundial.now() + regen_time = self.cert.getNotAfter() - rpki.sundial.timedelta(seconds = self.tenant.regen_margin) + + if now > regen_time and self.cert.getNotAfter() < self.ca_detail.latest_ca_cert.getNotAfter(): + logger.debug("%r past threshold %s, regenerating", self, regen_time) + return self.regenerate(publisher = publisher, fast = fast) + + if now > regen_time: + logger.warning("%r is past threshold %s but so is issuer %r, can't regenerate", self, regen_time, self.ca_detail) + + ca_resources = self.ca_detail.latest_ca_cert.get_3779resources() + ee_resources = self.cert.get_3779resources() + + if ee_resources.oversized(ca_resources): + logger.debug("%r oversized with respect to CA, regenerating", self) + return self.regenerate(publisher = publisher, fast = fast) + + v4 = rpki.resource_set.resource_set_ipv4(self.ipv4) + v6 = rpki.resource_set.resource_set_ipv6(self.ipv6) + + if ee_resources.v4 != v4 or ee_resources.v6 != v6: + logger.debug("%r resources do not match EE, regenerating", self) + return self.regenerate(publisher = publisher, fast = fast) + + if self.cert.get_AIA()[0] != self.ca_detail.ca_cert_uri: + logger.debug("%r AIA changed, regenerating", self) + return self.regenerate(publisher = publisher, fast = fast) + + + def generate(self, publisher, fast = False): + """ + Generate a ROA. + + At present we have no way of performing a direct lookup from a + desired set of resources to a covering certificate, so we have to + search. This could be quite slow if we have a lot of active + ca_detail objects. Punt on the issue for now, revisit if + profiling shows this as a hotspot. + + Once we have the right covering certificate, we generate the ROA + payload, generate a new EE certificate, use the EE certificate to + sign the ROA payload, publish the result, then throw away the + private key for the EE cert, all per the ROA specification. This + implies that generating a lot of ROAs will tend to thrash + /dev/random, but there is not much we can do about that. + + If fast is set, we leave generating the new manifest for our + caller to handle, presumably at the end of a bulk operation. + """ + + if self.ipv4 is None and self.ipv6 is None: + raise rpki.exceptions.EmptyROAPrefixList + + v4 = rpki.resource_set.resource_set_ipv4(self.ipv4) + v6 = rpki.resource_set.resource_set_ipv6(self.ipv6) + + # http://stackoverflow.com/questions/26270042/how-do-you-catch-this-exception + # "Django is amazing when its not terrifying." + try: + ca_detail = self.ca_detail + except CADetail.DoesNotExist: + ca_detail = None + + if ca_detail is not None and ca_detail.state == "active" and not ca_detail.has_expired(): + logger.debug("Keeping old ca_detail %r for ROA %r", ca_detail, self) + else: + logger.debug("Searching for new ca_detail for ROA %r", self) + for ca_detail in CADetail.objects.filter(ca__parent__tenant = self.tenant, state = "active"): + resources = ca_detail.latest_ca_cert.get_3779resources() + if not ca_detail.has_expired() and v4.issubset(resources.v4) and v6.issubset(resources.v6): + logger.debug("Using new ca_detail %r for ROA %r", ca_detail, self) + self.ca_detail = ca_detail + break + else: + raise rpki.exceptions.NoCoveringCertForROA("Could not find a certificate covering %r" % self) + + resources = rpki.resource_set.resource_bag(v4 = v4, v6 = v6) + keypair = rpki.x509.RSA.generate() + + self.cert = self.ca_detail.issue_ee( + ca = self.ca_detail.ca, + resources = resources, + subject_key = keypair.get_public(), + sia = (None, None, self.uri_from_key(keypair), self.ca_detail.ca.parent.repository.rrdp_notification_uri)) + self.roa = rpki.x509.ROA.build(self.asn, + rpki.resource_set.roa_prefix_set_ipv4(self.ipv4), + rpki.resource_set.roa_prefix_set_ipv6(self.ipv6), + keypair, + (self.cert,)) + self.published = rpki.sundial.now() + self.save() + + logger.debug("Generating %r URI %s", self, self.uri) + publisher.queue(uri = self.uri, new_obj = self.roa, + repository = self.ca_detail.ca.parent.repository, + handler = self.published_callback) + if not fast: + self.ca_detail.generate_manifest(publisher = publisher) + + + def published_callback(self, pdu): + """ + Check publication result. + """ + + rpki.publication.raise_if_error(pdu) + self.published = None + self.save() + + + def revoke(self, publisher, regenerate = False, allow_failure = False, fast = False): + """ + Withdraw ROA associated with this roa_obj. + + In order to preserve make-before-break properties without + duplicating code, this method also handles generating a + replacement ROA when requested. + + If allow_failure is set, failing to withdraw the ROA will not be + considered an error. + + If fast is set, SQL actions will be deferred, on the assumption + that our caller will handle regenerating CRL and manifest and + flushing the SQL cache. + """ + + ca_detail = self.ca_detail + logger.debug("%s %r, ca_detail %r state is %s", + "Regenerating" if regenerate else "Not regenerating", + self, ca_detail, ca_detail.state) + if regenerate: + self.generate(publisher = publisher, fast = fast) + logger.debug("Withdrawing %r %s and revoking its EE cert", self, self.uri) + RevokedCert.revoke(cert = self.cert, ca_detail = ca_detail) + publisher.queue(uri = self.uri, old_obj = self.roa, + repository = ca_detail.ca.parent.repository, + handler = False if allow_failure else None) + if not regenerate: + self.delete() + if not fast: + ca_detail.generate_crl(publisher = publisher) + ca_detail.generate_manifest(publisher = publisher) + + + def regenerate(self, publisher, fast = False): + """ + Reissue ROA associated with this roa_obj. + """ + + if self.ca_detail is None: + self.generate(publisher = publisher, fast = fast) + else: + self.revoke(publisher = publisher, regenerate = True, fast = fast) - if self.ca_detail is None: - self.generate(publisher = publisher, fast = fast) - else: - self.revoke(publisher = publisher, regenerate = True, fast = fast) + def uri_from_key(self, key): + """ + Return publication URI for a public key. + """ - def uri_from_key(self, key): - """ - Return publication URI for a public key. - """ + return self.ca_detail.ca.sia_uri + key.gSKI() + ".roa" - return self.ca_detail.ca.sia_uri + key.gSKI() + ".roa" + @property + def uri(self): + """ + Return the publication URI for this roa_obj's ROA. + """ - @property - def uri(self): - """ - Return the publication URI for this roa_obj's ROA. - """ - - return self.ca_detail.ca.sia_uri + self.uri_tail + return self.ca_detail.ca.sia_uri + self.uri_tail - @property - def uri_tail(self): - """ - Return the tail (filename portion) of the publication URI for this - roa_obj's ROA. - """ + @property + def uri_tail(self): + """ + Return the tail (filename portion) of the publication URI for this + roa_obj's ROA. + """ - return self.cert.gSKI() + ".roa" + return self.cert.gSKI() + ".roa" diff --git a/rpki/rtr/bgpdump.py b/rpki/rtr/bgpdump.py index 5ffabc4d..3336fb9f 100755 --- a/rpki/rtr/bgpdump.py +++ b/rpki/rtr/bgpdump.py @@ -39,292 +39,292 @@ from rpki.rtr.channels import Timestamp class IgnoreThisRecord(Exception): - pass + pass class PrefixPDU(rpki.rtr.generator.PrefixPDU): - @staticmethod - def from_bgpdump(line, rib_dump): - try: - assert isinstance(rib_dump, bool) - fields = line.split("|") - - # Parse prefix, including figuring out IP protocol version - cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU - self = cls() - self.timestamp = Timestamp(fields[1]) - p, l = fields[5].split("/") - self.prefix = rpki.POW.IPAddress(p) - self.prefixlen = self.max_prefixlen = int(l) - - # Withdrawals don't have AS paths, so be careful - assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W") - if fields[2] == "W": - self.asn = 0 - self.announce = 0 - else: - self.announce = 1 - if not fields[6] or "{" in fields[6] or "(" in fields[6]: - raise IgnoreThisRecord - a = fields[6].split()[-1] - if "." in a: - a = [int(s) for s in a.split(".")] - if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535: - logging.warn("Bad dotted ASNum %r, ignoring record", fields[6]) + @staticmethod + def from_bgpdump(line, rib_dump): + try: + assert isinstance(rib_dump, bool) + fields = line.split("|") + + # Parse prefix, including figuring out IP protocol version + cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU + self = cls() + self.timestamp = Timestamp(fields[1]) + p, l = fields[5].split("/") + self.prefix = rpki.POW.IPAddress(p) + self.prefixlen = self.max_prefixlen = int(l) + + # Withdrawals don't have AS paths, so be careful + assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W") + if fields[2] == "W": + self.asn = 0 + self.announce = 0 + else: + self.announce = 1 + if not fields[6] or "{" in fields[6] or "(" in fields[6]: + raise IgnoreThisRecord + a = fields[6].split()[-1] + if "." in a: + a = [int(s) for s in a.split(".")] + if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535: + logging.warn("Bad dotted ASNum %r, ignoring record", fields[6]) + raise IgnoreThisRecord + a = (a[0] << 16) | a[1] + else: + a = int(a) + self.asn = a + + self.check() + return self + + except IgnoreThisRecord: + raise + + except Exception, e: + logging.warn("Ignoring line %r: %s", line, e) raise IgnoreThisRecord - a = (a[0] << 16) | a[1] - else: - a = int(a) - self.asn = a - self.check() - return self - except IgnoreThisRecord: - raise +class AXFRSet(rpki.rtr.generator.AXFRSet): - except Exception, e: - logging.warn("Ignoring line %r: %s", line, e) - raise IgnoreThisRecord + @staticmethod + def read_bgpdump(filename): + assert filename.endswith(".bz2") + logging.debug("Reading %s", filename) + bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE) + bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE) + return bgpdump.stdout + + @classmethod + def parse_bgpdump_rib_dump(cls, filename): + assert os.path.basename(filename).startswith("ribs.") + self = cls() + self.serial = None + for line in cls.read_bgpdump(filename): + try: + pfx = PrefixPDU.from_bgpdump(line, rib_dump = True) + except IgnoreThisRecord: + continue + self.append(pfx) + self.serial = pfx.timestamp + if self.serial is None: + sys.exit("Failed to parse anything useful from %s" % filename) + self.sort() + for i in xrange(len(self) - 2, -1, -1): + if self[i] == self[i + 1]: + del self[i + 1] + return self + + def parse_bgpdump_update(self, filename): + assert os.path.basename(filename).startswith("updates.") + for line in self.read_bgpdump(filename): + try: + pfx = PrefixPDU.from_bgpdump(line, rib_dump = False) + except IgnoreThisRecord: + continue + announce = pfx.announce + pfx.announce = 1 + i = bisect.bisect_left(self, pfx) + if announce: + if i >= len(self) or pfx != self[i]: + self.insert(i, pfx) + else: + while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen: + del self[i] + self.serial = pfx.timestamp -class AXFRSet(rpki.rtr.generator.AXFRSet): +def bgpdump_convert_main(args): + """ + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + Simulate route origin data from a set of BGP dump files. + argv is an ordered list of filenames. Each file must be a BGP RIB + dumps, a BGP UPDATE dumps, or an AXFR dump in the format written by + this program's --cronjob command. The first file must be a RIB dump + or AXFR dump, it cannot be an UPDATE dump. Output will be a set of + AXFR and IXFR files with timestamps derived from the BGP dumps, + which can be used as input to this program's --server command for + test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL. + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + """ + + first = True + db = None + axfrs = [] + version = max(rpki.rtr.pdus.PDU.version_map.iterkeys()) + + for filename in args.files: + + if ".ax.v" in filename: + logging.debug("Reading %s", filename) + db = AXFRSet.load(filename) + + elif os.path.basename(filename).startswith("ribs."): + db = AXFRSet.parse_bgpdump_rib_dump(filename) + db.save_axfr() + + elif not first: + assert db is not None + db.parse_bgpdump_update(filename) + db.save_axfr() - @staticmethod - def read_bgpdump(filename): - assert filename.endswith(".bz2") - logging.debug("Reading %s", filename) - bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE) - bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE) - return bgpdump.stdout - - @classmethod - def parse_bgpdump_rib_dump(cls, filename): - assert os.path.basename(filename).startswith("ribs.") - self = cls() - self.serial = None - for line in cls.read_bgpdump(filename): - try: - pfx = PrefixPDU.from_bgpdump(line, rib_dump = True) - except IgnoreThisRecord: - continue - self.append(pfx) - self.serial = pfx.timestamp - if self.serial is None: - sys.exit("Failed to parse anything useful from %s" % filename) - self.sort() - for i in xrange(len(self) - 2, -1, -1): - if self[i] == self[i + 1]: - del self[i + 1] - return self - - def parse_bgpdump_update(self, filename): - assert os.path.basename(filename).startswith("updates.") - for line in self.read_bgpdump(filename): - try: - pfx = PrefixPDU.from_bgpdump(line, rib_dump = False) - except IgnoreThisRecord: - continue - announce = pfx.announce - pfx.announce = 1 - i = bisect.bisect_left(self, pfx) - if announce: - if i >= len(self) or pfx != self[i]: - self.insert(i, pfx) - else: - while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen: - del self[i] - self.serial = pfx.timestamp + else: + sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename) + logging.debug("DB serial now %d (%s)", db.serial, db.serial) + if first and rpki.rtr.server.read_current(version) == (None, None): + db.mark_current() + first = False -def bgpdump_convert_main(args): - """ - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - Simulate route origin data from a set of BGP dump files. - argv is an ordered list of filenames. Each file must be a BGP RIB - dumps, a BGP UPDATE dumps, or an AXFR dump in the format written by - this program's --cronjob command. The first file must be a RIB dump - or AXFR dump, it cannot be an UPDATE dump. Output will be a set of - AXFR and IXFR files with timestamps derived from the BGP dumps, - which can be used as input to this program's --server command for - test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL. - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - """ - - first = True - db = None - axfrs = [] - version = max(rpki.rtr.pdus.PDU.version_map.iterkeys()) - - for filename in args.files: - - if ".ax.v" in filename: - logging.debug("Reading %s", filename) - db = AXFRSet.load(filename) - - elif os.path.basename(filename).startswith("ribs."): - db = AXFRSet.parse_bgpdump_rib_dump(filename) - db.save_axfr() - - elif not first: - assert db is not None - db.parse_bgpdump_update(filename) - db.save_axfr() - - else: - sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename) - - logging.debug("DB serial now %d (%s)", db.serial, db.serial) - if first and rpki.rtr.server.read_current(version) == (None, None): - db.mark_current() - first = False - - for axfr in axfrs: - logging.debug("Loading %s", axfr) - ax = AXFRSet.load(axfr) - logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial) - db.save_ixfr(ax) - del ax - - axfrs.append(db.filename()) + for axfr in axfrs: + logging.debug("Loading %s", axfr) + ax = AXFRSet.load(axfr) + logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial) + db.save_ixfr(ax) + del ax + + axfrs.append(db.filename()) def bgpdump_select_main(args): - """ - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - Simulate route origin data from a set of BGP dump files. - Set current serial number to correspond to an .ax file created by - converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL. - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - """ + """ + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + Simulate route origin data from a set of BGP dump files. + Set current serial number to correspond to an .ax file created by + converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL. + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + """ - head, sep, tail = os.path.basename(args.ax_file).partition(".") - if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit(): - sys.exit("Argument must be name of a .ax file") + head, sep, tail = os.path.basename(args.ax_file).partition(".") + if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit(): + sys.exit("Argument must be name of a .ax file") - serial = Timestamp(head) - version = int(tail[4:]) + serial = Timestamp(head) + version = int(tail[4:]) - if version not in rpki.rtr.pdus.PDU.version_map: - sys.exit("Unknown protocol version %d" % version) + if version not in rpki.rtr.pdus.PDU.version_map: + sys.exit("Unknown protocol version %d" % version) - nonce = rpki.rtr.server.read_current(version)[1] - if nonce is None: - nonce = rpki.rtr.generator.new_nonce() + nonce = rpki.rtr.server.read_current(version)[1] + if nonce is None: + nonce = rpki.rtr.generator.new_nonce() - rpki.rtr.server.write_current(serial, nonce, version) - rpki.rtr.generator.kick_all(serial) + rpki.rtr.server.write_current(serial, nonce, version) + rpki.rtr.generator.kick_all(serial) class BGPDumpReplayClock(object): - """ - Internal clock for replaying BGP dump files. + """ + Internal clock for replaying BGP dump files. - * DANGER WILL ROBINSON! * - * DEBUGGING AND TEST USE ONLY! * + * DANGER WILL ROBINSON! * + * DEBUGGING AND TEST USE ONLY! * - This class replaces the normal on-disk serial number mechanism with - an in-memory version based on pre-computed data. + This class replaces the normal on-disk serial number mechanism with + an in-memory version based on pre-computed data. - bgpdump_server_main() uses this hack to replay historical data for - testing purposes. DO NOT USE THIS IN PRODUCTION. + bgpdump_server_main() uses this hack to replay historical data for + testing purposes. DO NOT USE THIS IN PRODUCTION. - You have been warned. - """ + You have been warned. + """ - def __init__(self): - self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")] - self.timestamps.sort() - self.offset = self.timestamps[0] - int(time.time()) - self.nonce = rpki.rtr.generator.new_nonce() + def __init__(self): + self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")] + self.timestamps.sort() + self.offset = self.timestamps[0] - int(time.time()) + self.nonce = rpki.rtr.generator.new_nonce() - def __nonzero__(self): - return len(self.timestamps) > 0 + def __nonzero__(self): + return len(self.timestamps) > 0 - def now(self): - return Timestamp.now(self.offset) + def now(self): + return Timestamp.now(self.offset) - def read_current(self, version): - now = self.now() - while len(self.timestamps) > 1 and now >= self.timestamps[1]: - del self.timestamps[0] - return self.timestamps[0], self.nonce + def read_current(self, version): + now = self.now() + while len(self.timestamps) > 1 and now >= self.timestamps[1]: + del self.timestamps[0] + return self.timestamps[0], self.nonce - def siesta(self): - now = self.now() - if len(self.timestamps) <= 1: - return None - elif now < self.timestamps[1]: - return self.timestamps[1] - now - else: - return 1 + def siesta(self): + now = self.now() + if len(self.timestamps) <= 1: + return None + elif now < self.timestamps[1]: + return self.timestamps[1] - now + else: + return 1 def bgpdump_server_main(args): - """ - Simulate route origin data from a set of BGP dump files. + """ + Simulate route origin data from a set of BGP dump files. + + * DANGER WILL ROBINSON! * + * DEBUGGING AND TEST USE ONLY! * + + This is a clone of server_main() which replaces the external serial + number updates triggered via the kickme channel by cronjob_main with + an internal clocking mechanism to replay historical test data. - * DANGER WILL ROBINSON! * - * DEBUGGING AND TEST USE ONLY! * + DO NOT USE THIS IN PRODUCTION. - This is a clone of server_main() which replaces the external serial - number updates triggered via the kickme channel by cronjob_main with - an internal clocking mechanism to replay historical test data. + You have been warned. + """ - DO NOT USE THIS IN PRODUCTION. + logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server._hostport_tag())) - You have been warned. - """ + logger.debug("[Starting]") - logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server._hostport_tag())) + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) - logger.debug("[Starting]") + # Yes, this really does replace a global function defined in another + # module with a bound method to our clock object. Fun stuff, huh? + # + clock = BGPDumpReplayClock() + rpki.rtr.server.read_current = clock.read_current - if args.rpki_rtr_dir: try: - os.chdir(args.rpki_rtr_dir) - except OSError, e: - sys.exit(e) - - # Yes, this really does replace a global function defined in another - # module with a bound method to our clock object. Fun stuff, huh? - # - clock = BGPDumpReplayClock() - rpki.rtr.server.read_current = clock.read_current - - try: - server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire) - old_serial = server.get_serial() - logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial) - while clock: - new_serial = server.get_serial() - if old_serial != new_serial: - logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial) - server.notify() - old_serial = new_serial - asyncore.loop(timeout = clock.siesta(), count = 1) - except KeyboardInterrupt: - sys.exit(0) + server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire) + old_serial = server.get_serial() + logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial) + while clock: + new_serial = server.get_serial() + if old_serial != new_serial: + logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial) + server.notify() + old_serial = new_serial + asyncore.loop(timeout = clock.siesta(), count = 1) + except KeyboardInterrupt: + sys.exit(0) def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__, - help = "Convert bgpdump to fake ROAs") - subparser.set_defaults(func = bgpdump_convert_main, default_log_to = "syslog") - subparser.add_argument("files", nargs = "+", help = "input files") - - subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__, - help = "Set current serial number for fake ROA data") - subparser.set_defaults(func = bgpdump_select_main, default_log_to = "syslog") - subparser.add_argument("ax_file", help = "name of the .ax to select") - - subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__, - help = "Replay fake ROAs generated from historical data") - subparser.set_defaults(func = bgpdump_server_main, default_log_to = "syslog") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + """ + Set up argparse stuff for commands in this module. + """ + + subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__, + help = "Convert bgpdump to fake ROAs") + subparser.set_defaults(func = bgpdump_convert_main, default_log_to = "syslog") + subparser.add_argument("files", nargs = "+", help = "input files") + + subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__, + help = "Set current serial number for fake ROA data") + subparser.set_defaults(func = bgpdump_select_main, default_log_to = "syslog") + subparser.add_argument("ax_file", help = "name of the .ax to select") + + subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__, + help = "Replay fake ROAs generated from historical data") + subparser.set_defaults(func = bgpdump_server_main, default_log_to = "syslog") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") diff --git a/rpki/rtr/channels.py b/rpki/rtr/channels.py index d14c024d..e2f443e8 100644 --- a/rpki/rtr/channels.py +++ b/rpki/rtr/channels.py @@ -32,215 +32,215 @@ import rpki.rtr.pdus class Timestamp(int): - """ - Wrapper around time module. - """ - - def __new__(cls, t): - # __new__() is a static method, not a class method, hence the odd calling sequence. - return super(Timestamp, cls).__new__(cls, t) - - @classmethod - def now(cls, delta = 0): - return cls(time.time() + delta) - - def __str__(self): - return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self)) - - -class ReadBuffer(object): - """ - Wrapper around synchronous/asynchronous read state. - - This also handles tracking the current protocol version, - because it has to go somewhere and there's no better place. - """ - - def __init__(self): - self.buffer = "" - self.version = None - - def update(self, need, callback): """ - Update count of needed bytes and callback, then dispatch to callback. + Wrapper around time module. """ - self.need = need - self.callback = callback - return self.retry() + def __new__(cls, t): + # __new__() is a static method, not a class method, hence the odd calling sequence. + return super(Timestamp, cls).__new__(cls, t) - def retry(self): - """ - Try dispatching to the callback again. - """ + @classmethod + def now(cls, delta = 0): + return cls(time.time() + delta) - return self.callback(self) + def __str__(self): + return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self)) - def available(self): - """ - How much data do we have available in this buffer? - """ - return len(self.buffer) - - def needed(self): - """ - How much more data does this buffer need to become ready? +class ReadBuffer(object): """ + Wrapper around synchronous/asynchronous read state. - return self.need - self.available() - - def ready(self): - """ - Is this buffer ready to read yet? + This also handles tracking the current protocol version, + because it has to go somewhere and there's no better place. """ - return self.available() >= self.need + def __init__(self): + self.buffer = "" + self.version = None - def get(self, n): - """ - Hand some data to the caller. - """ + def update(self, need, callback): + """ + Update count of needed bytes and callback, then dispatch to callback. + """ - b = self.buffer[:n] - self.buffer = self.buffer[n:] - return b + self.need = need + self.callback = callback + return self.retry() - def put(self, b): - """ - Accumulate some data. - """ + def retry(self): + """ + Try dispatching to the callback again. + """ - self.buffer += b + return self.callback(self) - def check_version(self, version): - """ - Track version number of PDUs read from this buffer. - Once set, the version must not change. - """ + def available(self): + """ + How much data do we have available in this buffer? + """ - if self.version is not None and version != self.version: - raise rpki.rtr.pdus.CorruptData( - "Received PDU version %d, expected %d" % (version, self.version)) - if self.version is None and version not in rpki.rtr.pdus.PDU.version_map: - raise rpki.rtr.pdus.UnsupportedProtocolVersion( - "Received PDU version %s, known versions %s" % ( - version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map))) - self.version = version + return len(self.buffer) + def needed(self): + """ + How much more data does this buffer need to become ready? + """ -class PDUChannel(asynchat.async_chat, object): - """ - asynchat subclass that understands our PDUs. This just handles - network I/O. Specific engines (client, server) should be subclasses - of this with methods that do something useful with the resulting - PDUs. - """ - - def __init__(self, root_pdu_class, sock = None): - asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super() - self.reader = ReadBuffer() - assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU) - self.root_pdu_class = root_pdu_class - - @property - def version(self): - return self.reader.version - - @version.setter - def version(self, version): - self.reader.check_version(version) - - def start_new_pdu(self): - """ - Start read of a new PDU. - """ - - try: - p = self.root_pdu_class.read_pdu(self.reader) - while p is not None: - self.deliver_pdu(p) - p = self.root_pdu_class.read_pdu(self.reader) - except rpki.rtr.pdus.PDUException, e: - self.push_pdu(e.make_error_report(version = self.version)) - self.close_when_done() - else: - assert not self.reader.ready() - self.set_terminator(self.reader.needed()) - - def collect_incoming_data(self, data): - """ - Collect data into the read buffer. - """ - - self.reader.put(data) - - def found_terminator(self): - """ - Got requested data, see if we now have a PDU. If so, pass it - along, then restart cycle for a new PDU. - """ - - p = self.reader.retry() - if p is None: - self.set_terminator(self.reader.needed()) - else: - self.deliver_pdu(p) - self.start_new_pdu() - - def push_pdu(self, pdu): - """ - Write PDU to stream. - """ + return self.need - self.available() - try: - self.push(pdu.to_pdu()) - except OSError, e: - if e.errno != errno.EAGAIN: - raise + def ready(self): + """ + Is this buffer ready to read yet? + """ - def log(self, msg): - """ - Intercept asyncore's logging. - """ + return self.available() >= self.need - logging.info(msg) + def get(self, n): + """ + Hand some data to the caller. + """ - def log_info(self, msg, tag = "info"): - """ - Intercept asynchat's logging. - """ + b = self.buffer[:n] + self.buffer = self.buffer[n:] + return b - logging.info("asynchat: %s: %s", tag, msg) + def put(self, b): + """ + Accumulate some data. + """ - def handle_error(self): - """ - Handle errors caught by asyncore main loop. - """ + self.buffer += b - logging.exception("[Unhandled exception]") - logging.critical("[Exiting after unhandled exception]") - sys.exit(1) + def check_version(self, version): + """ + Track version number of PDUs read from this buffer. + Once set, the version must not change. + """ - def init_file_dispatcher(self, fd): - """ - Kludge to plug asyncore.file_dispatcher into asynchat. Call from - subclass's __init__() method, after calling - PDUChannel.__init__(), and don't read this on a full stomach. - """ + if self.version is not None and version != self.version: + raise rpki.rtr.pdus.CorruptData( + "Received PDU version %d, expected %d" % (version, self.version)) + if self.version is None and version not in rpki.rtr.pdus.PDU.version_map: + raise rpki.rtr.pdus.UnsupportedProtocolVersion( + "Received PDU version %s, known versions %s" % ( + version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map))) + self.version = version - self.connected = True - self._fileno = fd - self.socket = asyncore.file_wrapper(fd) - self.add_channel() - flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) - flags = flags | os.O_NONBLOCK - fcntl.fcntl(fd, fcntl.F_SETFL, flags) - def handle_close(self): - """ - Exit when channel closed. +class PDUChannel(asynchat.async_chat, object): """ - - asynchat.async_chat.handle_close(self) - sys.exit(0) + asynchat subclass that understands our PDUs. This just handles + network I/O. Specific engines (client, server) should be subclasses + of this with methods that do something useful with the resulting + PDUs. + """ + + def __init__(self, root_pdu_class, sock = None): + asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super() + self.reader = ReadBuffer() + assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU) + self.root_pdu_class = root_pdu_class + + @property + def version(self): + return self.reader.version + + @version.setter + def version(self, version): + self.reader.check_version(version) + + def start_new_pdu(self): + """ + Start read of a new PDU. + """ + + try: + p = self.root_pdu_class.read_pdu(self.reader) + while p is not None: + self.deliver_pdu(p) + p = self.root_pdu_class.read_pdu(self.reader) + except rpki.rtr.pdus.PDUException, e: + self.push_pdu(e.make_error_report(version = self.version)) + self.close_when_done() + else: + assert not self.reader.ready() + self.set_terminator(self.reader.needed()) + + def collect_incoming_data(self, data): + """ + Collect data into the read buffer. + """ + + self.reader.put(data) + + def found_terminator(self): + """ + Got requested data, see if we now have a PDU. If so, pass it + along, then restart cycle for a new PDU. + """ + + p = self.reader.retry() + if p is None: + self.set_terminator(self.reader.needed()) + else: + self.deliver_pdu(p) + self.start_new_pdu() + + def push_pdu(self, pdu): + """ + Write PDU to stream. + """ + + try: + self.push(pdu.to_pdu()) + except OSError, e: + if e.errno != errno.EAGAIN: + raise + + def log(self, msg): + """ + Intercept asyncore's logging. + """ + + logging.info(msg) + + def log_info(self, msg, tag = "info"): + """ + Intercept asynchat's logging. + """ + + logging.info("asynchat: %s: %s", tag, msg) + + def handle_error(self): + """ + Handle errors caught by asyncore main loop. + """ + + logging.exception("[Unhandled exception]") + logging.critical("[Exiting after unhandled exception]") + sys.exit(1) + + def init_file_dispatcher(self, fd): + """ + Kludge to plug asyncore.file_dispatcher into asynchat. Call from + subclass's __init__() method, after calling + PDUChannel.__init__(), and don't read this on a full stomach. + """ + + self.connected = True + self._fileno = fd + self.socket = asyncore.file_wrapper(fd) + self.add_channel() + flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + def handle_close(self): + """ + Exit when channel closed. + """ + + asynchat.async_chat.handle_close(self) + sys.exit(0) diff --git a/rpki/rtr/client.py b/rpki/rtr/client.py index a35ab81d..9c7a00d6 100644 --- a/rpki/rtr/client.py +++ b/rpki/rtr/client.py @@ -37,13 +37,13 @@ from rpki.rtr.channels import Timestamp class PDU(rpki.rtr.pdus.PDU): - def consume(self, client): - """ - Handle results in test client. Default behavior is just to print - out the PDU; data PDU subclasses may override this. - """ + def consume(self, client): + """ + Handle results in test client. Default behavior is just to print + out the PDU; data PDU subclasses may override this. + """ - logging.debug(self) + logging.debug(self) clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU) @@ -52,407 +52,407 @@ clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU) @clone_pdu class SerialNotifyPDU(rpki.rtr.pdus.SerialNotifyPDU): - def consume(self, client): - """ - Respond to a SerialNotifyPDU with either a SerialQueryPDU or a - ResetQueryPDU, depending on what we already know. - """ + def consume(self, client): + """ + Respond to a SerialNotifyPDU with either a SerialQueryPDU or a + ResetQueryPDU, depending on what we already know. + """ - logging.debug(self) - if client.serial is None or client.nonce != self.nonce: - client.push_pdu(ResetQueryPDU(version = client.version)) - elif self.serial != client.serial: - client.push_pdu(SerialQueryPDU(version = client.version, - serial = client.serial, - nonce = client.nonce)) - else: - logging.debug("[Notify did not change serial number, ignoring]") + logging.debug(self) + if client.serial is None or client.nonce != self.nonce: + client.push_pdu(ResetQueryPDU(version = client.version)) + elif self.serial != client.serial: + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.serial, + nonce = client.nonce)) + else: + logging.debug("[Notify did not change serial number, ignoring]") @clone_pdu class CacheResponsePDU(rpki.rtr.pdus.CacheResponsePDU): - def consume(self, client): - """ - Handle CacheResponsePDU. - """ + def consume(self, client): + """ + Handle CacheResponsePDU. + """ - logging.debug(self) - if self.nonce != client.nonce: - logging.debug("[Nonce changed, resetting]") - client.cache_reset() + logging.debug(self) + if self.nonce != client.nonce: + logging.debug("[Nonce changed, resetting]") + client.cache_reset() @clone_pdu class EndOfDataPDUv0(rpki.rtr.pdus.EndOfDataPDUv0): - def consume(self, client): - """ - Handle EndOfDataPDU response. - """ + def consume(self, client): + """ + Handle EndOfDataPDU response. + """ - logging.debug(self) - client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) + logging.debug(self) + client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) @clone_pdu class EndOfDataPDUv1(rpki.rtr.pdus.EndOfDataPDUv1): - def consume(self, client): - """ - Handle EndOfDataPDU response. - """ + def consume(self, client): + """ + Handle EndOfDataPDU response. + """ - logging.debug(self) - client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) + logging.debug(self) + client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) @clone_pdu class CacheResetPDU(rpki.rtr.pdus.CacheResetPDU): - def consume(self, client): - """ - Handle CacheResetPDU response, by issuing a ResetQueryPDU. - """ + def consume(self, client): + """ + Handle CacheResetPDU response, by issuing a ResetQueryPDU. + """ - logging.debug(self) - client.cache_reset() - client.push_pdu(ResetQueryPDU(version = client.version)) + logging.debug(self) + client.cache_reset() + client.push_pdu(ResetQueryPDU(version = client.version)) class PrefixPDU(rpki.rtr.pdus.PrefixPDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - def consume(self, client): """ - Handle one incoming prefix PDU + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. """ - logging.debug(self) - client.consume_prefix(self) + def consume(self, client): + """ + Handle one incoming prefix PDU + """ + + logging.debug(self) + client.consume_prefix(self) @clone_pdu class IPv4PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv4PrefixPDU): - pass + pass @clone_pdu class IPv6PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv6PrefixPDU): - pass + pass @clone_pdu class ErrorReportPDU(PDU, rpki.rtr.pdus.ErrorReportPDU): - pass + pass @clone_pdu class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU): - """ - Router Key PDU. - """ - - def consume(self, client): """ - Handle one incoming Router Key PDU + Router Key PDU. """ - logging.debug(self) - client.consume_routerkey(self) + def consume(self, client): + """ + Handle one incoming Router Key PDU + """ + logging.debug(self) + client.consume_routerkey(self) -class ClientChannel(rpki.rtr.channels.PDUChannel): - """ - Client protocol engine, handles upcalls from PDUChannel. - """ - - serial = None - nonce = None - sql = None - host = None - port = None - cache_id = None - refresh = rpki.rtr.pdus.default_refresh - retry = rpki.rtr.pdus.default_retry - expire = rpki.rtr.pdus.default_expire - updated = Timestamp(0) - - def __init__(self, sock, proc, killsig, args, host = None, port = None): - self.killsig = killsig - self.proc = proc - self.args = args - self.host = args.host if host is None else host - self.port = args.port if port is None else port - super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU) - if args.force_version is not None: - self.version = args.force_version - self.start_new_pdu() - if args.sql_database: - self.setup_sql() - - @classmethod - def ssh(cls, args): - """ - Set up ssh connection and start listening for first PDU. - """ - if args.port is None: - argv = ("ssh", "-s", args.host, "rpki-rtr") - else: - argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr") - logging.debug("[Running ssh: %s]", " ".join(argv)) - s = socket.socketpair() - return cls(sock = s[1], - proc = subprocess.Popen(argv, executable = "/usr/bin/ssh", - stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGKILL, args = args) - - @classmethod - def tcp(cls, args): - """ - Set up TCP connection and start listening for first PDU. +class ClientChannel(rpki.rtr.channels.PDUChannel): """ - - logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port) - try: - addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.error, e: - logging.debug("[socket.getaddrinfo() failed: %s]", e) - else: - for ai in addrinfo: - af, socktype, proto, cn, sa = ai # pylint: disable=W0612 - logging.debug("[Trying addr %s port %s]", sa[0], sa[1]) + Client protocol engine, handles upcalls from PDUChannel. + """ + + serial = None + nonce = None + sql = None + host = None + port = None + cache_id = None + refresh = rpki.rtr.pdus.default_refresh + retry = rpki.rtr.pdus.default_retry + expire = rpki.rtr.pdus.default_expire + updated = Timestamp(0) + + def __init__(self, sock, proc, killsig, args, host = None, port = None): + self.killsig = killsig + self.proc = proc + self.args = args + self.host = args.host if host is None else host + self.port = args.port if port is None else port + super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU) + if args.force_version is not None: + self.version = args.force_version + self.start_new_pdu() + if args.sql_database: + self.setup_sql() + + @classmethod + def ssh(cls, args): + """ + Set up ssh connection and start listening for first PDU. + """ + + if args.port is None: + argv = ("ssh", "-s", args.host, "rpki-rtr") + else: + argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr") + logging.debug("[Running ssh: %s]", " ".join(argv)) + s = socket.socketpair() + return cls(sock = s[1], + proc = subprocess.Popen(argv, executable = "/usr/bin/ssh", + stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGKILL, args = args) + + @classmethod + def tcp(cls, args): + """ + Set up TCP connection and start listening for first PDU. + """ + + logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port) try: - s = socket.socket(af, socktype, proto) + addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM) except socket.error, e: - logging.debug("[socket.socket() failed: %s]", e) - continue + logging.debug("[socket.getaddrinfo() failed: %s]", e) + else: + for ai in addrinfo: + af, socktype, proto, cn, sa = ai # pylint: disable=W0612 + logging.debug("[Trying addr %s port %s]", sa[0], sa[1]) + try: + s = socket.socket(af, socktype, proto) + except socket.error, e: + logging.debug("[socket.socket() failed: %s]", e) + continue + try: + s.connect(sa) + except socket.error, e: + logging.exception("[socket.connect() failed: %s]", e) + s.close() + continue + return cls(sock = s, proc = None, killsig = None, args = args) + sys.exit(1) + + @classmethod + def loopback(cls, args): + """ + Set up loopback connection and start listening for first PDU. + """ + + s = socket.socketpair() + logging.debug("[Using direct subprocess kludge for testing]") + argv = (sys.executable, sys.argv[0], "server") + return cls(sock = s[1], + proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGINT, args = args, + host = args.host or "none", port = args.port or "none") + + @classmethod + def tls(cls, args): + """ + Set up TLS connection and start listening for first PDU. + + NB: This uses OpenSSL's "s_client" command, which does not + check server certificates properly, so this is not suitable for + production use. Fixing this would be a trivial change, it just + requires using a client program which does check certificates + properly (eg, gnutls-cli, or stunnel's client mode if that works + for such purposes this week). + """ + + argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port)) + logging.debug("[Running: %s]", " ".join(argv)) + s = socket.socketpair() + return cls(sock = s[1], + proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGKILL, args = args) + + def setup_sql(self): + """ + Set up an SQLite database to contain the table we receive. If + necessary, we will create the database. + """ + + import sqlite3 + missing = not os.path.exists(self.args.sql_database) + self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES) + self.sql.text_factory = str + cur = self.sql.cursor() + cur.execute("PRAGMA foreign_keys = on") + if missing: + cur.execute(''' + CREATE TABLE cache ( + cache_id INTEGER PRIMARY KEY NOT NULL, + host TEXT NOT NULL, + port TEXT NOT NULL, + version INTEGER, + nonce INTEGER, + serial INTEGER, + updated INTEGER, + refresh INTEGER, + retry INTEGER, + expire INTEGER, + UNIQUE (host, port))''') + cur.execute(''' + CREATE TABLE prefix ( + cache_id INTEGER NOT NULL + REFERENCES cache(cache_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + asn INTEGER NOT NULL, + prefix TEXT NOT NULL, + prefixlen INTEGER NOT NULL, + max_prefixlen INTEGER NOT NULL, + UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''') + cur.execute(''' + CREATE TABLE routerkey ( + cache_id INTEGER NOT NULL + REFERENCES cache(cache_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + asn INTEGER NOT NULL, + ski TEXT NOT NULL, + key TEXT NOT NULL, + UNIQUE (cache_id, asn, ski), + UNIQUE (cache_id, asn, key))''') + elif self.args.reset_session: + cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) + cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated " + "FROM cache WHERE host = ? AND port = ?", + (self.host, self.port)) try: - s.connect(sa) - except socket.error, e: - logging.exception("[socket.connect() failed: %s]", e) - s.close() - continue - return cls(sock = s, proc = None, killsig = None, args = args) - sys.exit(1) - - @classmethod - def loopback(cls, args): - """ - Set up loopback connection and start listening for first PDU. - """ - - s = socket.socketpair() - logging.debug("[Using direct subprocess kludge for testing]") - argv = (sys.executable, sys.argv[0], "server") - return cls(sock = s[1], - proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGINT, args = args, - host = args.host or "none", port = args.port or "none") - - @classmethod - def tls(cls, args): - """ - Set up TLS connection and start listening for first PDU. - - NB: This uses OpenSSL's "s_client" command, which does not - check server certificates properly, so this is not suitable for - production use. Fixing this would be a trivial change, it just - requires using a client program which does check certificates - properly (eg, gnutls-cli, or stunnel's client mode if that works - for such purposes this week). - """ - - argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port)) - logging.debug("[Running: %s]", " ".join(argv)) - s = socket.socketpair() - return cls(sock = s[1], - proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGKILL, args = args) - - def setup_sql(self): - """ - Set up an SQLite database to contain the table we receive. If - necessary, we will create the database. - """ - - import sqlite3 - missing = not os.path.exists(self.args.sql_database) - self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES) - self.sql.text_factory = str - cur = self.sql.cursor() - cur.execute("PRAGMA foreign_keys = on") - if missing: - cur.execute(''' - CREATE TABLE cache ( - cache_id INTEGER PRIMARY KEY NOT NULL, - host TEXT NOT NULL, - port TEXT NOT NULL, - version INTEGER, - nonce INTEGER, - serial INTEGER, - updated INTEGER, - refresh INTEGER, - retry INTEGER, - expire INTEGER, - UNIQUE (host, port))''') - cur.execute(''' - CREATE TABLE prefix ( - cache_id INTEGER NOT NULL - REFERENCES cache(cache_id) - ON DELETE CASCADE - ON UPDATE CASCADE, - asn INTEGER NOT NULL, - prefix TEXT NOT NULL, - prefixlen INTEGER NOT NULL, - max_prefixlen INTEGER NOT NULL, - UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''') - cur.execute(''' - CREATE TABLE routerkey ( - cache_id INTEGER NOT NULL - REFERENCES cache(cache_id) - ON DELETE CASCADE - ON UPDATE CASCADE, - asn INTEGER NOT NULL, - ski TEXT NOT NULL, - key TEXT NOT NULL, - UNIQUE (cache_id, asn, ski), - UNIQUE (cache_id, asn, key))''') - elif self.args.reset_session: - cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) - cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated " - "FROM cache WHERE host = ? AND port = ?", - (self.host, self.port)) - try: - self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone() - if version is not None and self.version is not None and version != self.version: - cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) - raise TypeError # Simulate lookup failure case - if version is not None: - self.version = version - if refresh is not None: + self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone() + if version is not None and self.version is not None and version != self.version: + cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) + raise TypeError # Simulate lookup failure case + if version is not None: + self.version = version + if refresh is not None: + self.refresh = refresh + if retry is not None: + self.retry = retry + if expire is not None: + self.expire = expire + if updated is not None: + self.updated = Timestamp(updated) + except TypeError: + cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port)) + self.cache_id = cur.lastrowid + self.sql.commit() + logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]", + self.cache_id, self.version, self.nonce, + self.serial, self.refresh, self.retry, self.expire, self.updated) + + def cache_reset(self): + """ + Handle CacheResetPDU actions. + """ + + self.serial = None + if self.sql: + cur = self.sql.cursor() + cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,)) + cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,)) + cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id)) + self.sql.commit() + + def end_of_data(self, version, serial, nonce, refresh, retry, expire): + """ + Handle EndOfDataPDU actions. + """ + + assert version == self.version + self.serial = serial + self.nonce = nonce self.refresh = refresh - if retry is not None: - self.retry = retry - if expire is not None: - self.expire = expire - if updated is not None: - self.updated = Timestamp(updated) - except TypeError: - cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port)) - self.cache_id = cur.lastrowid - self.sql.commit() - logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]", - self.cache_id, self.version, self.nonce, - self.serial, self.refresh, self.retry, self.expire, self.updated) - - def cache_reset(self): - """ - Handle CacheResetPDU actions. - """ - - self.serial = None - if self.sql: - cur = self.sql.cursor() - cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,)) - cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,)) - cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id)) - self.sql.commit() - - def end_of_data(self, version, serial, nonce, refresh, retry, expire): - """ - Handle EndOfDataPDU actions. - """ - - assert version == self.version - self.serial = serial - self.nonce = nonce - self.refresh = refresh - self.retry = retry - self.expire = expire - self.updated = Timestamp.now() - if self.sql: - self.sql.execute("UPDATE cache SET" - " version = ?, serial = ?, nonce = ?," - " refresh = ?, retry = ?, expire = ?," - " updated = ? " - "WHERE cache_id = ?", - (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id)) - self.sql.commit() - - def consume_prefix(self, prefix): - """ - Handle one prefix PDU. - """ - - if self.sql: - values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen) - if prefix.announce: - self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) " - "VALUES (?, ?, ?, ?, ?)", - values) - else: - self.sql.execute("DELETE FROM prefix " - "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?", - values) - - def consume_routerkey(self, routerkey): - """ - Handle one Router Key PDU. - """ - - if self.sql: - values = (self.cache_id, routerkey.asn, - base64.urlsafe_b64encode(routerkey.ski).rstrip("="), - base64.b64encode(routerkey.key)) - if routerkey.announce: - self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) " - "VALUES (?, ?, ?, ?)", - values) - else: - self.sql.execute("DELETE FROM routerkey " - "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)", - values) - - def deliver_pdu(self, pdu): - """ - Handle received PDU. - """ - - pdu.consume(self) - - def push_pdu(self, pdu): - """ - Log outbound PDU then write it to stream. - """ - - logging.debug(pdu) - super(ClientChannel, self).push_pdu(pdu) - - def cleanup(self): - """ - Force clean up this client's child process. If everything goes - well, child will have exited already before this method is called, - but we may need to whack it with a stick if something breaks. - """ - - if self.proc is not None and self.proc.returncode is None: - try: - os.kill(self.proc.pid, self.killsig) - except OSError: - pass - - def handle_close(self): - """ - Intercept close event so we can log it, then shut down. - """ - - logging.debug("Server closed channel") - super(ClientChannel, self).handle_close() + self.retry = retry + self.expire = expire + self.updated = Timestamp.now() + if self.sql: + self.sql.execute("UPDATE cache SET" + " version = ?, serial = ?, nonce = ?," + " refresh = ?, retry = ?, expire = ?," + " updated = ? " + "WHERE cache_id = ?", + (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id)) + self.sql.commit() + + def consume_prefix(self, prefix): + """ + Handle one prefix PDU. + """ + + if self.sql: + values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen) + if prefix.announce: + self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) " + "VALUES (?, ?, ?, ?, ?)", + values) + else: + self.sql.execute("DELETE FROM prefix " + "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?", + values) + + def consume_routerkey(self, routerkey): + """ + Handle one Router Key PDU. + """ + + if self.sql: + values = (self.cache_id, routerkey.asn, + base64.urlsafe_b64encode(routerkey.ski).rstrip("="), + base64.b64encode(routerkey.key)) + if routerkey.announce: + self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) " + "VALUES (?, ?, ?, ?)", + values) + else: + self.sql.execute("DELETE FROM routerkey " + "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)", + values) + + def deliver_pdu(self, pdu): + """ + Handle received PDU. + """ + + pdu.consume(self) + + def push_pdu(self, pdu): + """ + Log outbound PDU then write it to stream. + """ + + logging.debug(pdu) + super(ClientChannel, self).push_pdu(pdu) + + def cleanup(self): + """ + Force clean up this client's child process. If everything goes + well, child will have exited already before this method is called, + but we may need to whack it with a stick if something breaks. + """ + + if self.proc is not None and self.proc.returncode is None: + try: + os.kill(self.proc.pid, self.killsig) + except OSError: + pass + + def handle_close(self): + """ + Intercept close event so we can log it, then shut down. + """ + + logging.debug("Server closed channel") + super(ClientChannel, self).handle_close() # Hack to let us subclass this from scripts without needing to rewrite client_main(). @@ -460,73 +460,73 @@ class ClientChannel(rpki.rtr.channels.PDUChannel): ClientChannelClass = ClientChannel def client_main(args): - """ - Test client, intended primarily for debugging. - """ + """ + Test client, intended primarily for debugging. + """ - logging.debug("[Startup]") + logging.debug("[Startup]") - assert issubclass(ClientChannelClass, ClientChannel) - constructor = getattr(ClientChannelClass, args.protocol) + assert issubclass(ClientChannelClass, ClientChannel) + constructor = getattr(ClientChannelClass, args.protocol) - client = None - try: - client = constructor(args) + client = None + try: + client = constructor(args) - polled = client.updated - wakeup = None + polled = client.updated + wakeup = None - while True: + while True: - now = Timestamp.now() + now = Timestamp.now() - if client.serial is not None and now > client.updated + client.expire: - logging.info("[Expiring client data: serial %s, last updated %s, expire %s]", - client.serial, client.updated, client.expire) - client.cache_reset() + if client.serial is not None and now > client.updated + client.expire: + logging.info("[Expiring client data: serial %s, last updated %s, expire %s]", + client.serial, client.updated, client.expire) + client.cache_reset() - if client.serial is None or client.nonce is None: - polled = now - client.push_pdu(ResetQueryPDU(version = client.version)) + if client.serial is None or client.nonce is None: + polled = now + client.push_pdu(ResetQueryPDU(version = client.version)) - elif now >= client.updated + client.refresh: - polled = now - client.push_pdu(SerialQueryPDU(version = client.version, - serial = client.serial, - nonce = client.nonce)) + elif now >= client.updated + client.refresh: + polled = now + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.serial, + nonce = client.nonce)) - remaining = 1 + remaining = 1 - while remaining > 0: - now = Timestamp.now() - timer = client.retry if (now >= client.updated + client.refresh) else client.refresh - wokeup = wakeup - wakeup = max(now, Timestamp(max(polled, client.updated) + timer)) - remaining = wakeup - now - if wakeup != wokeup: - logging.info("[Last client poll %s, next %s]", polled, wakeup) - asyncore.loop(timeout = remaining, count = 1) + while remaining > 0: + now = Timestamp.now() + timer = client.retry if (now >= client.updated + client.refresh) else client.refresh + wokeup = wakeup + wakeup = max(now, Timestamp(max(polled, client.updated) + timer)) + remaining = wakeup - now + if wakeup != wokeup: + logging.info("[Last client poll %s, next %s]", polled, wakeup) + asyncore.loop(timeout = remaining, count = 1) - except KeyboardInterrupt: - sys.exit(0) + except KeyboardInterrupt: + sys.exit(0) - finally: - if client is not None: - client.cleanup() + finally: + if client is not None: + client.cleanup() def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - subparser = subparsers.add_parser("client", description = client_main.__doc__, - help = "Test client for RPKI-RTR protocol") - subparser.set_defaults(func = client_main, default_log_to = "stderr") - subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state") - subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version") - subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database") - subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol") - subparser.add_argument("host", nargs = "?", help = "server host") - subparser.add_argument("port", nargs = "?", help = "server port") - return subparser + """ + Set up argparse stuff for commands in this module. + """ + + subparser = subparsers.add_parser("client", description = client_main.__doc__, + help = "Test client for RPKI-RTR protocol") + subparser.set_defaults(func = client_main, default_log_to = "stderr") + subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state") + subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version") + subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database") + subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol") + subparser.add_argument("host", nargs = "?", help = "server host") + subparser.add_argument("port", nargs = "?", help = "server port") + return subparser diff --git a/rpki/rtr/generator.py b/rpki/rtr/generator.py index 26e25b6e..e00e44b7 100644 --- a/rpki/rtr/generator.py +++ b/rpki/rtr/generator.py @@ -37,539 +37,539 @@ import rpki.rtr.server from rpki.rtr.channels import Timestamp class PrefixPDU(rpki.rtr.pdus.PrefixPDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - @staticmethod - def from_text(version, asn, addr): - """ - Construct a prefix from its text form. - """ - - cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU - self = cls(version = version) - self.asn = long(asn) - p, l = addr.split("/") - self.prefix = rpki.POW.IPAddress(p) - if "-" in l: - self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-")) - else: - self.prefixlen = self.max_prefixlen = int(l) - self.announce = 1 - self.check() - return self - - @staticmethod - def from_roa(version, asn, prefix_tuple): - """ - Construct a prefix from a ROA. """ - - address, length, maxlength = prefix_tuple - cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU - self = cls(version = version) - self.asn = asn - self.prefix = address - self.prefixlen = length - self.max_prefixlen = length if maxlength is None else maxlength - self.announce = 1 - self.check() - return self + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + @staticmethod + def from_text(version, asn, addr): + """ + Construct a prefix from its text form. + """ + + cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU + self = cls(version = version) + self.asn = long(asn) + p, l = addr.split("/") + self.prefix = rpki.POW.IPAddress(p) + if "-" in l: + self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-")) + else: + self.prefixlen = self.max_prefixlen = int(l) + self.announce = 1 + self.check() + return self + + @staticmethod + def from_roa(version, asn, prefix_tuple): + """ + Construct a prefix from a ROA. + """ + + address, length, maxlength = prefix_tuple + cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU + self = cls(version = version) + self.asn = asn + self.prefix = address + self.prefixlen = length + self.max_prefixlen = length if maxlength is None else maxlength + self.announce = 1 + self.check() + return self class IPv4PrefixPDU(PrefixPDU): - """ - IPv4 flavor of a prefix. - """ + """ + IPv4 flavor of a prefix. + """ - pdu_type = 4 - address_byte_count = 4 + pdu_type = 4 + address_byte_count = 4 class IPv6PrefixPDU(PrefixPDU): - """ - IPv6 flavor of a prefix. - """ - - pdu_type = 6 - address_byte_count = 16 - -class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU): - """ - Router Key PDU. - """ - - @classmethod - def from_text(cls, version, asn, gski, key): """ - Construct a router key from its text form. + IPv6 flavor of a prefix. """ - self = cls(version = version) - self.asn = long(asn) - self.ski = base64.urlsafe_b64decode(gski + "=") - self.key = base64.b64decode(key) - self.announce = 1 - self.check() - return self + pdu_type = 6 + address_byte_count = 16 - @classmethod - def from_certificate(cls, version, asn, ski, key): +class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU): """ - Construct a router key from a certificate. + Router Key PDU. """ - self = cls(version = version) - self.asn = asn - self.ski = ski - self.key = key - self.announce = 1 - self.check() - return self + @classmethod + def from_text(cls, version, asn, gski, key): + """ + Construct a router key from its text form. + """ + self = cls(version = version) + self.asn = long(asn) + self.ski = base64.urlsafe_b64decode(gski + "=") + self.key = base64.b64decode(key) + self.announce = 1 + self.check() + return self -class ROA(rpki.POW.ROA): # pylint: disable=W0232 - """ - Minor additions to rpki.POW.ROA. - """ - - @classmethod - def derReadFile(cls, fn): # pylint: disable=E1002 - self = super(ROA, cls).derReadFile(fn) - self.extractWithoutVerifying() - return self - - @property - def prefixes(self): - v4, v6 = self.getPrefixes() - if v4 is not None: - for p in v4: - yield p - if v6 is not None: - for p in v6: - yield p - -class X509(rpki.POW.X509): # pylint: disable=W0232 - """ - Minor additions to rpki.POW.X509. - """ + @classmethod + def from_certificate(cls, version, asn, ski, key): + """ + Construct a router key from a certificate. + """ - @property - def asns(self): - resources = self.getRFC3779() - if resources is not None and resources[0] is not None: - for min_asn, max_asn in resources[0]: - for asn in xrange(min_asn, max_asn + 1): - yield asn + self = cls(version = version) + self.asn = asn + self.ski = ski + self.key = key + self.announce = 1 + self.check() + return self -class PDUSet(list): - """ - Object representing a set of PDUs, that is, one versioned and - (theoretically) consistant set of prefixes and router keys extracted - from rcynic's output. - """ - - def __init__(self, version): - assert version in rpki.rtr.pdus.PDU.version_map - super(PDUSet, self).__init__() - self.version = version - - @classmethod - def _load_file(cls, filename, version): +class ROA(rpki.POW.ROA): # pylint: disable=W0232 """ - Low-level method to read PDUSet from a file. + Minor additions to rpki.POW.ROA. """ - self = cls(version = version) - f = open(filename, "rb") - r = rpki.rtr.channels.ReadBuffer() - while True: - p = rpki.rtr.pdus.PDU.read_pdu(r) - while p is None: - b = f.read(r.needed()) - if b == "": - assert r.available() == 0 - return self - r.put(b) - p = r.retry() - assert p.version == self.version - self.append(p) - - @staticmethod - def seq_ge(a, b): - return ((a - b) % (1 << 32)) < (1 << 31) + @classmethod + def derReadFile(cls, fn): # pylint: disable=E1002 + self = super(ROA, cls).derReadFile(fn) + self.extractWithoutVerifying() + return self + @property + def prefixes(self): + v4, v6 = self.getPrefixes() + if v4 is not None: + for p in v4: + yield p + if v6 is not None: + for p in v6: + yield p -class AXFRSet(PDUSet): - """ - Object representing a complete set of PDUs, that is, one versioned - and (theoretically) consistant set of prefixes and router - certificates extracted from rcynic's output, all with the announce - field set. - """ - - @classmethod - def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None): +class X509(rpki.POW.X509): # pylint: disable=W0232 + """ + Minor additions to rpki.POW.X509. """ - Parse ROAS and router certificates fetched (and validated!) by - rcynic to create a new AXFRSet. - In normal operation, we use os.walk() and the rpki.POW library to - parse these data directly, but we can, if so instructed, use - external programs instead, for testing, simulation, or to provide - a way to inject local data. + @property + def asns(self): + resources = self.getRFC3779() + if resources is not None and resources[0] is not None: + for min_asn, max_asn in resources[0]: + for asn in xrange(min_asn, max_asn + 1): + yield asn - At some point the ability to parse these data from external - programs may move to a separate constructor function, so that we - can make this one a bit simpler and faster. - """ - self = cls(version = version) - self.serial = rpki.rtr.channels.Timestamp.now() - - include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version] - - if scan_roas is None or (scan_routercerts is None and include_routercerts): - for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612 - for fn in files: - if scan_roas is None and fn.endswith(".roa"): - roa = ROA.derReadFile(os.path.join(root, fn)) - asn = roa.getASID() - self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple) - for prefix_tuple in roa.prefixes) - if include_routercerts and scan_routercerts is None and fn.endswith(".cer"): - x = X509.derReadFile(os.path.join(root, fn)) - eku = x.getEKU() - if eku is not None and rpki.oids.id_kp_bgpsec_router in eku: - ski = x.getSKI() - key = x.getPublicKey().derWritePublic() - self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key) - for asn in x.asns) - - if scan_roas is not None: - try: - p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE) - for line in p.stdout: - line = line.split() - asn = line[1] - self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr) - for addr in line[2:]) - except OSError, e: - sys.exit("Could not run %s: %s" % (scan_roas, e)) - - if include_routercerts and scan_routercerts is not None: - try: - p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE) - for line in p.stdout: - line = line.split() - gski = line[0] - key = line[-1] - self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key) - for asn in line[1:-1]) - except OSError, e: - sys.exit("Could not run %s: %s" % (scan_routercerts, e)) - - self.sort() - for i in xrange(len(self) - 2, -1, -1): - if self[i] == self[i + 1]: - del self[i + 1] - return self - - @classmethod - def load(cls, filename): - """ - Load an AXFRSet from a file, parse filename to obtain version and serial. +class PDUSet(list): """ + Object representing a set of PDUs, that is, one versioned and + (theoretically) consistant set of prefixes and router keys extracted + from rcynic's output. + """ + + def __init__(self, version): + assert version in rpki.rtr.pdus.PDU.version_map + super(PDUSet, self).__init__() + self.version = version + + @classmethod + def _load_file(cls, filename, version): + """ + Low-level method to read PDUSet from a file. + """ + + self = cls(version = version) + f = open(filename, "rb") + r = rpki.rtr.channels.ReadBuffer() + while True: + p = rpki.rtr.pdus.PDU.read_pdu(r) + while p is None: + b = f.read(r.needed()) + if b == "": + assert r.available() == 0 + return self + r.put(b) + p = r.retry() + assert p.version == self.version + self.append(p) + + @staticmethod + def seq_ge(a, b): + return ((a - b) % (1 << 32)) < (1 << 31) - fn1, fn2, fn3 = os.path.basename(filename).split(".") - assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit() - version = int(fn3[1:]) - self = cls._load_file(filename, version) - self.serial = rpki.rtr.channels.Timestamp(fn1) - return self - def filename(self): - """ - Generate filename for this AXFRSet. +class AXFRSet(PDUSet): """ + Object representing a complete set of PDUs, that is, one versioned + and (theoretically) consistant set of prefixes and router + certificates extracted from rcynic's output, all with the announce + field set. + """ + + @classmethod + def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None): + """ + Parse ROAS and router certificates fetched (and validated!) by + rcynic to create a new AXFRSet. + + In normal operation, we use os.walk() and the rpki.POW library to + parse these data directly, but we can, if so instructed, use + external programs instead, for testing, simulation, or to provide + a way to inject local data. + + At some point the ability to parse these data from external + programs may move to a separate constructor function, so that we + can make this one a bit simpler and faster. + """ + + self = cls(version = version) + self.serial = rpki.rtr.channels.Timestamp.now() + + include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version] + + if scan_roas is None or (scan_routercerts is None and include_routercerts): + for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612 + for fn in files: + if scan_roas is None and fn.endswith(".roa"): + roa = ROA.derReadFile(os.path.join(root, fn)) + asn = roa.getASID() + self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple) + for prefix_tuple in roa.prefixes) + if include_routercerts and scan_routercerts is None and fn.endswith(".cer"): + x = X509.derReadFile(os.path.join(root, fn)) + eku = x.getEKU() + if eku is not None and rpki.oids.id_kp_bgpsec_router in eku: + ski = x.getSKI() + key = x.getPublicKey().derWritePublic() + self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key) + for asn in x.asns) + + if scan_roas is not None: + try: + p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE) + for line in p.stdout: + line = line.split() + asn = line[1] + self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr) + for addr in line[2:]) + except OSError, e: + sys.exit("Could not run %s: %s" % (scan_roas, e)) + + if include_routercerts and scan_routercerts is not None: + try: + p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE) + for line in p.stdout: + line = line.split() + gski = line[0] + key = line[-1] + self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key) + for asn in line[1:-1]) + except OSError, e: + sys.exit("Could not run %s: %s" % (scan_routercerts, e)) + + self.sort() + for i in xrange(len(self) - 2, -1, -1): + if self[i] == self[i + 1]: + del self[i + 1] + return self + + @classmethod + def load(cls, filename): + """ + Load an AXFRSet from a file, parse filename to obtain version and serial. + """ + + fn1, fn2, fn3 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit() + version = int(fn3[1:]) + self = cls._load_file(filename, version) + self.serial = rpki.rtr.channels.Timestamp(fn1) + return self + + def filename(self): + """ + Generate filename for this AXFRSet. + """ + + return "%d.ax.v%d" % (self.serial, self.version) + + @classmethod + def load_current(cls, version): + """ + Load current AXFRSet. Return None if can't. + """ + + serial = rpki.rtr.server.read_current(version)[0] + if serial is None: + return None + try: + return cls.load("%d.ax.v%d" % (serial, version)) + except IOError: + return None + + def save_axfr(self): + """ + Write AXFRSet to file with magic filename. + """ + + f = open(self.filename(), "wb") + for p in self: + f.write(p.to_pdu()) + f.close() + + def destroy_old_data(self): + """ + Destroy old data files, presumably because our nonce changed and + the old serial numbers are no longer valid. + """ + + for i in glob.iglob("*.ix.*.v%d" % self.version): + os.unlink(i) + for i in glob.iglob("*.ax.v%d" % self.version): + if i != self.filename(): + os.unlink(i) + + @staticmethod + def new_nonce(force_zero_nonce): + """ + Create and return a new nonce value. + """ + + if force_zero_nonce: + return 0 + try: + return int(random.SystemRandom().getrandbits(16)) + except NotImplementedError: + return int(random.getrandbits(16)) + + def mark_current(self, force_zero_nonce = False): + """ + Save current serial number and nonce, creating new nonce if + necessary. Creating a new nonce triggers cleanup of old state, as + the new nonce invalidates all old serial numbers. + """ + + assert self.version in rpki.rtr.pdus.PDU.version_map + old_serial, nonce = rpki.rtr.server.read_current(self.version) + if old_serial is None or self.seq_ge(old_serial, self.serial): + logging.debug("Creating new nonce and deleting stale data") + nonce = self.new_nonce(force_zero_nonce) + self.destroy_old_data() + rpki.rtr.server.write_current(self.serial, nonce, self.version) + + def save_ixfr(self, other): + """ + Comparing this AXFRSet with an older one and write the resulting + IXFRSet to file with magic filename. Since we store PDUSets + in sorted order, computing the difference is a trivial linear + comparison. + """ + + f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb") + old = other + new = self + len_old = len(old) + len_new = len(new) + i_old = i_new = 0 + while i_old < len_old and i_new < len_new: + if old[i_old] < new[i_new]: + f.write(old[i_old].to_pdu(announce = 0)) + i_old += 1 + elif old[i_old] > new[i_new]: + f.write(new[i_new].to_pdu(announce = 1)) + i_new += 1 + else: + i_old += 1 + i_new += 1 + for i in xrange(i_old, len_old): + f.write(old[i].to_pdu(announce = 0)) + for i in xrange(i_new, len_new): + f.write(new[i].to_pdu(announce = 1)) + f.close() + + def show(self): + """ + Print this AXFRSet. + """ + + logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version) + for p in self: + logging.debug(p) - return "%d.ax.v%d" % (self.serial, self.version) - @classmethod - def load_current(cls, version): +class IXFRSet(PDUSet): """ - Load current AXFRSet. Return None if can't. + Object representing an incremental set of PDUs, that is, the + differences between one versioned and (theoretically) consistant set + of prefixes and router certificates extracted from rcynic's output + and another, with the announce fields set or cleared as necessary to + indicate the changes. """ - serial = rpki.rtr.server.read_current(version)[0] - if serial is None: - return None - try: - return cls.load("%d.ax.v%d" % (serial, version)) - except IOError: - return None + @classmethod + def load(cls, filename): + """ + Load an IXFRSet from a file, parse filename to obtain version and serials. + """ - def save_axfr(self): - """ - Write AXFRSet to file with magic filename. - """ + fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit() + version = int(fn4[1:]) + self = cls._load_file(filename, version) + self.from_serial = rpki.rtr.channels.Timestamp(fn3) + self.to_serial = rpki.rtr.channels.Timestamp(fn1) + return self - f = open(self.filename(), "wb") - for p in self: - f.write(p.to_pdu()) - f.close() + def filename(self): + """ + Generate filename for this IXFRSet. + """ - def destroy_old_data(self): - """ - Destroy old data files, presumably because our nonce changed and - the old serial numbers are no longer valid. - """ + return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version) - for i in glob.iglob("*.ix.*.v%d" % self.version): - os.unlink(i) - for i in glob.iglob("*.ax.v%d" % self.version): - if i != self.filename(): - os.unlink(i) + def show(self): + """ + Print this IXFRSet. + """ - @staticmethod - def new_nonce(force_zero_nonce): - """ - Create and return a new nonce value. - """ + logging.debug("# IXFR %d (%s) -> %d (%s) v%d", + self.from_serial, self.from_serial, + self.to_serial, self.to_serial, + self.version) + for p in self: + logging.debug(p) - if force_zero_nonce: - return 0 - try: - return int(random.SystemRandom().getrandbits(16)) - except NotImplementedError: - return int(random.getrandbits(16)) - def mark_current(self, force_zero_nonce = False): +def kick_all(serial): """ - Save current serial number and nonce, creating new nonce if - necessary. Creating a new nonce triggers cleanup of old state, as - the new nonce invalidates all old serial numbers. + Kick any existing server processes to wake them up. """ - assert self.version in rpki.rtr.pdus.PDU.version_map - old_serial, nonce = rpki.rtr.server.read_current(self.version) - if old_serial is None or self.seq_ge(old_serial, self.serial): - logging.debug("Creating new nonce and deleting stale data") - nonce = self.new_nonce(force_zero_nonce) - self.destroy_old_data() - rpki.rtr.server.write_current(self.serial, nonce, self.version) + try: + os.stat(rpki.rtr.server.kickme_dir) + except OSError: + logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir) + os.makedirs(rpki.rtr.server.kickme_dir) + + msg = "Good morning, serial %d is ready" % serial + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base): + try: + logging.debug("# Kicking %s", name) + sock.sendto(msg, name) + except socket.error: + try: + logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name) + os.unlink(name) + except Exception, e: + logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e) + except Exception, e: + logging.warning("# Failed to kick %s and don't understand why: %s", name, e) + sock.close() - def save_ixfr(self, other): - """ - Comparing this AXFRSet with an older one and write the resulting - IXFRSet to file with magic filename. Since we store PDUSets - in sorted order, computing the difference is a trivial linear - comparison. - """ - f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb") - old = other - new = self - len_old = len(old) - len_new = len(new) - i_old = i_new = 0 - while i_old < len_old and i_new < len_new: - if old[i_old] < new[i_new]: - f.write(old[i_old].to_pdu(announce = 0)) - i_old += 1 - elif old[i_old] > new[i_new]: - f.write(new[i_new].to_pdu(announce = 1)) - i_new += 1 - else: - i_old += 1 - i_new += 1 - for i in xrange(i_old, len_old): - f.write(old[i].to_pdu(announce = 0)) - for i in xrange(i_new, len_new): - f.write(new[i].to_pdu(announce = 1)) - f.close() - - def show(self): - """ - Print this AXFRSet. +def cronjob_main(args): """ - - logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version) - for p in self: - logging.debug(p) + Run this right after running rcynic to wade through the ROAs and + router certificates that rcynic collects and translate that data + into the form used in the rpki-router protocol. Output is an + updated database containing both full dumps (AXFR) and incremental + dumps against a specific prior version (IXFR). After updating the + database, kicks any active servers, so that they can notify their + clients that a new version is available. + """ + + if args.rpki_rtr_dir: + try: + if not os.path.isdir(args.rpki_rtr_dir): + os.makedirs(args.rpki_rtr_dir) + os.chdir(args.rpki_rtr_dir) + except OSError, e: + logging.critical(str(e)) + sys.exit(1) + + for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True): + + logging.debug("# Generating updates for protocol version %d", version) + + old_ixfrs = glob.glob("*.ix.*.v%d" % version) + + current = rpki.rtr.server.read_current(version)[0] + cutoff = Timestamp.now(-(24 * 60 * 60)) + for f in glob.iglob("*.ax.v%d" % version): + t = Timestamp(int(f.split(".")[0])) + if t < cutoff and t != current: + logging.debug("# Deleting old file %s, timestamp %s", f, t) + os.unlink(f) + + pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts) + if pdus == rpki.rtr.generator.AXFRSet.load_current(version): + logging.debug("# No change, new serial not needed") + continue + pdus.save_axfr() + for axfr in glob.iglob("*.ax.v%d" % version): + if axfr != pdus.filename(): + pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr)) + pdus.mark_current(args.force_zero_nonce) + + logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial) + + rpki.rtr.generator.kick_all(pdus.serial) + + old_ixfrs.sort() + for ixfr in old_ixfrs: + try: + logging.debug("# Deleting old file %s", ixfr) + os.unlink(ixfr) + except OSError: + pass -class IXFRSet(PDUSet): - """ - Object representing an incremental set of PDUs, that is, the - differences between one versioned and (theoretically) consistant set - of prefixes and router certificates extracted from rcynic's output - and another, with the announce fields set or cleared as necessary to - indicate the changes. - """ - - @classmethod - def load(cls, filename): +def show_main(args): """ - Load an IXFRSet from a file, parse filename to obtain version and serials. + Display current rpki-rtr server database in textual form. """ - fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".") - assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit() - version = int(fn4[1:]) - self = cls._load_file(filename, version) - self.from_serial = rpki.rtr.channels.Timestamp(fn3) - self.to_serial = rpki.rtr.channels.Timestamp(fn1) - return self + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) - def filename(self): - """ - Generate filename for this IXFRSet. - """ + g = glob.glob("*.ax.v*") + g.sort() + for f in g: + rpki.rtr.generator.AXFRSet.load(f).show() - return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version) + g = glob.glob("*.ix.*.v*") + g.sort() + for f in g: + rpki.rtr.generator.IXFRSet.load(f).show() - def show(self): +def argparse_setup(subparsers): """ - Print this IXFRSet. + Set up argparse stuff for commands in this module. """ - logging.debug("# IXFR %d (%s) -> %d (%s) v%d", - self.from_serial, self.from_serial, - self.to_serial, self.to_serial, - self.version) - for p in self: - logging.debug(p) - + subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__, + help = "Generate RPKI-RTR database from rcynic output") + subparser.set_defaults(func = cronjob_main, default_log_to = "syslog") + subparser.add_argument("--scan-roas", help = "specify an external scan_roas program") + subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program") + subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero") + subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") -def kick_all(serial): - """ - Kick any existing server processes to wake them up. - """ - - try: - os.stat(rpki.rtr.server.kickme_dir) - except OSError: - logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir) - os.makedirs(rpki.rtr.server.kickme_dir) - - msg = "Good morning, serial %d is ready" % serial - sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) - for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base): - try: - logging.debug("# Kicking %s", name) - sock.sendto(msg, name) - except socket.error: - try: - logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name) - os.unlink(name) - except Exception, e: - logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e) - except Exception, e: - logging.warning("# Failed to kick %s and don't understand why: %s", name, e) - sock.close() - - -def cronjob_main(args): - """ - Run this right after running rcynic to wade through the ROAs and - router certificates that rcynic collects and translate that data - into the form used in the rpki-router protocol. Output is an - updated database containing both full dumps (AXFR) and incremental - dumps against a specific prior version (IXFR). After updating the - database, kicks any active servers, so that they can notify their - clients that a new version is available. - """ - - if args.rpki_rtr_dir: - try: - if not os.path.isdir(args.rpki_rtr_dir): - os.makedirs(args.rpki_rtr_dir) - os.chdir(args.rpki_rtr_dir) - except OSError, e: - logging.critical(str(e)) - sys.exit(1) - - for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True): - - logging.debug("# Generating updates for protocol version %d", version) - - old_ixfrs = glob.glob("*.ix.*.v%d" % version) - - current = rpki.rtr.server.read_current(version)[0] - cutoff = Timestamp.now(-(24 * 60 * 60)) - for f in glob.iglob("*.ax.v%d" % version): - t = Timestamp(int(f.split(".")[0])) - if t < cutoff and t != current: - logging.debug("# Deleting old file %s, timestamp %s", f, t) - os.unlink(f) - - pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts) - if pdus == rpki.rtr.generator.AXFRSet.load_current(version): - logging.debug("# No change, new serial not needed") - continue - pdus.save_axfr() - for axfr in glob.iglob("*.ax.v%d" % version): - if axfr != pdus.filename(): - pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr)) - pdus.mark_current(args.force_zero_nonce) - - logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial) - - rpki.rtr.generator.kick_all(pdus.serial) - - old_ixfrs.sort() - for ixfr in old_ixfrs: - try: - logging.debug("# Deleting old file %s", ixfr) - os.unlink(ixfr) - except OSError: - pass - - -def show_main(args): - """ - Display current rpki-rtr server database in textual form. - """ - - if args.rpki_rtr_dir: - try: - os.chdir(args.rpki_rtr_dir) - except OSError, e: - sys.exit(e) - - g = glob.glob("*.ax.v*") - g.sort() - for f in g: - rpki.rtr.generator.AXFRSet.load(f).show() - - g = glob.glob("*.ix.*.v*") - g.sort() - for f in g: - rpki.rtr.generator.IXFRSet.load(f).show() - -def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__, - help = "Generate RPKI-RTR database from rcynic output") - subparser.set_defaults(func = cronjob_main, default_log_to = "syslog") - subparser.add_argument("--scan-roas", help = "specify an external scan_roas program") - subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program") - subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero") - subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") - - subparser = subparsers.add_parser("show", description = show_main.__doc__, - help = "Display content of RPKI-RTR database") - subparser.set_defaults(func = show_main, default_log_to = "stderr") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + subparser = subparsers.add_parser("show", description = show_main.__doc__, + help = "Display content of RPKI-RTR database") + subparser.set_defaults(func = show_main, default_log_to = "stderr") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") diff --git a/rpki/rtr/main.py b/rpki/rtr/main.py index 12de30cc..34f5598d 100644 --- a/rpki/rtr/main.py +++ b/rpki/rtr/main.py @@ -31,64 +31,64 @@ import argparse class Formatter(logging.Formatter): - converter = time.gmtime + converter = time.gmtime - def __init__(self, debug, fmt, datefmt): - self.debug = debug - super(Formatter, self).__init__(fmt, datefmt) + def __init__(self, debug, fmt, datefmt): + self.debug = debug + super(Formatter, self).__init__(fmt, datefmt) - def format(self, record): - if getattr(record, "connection", None) is None: - record.connection = "" - return super(Formatter, self).format(record) + def format(self, record): + if getattr(record, "connection", None) is None: + record.connection = "" + return super(Formatter, self).format(record) - def formatException(self, ei): - if self.debug: - return super(Formatter, self).formatException(ei) - else: - return str(ei[1]) + def formatException(self, ei): + if self.debug: + return super(Formatter, self).formatException(ei) + else: + return str(ei[1]) def main(): - os.environ["TZ"] = "UTC" - time.tzset() - - from rpki.rtr.server import argparse_setup as argparse_setup_server - from rpki.rtr.client import argparse_setup as argparse_setup_client - from rpki.rtr.generator import argparse_setup as argparse_setup_generator - - if "rpki.rtr.bgpdump" in sys.modules: - from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump - else: - def argparse_setup_bgpdump(ignored): - pass - - argparser = argparse.ArgumentParser(description = __doc__) - argparser.add_argument("--debug", action = "store_true", help = "debugging mode") - argparser.add_argument("--log-level", default = "debug", - choices = ("debug", "info", "warning", "error", "critical"), - type = lambda s: s.lower()) - argparser.add_argument("--log-to", - choices = ("syslog", "stderr")) - subparsers = argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode") - argparse_setup_server(subparsers) - argparse_setup_client(subparsers) - argparse_setup_generator(subparsers) - argparse_setup_bgpdump(subparsers) - args = argparser.parse_args() - - fmt = "rpki-rtr/" + args.mode + "%(connection)s[%(process)d] %(message)s" - - if (args.log_to or args.default_log_to) == "stderr": - handler = logging.StreamHandler() - fmt = "%(asctime)s " + fmt - elif os.path.exists("/dev/log"): - handler = logging.handlers.SysLogHandler("/dev/log") - else: - handler = logging.handlers.SysLogHandler() - - handler.setFormatter(Formatter(args.debug, fmt, "%Y-%m-%dT%H:%M:%SZ")) - logging.root.addHandler(handler) - logging.root.setLevel(int(getattr(logging, args.log_level.upper()))) - - return args.func(args) + os.environ["TZ"] = "UTC" + time.tzset() + + from rpki.rtr.server import argparse_setup as argparse_setup_server + from rpki.rtr.client import argparse_setup as argparse_setup_client + from rpki.rtr.generator import argparse_setup as argparse_setup_generator + + if "rpki.rtr.bgpdump" in sys.modules: + from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump + else: + def argparse_setup_bgpdump(ignored): + pass + + argparser = argparse.ArgumentParser(description = __doc__) + argparser.add_argument("--debug", action = "store_true", help = "debugging mode") + argparser.add_argument("--log-level", default = "debug", + choices = ("debug", "info", "warning", "error", "critical"), + type = lambda s: s.lower()) + argparser.add_argument("--log-to", + choices = ("syslog", "stderr")) + subparsers = argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode") + argparse_setup_server(subparsers) + argparse_setup_client(subparsers) + argparse_setup_generator(subparsers) + argparse_setup_bgpdump(subparsers) + args = argparser.parse_args() + + fmt = "rpki-rtr/" + args.mode + "%(connection)s[%(process)d] %(message)s" + + if (args.log_to or args.default_log_to) == "stderr": + handler = logging.StreamHandler() + fmt = "%(asctime)s " + fmt + elif os.path.exists("/dev/log"): + handler = logging.handlers.SysLogHandler("/dev/log") + else: + handler = logging.handlers.SysLogHandler() + + handler.setFormatter(Formatter(args.debug, fmt, "%Y-%m-%dT%H:%M:%SZ")) + logging.root.addHandler(handler) + logging.root.setLevel(int(getattr(logging, args.log_level.upper()))) + + return args.func(args) diff --git a/rpki/rtr/pdus.py b/rpki/rtr/pdus.py index 0d2e5928..94f579a1 100644 --- a/rpki/rtr/pdus.py +++ b/rpki/rtr/pdus.py @@ -28,292 +28,292 @@ import rpki.POW # Exceptions class PDUException(Exception): - """ - Parent exception type for exceptions that signal particular protocol - errors. String value of exception instance will be the message to - put in the ErrorReportPDU, error_report_code value of exception - will be the numeric code to use. - """ - - def __init__(self, msg = None, pdu = None): - super(PDUException, self).__init__() - assert msg is None or isinstance(msg, (str, unicode)) - self.error_report_msg = msg - self.error_report_pdu = pdu - - def __str__(self): - return self.error_report_msg or self.__class__.__name__ - - def make_error_report(self, version): - return ErrorReportPDU(version = version, - errno = self.error_report_code, - errmsg = self.error_report_msg, - errpdu = self.error_report_pdu) + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the ErrorReportPDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + super(PDUException, self).__init__() + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self, version): + return ErrorReportPDU(version = version, + errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) class UnsupportedProtocolVersion(PDUException): - error_report_code = 4 + error_report_code = 4 class UnsupportedPDUType(PDUException): - error_report_code = 5 + error_report_code = 5 class CorruptData(PDUException): - error_report_code = 0 + error_report_code = 0 # Decorators def wire_pdu(cls, versions = None): - """ - Class decorator to add a PDU class to the set of known PDUs - for all supported protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for all supported protocol versions. + """ - for v in PDU.version_map.iterkeys() if versions is None else versions: - assert cls.pdu_type not in PDU.version_map[v] - PDU.version_map[v][cls.pdu_type] = cls - return cls + for v in PDU.version_map.iterkeys() if versions is None else versions: + assert cls.pdu_type not in PDU.version_map[v] + PDU.version_map[v][cls.pdu_type] = cls + return cls def wire_pdu_only(*versions): - """ - Class decorator to add a PDU class to the set of known PDUs - for specific protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for specific protocol versions. + """ - assert versions and all(v in PDU.version_map for v in versions) - return lambda cls: wire_pdu(cls, versions) + assert versions and all(v in PDU.version_map for v in versions) + return lambda cls: wire_pdu(cls, versions) def clone_pdu_root(root_pdu_class): - """ - Replace a PDU root class's version_map with a two-level deep copy of itself, - and return a class decorator which subclasses can use to replace their - parent classes with themselves in the resulting cloned version map. + """ + Replace a PDU root class's version_map with a two-level deep copy of itself, + and return a class decorator which subclasses can use to replace their + parent classes with themselves in the resulting cloned version map. - This function is not itself a decorator, it returns one. - """ + This function is not itself a decorator, it returns one. + """ - root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) + root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) - def decorator(cls): - for pdu_map in root_pdu_class.version_map.itervalues(): - for pdu_type, pdu_class in pdu_map.items(): - if pdu_class in cls.__bases__: - pdu_map[pdu_type] = cls - return cls + def decorator(cls): + for pdu_map in root_pdu_class.version_map.itervalues(): + for pdu_type, pdu_class in pdu_map.items(): + if pdu_class in cls.__bases__: + pdu_map[pdu_type] = cls + return cls - return decorator + return decorator # PDUs class PDU(object): - """ - Base PDU. Real PDUs are subclasses of this class. - """ + """ + Base PDU. Real PDUs are subclasses of this class. + """ - version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu + version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu - _pdu = None # Cached when first generated + _pdu = None # Cached when first generated - header_struct = struct.Struct("!BB2xL") + header_struct = struct.Struct("!BB2xL") - def __init__(self, version): - assert version in self.version_map - self.version = version + def __init__(self, version): + assert version in self.version_map + self.version = version - def __cmp__(self, other): - return cmp(self.to_pdu(), other.to_pdu()) + def __cmp__(self, other): + return cmp(self.to_pdu(), other.to_pdu()) - @property - def default_version(self): - return max(self.version_map.iterkeys()) + @property + def default_version(self): + return max(self.version_map.iterkeys()) - def check(self): - pass + def check(self): + pass - @classmethod - def read_pdu(cls, reader): - return reader.update(need = cls.header_struct.size, callback = cls.got_header) + @classmethod + def read_pdu(cls, reader): + return reader.update(need = cls.header_struct.size, callback = cls.got_header) - @classmethod - def got_header(cls, reader): - if not reader.ready(): - return None - assert reader.available() >= cls.header_struct.size - version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) - reader.check_version(version) - if pdu_type not in cls.version_map[version]: - raise UnsupportedPDUType( - "Received unsupported PDU type %d" % pdu_type) - if length < 8: - raise CorruptData( - "Received PDU with length %d, which is too short to be valid" % length) - self = cls.version_map[version][pdu_type](version = version) - return reader.update(need = length, callback = self.got_pdu) + @classmethod + def got_header(cls, reader): + if not reader.ready(): + return None + assert reader.available() >= cls.header_struct.size + version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) + reader.check_version(version) + if pdu_type not in cls.version_map[version]: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) + self = cls.version_map[version][pdu_type](version = version) + return reader.update(need = length, callback = self.got_pdu) class PDUWithSerial(PDU): - """ - Base class for PDUs consisting of just a serial number and nonce. - """ - - header_struct = struct.Struct("!BBHLL") - - def __init__(self, version, serial = None, nonce = None): - super(PDUWithSerial, self).__init__(version) - if serial is not None: - assert isinstance(serial, int) - self.serial = serial - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce - - def __str__(self): - return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) - - def to_pdu(self): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a serial number and nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 12: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHLL") + + def __init__(self, version, serial = None, nonce = None): + super(PDUWithSerial, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self class PDUWithNonce(PDU): - """ - Base class for PDUs consisting of just a nonce. - """ - - header_struct = struct.Struct("!BBHL") - - def __init__(self, version, nonce = None): - super(PDUWithNonce, self).__init__(version) - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce - - def __str__(self): - return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) - - def to_pdu(self): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) - return self._pdu + header_struct = struct.Struct("!BBHL") - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + def __init__(self, version, nonce = None): + super(PDUWithNonce, self).__init__(version) + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + def __str__(self): + return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) -class PDUEmpty(PDU): - """ - Base class for empty PDUs. - """ + def to_pdu(self): + """ + Generate the wire format PDU. + """ - header_struct = struct.Struct("!BBHL") + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) + return self._pdu - def __str__(self): - return "[%s]" % self.__class__.__name__ + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self - def to_pdu(self): + +class PDUEmpty(PDU): """ - Generate the wire format PDU for this prefix. + Base class for empty PDUs. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, zero, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if zero != 0: - raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHL") + + def __str__(self): + return "[%s]" % self.__class__.__name__ + + def to_pdu(self): + """ + Generate the wire format PDU for this prefix. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, zero, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class SerialNotifyPDU(PDUWithSerial): - """ - Serial Notify PDU. - """ + """ + Serial Notify PDU. + """ - pdu_type = 0 + pdu_type = 0 @wire_pdu class SerialQueryPDU(PDUWithSerial): - """ - Serial Query PDU. - """ + """ + Serial Query PDU. + """ - pdu_type = 1 + pdu_type = 1 - def __init__(self, version, serial = None, nonce = None): - super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) + def __init__(self, version, serial = None, nonce = None): + super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) @wire_pdu class ResetQueryPDU(PDUEmpty): - """ - Reset Query PDU. - """ + """ + Reset Query PDU. + """ - pdu_type = 2 + pdu_type = 2 - def __init__(self, version): - super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) + def __init__(self, version): + super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) @wire_pdu class CacheResponsePDU(PDUWithNonce): - """ - Cache Response PDU. - """ + """ + Cache Response PDU. + """ - pdu_type = 3 + pdu_type = 3 def EndOfDataPDU(version, *args, **kwargs): - """ - Factory for the EndOfDataPDU classes, which take different forms in - different protocol versions. - """ + """ + Factory for the EndOfDataPDU classes, which take different forms in + different protocol versions. + """ - if version == 0: - return EndOfDataPDUv0(version, *args, **kwargs) - if version == 1: - return EndOfDataPDUv1(version, *args, **kwargs) - raise NotImplementedError + if version == 0: + return EndOfDataPDUv0(version, *args, **kwargs) + if version == 1: + return EndOfDataPDUv1(version, *args, **kwargs) + raise NotImplementedError # Min, max, and default values, from the current RFC 6810 bis I-D. @@ -324,325 +324,325 @@ def EndOfDataPDU(version, *args, **kwargs): default_refresh = 3600 def valid_refresh(refresh): - if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: - raise ValueError - return refresh + if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: + raise ValueError + return refresh default_retry = 600 def valid_retry(retry): - if not isinstance(retry, int) or retry < 120 or retry > 7200: - raise ValueError - return retry + if not isinstance(retry, int) or retry < 120 or retry > 7200: + raise ValueError + return retry default_expire = 7200 def valid_expire(expire): - if not isinstance(expire, int) or expire < 600 or expire > 172800: - raise ValueError - return expire + if not isinstance(expire, int) or expire < 600 or expire > 172800: + raise ValueError + return expire @wire_pdu_only(0) class EndOfDataPDUv0(PDUWithSerial): - """ - End of Data PDU, protocol version 0. - """ + """ + End of Data PDU, protocol version 0. + """ - pdu_type = 7 + pdu_type = 7 - def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): - super(EndOfDataPDUv0, self).__init__(version, serial, nonce) - self.refresh = valid_refresh(default_refresh if refresh is None else refresh) - self.retry = valid_retry( default_retry if retry is None else retry) - self.expire = valid_expire( default_expire if expire is None else expire) + def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): + super(EndOfDataPDUv0, self).__init__(version, serial, nonce) + self.refresh = valid_refresh(default_refresh if refresh is None else refresh) + self.retry = valid_retry( default_retry if retry is None else retry) + self.expire = valid_expire( default_expire if expire is None else expire) @wire_pdu_only(1) class EndOfDataPDUv1(EndOfDataPDUv0): - """ - End of Data PDU, protocol version 1. - """ + """ + End of Data PDU, protocol version 1. + """ - header_struct = struct.Struct("!BBHLLLLL") + header_struct = struct.Struct("!BBHLLLLL") - def __str__(self): - return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( - self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) + def __str__(self): + return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( + self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) - def to_pdu(self): - """ - Generate the wire format PDU. - """ + def to_pdu(self): + """ + Generate the wire format PDU. + """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial, - self.refresh, self.retry, self.expire) - return self._pdu + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial, + self.refresh, self.retry, self.expire) + return self._pdu - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ - = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 24: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ + = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 24: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class CacheResetPDU(PDUEmpty): - """ - Cache reset PDU. - """ + """ + Cache reset PDU. + """ - pdu_type = 8 + pdu_type = 8 class PrefixPDU(PDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - header_struct = struct.Struct("!BB2xLBBBx") - asnum_struct = struct.Struct("!L") - - def __str__(self): - plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def show(self): - logging.debug("# Class: %s", self.__class__.__name__) - logging.debug("# ASN: %s", self.asn) - logging.debug("# Prefix: %s", self.prefix) - logging.debug("# Prefixlen: %s", self.prefixlen) - logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) - logging.debug("# Announce: %s", self.announce) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if self.prefix.bits != self.address_byte_count * 8: - raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) - if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: - raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) - if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: - raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - """ - Generate the wire format PDU for this prefix. - """ - - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, - announce if announce is not None else self.announce, - self.prefixlen, self.max_prefixlen) + - self.prefix.toBytes() + - self.asnum_struct.pack(self.asn)) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b1 = reader.get(self.header_struct.size) - b2 = reader.get(self.address_byte_count) - b3 = reader.get(self.asnum_struct.size) - version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) - assert version == self.version and pdu_type == self.pdu_type - if length != len(b1) + len(b2) + len(b3): - raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) - self.prefix = rpki.POW.IPAddress.fromBytes(b2) - self.asn = self.asnum_struct.unpack(b3)[0] - assert b1 + b2 + b3 == self.to_pdu() - return self + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + header_struct = struct.Struct("!BB2xLBBBx") + asnum_struct = struct.Struct("!L") + + def __str__(self): + plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def show(self): + logging.debug("# Class: %s", self.__class__.__name__) + logging.debug("# ASN: %s", self.asn) + logging.debug("# Prefix: %s", self.prefix) + logging.debug("# Prefixlen: %s", self.prefixlen) + logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) + logging.debug("# Announce: %s", self.announce) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefix.bits != self.address_byte_count * 8: + raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this prefix. + """ + + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, + announce if announce is not None else self.announce, + self.prefixlen, self.max_prefixlen) + + self.prefix.toBytes() + + self.asnum_struct.pack(self.asn)) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b1 = reader.get(self.header_struct.size) + b2 = reader.get(self.address_byte_count) + b3 = reader.get(self.asnum_struct.size) + version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) + assert version == self.version and pdu_type == self.pdu_type + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) + self.prefix = rpki.POW.IPAddress.fromBytes(b2) + self.asn = self.asnum_struct.unpack(b3)[0] + assert b1 + b2 + b3 == self.to_pdu() + return self @wire_pdu class IPv4PrefixPDU(PrefixPDU): - """ - IPv4 flavor of a prefix. - """ + """ + IPv4 flavor of a prefix. + """ - pdu_type = 4 - address_byte_count = 4 + pdu_type = 4 + address_byte_count = 4 @wire_pdu class IPv6PrefixPDU(PrefixPDU): - """ - IPv6 flavor of a prefix. - """ + """ + IPv6 flavor of a prefix. + """ - pdu_type = 6 - address_byte_count = 16 + pdu_type = 6 + address_byte_count = 16 @wire_pdu_only(1) class RouterKeyPDU(PDU): - """ - Router Key PDU. - """ - - pdu_type = 9 - - header_struct = struct.Struct("!BBBxL20sL") - - def __str__(self): - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, - base64.urlsafe_b64encode(self.ski).rstrip("="), - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if len(self.ski) != 20: - raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) - pdulen = self.header_struct.size + len(self.key) - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + len(self.key) - pdu = (self.header_struct.pack(self.version, - self.pdu_type, - announce if announce is not None else self.announce, - pdulen, - self.ski, - self.asn) - + self.key) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - if remaining <= 0: - raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) - self.key = reader.get(remaining) - assert header + self.key == self.to_pdu() - return self + """ + Router Key PDU. + """ + + pdu_type = 9 + + header_struct = struct.Struct("!BBBxL20sL") + + def __str__(self): + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, + base64.urlsafe_b64encode(self.ski).rstrip("="), + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if len(self.ski) != 20: + raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) + pdulen = self.header_struct.size + len(self.key) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + len(self.key) + pdu = (self.header_struct.pack(self.version, + self.pdu_type, + announce if announce is not None else self.announce, + pdulen, + self.ski, + self.asn) + + self.key) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + if remaining <= 0: + raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) + self.key = reader.get(remaining) + assert header + self.key == self.to_pdu() + return self @wire_pdu class ErrorReportPDU(PDU): - """ - Error Report PDU. - """ - - pdu_type = 10 - - header_struct = struct.Struct("!BBHL") - string_struct = struct.Struct("!L") - - errors = { - 2 : "No Data Available" } - - fatal = { - 0 : "Corrupt Data", - 1 : "Internal Error", - 3 : "Invalid Request", - 4 : "Unsupported Protocol Version", - 5 : "Unsupported PDU Type", - 6 : "Withdrawal of Unknown Record", - 7 : "Duplicate Announcement Received" } - - assert set(errors) & set(fatal) == set() - - errors.update(fatal) - - codes = dict((v, k) for k, v in errors.items()) - - def __init__(self, version, errno = None, errpdu = None, errmsg = None): - super(ErrorReportPDU, self).__init__(version) - assert errno is None or errno in self.errors - self.errno = errno - self.errpdu = errpdu - self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] - - def __str__(self): - return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) - - def to_counted_string(self, s): - return self.string_struct.pack(len(s)) + s - - def read_counted_string(self, reader, remaining): - assert remaining >= self.string_struct.size - n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] - assert remaining >= self.string_struct.size + n - return n, reader.get(n), (remaining - self.string_struct.size - n) - - def to_pdu(self): - """ - Generate the wire format PDU for this error report. - """ - - if self._pdu is None: - assert isinstance(self.errno, int) - assert not isinstance(self.errpdu, ErrorReportPDU) - p = self.errpdu - if p is None: - p = "" - elif isinstance(p, PDU): - p = p.to_pdu() - assert isinstance(p, str) - pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) - self._pdu += self.to_counted_string(p) - self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.errno, length = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) - self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) - if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: - raise CorruptData("Got PDU length %d, expected %d" % ( - length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) - assert (header - + self.to_counted_string(self.errpdu) - + self.to_counted_string(self.errmsg.encode("utf8")) - == self.to_pdu()) - return self + """ + Error Report PDU. + """ + + pdu_type = 10 + + header_struct = struct.Struct("!BBHL") + string_struct = struct.Struct("!L") + + errors = { + 2 : "No Data Available" } + + fatal = { + 0 : "Corrupt Data", + 1 : "Internal Error", + 3 : "Invalid Request", + 4 : "Unsupported Protocol Version", + 5 : "Unsupported PDU Type", + 6 : "Withdrawal of Unknown Record", + 7 : "Duplicate Announcement Received" } + + assert set(errors) & set(fatal) == set() + + errors.update(fatal) + + codes = dict((v, k) for k, v in errors.items()) + + def __init__(self, version, errno = None, errpdu = None, errmsg = None): + super(ErrorReportPDU, self).__init__(version) + assert errno is None or errno in self.errors + self.errno = errno + self.errpdu = errpdu + self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] + + def __str__(self): + return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) + + def to_counted_string(self, s): + return self.string_struct.pack(len(s)) + s + + def read_counted_string(self, reader, remaining): + assert remaining >= self.string_struct.size + n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] + assert remaining >= self.string_struct.size + n + return n, reader.get(n), (remaining - self.string_struct.size - n) + + def to_pdu(self): + """ + Generate the wire format PDU for this error report. + """ + + if self._pdu is None: + assert isinstance(self.errno, int) + assert not isinstance(self.errpdu, ErrorReportPDU) + p = self.errpdu + if p is None: + p = "" + elif isinstance(p, PDU): + p = p.to_pdu() + assert isinstance(p, str) + pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) + self._pdu += self.to_counted_string(p) + self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.errno, length = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) + self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) + assert (header + + self.to_counted_string(self.errpdu) + + self.to_counted_string(self.errmsg.encode("utf8")) + == self.to_pdu()) + return self diff --git a/rpki/rtr/server.py b/rpki/rtr/server.py index 1c7a5e78..f57c3037 100644 --- a/rpki/rtr/server.py +++ b/rpki/rtr/server.py @@ -44,37 +44,37 @@ kickme_base = os.path.join(kickme_dir, "kickme") class PDU(rpki.rtr.pdus.PDU): - """ - Generic server PDU. - """ - - def send_file(self, server, filename): """ - Send a content of a file as a cache response. Caller should catch IOError. + Generic server PDU. """ - fn2 = os.path.splitext(filename)[1] - assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version - - f = open(filename, "rb") - server.push_pdu(CacheResponsePDU(version = server.version, - nonce = server.current_nonce)) - server.push_file(f) - server.push_pdu(EndOfDataPDU(version = server.version, - serial = server.current_serial, - nonce = server.current_nonce, - refresh = server.refresh, - retry = server.retry, - expire = server.expire)) - - def send_nodata(self, server): - """ - Send a nodata error. - """ + def send_file(self, server, filename): + """ + Send a content of a file as a cache response. Caller should catch IOError. + """ + + fn2 = os.path.splitext(filename)[1] + assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version - server.push_pdu(ErrorReportPDU(version = server.version, - errno = ErrorReportPDU.codes["No Data Available"], - errpdu = self)) + f = open(filename, "rb") + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) + server.push_file(f) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce, + refresh = server.refresh, + retry = server.retry, + expire = server.expire)) + + def send_nodata(self, server): + """ + Send a nodata error. + """ + + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["No Data Available"], + errpdu = self)) clone_pdu = clone_pdu_root(PDU) @@ -82,512 +82,512 @@ clone_pdu = clone_pdu_root(PDU) @clone_pdu class SerialQueryPDU(PDU, rpki.rtr.pdus.SerialQueryPDU): - """ - Serial Query PDU. - """ - - def serve(self, server): - """ - Received a serial query, send incremental transfer in response. - If client is already up to date, just send an empty incremental - transfer. """ - - server.logger.debug(self) - if server.get_serial() is None: - self.send_nodata(server) - elif server.current_nonce != self.nonce: - server.logger.info("[Client requested wrong nonce, resetting client]") - server.push_pdu(CacheResetPDU(version = server.version)) - elif server.current_serial == self.serial: - server.logger.debug("[Client is already current, sending empty IXFR]") - server.push_pdu(CacheResponsePDU(version = server.version, - nonce = server.current_nonce)) - server.push_pdu(EndOfDataPDU(version = server.version, - serial = server.current_serial, - nonce = server.current_nonce, - refresh = server.refresh, - retry = server.retry, - expire = server.expire)) - elif disable_incrementals: - server.push_pdu(CacheResetPDU(version = server.version)) - else: - try: - self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version)) - except IOError: - server.push_pdu(CacheResetPDU(version = server.version)) + Serial Query PDU. + """ + + def serve(self, server): + """ + Received a serial query, send incremental transfer in response. + If client is already up to date, just send an empty incremental + transfer. + """ + + server.logger.debug(self) + if server.get_serial() is None: + self.send_nodata(server) + elif server.current_nonce != self.nonce: + server.logger.info("[Client requested wrong nonce, resetting client]") + server.push_pdu(CacheResetPDU(version = server.version)) + elif server.current_serial == self.serial: + server.logger.debug("[Client is already current, sending empty IXFR]") + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce, + refresh = server.refresh, + retry = server.retry, + expire = server.expire)) + elif disable_incrementals: + server.push_pdu(CacheResetPDU(version = server.version)) + else: + try: + self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version)) + except IOError: + server.push_pdu(CacheResetPDU(version = server.version)) @clone_pdu class ResetQueryPDU(PDU, rpki.rtr.pdus.ResetQueryPDU): - """ - Reset Query PDU. - """ - - def serve(self, server): """ - Received a reset query, send full current state in response. + Reset Query PDU. """ - server.logger.debug(self) - if server.get_serial() is None: - self.send_nodata(server) - else: - try: - fn = "%d.ax.v%d" % (server.current_serial, server.version) - self.send_file(server, fn) - except IOError: - server.push_pdu(ErrorReportPDU(version = server.version, - errno = ErrorReportPDU.codes["Internal Error"], - errpdu = self, - errmsg = "Couldn't open %s" % fn)) + def serve(self, server): + """ + Received a reset query, send full current state in response. + """ + + server.logger.debug(self) + if server.get_serial() is None: + self.send_nodata(server) + else: + try: + fn = "%d.ax.v%d" % (server.current_serial, server.version) + self.send_file(server, fn) + except IOError: + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["Internal Error"], + errpdu = self, + errmsg = "Couldn't open %s" % fn)) @clone_pdu class ErrorReportPDU(rpki.rtr.pdus.ErrorReportPDU): - """ - Error Report PDU. - """ - - def serve(self, server): """ - Received an ErrorReportPDU from client. Not much we can do beyond - logging it, then killing the connection if error was fatal. + Error Report PDU. """ - server.logger.error(self) - if self.errno in self.fatal: - server.logger.error("[Shutting down due to reported fatal protocol error]") - sys.exit(1) + def serve(self, server): + """ + Received an ErrorReportPDU from client. Not much we can do beyond + logging it, then killing the connection if error was fatal. + """ + + server.logger.error(self) + if self.errno in self.fatal: + server.logger.error("[Shutting down due to reported fatal protocol error]") + sys.exit(1) def read_current(version): - """ - Read current serial number and nonce. Return None for both if - serial and nonce not recorded. For backwards compatibility, treat - file containing just a serial number as having a nonce of zero. - """ - - if version is None: - return None, None - try: - with open("current.v%d" % version, "r") as f: - values = tuple(int(s) for s in f.read().split()) - return values[0], values[1] - except IndexError: - return values[0], 0 - except IOError: - return None, None + """ + Read current serial number and nonce. Return None for both if + serial and nonce not recorded. For backwards compatibility, treat + file containing just a serial number as having a nonce of zero. + """ + + if version is None: + return None, None + try: + with open("current.v%d" % version, "r") as f: + values = tuple(int(s) for s in f.read().split()) + return values[0], values[1] + except IndexError: + return values[0], 0 + except IOError: + return None, None def write_current(serial, nonce, version): - """ - Write serial number and nonce. - """ + """ + Write serial number and nonce. + """ - curfn = "current.v%d" % version - tmpfn = curfn + "%d.tmp" % os.getpid() - with open(tmpfn, "w") as f: - f.write("%d %d\n" % (serial, nonce)) - os.rename(tmpfn, curfn) + curfn = "current.v%d" % version + tmpfn = curfn + "%d.tmp" % os.getpid() + with open(tmpfn, "w") as f: + f.write("%d %d\n" % (serial, nonce)) + os.rename(tmpfn, curfn) class FileProducer(object): - """ - File-based producer object for asynchat. - """ + """ + File-based producer object for asynchat. + """ - def __init__(self, handle, buffersize): - self.handle = handle - self.buffersize = buffersize + def __init__(self, handle, buffersize): + self.handle = handle + self.buffersize = buffersize - def more(self): - return self.handle.read(self.buffersize) + def more(self): + return self.handle.read(self.buffersize) class ServerWriteChannel(rpki.rtr.channels.PDUChannel): - """ - Kludge to deal with ssh's habit of sometimes (compile time option) - invoking us with two unidirectional pipes instead of one - bidirectional socketpair. All the server logic is in the - ServerChannel class, this class just deals with sending the - server's output to a different file descriptor. - """ - - def __init__(self): """ - Set up stdout. + Kludge to deal with ssh's habit of sometimes (compile time option) + invoking us with two unidirectional pipes instead of one + bidirectional socketpair. All the server logic is in the + ServerChannel class, this class just deals with sending the + server's output to a different file descriptor. """ - super(ServerWriteChannel, self).__init__(root_pdu_class = PDU) - self.init_file_dispatcher(sys.stdout.fileno()) + def __init__(self): + """ + Set up stdout. + """ - def readable(self): - """ - This channel is never readable. - """ + super(ServerWriteChannel, self).__init__(root_pdu_class = PDU) + self.init_file_dispatcher(sys.stdout.fileno()) - return False + def readable(self): + """ + This channel is never readable. + """ - def push_file(self, f): - """ - Write content of a file to stream. - """ + return False - try: - self.push_with_producer(FileProducer(f, self.ac_out_buffer_size)) - except OSError, e: - if e.errno != errno.EAGAIN: - raise + def push_file(self, f): + """ + Write content of a file to stream. + """ + try: + self.push_with_producer(FileProducer(f, self.ac_out_buffer_size)) + except OSError, e: + if e.errno != errno.EAGAIN: + raise -class ServerChannel(rpki.rtr.channels.PDUChannel): - """ - Server protocol engine, handles upcalls from PDUChannel to - implement protocol logic. - """ - def __init__(self, logger, refresh, retry, expire): +class ServerChannel(rpki.rtr.channels.PDUChannel): """ - Set up stdin and stdout as connection and start listening for - first PDU. + Server protocol engine, handles upcalls from PDUChannel to + implement protocol logic. """ - super(ServerChannel, self).__init__(root_pdu_class = PDU) - self.init_file_dispatcher(sys.stdin.fileno()) - self.writer = ServerWriteChannel() - self.logger = logger - self.refresh = refresh - self.retry = retry - self.expire = expire - self.get_serial() - self.start_new_pdu() - - def writable(self): - """ - This channel is never writable. - """ + def __init__(self, logger, refresh, retry, expire): + """ + Set up stdin and stdout as connection and start listening for + first PDU. + """ - return False + super(ServerChannel, self).__init__(root_pdu_class = PDU) + self.init_file_dispatcher(sys.stdin.fileno()) + self.writer = ServerWriteChannel() + self.logger = logger + self.refresh = refresh + self.retry = retry + self.expire = expire + self.get_serial() + self.start_new_pdu() - def push(self, data): - """ - Redirect to writer channel. - """ + def writable(self): + """ + This channel is never writable. + """ - return self.writer.push(data) + return False - def push_with_producer(self, producer): - """ - Redirect to writer channel. - """ + def push(self, data): + """ + Redirect to writer channel. + """ - return self.writer.push_with_producer(producer) + return self.writer.push(data) - def push_pdu(self, pdu): - """ - Redirect to writer channel. - """ + def push_with_producer(self, producer): + """ + Redirect to writer channel. + """ - return self.writer.push_pdu(pdu) + return self.writer.push_with_producer(producer) - def push_file(self, f): - """ - Redirect to writer channel. - """ + def push_pdu(self, pdu): + """ + Redirect to writer channel. + """ - return self.writer.push_file(f) + return self.writer.push_pdu(pdu) - def deliver_pdu(self, pdu): - """ - Handle received PDU. - """ + def push_file(self, f): + """ + Redirect to writer channel. + """ - pdu.serve(self) + return self.writer.push_file(f) - def get_serial(self): - """ - Read, cache, and return current serial number, or None if we can't - find the serial number file. The latter condition should never - happen, but maybe we got started in server mode while the cronjob - mode instance is still building its database. - """ + def deliver_pdu(self, pdu): + """ + Handle received PDU. + """ - self.current_serial, self.current_nonce = read_current(self.version) - return self.current_serial + pdu.serve(self) - def check_serial(self): - """ - Check for a new serial number. - """ + def get_serial(self): + """ + Read, cache, and return current serial number, or None if we can't + find the serial number file. The latter condition should never + happen, but maybe we got started in server mode while the cronjob + mode instance is still building its database. + """ - old_serial = self.current_serial - return old_serial != self.get_serial() + self.current_serial, self.current_nonce = read_current(self.version) + return self.current_serial - def notify(self, data = None, force = False): - """ - Cronjob instance kicked us: check whether our serial number has - changed, and send a notify message if so. + def check_serial(self): + """ + Check for a new serial number. + """ - We have to check rather than just blindly notifying when kicked - because the cronjob instance has no good way of knowing which - protocol version we're running, thus has no good way of knowing - whether we care about a particular change set or not. - """ + old_serial = self.current_serial + return old_serial != self.get_serial() - if force or self.check_serial(): - self.push_pdu(SerialNotifyPDU(version = self.version, - serial = self.current_serial, - nonce = self.current_nonce)) - else: - self.logger.debug("Cronjob kicked me but I see no serial change, ignoring") + def notify(self, data = None, force = False): + """ + Cronjob instance kicked us: check whether our serial number has + changed, and send a notify message if so. + + We have to check rather than just blindly notifying when kicked + because the cronjob instance has no good way of knowing which + protocol version we're running, thus has no good way of knowing + whether we care about a particular change set or not. + """ + + if force or self.check_serial(): + self.push_pdu(SerialNotifyPDU(version = self.version, + serial = self.current_serial, + nonce = self.current_nonce)) + else: + self.logger.debug("Cronjob kicked me but I see no serial change, ignoring") class KickmeChannel(asyncore.dispatcher, object): - """ - asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to - kick servers when it's time to send notify PDUs to clients. - """ - - def __init__(self, server): - asyncore.dispatcher.__init__(self) # Old-style class - self.server = server - self.sockname = "%s.%d" % (kickme_base, os.getpid()) - self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM) - try: - self.bind(self.sockname) - os.chmod(self.sockname, 0660) - except socket.error, e: - self.server.logger.exception("Couldn't bind() kickme socket: %r", e) - self.close() - except OSError, e: - self.server.logger.exception("Couldn't chmod() kickme socket: %r", e) - - def writable(self): """ - This socket is read-only, never writable. + asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to + kick servers when it's time to send notify PDUs to clients. """ - return False + def __init__(self, server): + asyncore.dispatcher.__init__(self) # Old-style class + self.server = server + self.sockname = "%s.%d" % (kickme_base, os.getpid()) + self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM) + try: + self.bind(self.sockname) + os.chmod(self.sockname, 0660) + except socket.error, e: + self.server.logger.exception("Couldn't bind() kickme socket: %r", e) + self.close() + except OSError, e: + self.server.logger.exception("Couldn't chmod() kickme socket: %r", e) + + def writable(self): + """ + This socket is read-only, never writable. + """ + + return False + + def handle_connect(self): + """ + Ignore connect events (not very useful on datagram socket). + """ + + pass + + def handle_read(self): + """ + Handle receipt of a datagram. + """ + + data = self.recv(512) + self.server.notify(data) + + def cleanup(self): + """ + Clean up this dispatcher's socket. + """ + + self.close() + try: + os.unlink(self.sockname) + except: # pylint: disable=W0702 + pass - def handle_connect(self): - """ - Ignore connect events (not very useful on datagram socket). - """ + def log(self, msg): + """ + Intercept asyncore's logging. + """ - pass + self.server.logger.info(msg) - def handle_read(self): - """ - Handle receipt of a datagram. - """ + def log_info(self, msg, tag = "info"): + """ + Intercept asyncore's logging. + """ - data = self.recv(512) - self.server.notify(data) + self.server.logger.info("asyncore: %s: %s", tag, msg) - def cleanup(self): - """ - Clean up this dispatcher's socket. - """ + def handle_error(self): + """ + Handle errors caught by asyncore main loop. + """ - self.close() - try: - os.unlink(self.sockname) - except: # pylint: disable=W0702 - pass + self.server.logger.exception("[Unhandled exception]") + self.server.logger.critical("[Exiting after unhandled exception]") + sys.exit(1) - def log(self, msg): + +def _hostport_tag(): """ - Intercept asyncore's logging. + Construct hostname/address + port when we're running under a + protocol we understand well enough to do that. This is all + kludgery. Just grit your teeth, or perhaps just close your eyes. """ - self.server.logger.info(msg) + proto = None - def log_info(self, msg, tag = "info"): - """ - Intercept asyncore's logging. - """ + if proto is None: + try: + host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername() + proto = "tcp" + except: # pylint: disable=W0702 + pass - self.server.logger.info("asyncore: %s: %s", tag, msg) + if proto is None: + try: + host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2] + proto = "tcp" + except: # pylint: disable=W0702 + pass - def handle_error(self): + if proto is None: + try: + host, port = os.environ["SSH_CONNECTION"].split()[0:2] + proto = "ssh" + except: # pylint: disable=W0702 + pass + + if proto is None: + try: + host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT") + proto = "ssl" + except: # pylint: disable=W0702 + pass + + if proto is None: + return "" + elif not port: + return "/%s/%s" % (proto, host) + elif ":" in host: + return "/%s/%s.%s" % (proto, host, port) + else: + return "/%s/%s:%s" % (proto, host, port) + + +def server_main(args): """ - Handle errors caught by asyncore main loop. + Implement the server side of the rpkk-router protocol. Other than + one PF_UNIX socket inode, this doesn't write anything to disk, so it + can be run with minimal privileges. Most of the work has already + been done by the database generator, so all this server has to do is + pass the results along to a client. """ - self.server.logger.exception("[Unhandled exception]") - self.server.logger.critical("[Exiting after unhandled exception]") - sys.exit(1) + logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag())) + logger.debug("[Starting]") -def _hostport_tag(): - """ - Construct hostname/address + port when we're running under a - protocol we understand well enough to do that. This is all - kludgery. Just grit your teeth, or perhaps just close your eyes. - """ - - proto = None + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) - if proto is None: + kickme = None try: - host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername() - proto = "tcp" - except: # pylint: disable=W0702 - pass + server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire) + kickme = rpki.rtr.server.KickmeChannel(server = server) + asyncore.loop(timeout = None) + signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition + except KeyboardInterrupt: + sys.exit(0) + finally: + signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition + if kickme is not None: + kickme.cleanup() - if proto is None: - try: - host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2] - proto = "tcp" - except: # pylint: disable=W0702 - pass - if proto is None: - try: - host, port = os.environ["SSH_CONNECTION"].split()[0:2] - proto = "ssh" - except: # pylint: disable=W0702 - pass +def listener_main(args): + """ + Totally insecure TCP listener for rpki-rtr protocol. We only + implement this because it's all that the routers currently support. + In theory, we will all be running TCP-AO in the future, at which + point this listener will go away or become a TCP-AO listener. + """ - if proto is None: - try: - host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT") - proto = "ssl" - except: # pylint: disable=W0702 - pass + # Perhaps we should daemonize? Deal with that later. - if proto is None: - return "" - elif not port: - return "/%s/%s" % (proto, host) - elif ":" in host: - return "/%s/%s.%s" % (proto, host, port) - else: - return "/%s/%s:%s" % (proto, host, port) + # server_main() handles args.rpki_rtr_dir. + listener = None + try: + listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except: # pylint: disable=W0702 + if listener is not None: + listener.close() + listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except AttributeError: + pass + listener.bind(("", args.port)) + listener.listen(5) + logging.debug("[Listening on port %s]", args.port) + while True: + try: + s, ai = listener.accept() + except KeyboardInterrupt: + sys.exit(0) + logging.debug("[Received connection from %r]", ai) + pid = os.fork() + if pid == 0: + os.dup2(s.fileno(), 0) # pylint: disable=E1103 + os.dup2(s.fileno(), 1) # pylint: disable=E1103 + s.close() + #os.closerange(3, os.sysconf("SC_OPEN_MAX")) + server_main(args) + sys.exit() + else: + logging.debug("[Spawned server %d]", pid) + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612 + if pid: + logging.debug("[Server %s exited]", pid) + continue + except: # pylint: disable=W0702 + pass + break -def server_main(args): - """ - Implement the server side of the rpkk-router protocol. Other than - one PF_UNIX socket inode, this doesn't write anything to disk, so it - can be run with minimal privileges. Most of the work has already - been done by the database generator, so all this server has to do is - pass the results along to a client. - """ - logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag())) +def argparse_setup(subparsers): + """ + Set up argparse stuff for commands in this module. + """ - logger.debug("[Starting]") + # These could have been lambdas, but doing it this way results in + # more useful error messages on argparse failures. - if args.rpki_rtr_dir: - try: - os.chdir(args.rpki_rtr_dir) - except OSError, e: - sys.exit(e) - - kickme = None - try: - server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire) - kickme = rpki.rtr.server.KickmeChannel(server = server) - asyncore.loop(timeout = None) - signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition - except KeyboardInterrupt: - sys.exit(0) - finally: - signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition - if kickme is not None: - kickme.cleanup() + def refresh(v): + return rpki.rtr.pdus.valid_refresh(int(v)) + def retry(v): + return rpki.rtr.pdus.valid_retry(int(v)) -def listener_main(args): - """ - Totally insecure TCP listener for rpki-rtr protocol. We only - implement this because it's all that the routers currently support. - In theory, we will all be running TCP-AO in the future, at which - point this listener will go away or become a TCP-AO listener. - """ - - # Perhaps we should daemonize? Deal with that later. - - # server_main() handles args.rpki_rtr_dir. - - listener = None - try: - listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - except: # pylint: disable=W0702 - if listener is not None: - listener.close() - listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - try: - listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - except AttributeError: - pass - listener.bind(("", args.port)) - listener.listen(5) - logging.debug("[Listening on port %s]", args.port) - while True: - try: - s, ai = listener.accept() - except KeyboardInterrupt: - sys.exit(0) - logging.debug("[Received connection from %r]", ai) - pid = os.fork() - if pid == 0: - os.dup2(s.fileno(), 0) # pylint: disable=E1103 - os.dup2(s.fileno(), 1) # pylint: disable=E1103 - s.close() - #os.closerange(3, os.sysconf("SC_OPEN_MAX")) - server_main(args) - sys.exit() - else: - logging.debug("[Spawned server %d]", pid) - while True: - try: - pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612 - if pid: - logging.debug("[Server %s exited]", pid) - continue - except: # pylint: disable=W0702 - pass - break + def expire(v): + return rpki.rtr.pdus.valid_expire(int(v)) + # Some duplication of arguments here, not enough to be worth huge + # effort to clean up, worry about it later in any case. -def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - # These could have been lambdas, but doing it this way results in - # more useful error messages on argparse failures. - - def refresh(v): - return rpki.rtr.pdus.valid_refresh(int(v)) - - def retry(v): - return rpki.rtr.pdus.valid_retry(int(v)) - - def expire(v): - return rpki.rtr.pdus.valid_expire(int(v)) - - # Some duplication of arguments here, not enough to be worth huge - # effort to clean up, worry about it later in any case. - - subparser = subparsers.add_parser("server", description = server_main.__doc__, - help = "RPKI-RTR protocol server") - subparser.set_defaults(func = server_main, default_log_to = "syslog") - subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") - subparser.add_argument("--retry", type = retry, help = "override default retry timer") - subparser.add_argument("--expire", type = expire, help = "override default expire timer") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") - - subparser = subparsers.add_parser("listener", description = listener_main.__doc__, - help = "TCP listener for RPKI-RTR protocol server") - subparser.set_defaults(func = listener_main, default_log_to = "syslog") - subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") - subparser.add_argument("--retry", type = retry, help = "override default retry timer") - subparser.add_argument("--expire", type = expire, help = "override default expire timer") - subparser.add_argument("port", type = int, help = "TCP port on which to listen") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + subparser = subparsers.add_parser("server", description = server_main.__doc__, + help = "RPKI-RTR protocol server") + subparser.set_defaults(func = server_main, default_log_to = "syslog") + subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") + subparser.add_argument("--retry", type = retry, help = "override default retry timer") + subparser.add_argument("--expire", type = expire, help = "override default expire timer") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + + subparser = subparsers.add_parser("listener", description = listener_main.__doc__, + help = "TCP listener for RPKI-RTR protocol server") + subparser.set_defaults(func = listener_main, default_log_to = "syslog") + subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") + subparser.add_argument("--retry", type = retry, help = "override default retry timer") + subparser.add_argument("--expire", type = expire, help = "override default expire timer") + subparser.add_argument("port", type = int, help = "TCP port on which to listen") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") diff --git a/rpki/sundial.py b/rpki/sundial.py index 60037277..0381599f 100644 --- a/rpki/sundial.py +++ b/rpki/sundial.py @@ -48,257 +48,257 @@ import datetime as pydatetime import re def now(): - """ - Get current timestamp. - """ - - return datetime.utcnow() - -class ParseFailure(Exception): - """ - Parse failure constructing timedelta. - """ - -class datetime(pydatetime.datetime): - """ - RPKI extensions to standard datetime.datetime class. All work here - is in UTC, so we use naive datetime objects. - """ - - def totimestamp(self): - """ - Convert to seconds from epoch (like time.time()). Conversion - method is a bit silly, but avoids time module timezone whackiness. - """ - - return int(self.strftime("%s")) - - @classmethod - def fromXMLtime(cls, x): - """ - Convert from XML time representation. - """ - - if x is None: - return None - else: - return cls.strptime(x, "%Y-%m-%dT%H:%M:%SZ") - - def toXMLtime(self): """ - Convert to XML time representation. + Get current timestamp. """ - return self.strftime("%Y-%m-%dT%H:%M:%SZ") - - def __str__(self): - return self.toXMLtime() - - @classmethod - def from_datetime(cls, x): - """ - Convert a datetime.datetime object into this subclass. This is - whacky due to the weird constructors for datetime. - """ - - return cls.combine(x.date(), x.time()) - - def to_datetime(self): - """ - Convert to a datetime.datetime object. In most cases this - shouldn't be necessary, but convincing SQL interfaces to use - subclasses of datetime can be hard. - """ - - return pydatetime.datetime(year = self.year, month = self.month, day = self.day, - hour = self.hour, minute = self.minute, second = self.second, - microsecond = 0, tzinfo = None) - - - @classmethod - def fromOpenSSL(cls, x): - """ - Convert from the format OpenSSL's command line tool uses into this - subclass. May require rewriting if we run into locale problems. - """ + return datetime.utcnow() - if x.startswith("notBefore=") or x.startswith("notAfter="): - x = x.partition("=")[2] - return cls.strptime(x, "%b %d %H:%M:%S %Y GMT") - - @classmethod - def from_sql(cls, x): - """ - Convert from SQL storage format. - """ - - return cls.from_datetime(x) - - def to_sql(self): - """ - Convert to SQL storage format. - """ - - return self.to_datetime() - - def later(self, other): +class ParseFailure(Exception): """ - Return the later of two timestamps. + Parse failure constructing timedelta. """ - return other if other > self else self - - def earlier(self, other): +class datetime(pydatetime.datetime): """ - Return the earlier of two timestamps. + RPKI extensions to standard datetime.datetime class. All work here + is in UTC, so we use naive datetime objects. """ - return other if other < self else self - - def __add__(self, y): return _cast(pydatetime.datetime.__add__(self, y)) - def __radd__(self, y): return _cast(pydatetime.datetime.__radd__(self, y)) - def __rsub__(self, y): return _cast(pydatetime.datetime.__rsub__(self, y)) - def __sub__(self, y): return _cast(pydatetime.datetime.__sub__(self, y)) + def totimestamp(self): + """ + Convert to seconds from epoch (like time.time()). Conversion + method is a bit silly, but avoids time module timezone whackiness. + """ + + return int(self.strftime("%s")) + + @classmethod + def fromXMLtime(cls, x): + """ + Convert from XML time representation. + """ + + if x is None: + return None + else: + return cls.strptime(x, "%Y-%m-%dT%H:%M:%SZ") + + def toXMLtime(self): + """ + Convert to XML time representation. + """ + + return self.strftime("%Y-%m-%dT%H:%M:%SZ") + + def __str__(self): + return self.toXMLtime() + + @classmethod + def from_datetime(cls, x): + """ + Convert a datetime.datetime object into this subclass. This is + whacky due to the weird constructors for datetime. + """ + + return cls.combine(x.date(), x.time()) + + def to_datetime(self): + """ + Convert to a datetime.datetime object. In most cases this + shouldn't be necessary, but convincing SQL interfaces to use + subclasses of datetime can be hard. + """ + + return pydatetime.datetime(year = self.year, month = self.month, day = self.day, + hour = self.hour, minute = self.minute, second = self.second, + microsecond = 0, tzinfo = None) + + + @classmethod + def fromOpenSSL(cls, x): + """ + Convert from the format OpenSSL's command line tool uses into this + subclass. May require rewriting if we run into locale problems. + """ + + if x.startswith("notBefore=") or x.startswith("notAfter="): + x = x.partition("=")[2] + return cls.strptime(x, "%b %d %H:%M:%S %Y GMT") + + @classmethod + def from_sql(cls, x): + """ + Convert from SQL storage format. + """ + + return cls.from_datetime(x) + + def to_sql(self): + """ + Convert to SQL storage format. + """ + + return self.to_datetime() + + def later(self, other): + """ + Return the later of two timestamps. + """ + + return other if other > self else self + + def earlier(self, other): + """ + Return the earlier of two timestamps. + """ + + return other if other < self else self + + def __add__(self, y): return _cast(pydatetime.datetime.__add__(self, y)) + def __radd__(self, y): return _cast(pydatetime.datetime.__radd__(self, y)) + def __rsub__(self, y): return _cast(pydatetime.datetime.__rsub__(self, y)) + def __sub__(self, y): return _cast(pydatetime.datetime.__sub__(self, y)) + + @classmethod + def DateTime_or_None(cls, s): + """ + MySQLdb converter. Parse as this class if we can, let the default + MySQLdb DateTime_or_None() converter deal with failure cases. + """ + + for sep in " T": + d, _, t = s.partition(sep) # pylint: disable=W0612 + if t: + try: + return cls(*[int(x) for x in d.split("-") + t.split(":")]) + except: # pylint: disable=W0702 + break + + from rpki.mysql_import import MySQLdb + return MySQLdb.times.DateTime_or_None(s) - @classmethod - def DateTime_or_None(cls, s): - """ - MySQLdb converter. Parse as this class if we can, let the default - MySQLdb DateTime_or_None() converter deal with failure cases. +class timedelta(pydatetime.timedelta): """ + Timedelta with text parsing. This accepts two input formats: - for sep in " T": - d, _, t = s.partition(sep) # pylint: disable=W0612 - if t: - try: - return cls(*[int(x) for x in d.split("-") + t.split(":")]) - except: # pylint: disable=W0702 - break + - A simple integer, indicating a number of seconds. - from rpki.mysql_import import MySQLdb - return MySQLdb.times.DateTime_or_None(s) + - A string of the form "uY vW wD xH yM zS" where u, v, w, x, y, and z + are integers and Y, W, D, H, M, and S indicate years, weeks, days, + hours, minutes, and seconds. All of the fields are optional, but + at least one must be specified. Eg,"3D4H" means "three days plus + four hours". -class timedelta(pydatetime.timedelta): - """ - Timedelta with text parsing. This accepts two input formats: - - - A simple integer, indicating a number of seconds. - - - A string of the form "uY vW wD xH yM zS" where u, v, w, x, y, and z - are integers and Y, W, D, H, M, and S indicate years, weeks, days, - hours, minutes, and seconds. All of the fields are optional, but - at least one must be specified. Eg,"3D4H" means "three days plus - four hours". - - There is no "months" format, because the definition of a month is too - fuzzy to be useful (what day is six months from August 30th?) - - Similarly, the "years" conversion may produce surprising results, as - "one year" in conventional English does not refer to a fixed interval - but rather a fixed (and in some cases undefined) offset within the - Gregorian calendar (what day is one year from February 29th?) 1Y as - implemented by this code refers to a specific number of seconds. - If you mean 365 days or 52 weeks, say that instead. - """ - - ## @var regexp - # Hideously ugly regular expression to parse the complex text form. - # Tags are intended for use with re.MatchObject.groupdict() and map - # directly to the keywords expected by the timedelta constructor. - - regexp = re.compile("\\s*".join(("^", - "(?:(?P<years>\\d+)Y)?", - "(?:(?P<weeks>\\d+)W)?", - "(?:(?P<days>\\d+)D)?", - "(?:(?P<hours>\\d+)H)?", - "(?:(?P<minutes>\\d+)M)?", - "(?:(?P<seconds>\\d+)S)?", - "$")), - re.I) - - ## @var years_to_seconds - # Conversion factor from years to seconds (value furnished by the - # "units" program). - - years_to_seconds = 31556926 - - @classmethod - def parse(cls, arg): - """ - Parse text into a timedelta object. - """ + There is no "months" format, because the definition of a month is too + fuzzy to be useful (what day is six months from August 30th?) - if not isinstance(arg, str): - return cls(seconds = arg) - elif arg.isdigit(): - return cls(seconds = int(arg)) - else: - match = cls.regexp.match(arg) - if match: - #return cls(**dict((k, int(v)) for (k, v) in match.groupdict().items() if v is not None)) - d = match.groupdict("0") - for k, v in d.iteritems(): - d[k] = int(v) - d["days"] += d.pop("weeks") * 7 - d["seconds"] += d.pop("years") * cls.years_to_seconds - return cls(**d) - else: - raise ParseFailure("Couldn't parse timedelta %r" % (arg,)) - - def convert_to_seconds(self): - """ - Convert a timedelta interval to seconds. + Similarly, the "years" conversion may produce surprising results, as + "one year" in conventional English does not refer to a fixed interval + but rather a fixed (and in some cases undefined) offset within the + Gregorian calendar (what day is one year from February 29th?) 1Y as + implemented by this code refers to a specific number of seconds. + If you mean 365 days or 52 weeks, say that instead. """ - return self.days * 24 * 60 * 60 + self.seconds + ## @var regexp + # Hideously ugly regular expression to parse the complex text form. + # Tags are intended for use with re.MatchObject.groupdict() and map + # directly to the keywords expected by the timedelta constructor. + + regexp = re.compile("\\s*".join(("^", + "(?:(?P<years>\\d+)Y)?", + "(?:(?P<weeks>\\d+)W)?", + "(?:(?P<days>\\d+)D)?", + "(?:(?P<hours>\\d+)H)?", + "(?:(?P<minutes>\\d+)M)?", + "(?:(?P<seconds>\\d+)S)?", + "$")), + re.I) + + ## @var years_to_seconds + # Conversion factor from years to seconds (value furnished by the + # "units" program). + + years_to_seconds = 31556926 + + @classmethod + def parse(cls, arg): + """ + Parse text into a timedelta object. + """ + + if not isinstance(arg, str): + return cls(seconds = arg) + elif arg.isdigit(): + return cls(seconds = int(arg)) + else: + match = cls.regexp.match(arg) + if match: + #return cls(**dict((k, int(v)) for (k, v) in match.groupdict().items() if v is not None)) + d = match.groupdict("0") + for k, v in d.iteritems(): + d[k] = int(v) + d["days"] += d.pop("weeks") * 7 + d["seconds"] += d.pop("years") * cls.years_to_seconds + return cls(**d) + else: + raise ParseFailure("Couldn't parse timedelta %r" % (arg,)) + + def convert_to_seconds(self): + """ + Convert a timedelta interval to seconds. + """ + + return self.days * 24 * 60 * 60 + self.seconds + + @classmethod + def fromtimedelta(cls, x): + """ + Convert a datetime.timedelta object into this subclass. + """ + + return cls(days = x.days, seconds = x.seconds, microseconds = x.microseconds) + + def __abs__(self): return _cast(pydatetime.timedelta.__abs__(self)) + def __add__(self, x): return _cast(pydatetime.timedelta.__add__(self, x)) + def __div__(self, x): return _cast(pydatetime.timedelta.__div__(self, x)) + def __floordiv__(self, x): return _cast(pydatetime.timedelta.__floordiv__(self, x)) + def __mul__(self, x): return _cast(pydatetime.timedelta.__mul__(self, x)) + def __neg__(self): return _cast(pydatetime.timedelta.__neg__(self)) + def __pos__(self): return _cast(pydatetime.timedelta.__pos__(self)) + def __radd__(self, x): return _cast(pydatetime.timedelta.__radd__(self, x)) + def __rdiv__(self, x): return _cast(pydatetime.timedelta.__rdiv__(self, x)) + def __rfloordiv__(self, x): return _cast(pydatetime.timedelta.__rfloordiv__(self, x)) + def __rmul__(self, x): return _cast(pydatetime.timedelta.__rmul__(self, x)) + def __rsub__(self, x): return _cast(pydatetime.timedelta.__rsub__(self, x)) + def __sub__(self, x): return _cast(pydatetime.timedelta.__sub__(self, x)) - @classmethod - def fromtimedelta(cls, x): +def _cast(x): """ - Convert a datetime.timedelta object into this subclass. + Cast result of arithmetic operations back into correct subtype. """ - return cls(days = x.days, seconds = x.seconds, microseconds = x.microseconds) - - def __abs__(self): return _cast(pydatetime.timedelta.__abs__(self)) - def __add__(self, x): return _cast(pydatetime.timedelta.__add__(self, x)) - def __div__(self, x): return _cast(pydatetime.timedelta.__div__(self, x)) - def __floordiv__(self, x): return _cast(pydatetime.timedelta.__floordiv__(self, x)) - def __mul__(self, x): return _cast(pydatetime.timedelta.__mul__(self, x)) - def __neg__(self): return _cast(pydatetime.timedelta.__neg__(self)) - def __pos__(self): return _cast(pydatetime.timedelta.__pos__(self)) - def __radd__(self, x): return _cast(pydatetime.timedelta.__radd__(self, x)) - def __rdiv__(self, x): return _cast(pydatetime.timedelta.__rdiv__(self, x)) - def __rfloordiv__(self, x): return _cast(pydatetime.timedelta.__rfloordiv__(self, x)) - def __rmul__(self, x): return _cast(pydatetime.timedelta.__rmul__(self, x)) - def __rsub__(self, x): return _cast(pydatetime.timedelta.__rsub__(self, x)) - def __sub__(self, x): return _cast(pydatetime.timedelta.__sub__(self, x)) - -def _cast(x): - """ - Cast result of arithmetic operations back into correct subtype. - """ - - if isinstance(x, pydatetime.datetime): - return datetime.from_datetime(x) - if isinstance(x, pydatetime.timedelta): - return timedelta.fromtimedelta(x) - return x + if isinstance(x, pydatetime.datetime): + return datetime.from_datetime(x) + if isinstance(x, pydatetime.timedelta): + return timedelta.fromtimedelta(x) + return x if __name__ == "__main__": - def test(t): - print - print "str: ", t - print "repr: ", repr(t) - print "seconds since epoch:", t.strftime("%s") - print "XMLtime: ", t.toXMLtime() - print + def test(t): + print + print "str: ", t + print "repr: ", repr(t) + print "seconds since epoch:", t.strftime("%s") + print "XMLtime: ", t.toXMLtime() + print - print - print "Testing time conversion routines" - test(now()) - test(now() + timedelta(days = 30)) - test(now() + timedelta.parse("3d5s")) - test(now() + timedelta.parse(" 3d 5s ")) - test(now() + timedelta.parse("1y3d5h")) + print + print "Testing time conversion routines" + test(now()) + test(now() + timedelta(days = 30)) + test(now() + timedelta.parse("3d5s")) + test(now() + timedelta.parse(" 3d 5s ")) + test(now() + timedelta.parse("1y3d5h")) diff --git a/rpki/up_down.py b/rpki/up_down.py index 90965e45..cfe86714 100644 --- a/rpki/up_down.py +++ b/rpki/up_down.py @@ -61,34 +61,34 @@ tag_status = xmlns + "status" class multi_uri(list): - """ - Container for a set of URIs. This probably could be simplified. - """ - - def __init__(self, ini): - list.__init__(self) - if isinstance(ini, (list, tuple)): - self[:] = ini - elif isinstance(ini, str): - self[:] = ini.split(",") - for s in self: - if s.strip() != s or "://" not in s: - raise rpki.exceptions.BadURISyntax("Bad URI \"%s\"" % s) - else: - raise TypeError - - def __str__(self): - return ",".join(self) - - def rsync(self): """ - Find first rsync://... URI in self. + Container for a set of URIs. This probably could be simplified. """ - for s in self: - if s.startswith("rsync://"): - return s - return None + def __init__(self, ini): + list.__init__(self) + if isinstance(ini, (list, tuple)): + self[:] = ini + elif isinstance(ini, str): + self[:] = ini.split(",") + for s in self: + if s.strip() != s or "://" not in s: + raise rpki.exceptions.BadURISyntax("Bad URI \"%s\"" % s) + else: + raise TypeError + + def __str__(self): + return ",".join(self) + + def rsync(self): + """ + Find first rsync://... URI in self. + """ + + for s in self: + if s.startswith("rsync://"): + return s + return None error_response_codes = { @@ -111,62 +111,62 @@ exception_map = { def check_response(r_msg, q_type): - """ - Additional checks beyond the XML schema for whether this looks like - a reasonable up-down response message. - """ + """ + Additional checks beyond the XML schema for whether this looks like + a reasonable up-down response message. + """ - r_type = r_msg.get("type") + r_type = r_msg.get("type") - if r_type == "error_response": - raise rpki.exceptions.UpstreamError(error_response_codes[int(r_msg.findtext(tag_status))]) + if r_type == "error_response": + raise rpki.exceptions.UpstreamError(error_response_codes[int(r_msg.findtext(tag_status))]) - if r_type != q_type + "_response": - raise rpki.exceptions.UnexpectedUpDownResponse + if r_type != q_type + "_response": + raise rpki.exceptions.UnexpectedUpDownResponse - if r_type == "issue_response" and (len(r_msg) != 1 or len(r_msg[0]) != 2): - logger.debug("Weird issue_response %r: len(r_msg) %s len(r_msg[0]) %s", - r_msg, len(r_msg), len(r_msg[0]) if len(r_msg) else None) - logger.debug("Offending message\n%s", ElementToString(r_msg)) - raise rpki.exceptions.BadIssueResponse + if r_type == "issue_response" and (len(r_msg) != 1 or len(r_msg[0]) != 2): + logger.debug("Weird issue_response %r: len(r_msg) %s len(r_msg[0]) %s", + r_msg, len(r_msg), len(r_msg[0]) if len(r_msg) else None) + logger.debug("Offending message\n%s", ElementToString(r_msg)) + raise rpki.exceptions.BadIssueResponse def generate_error_response(r_msg, status = 2001, description = None): - """ - Generate an error response. If status is given, it specifies the - numeric code to use, otherwise we default to "internal error". - If description is specified, we use it as the description, otherwise - we just use the default string associated with status. - """ - - assert status in error_response_codes - del r_msg[:] - r_msg.set("type", "error_response") - SubElement(r_msg, tag_status).text = str(status) - se = SubElement(r_msg, tag_description) - se.set("{http://www.w3.org/XML/1998/namespace}lang", "en-US") - se.text = str(description or error_response_codes[status]) + """ + Generate an error response. If status is given, it specifies the + numeric code to use, otherwise we default to "internal error". + If description is specified, we use it as the description, otherwise + we just use the default string associated with status. + """ + + assert status in error_response_codes + del r_msg[:] + r_msg.set("type", "error_response") + SubElement(r_msg, tag_status).text = str(status) + se = SubElement(r_msg, tag_description) + se.set("{http://www.w3.org/XML/1998/namespace}lang", "en-US") + se.text = str(description or error_response_codes[status]) def generate_error_response_from_exception(r_msg, e, q_type): - """ - Construct an error response from an exception. q_type - specifies the kind of query to which this is a response, since the - same exception can generate different codes in response to different - queries. - """ + """ + Construct an error response from an exception. q_type + specifies the kind of query to which this is a response, since the + same exception can generate different codes in response to different + queries. + """ - t = type(e) - code = (exception_map.get((t, q_type)) or exception_map.get(t) or 2001) - generate_error_response(r_msg, code, e) + t = type(e) + code = (exception_map.get((t, q_type)) or exception_map.get(t) or 2001) + generate_error_response(r_msg, code, e) class cms_msg(rpki.x509.XML_CMS_object): - """ - CMS-signed up-down PDU. - """ - - encoding = "UTF-8" - schema = rpki.relaxng.up_down - allow_extra_certs = True - allow_extra_crls = True + """ + CMS-signed up-down PDU. + """ + + encoding = "UTF-8" + schema = rpki.relaxng.up_down + allow_extra_certs = True + allow_extra_crls = True diff --git a/rpki/x509.py b/rpki/x509.py index 16981d06..d904bb0f 100644 --- a/rpki/x509.py +++ b/rpki/x509.py @@ -52,1175 +52,1175 @@ import rpki.relaxng logger = logging.getLogger(__name__) def base64_with_linebreaks(der): - """ - Encode DER (really, anything) as Base64 text, with linebreaks to - keep the result (sort of) readable. - """ - - b = base64.b64encode(der) - n = len(b) - return "\n" + "\n".join(b[i : min(i + 64, n)] for i in xrange(0, n, 64)) + "\n" - -def looks_like_PEM(text): - """ - Guess whether text looks like a PEM encoding. - """ - - i = text.find("-----BEGIN ") - return i >= 0 and text.find("\n-----END ", i) > i - -def first_uri_matching_prefix(xia, prefix): - """ - Find first URI in a sequence of AIA or SIA URIs which matches a - particular prefix string. Returns the URI if found, otherwise None. - """ - - if xia is not None: - for uri in xia: - if uri.startswith(prefix): - return uri - return None - -def first_rsync_uri(xia): - """ - Find first rsync URI in a sequence of AIA or SIA URIs. - Returns the URI if found, otherwise None. - """ - - return first_uri_matching_prefix(xia, "rsync://") - -def first_http_uri(xia): - """ - Find first HTTP URI in a sequence of AIA or SIA URIs. - Returns the URI if found, otherwise None. - """ - - return first_uri_matching_prefix(xia, "http://") - -def first_https_uri(xia): - """ - Find first HTTPS URI in a sequence of AIA or SIA URIs. - Returns the URI if found, otherwise None. - """ - - return first_uri_matching_prefix(xia, "https://") - -def sha1(data): - """ - Calculate SHA-1 digest of some data. - Convenience wrapper around rpki.POW.Digest class. - """ - - d = rpki.POW.Digest(rpki.POW.SHA1_DIGEST) - d.update(data) - return d.digest() - -def sha256(data): - """ - Calculate SHA-256 digest of some data. - Convenience wrapper around rpki.POW.Digest class. - """ - - d = rpki.POW.Digest(rpki.POW.SHA256_DIGEST) - d.update(data) - return d.digest() - - -class X501DN(object): - """ - Class to hold an X.501 Distinguished Name. - - This is nothing like a complete implementation, just enough for our - purposes. See RFC 5280 4.1.2.4 for the ASN.1 details. In brief: - - - A DN is a SEQUENCE OF RDNs. - - - A RDN is a SET OF AttributeAndValues; in practice, multi-value - RDNs are rare, so an RDN is almost always a set with a single - element. - - - An AttributeAndValue is a SEQUENCE consisting of a OID and a - value, where a whole bunch of things including both syntax and - semantics of the value are determined by the OID. - - - The value is some kind of ASN.1 string; there are far too many - encoding options options, most of which are either strongly - discouraged or outright forbidden by the PKIX profile, but which - persist for historical reasons. The only ones PKIX actually - likes are PrintableString and UTF8String, but there are nuances - and special cases where some of the others are required. - - The RPKI profile further restricts DNs to a single mandatory - CommonName attribute with a single optional SerialNumber attribute - (not to be confused with the certificate serial number). - - BPKI certificates should (we hope) follow the general PKIX guideline - but the ones we construct ourselves are likely to be relatively - simple. - """ - - def __str__(self): - return "".join("/" + "+".join("%s=%s" % (rpki.oids.oid2name(a[0]), a[1]) - for a in rdn) - for rdn in self.dn) - - def __cmp__(self, other): - return cmp(self.dn, other.dn) - - def __repr__(self): - return rpki.log.log_repr(self, str(self)) - - def _debug(self): - logger.debug("++ %r %r", self, self.dn) - - @classmethod - def from_cn(cls, cn, sn = None): - assert isinstance(cn, (str, unicode)) - if isinstance(sn, (int, long)): - sn = "%08X" % sn - elif isinstance(sn, (str, unicode)): - assert all(c in "0123456789abcdefABCDEF" for c in sn) - sn = str(sn) - self = cls() - if sn is not None: - self.dn = (((rpki.oids.commonName, cn),), ((rpki.oids.serialNumber, sn),)) - else: - self.dn = (((rpki.oids.commonName, cn),),) - return self - - @classmethod - def from_POW(cls, t): - assert isinstance(t, tuple) - self = cls() - self.dn = t - return self - - def get_POW(self): - return self.dn - - def extract_cn_and_sn(self): - cn = None - sn = None - - for rdn in self.dn: - if len(rdn) == 1 and len(rdn[0]) == 2: - oid = rdn[0][0] - val = rdn[0][1] - if oid == rpki.oids.commonName and cn is None: - cn = val - continue - if oid == rpki.oids.serialNumber and sn is None: - sn = val - continue - raise rpki.exceptions.BadX510DN("Bad subject name: %s" % (self.dn,)) - - if cn is None: - raise rpki.exceptions.BadX510DN("Subject name is missing CN: %s" % (self.dn,)) - - return cn, sn - - -class DER_object(object): - """ - Virtual class to hold a generic DER object. - """ - - ## @var formats - # Formats supported in this object. This is kind of redundant now - # that we're down to a single ASN.1 package and everything supports - # the same DER and POW formats, it's mostly historical baggage from - # the days when we had three different ASN.1 encoders, each with its - # own low-level Python object format. Clean up, some day. - formats = ("DER", "POW") - - ## @var POW_class - # Class of underlying POW object. Concrete subclasses must supply this. - POW_class = None - - ## Other attributes that self.clear() should whack. - other_clear = () - - ## @var DER - # DER value of this object - DER = None - - ## @var failure_threshold - # Rate-limiting interval between whines about Auto_update objects. - failure_threshold = rpki.sundial.timedelta(minutes = 5) - - def empty(self): - """ - Test whether this object is empty. - """ - - return all(getattr(self, a, None) is None for a in self.formats) - - def clear(self): - """ - Make this object empty. """ - - for a in self.formats + self.other_clear: - setattr(self, a, None) - self.filename = None - self.timestamp = None - self.lastfail = None - - def __init__(self, **kw): - """ - Initialize a DER_object. + Encode DER (really, anything) as Base64 text, with linebreaks to + keep the result (sort of) readable. """ - self.clear() - if len(kw): - self.set(**kw) + b = base64.b64encode(der) + n = len(b) + return "\n" + "\n".join(b[i : min(i + 64, n)] for i in xrange(0, n, 64)) + "\n" - def set(self, **kw): - """ - Set this object by setting one of its known formats. - - This method only allows one to set one format at a time. - Subsequent calls will clear the object first. The point of all - this is to let the object's internal converters handle mustering - the object into whatever format you need at the moment. - """ - - if len(kw) == 1: - name = kw.keys()[0] - if name in self.formats: - self.clear() - setattr(self, name, kw[name]) - return - if name == "PEM": - self.clear() - self._set_PEM(kw[name]) - return - if name == "Base64": - self.clear() - self.DER = base64.b64decode(kw[name]) - return - if name == "Auto_update": - self.filename = kw[name] - self.check_auto_update() - return - if name in ("PEM_file", "DER_file", "Auto_file"): - f = open(kw[name], "rb") - value = f.read() - f.close() - self.clear() - if name == "PEM_file" or (name == "Auto_file" and looks_like_PEM(value)): - self._set_PEM(value) - else: - self.DER = value - return - raise rpki.exceptions.DERObjectConversionError("Can't honor conversion request %r" % (kw,)) - - def check_auto_update(self): - """ - Check for updates to a DER object that auto-updates from a file. - """ - - if self.filename is None: - return - try: - filename = self.filename - timestamp = os.stat(self.filename).st_mtime - if self.timestamp is None or self.timestamp < timestamp: - logger.debug("Updating %s, timestamp %s", - filename, rpki.sundial.datetime.fromtimestamp(timestamp)) - f = open(filename, "rb") - value = f.read() - f.close() - self.clear() - if looks_like_PEM(value): - self._set_PEM(value) - else: - self.DER = value - self.filename = filename - self.timestamp = timestamp - except (IOError, OSError), e: - now = rpki.sundial.now() - if self.lastfail is None or now > self.lastfail + self.failure_threshold: - logger.warning("Could not auto_update %r (last failure %s): %s", self, self.lastfail, e) - self.lastfail = now - else: - self.lastfail = None - - @property - def mtime(self): - """ - Retrieve os.stat().st_mtime for auto-update files. - """ - - return os.stat(self.filename).st_mtime - - def check(self): +def looks_like_PEM(text): """ - Perform basic checks on a DER object. + Guess whether text looks like a PEM encoding. """ - self.check_auto_update() - assert not self.empty() + i = text.find("-----BEGIN ") + return i >= 0 and text.find("\n-----END ", i) > i - def _set_PEM(self, pem): +def first_uri_matching_prefix(xia, prefix): """ - Set the POW value of this object based on a PEM input value. - Subclasses may need to override this. + Find first URI in a sequence of AIA or SIA URIs which matches a + particular prefix string. Returns the URI if found, otherwise None. """ - assert self.empty() - self.POW = self.POW_class.pemRead(pem) + if xia is not None: + for uri in xia: + if uri.startswith(prefix): + return uri + return None - def get_DER(self): +def first_rsync_uri(xia): """ - Get the DER value of this object. - Subclasses may need to override this method. + Find first rsync URI in a sequence of AIA or SIA URIs. + Returns the URI if found, otherwise None. """ - self.check() - if self.DER: - return self.DER - if self.POW: - self.DER = self.POW.derWrite() - return self.get_DER() - raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + return first_uri_matching_prefix(xia, "rsync://") - def get_POW(self): +def first_http_uri(xia): """ - Get the rpki.POW value of this object. - Subclasses may need to override this method. + Find first HTTP URI in a sequence of AIA or SIA URIs. + Returns the URI if found, otherwise None. """ - self.check() - if not self.POW: # pylint: disable=E0203 - self.POW = self.POW_class.derRead(self.get_DER()) - return self.POW + return first_uri_matching_prefix(xia, "http://") - def get_Base64(self): +def first_https_uri(xia): """ - Get the Base64 encoding of the DER value of this object. + Find first HTTPS URI in a sequence of AIA or SIA URIs. + Returns the URI if found, otherwise None. """ - return base64_with_linebreaks(self.get_DER()) + return first_uri_matching_prefix(xia, "https://") - def get_PEM(self): +def sha1(data): """ - Get the PEM representation of this object. + Calculate SHA-1 digest of some data. + Convenience wrapper around rpki.POW.Digest class. """ - return self.get_POW().pemWrite() + d = rpki.POW.Digest(rpki.POW.SHA1_DIGEST) + d.update(data) + return d.digest() - def __cmp__(self, other): - """ - Compare two DER-encoded objects. - """ - - if self is None and other is None: - return 0 - elif self is None: - return -1 - elif other is None: - return 1 - elif isinstance(other, str): - return cmp(self.get_DER(), other) - else: - return cmp(self.get_DER(), other.get_DER()) - - def hSKI(self): +def sha256(data): """ - Return hexadecimal string representation of SKI for this object. - Only work for subclasses that implement get_SKI(). + Calculate SHA-256 digest of some data. + Convenience wrapper around rpki.POW.Digest class. """ - ski = self.get_SKI() - return ":".join(("%02X" % ord(i) for i in ski)) if ski else "" - - def gSKI(self): - """ - Calculate g(SKI) for this object. Only work for subclasses - that implement get_SKI(). - """ + d = rpki.POW.Digest(rpki.POW.SHA256_DIGEST) + d.update(data) + return d.digest() - return base64.urlsafe_b64encode(self.get_SKI()).rstrip("=") - def hAKI(self): - """ - Return hexadecimal string representation of AKI for this - object. Only work for subclasses that implement get_AKI(). +class X501DN(object): """ + Class to hold an X.501 Distinguished Name. - aki = self.get_AKI() - return ":".join(("%02X" % ord(i) for i in aki)) if aki else "" + This is nothing like a complete implementation, just enough for our + purposes. See RFC 5280 4.1.2.4 for the ASN.1 details. In brief: - def gAKI(self): - """ - Calculate g(AKI) for this object. Only work for subclasses - that implement get_AKI(). - """ + - A DN is a SEQUENCE OF RDNs. - return base64.urlsafe_b64encode(self.get_AKI()).rstrip("=") + - A RDN is a SET OF AttributeAndValues; in practice, multi-value + RDNs are rare, so an RDN is almost always a set with a single + element. - def get_AKI(self): - """ - Get the AKI extension from this object, if supported. - """ + - An AttributeAndValue is a SEQUENCE consisting of a OID and a + value, where a whole bunch of things including both syntax and + semantics of the value are determined by the OID. - return self.get_POW().getAKI() + - The value is some kind of ASN.1 string; there are far too many + encoding options options, most of which are either strongly + discouraged or outright forbidden by the PKIX profile, but which + persist for historical reasons. The only ones PKIX actually + likes are PrintableString and UTF8String, but there are nuances + and special cases where some of the others are required. - def get_SKI(self): - """ - Get the SKI extension from this object, if supported. - """ - - return self.get_POW().getSKI() + The RPKI profile further restricts DNs to a single mandatory + CommonName attribute with a single optional SerialNumber attribute + (not to be confused with the certificate serial number). - def get_EKU(self): + BPKI certificates should (we hope) follow the general PKIX guideline + but the ones we construct ourselves are likely to be relatively + simple. """ - Get the Extended Key Usage extension from this object, if supported. - """ - - return self.get_POW().getEKU() - def get_SIA(self): - """ - Get the SIA extension from this object. Only works for subclasses - that support getSIA(). - """ + def __str__(self): + return "".join("/" + "+".join("%s=%s" % (rpki.oids.oid2name(a[0]), a[1]) + for a in rdn) + for rdn in self.dn) - return self.get_POW().getSIA() + def __cmp__(self, other): + return cmp(self.dn, other.dn) - def get_sia_directory_uri(self): - """ - Get SIA directory (id-ad-caRepository) URI from this object. - Only works for subclasses that support getSIA(). - """ + def __repr__(self): + return rpki.log.log_repr(self, str(self)) - sia = self.get_POW().getSIA() - return None if sia is None else first_rsync_uri(sia[0]) + def _debug(self): + logger.debug("++ %r %r", self, self.dn) - def get_sia_manifest_uri(self): - """ - Get SIA manifest (id-ad-rpkiManifest) URI from this object. - Only works for subclasses that support getSIA(). - """ + @classmethod + def from_cn(cls, cn, sn = None): + assert isinstance(cn, (str, unicode)) + if isinstance(sn, (int, long)): + sn = "%08X" % sn + elif isinstance(sn, (str, unicode)): + assert all(c in "0123456789abcdefABCDEF" for c in sn) + sn = str(sn) + self = cls() + if sn is not None: + self.dn = (((rpki.oids.commonName, cn),), ((rpki.oids.serialNumber, sn),)) + else: + self.dn = (((rpki.oids.commonName, cn),),) + return self - sia = self.get_POW().getSIA() - return None if sia is None else first_rsync_uri(sia[1]) + @classmethod + def from_POW(cls, t): + assert isinstance(t, tuple) + self = cls() + self.dn = t + return self - def get_sia_object_uri(self): - """ - Get SIA object (id-ad-signedObject) URI from this object. - Only works for subclasses that support getSIA(). - """ + def get_POW(self): + return self.dn - sia = self.get_POW().getSIA() - return None if sia is None else first_rsync_uri(sia[2]) + def extract_cn_and_sn(self): + cn = None + sn = None - def get_sia_rrdp_notify(self): - """ - Get SIA RRDP (id-ad-rpkiNotify) URI from this object. - We prefer HTTPS over HTTP if both are present. - Only works for subclasses that support getSIA(). - """ + for rdn in self.dn: + if len(rdn) == 1 and len(rdn[0]) == 2: + oid = rdn[0][0] + val = rdn[0][1] + if oid == rpki.oids.commonName and cn is None: + cn = val + continue + if oid == rpki.oids.serialNumber and sn is None: + sn = val + continue + raise rpki.exceptions.BadX510DN("Bad subject name: %s" % (self.dn,)) - sia = self.get_POW().getSIA() - return None if sia is None else first_https_uri(sia[3]) or first_http_uri(sia[3]) + if cn is None: + raise rpki.exceptions.BadX510DN("Subject name is missing CN: %s" % (self.dn,)) - def get_AIA(self): - """ - Get the SIA extension from this object. Only works for subclasses - that support getAIA(). - """ + return cn, sn - return self.get_POW().getAIA() - def get_aia_uri(self): +class DER_object(object): """ - Get AIA (id-ad-caIssuers) URI from this object. - Only works for subclasses that support getAIA(). + Virtual class to hold a generic DER object. """ - return first_rsync_uri(self.get_POW().getAIA()) + ## @var formats + # Formats supported in this object. This is kind of redundant now + # that we're down to a single ASN.1 package and everything supports + # the same DER and POW formats, it's mostly historical baggage from + # the days when we had three different ASN.1 encoders, each with its + # own low-level Python object format. Clean up, some day. + formats = ("DER", "POW") - def get_basicConstraints(self): - """ - Get the basicConstraints extension from this object. Only works - for subclasses that support getExtension(). - """ + ## @var POW_class + # Class of underlying POW object. Concrete subclasses must supply this. + POW_class = None - return self.get_POW().getBasicConstraints() + ## Other attributes that self.clear() should whack. + other_clear = () - def is_CA(self): - """ - Return True if and only if object has the basicConstraints - extension and its cA value is true. - """ + ## @var DER + # DER value of this object + DER = None - basicConstraints = self.get_basicConstraints() - return basicConstraints is not None and basicConstraints[0] + ## @var failure_threshold + # Rate-limiting interval between whines about Auto_update objects. + failure_threshold = rpki.sundial.timedelta(minutes = 5) - def get_3779resources(self): - """ - Get RFC 3779 resources as rpki.resource_set objects. - """ + def empty(self): + """ + Test whether this object is empty. + """ - resources = rpki.resource_set.resource_bag.from_POW_rfc3779(self.get_POW().getRFC3779()) - try: - resources.valid_until = self.getNotAfter() - except AttributeError: - pass - return resources + return all(getattr(self, a, None) is None for a in self.formats) - @classmethod - def from_sql(cls, x): - """ - Convert from SQL storage format. - """ + def clear(self): + """ + Make this object empty. + """ - return cls(DER = x) + for a in self.formats + self.other_clear: + setattr(self, a, None) + self.filename = None + self.timestamp = None + self.lastfail = None - def to_sql(self): - """ - Convert to SQL storage format. - """ + def __init__(self, **kw): + """ + Initialize a DER_object. + """ - return self.get_DER() + self.clear() + if len(kw): + self.set(**kw) + + def set(self, **kw): + """ + Set this object by setting one of its known formats. + + This method only allows one to set one format at a time. + Subsequent calls will clear the object first. The point of all + this is to let the object's internal converters handle mustering + the object into whatever format you need at the moment. + """ + + if len(kw) == 1: + name = kw.keys()[0] + if name in self.formats: + self.clear() + setattr(self, name, kw[name]) + return + if name == "PEM": + self.clear() + self._set_PEM(kw[name]) + return + if name == "Base64": + self.clear() + self.DER = base64.b64decode(kw[name]) + return + if name == "Auto_update": + self.filename = kw[name] + self.check_auto_update() + return + if name in ("PEM_file", "DER_file", "Auto_file"): + f = open(kw[name], "rb") + value = f.read() + f.close() + self.clear() + if name == "PEM_file" or (name == "Auto_file" and looks_like_PEM(value)): + self._set_PEM(value) + else: + self.DER = value + return + raise rpki.exceptions.DERObjectConversionError("Can't honor conversion request %r" % (kw,)) + + def check_auto_update(self): + """ + Check for updates to a DER object that auto-updates from a file. + """ + + if self.filename is None: + return + try: + filename = self.filename + timestamp = os.stat(self.filename).st_mtime + if self.timestamp is None or self.timestamp < timestamp: + logger.debug("Updating %s, timestamp %s", + filename, rpki.sundial.datetime.fromtimestamp(timestamp)) + f = open(filename, "rb") + value = f.read() + f.close() + self.clear() + if looks_like_PEM(value): + self._set_PEM(value) + else: + self.DER = value + self.filename = filename + self.timestamp = timestamp + except (IOError, OSError), e: + now = rpki.sundial.now() + if self.lastfail is None or now > self.lastfail + self.failure_threshold: + logger.warning("Could not auto_update %r (last failure %s): %s", self, self.lastfail, e) + self.lastfail = now + else: + self.lastfail = None - def dumpasn1(self): - """ - Pretty print an ASN.1 DER object using cryptlib dumpasn1 tool. - Use a temporary file rather than popen4() because dumpasn1 uses - seek() when decoding ASN.1 content nested in OCTET STRING values. - """ + @property + def mtime(self): + """ + Retrieve os.stat().st_mtime for auto-update files. + """ - ret = None - fn = "dumpasn1.%d.tmp" % os.getpid() - try: - f = open(fn, "wb") - f.write(self.get_DER()) - f.close() - p = subprocess.Popen(("dumpasn1", "-a", fn), stdout = subprocess.PIPE, stderr = subprocess.STDOUT) - ret = "\n".join(x for x in p.communicate()[0].splitlines() if x.startswith(" ")) - except Exception, e: - ret = "[Could not run dumpasn1: %s]" % e - finally: - os.unlink(fn) - return ret - - def tracking_data(self, uri): - """ - Return a string containing data we want to log when tracking how - objects move through the RPKI system. Subclasses may wrap this to - provide more information, but should make sure to include at least - this information at the start of the tracking line. - """ + return os.stat(self.filename).st_mtime - try: - return "%s %s %s" % (uri, - self.creation_timestamp, - "".join(("%02X" % ord(b) for b in sha1(self.get_DER())))) - except: # pylint: disable=W0702 - return uri + def check(self): + """ + Perform basic checks on a DER object. + """ - def __getstate__(self): - """ - Pickling protocol -- pickle the DER encoding. - """ + self.check_auto_update() + assert not self.empty() + + def _set_PEM(self, pem): + """ + Set the POW value of this object based on a PEM input value. + Subclasses may need to override this. + """ + + assert self.empty() + self.POW = self.POW_class.pemRead(pem) + + def get_DER(self): + """ + Get the DER value of this object. + Subclasses may need to override this method. + """ + + self.check() + if self.DER: + return self.DER + if self.POW: + self.DER = self.POW.derWrite() + return self.get_DER() + raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + + def get_POW(self): + """ + Get the rpki.POW value of this object. + Subclasses may need to override this method. + """ + + self.check() + if not self.POW: # pylint: disable=E0203 + self.POW = self.POW_class.derRead(self.get_DER()) + return self.POW + + def get_Base64(self): + """ + Get the Base64 encoding of the DER value of this object. + """ + + return base64_with_linebreaks(self.get_DER()) + + def get_PEM(self): + """ + Get the PEM representation of this object. + """ + + return self.get_POW().pemWrite() + + def __cmp__(self, other): + """ + Compare two DER-encoded objects. + """ + + if self is None and other is None: + return 0 + elif self is None: + return -1 + elif other is None: + return 1 + elif isinstance(other, str): + return cmp(self.get_DER(), other) + else: + return cmp(self.get_DER(), other.get_DER()) - return self.get_DER() + def hSKI(self): + """ + Return hexadecimal string representation of SKI for this object. + Only work for subclasses that implement get_SKI(). + """ + + ski = self.get_SKI() + return ":".join(("%02X" % ord(i) for i in ski)) if ski else "" - def __setstate__(self, state): - """ - Pickling protocol -- unpickle the DER encoding. - """ + def gSKI(self): + """ + Calculate g(SKI) for this object. Only work for subclasses + that implement get_SKI(). + """ + + return base64.urlsafe_b64encode(self.get_SKI()).rstrip("=") - self.set(DER = state) + def hAKI(self): + """ + Return hexadecimal string representation of AKI for this + object. Only work for subclasses that implement get_AKI(). + """ + + aki = self.get_AKI() + return ":".join(("%02X" % ord(i) for i in aki)) if aki else "" + + def gAKI(self): + """ + Calculate g(AKI) for this object. Only work for subclasses + that implement get_AKI(). + """ + + return base64.urlsafe_b64encode(self.get_AKI()).rstrip("=") + + def get_AKI(self): + """ + Get the AKI extension from this object, if supported. + """ + + return self.get_POW().getAKI() + + def get_SKI(self): + """ + Get the SKI extension from this object, if supported. + """ + + return self.get_POW().getSKI() + + def get_EKU(self): + """ + Get the Extended Key Usage extension from this object, if supported. + """ + + return self.get_POW().getEKU() + + def get_SIA(self): + """ + Get the SIA extension from this object. Only works for subclasses + that support getSIA(). + """ + + return self.get_POW().getSIA() + + def get_sia_directory_uri(self): + """ + Get SIA directory (id-ad-caRepository) URI from this object. + Only works for subclasses that support getSIA(). + """ + + sia = self.get_POW().getSIA() + return None if sia is None else first_rsync_uri(sia[0]) + + def get_sia_manifest_uri(self): + """ + Get SIA manifest (id-ad-rpkiManifest) URI from this object. + Only works for subclasses that support getSIA(). + """ + + sia = self.get_POW().getSIA() + return None if sia is None else first_rsync_uri(sia[1]) + + def get_sia_object_uri(self): + """ + Get SIA object (id-ad-signedObject) URI from this object. + Only works for subclasses that support getSIA(). + """ + + sia = self.get_POW().getSIA() + return None if sia is None else first_rsync_uri(sia[2]) + + def get_sia_rrdp_notify(self): + """ + Get SIA RRDP (id-ad-rpkiNotify) URI from this object. + We prefer HTTPS over HTTP if both are present. + Only works for subclasses that support getSIA(). + """ + + sia = self.get_POW().getSIA() + return None if sia is None else first_https_uri(sia[3]) or first_http_uri(sia[3]) + + def get_AIA(self): + """ + Get the SIA extension from this object. Only works for subclasses + that support getAIA(). + """ + + return self.get_POW().getAIA() + + def get_aia_uri(self): + """ + Get AIA (id-ad-caIssuers) URI from this object. + Only works for subclasses that support getAIA(). + """ + + return first_rsync_uri(self.get_POW().getAIA()) + + def get_basicConstraints(self): + """ + Get the basicConstraints extension from this object. Only works + for subclasses that support getExtension(). + """ + + return self.get_POW().getBasicConstraints() + + def is_CA(self): + """ + Return True if and only if object has the basicConstraints + extension and its cA value is true. + """ + + basicConstraints = self.get_basicConstraints() + return basicConstraints is not None and basicConstraints[0] + + def get_3779resources(self): + """ + Get RFC 3779 resources as rpki.resource_set objects. + """ + + resources = rpki.resource_set.resource_bag.from_POW_rfc3779(self.get_POW().getRFC3779()) + try: + resources.valid_until = self.getNotAfter() + except AttributeError: + pass + return resources + + @classmethod + def from_sql(cls, x): + """ + Convert from SQL storage format. + """ + + return cls(DER = x) + + def to_sql(self): + """ + Convert to SQL storage format. + """ + + return self.get_DER() + + def dumpasn1(self): + """ + Pretty print an ASN.1 DER object using cryptlib dumpasn1 tool. + Use a temporary file rather than popen4() because dumpasn1 uses + seek() when decoding ASN.1 content nested in OCTET STRING values. + """ + + ret = None + fn = "dumpasn1.%d.tmp" % os.getpid() + try: + f = open(fn, "wb") + f.write(self.get_DER()) + f.close() + p = subprocess.Popen(("dumpasn1", "-a", fn), stdout = subprocess.PIPE, stderr = subprocess.STDOUT) + ret = "\n".join(x for x in p.communicate()[0].splitlines() if x.startswith(" ")) + except Exception, e: + ret = "[Could not run dumpasn1: %s]" % e + finally: + os.unlink(fn) + return ret + + def tracking_data(self, uri): + """ + Return a string containing data we want to log when tracking how + objects move through the RPKI system. Subclasses may wrap this to + provide more information, but should make sure to include at least + this information at the start of the tracking line. + """ + + try: + return "%s %s %s" % (uri, + self.creation_timestamp, + "".join(("%02X" % ord(b) for b in sha1(self.get_DER())))) + except: # pylint: disable=W0702 + return uri + + def __getstate__(self): + """ + Pickling protocol -- pickle the DER encoding. + """ + + return self.get_DER() + + def __setstate__(self, state): + """ + Pickling protocol -- unpickle the DER encoding. + """ + + self.set(DER = state) class X509(DER_object): - """ - X.509 certificates. - - This class is designed to hold all the different representations of - X.509 certs we're using and convert between them. X.509 support in - Python a nasty maze of half-cooked stuff (except perhaps for - cryptlib, which is just different). Users of this module should not - have to care about this implementation nightmare. - """ - - POW_class = rpki.POW.X509 - - def getIssuer(self): - """ - Get the issuer of this certificate. - """ - - return X501DN.from_POW(self.get_POW().getIssuer()) - - def getSubject(self): - """ - Get the subject of this certificate. """ + X.509 certificates. - return X501DN.from_POW(self.get_POW().getSubject()) - - def getNotBefore(self): - """ - Get the inception time of this certificate. - """ - - return self.get_POW().getNotBefore() - - def getNotAfter(self): + This class is designed to hold all the different representations of + X.509 certs we're using and convert between them. X.509 support in + Python a nasty maze of half-cooked stuff (except perhaps for + cryptlib, which is just different). Users of this module should not + have to care about this implementation nightmare. """ - Get the expiration time of this certificate. - """ - - return self.get_POW().getNotAfter() - - def getSerial(self): - """ - Get the serial number of this certificate. - """ - - return self.get_POW().getSerial() - - def getPublicKey(self): - """ - Extract the public key from this certificate. - """ - - return PublicKey(POW = self.get_POW().getPublicKey()) - - def get_SKI(self): - """ - Get the SKI extension from this object. - """ - - return self.get_POW().getSKI() - - def expired(self): - """ - Test whether this certificate has expired. - """ - - return self.getNotAfter() <= rpki.sundial.now() - - def issue(self, keypair, subject_key, serial, sia, aia, crldp, notAfter, - cn = None, resources = None, is_ca = True, notBefore = None, - sn = None, eku = None): - """ - Issue an RPKI certificate. - """ - - assert aia is not None and crldp is not None - - assert eku is None or not is_ca - - return self._issue( - keypair = keypair, - subject_key = subject_key, - serial = serial, - sia = sia, - aia = aia, - crldp = crldp, - notBefore = notBefore, - notAfter = notAfter, - cn = cn, - sn = sn, - resources = resources, - is_ca = is_ca, - aki = self.get_SKI(), - issuer_name = self.getSubject(), - eku = eku) - - - @classmethod - def self_certify(cls, keypair, subject_key, serial, sia, notAfter, - cn = None, resources = None, notBefore = None, - sn = None): - """ - Generate a self-certified RPKI certificate. - """ - - ski = subject_key.get_SKI() - - if cn is None: - cn = "".join(("%02X" % ord(i) for i in ski)) - - return cls._issue( - keypair = keypair, - subject_key = subject_key, - serial = serial, - sia = sia, - aia = None, - crldp = None, - notBefore = notBefore, - notAfter = notAfter, - cn = cn, - sn = sn, - resources = resources, - is_ca = True, - aki = ski, - issuer_name = X501DN.from_cn(cn, sn), - eku = None) - - - @classmethod - def _issue(cls, keypair, subject_key, serial, sia, aia, crldp, notAfter, - cn, sn, resources, is_ca, aki, issuer_name, notBefore, eku): - """ - Common code to issue an RPKI certificate. - """ - - if not sia or len(sia) != 4 or not sia[3]: - logger.debug("Oops! _issue() sia: %r", sia) - rpki.log.show_stack(logger) - - now = rpki.sundial.now() - ski = subject_key.get_SKI() - - if notBefore is None: - notBefore = now - - if cn is None: - cn = "".join(("%02X" % ord(i) for i in ski)) - - if now >= notAfter: - raise rpki.exceptions.PastNotAfter("notAfter value %s is already in the past" % notAfter) - - if notBefore >= notAfter: - raise rpki.exceptions.NullValidityInterval("notAfter value %s predates notBefore value %s" % - (notAfter, notBefore)) - - cert = rpki.POW.X509() - - cert.setVersion(2) - cert.setSerial(serial) - cert.setIssuer(issuer_name.get_POW()) - cert.setSubject(X501DN.from_cn(cn, sn).get_POW()) - cert.setNotBefore(notBefore) - cert.setNotAfter(notAfter) - cert.setPublicKey(subject_key.get_POW()) - cert.setSKI(ski) - cert.setAKI(aki) - cert.setCertificatePolicies((rpki.oids.id_cp_ipAddr_asNumber,)) - - if crldp is not None: - cert.setCRLDP((crldp,)) - - if aia is not None: - cert.setAIA((aia,)) - - if is_ca: - cert.setBasicConstraints(True, None) - cert.setKeyUsage(frozenset(("keyCertSign", "cRLSign"))) - - else: - cert.setKeyUsage(frozenset(("digitalSignature",))) - - assert sia is not None or not is_ca - - if sia is not None: - caRepository, rpkiManifest, signedObject, rpkiNotify = sia - cert.setSIA( - (caRepository,) if isinstance(caRepository, str) else caRepository, - (rpkiManifest,) if isinstance(rpkiManifest, str) else rpkiManifest, - (signedObject,) if isinstance(signedObject, str) else signedObject, - (rpkiNotify,) if isinstance(rpkiNotify, str) else rpkiNotify) - - if resources is not None: - cert.setRFC3779( - asn = ("inherit" if resources.asn.inherit else - ((r.min, r.max) for r in resources.asn)), - ipv4 = ("inherit" if resources.v4.inherit else - ((r.min, r.max) for r in resources.v4)), - ipv6 = ("inherit" if resources.v6.inherit else - ((r.min, r.max) for r in resources.v6))) - - if eku is not None: - assert not is_ca - cert.setEKU(eku) - cert.sign(keypair.get_POW(), rpki.POW.SHA256_DIGEST) + POW_class = rpki.POW.X509 - return cls(POW = cert) + def getIssuer(self): + """ + Get the issuer of this certificate. + """ - def bpki_cross_certify(self, keypair, source_cert, serial, notAfter, - now = None, pathLenConstraint = 0): - """ - Issue a BPKI certificate with values taking from an existing certificate. - """ + return X501DN.from_POW(self.get_POW().getIssuer()) - return self.bpki_certify( - keypair = keypair, - subject_name = source_cert.getSubject(), - subject_key = source_cert.getPublicKey(), - serial = serial, - notAfter = notAfter, - now = now, - pathLenConstraint = pathLenConstraint, - is_ca = True) - - @classmethod - def bpki_self_certify(cls, keypair, subject_name, serial, notAfter, - now = None, pathLenConstraint = None): - """ - Issue a self-signed BPKI CA certificate. - """ + def getSubject(self): + """ + Get the subject of this certificate. + """ + + return X501DN.from_POW(self.get_POW().getSubject()) + + def getNotBefore(self): + """ + Get the inception time of this certificate. + """ + + return self.get_POW().getNotBefore() + + def getNotAfter(self): + """ + Get the expiration time of this certificate. + """ + + return self.get_POW().getNotAfter() + + def getSerial(self): + """ + Get the serial number of this certificate. + """ + + return self.get_POW().getSerial() + + def getPublicKey(self): + """ + Extract the public key from this certificate. + """ + + return PublicKey(POW = self.get_POW().getPublicKey()) + + def get_SKI(self): + """ + Get the SKI extension from this object. + """ - return cls._bpki_certify( - keypair = keypair, - issuer_name = subject_name, - subject_name = subject_name, - subject_key = keypair.get_public(), - serial = serial, - now = now, - notAfter = notAfter, - pathLenConstraint = pathLenConstraint, - is_ca = True) - - def bpki_certify(self, keypair, subject_name, subject_key, serial, notAfter, is_ca, - now = None, pathLenConstraint = None): - """ - Issue a normal BPKI certificate. - """ + return self.get_POW().getSKI() + + def expired(self): + """ + Test whether this certificate has expired. + """ + + return self.getNotAfter() <= rpki.sundial.now() + + def issue(self, keypair, subject_key, serial, sia, aia, crldp, notAfter, + cn = None, resources = None, is_ca = True, notBefore = None, + sn = None, eku = None): + """ + Issue an RPKI certificate. + """ + + assert aia is not None and crldp is not None + + assert eku is None or not is_ca + + return self._issue( + keypair = keypair, + subject_key = subject_key, + serial = serial, + sia = sia, + aia = aia, + crldp = crldp, + notBefore = notBefore, + notAfter = notAfter, + cn = cn, + sn = sn, + resources = resources, + is_ca = is_ca, + aki = self.get_SKI(), + issuer_name = self.getSubject(), + eku = eku) + + + @classmethod + def self_certify(cls, keypair, subject_key, serial, sia, notAfter, + cn = None, resources = None, notBefore = None, + sn = None): + """ + Generate a self-certified RPKI certificate. + """ + + ski = subject_key.get_SKI() + + if cn is None: + cn = "".join(("%02X" % ord(i) for i in ski)) + + return cls._issue( + keypair = keypair, + subject_key = subject_key, + serial = serial, + sia = sia, + aia = None, + crldp = None, + notBefore = notBefore, + notAfter = notAfter, + cn = cn, + sn = sn, + resources = resources, + is_ca = True, + aki = ski, + issuer_name = X501DN.from_cn(cn, sn), + eku = None) + + + @classmethod + def _issue(cls, keypair, subject_key, serial, sia, aia, crldp, notAfter, + cn, sn, resources, is_ca, aki, issuer_name, notBefore, eku): + """ + Common code to issue an RPKI certificate. + """ + + if not sia or len(sia) != 4 or not sia[3]: + logger.debug("Oops! _issue() sia: %r", sia) + rpki.log.show_stack(logger) + + now = rpki.sundial.now() + ski = subject_key.get_SKI() + + if notBefore is None: + notBefore = now + + if cn is None: + cn = "".join(("%02X" % ord(i) for i in ski)) + + if now >= notAfter: + raise rpki.exceptions.PastNotAfter("notAfter value %s is already in the past" % notAfter) + + if notBefore >= notAfter: + raise rpki.exceptions.NullValidityInterval("notAfter value %s predates notBefore value %s" % + (notAfter, notBefore)) + + cert = rpki.POW.X509() + + cert.setVersion(2) + cert.setSerial(serial) + cert.setIssuer(issuer_name.get_POW()) + cert.setSubject(X501DN.from_cn(cn, sn).get_POW()) + cert.setNotBefore(notBefore) + cert.setNotAfter(notAfter) + cert.setPublicKey(subject_key.get_POW()) + cert.setSKI(ski) + cert.setAKI(aki) + cert.setCertificatePolicies((rpki.oids.id_cp_ipAddr_asNumber,)) + + if crldp is not None: + cert.setCRLDP((crldp,)) + + if aia is not None: + cert.setAIA((aia,)) + + if is_ca: + cert.setBasicConstraints(True, None) + cert.setKeyUsage(frozenset(("keyCertSign", "cRLSign"))) - assert keypair.get_public() == self.getPublicKey() - return self._bpki_certify( - keypair = keypair, - issuer_name = self.getSubject(), - subject_name = subject_name, - subject_key = subject_key, - serial = serial, - now = now, - notAfter = notAfter, - pathLenConstraint = pathLenConstraint, - is_ca = is_ca) - - @classmethod - def _bpki_certify(cls, keypair, issuer_name, subject_name, subject_key, - serial, now, notAfter, pathLenConstraint, is_ca): - """ - Issue a BPKI certificate. This internal method does the real - work, after one of the wrapper methods has extracted the relevant - fields. - """ + else: + cert.setKeyUsage(frozenset(("digitalSignature",))) + + assert sia is not None or not is_ca + + if sia is not None: + caRepository, rpkiManifest, signedObject, rpkiNotify = sia + cert.setSIA( + (caRepository,) if isinstance(caRepository, str) else caRepository, + (rpkiManifest,) if isinstance(rpkiManifest, str) else rpkiManifest, + (signedObject,) if isinstance(signedObject, str) else signedObject, + (rpkiNotify,) if isinstance(rpkiNotify, str) else rpkiNotify) + + if resources is not None: + cert.setRFC3779( + asn = ("inherit" if resources.asn.inherit else + ((r.min, r.max) for r in resources.asn)), + ipv4 = ("inherit" if resources.v4.inherit else + ((r.min, r.max) for r in resources.v4)), + ipv6 = ("inherit" if resources.v6.inherit else + ((r.min, r.max) for r in resources.v6))) + + if eku is not None: + assert not is_ca + cert.setEKU(eku) + + cert.sign(keypair.get_POW(), rpki.POW.SHA256_DIGEST) + + return cls(POW = cert) + + def bpki_cross_certify(self, keypair, source_cert, serial, notAfter, + now = None, pathLenConstraint = 0): + """ + Issue a BPKI certificate with values taking from an existing certificate. + """ + + return self.bpki_certify( + keypair = keypair, + subject_name = source_cert.getSubject(), + subject_key = source_cert.getPublicKey(), + serial = serial, + notAfter = notAfter, + now = now, + pathLenConstraint = pathLenConstraint, + is_ca = True) + + @classmethod + def bpki_self_certify(cls, keypair, subject_name, serial, notAfter, + now = None, pathLenConstraint = None): + """ + Issue a self-signed BPKI CA certificate. + """ + + return cls._bpki_certify( + keypair = keypair, + issuer_name = subject_name, + subject_name = subject_name, + subject_key = keypair.get_public(), + serial = serial, + now = now, + notAfter = notAfter, + pathLenConstraint = pathLenConstraint, + is_ca = True) + + def bpki_certify(self, keypair, subject_name, subject_key, serial, notAfter, is_ca, + now = None, pathLenConstraint = None): + """ + Issue a normal BPKI certificate. + """ + + assert keypair.get_public() == self.getPublicKey() + return self._bpki_certify( + keypair = keypair, + issuer_name = self.getSubject(), + subject_name = subject_name, + subject_key = subject_key, + serial = serial, + now = now, + notAfter = notAfter, + pathLenConstraint = pathLenConstraint, + is_ca = is_ca) + + @classmethod + def _bpki_certify(cls, keypair, issuer_name, subject_name, subject_key, + serial, now, notAfter, pathLenConstraint, is_ca): + """ + Issue a BPKI certificate. This internal method does the real + work, after one of the wrapper methods has extracted the relevant + fields. + """ + + if now is None: + now = rpki.sundial.now() + + issuer_key = keypair.get_public() + + assert (issuer_key == subject_key) == (issuer_name == subject_name) + assert is_ca or issuer_name != subject_name + assert is_ca or pathLenConstraint is None + assert pathLenConstraint is None or (isinstance(pathLenConstraint, (int, long)) and + pathLenConstraint >= 0) + + cert = rpki.POW.X509() + cert.setVersion(2) + cert.setSerial(serial) + cert.setIssuer(issuer_name.get_POW()) + cert.setSubject(subject_name.get_POW()) + cert.setNotBefore(now) + cert.setNotAfter(notAfter) + cert.setPublicKey(subject_key.get_POW()) + cert.setSKI(subject_key.get_POW().calculateSKI()) + if issuer_key != subject_key: + cert.setAKI(issuer_key.get_POW().calculateSKI()) + if is_ca: + cert.setBasicConstraints(True, pathLenConstraint) + cert.sign(keypair.get_POW(), rpki.POW.SHA256_DIGEST) + return cls(POW = cert) + + @classmethod + def normalize_chain(cls, chain): + """ + Normalize a chain of certificates into a tuple of X509 objects. + Given all the glue certificates needed for BPKI cross + certification, it's easiest to allow sloppy arguments to the CMS + validation methods and provide a single method that normalizes the + allowed cases. So this method allows X509, None, lists, and + tuples, and returns a tuple of X509 objects. + """ + + if isinstance(chain, cls): + chain = (chain,) + return tuple(x for x in chain if x is not None) + + @property + def creation_timestamp(self): + """ + Time at which this object was created. + """ + + return self.getNotBefore() - if now is None: - now = rpki.sundial.now() - - issuer_key = keypair.get_public() - - assert (issuer_key == subject_key) == (issuer_name == subject_name) - assert is_ca or issuer_name != subject_name - assert is_ca or pathLenConstraint is None - assert pathLenConstraint is None or (isinstance(pathLenConstraint, (int, long)) and - pathLenConstraint >= 0) - - cert = rpki.POW.X509() - cert.setVersion(2) - cert.setSerial(serial) - cert.setIssuer(issuer_name.get_POW()) - cert.setSubject(subject_name.get_POW()) - cert.setNotBefore(now) - cert.setNotAfter(notAfter) - cert.setPublicKey(subject_key.get_POW()) - cert.setSKI(subject_key.get_POW().calculateSKI()) - if issuer_key != subject_key: - cert.setAKI(issuer_key.get_POW().calculateSKI()) - if is_ca: - cert.setBasicConstraints(True, pathLenConstraint) - cert.sign(keypair.get_POW(), rpki.POW.SHA256_DIGEST) - return cls(POW = cert) - - @classmethod - def normalize_chain(cls, chain): +class PKCS10(DER_object): """ - Normalize a chain of certificates into a tuple of X509 objects. - Given all the glue certificates needed for BPKI cross - certification, it's easiest to allow sloppy arguments to the CMS - validation methods and provide a single method that normalizes the - allowed cases. So this method allows X509, None, lists, and - tuples, and returns a tuple of X509 objects. + Class to hold a PKCS #10 request. """ - if isinstance(chain, cls): - chain = (chain,) - return tuple(x for x in chain if x is not None) - - @property - def creation_timestamp(self): - """ - Time at which this object was created. - """ + POW_class = rpki.POW.PKCS10 - return self.getNotBefore() + ## @var expected_ca_keyUsage + # KeyUsage extension flags expected for CA requests. -class PKCS10(DER_object): - """ - Class to hold a PKCS #10 request. - """ + expected_ca_keyUsage = frozenset(("keyCertSign", "cRLSign")) - POW_class = rpki.POW.PKCS10 + ## @var allowed_extensions + # Extensions allowed by RPKI profile. - ## @var expected_ca_keyUsage - # KeyUsage extension flags expected for CA requests. + allowed_extensions = frozenset((rpki.oids.basicConstraints, + rpki.oids.keyUsage, + rpki.oids.subjectInfoAccess, + rpki.oids.extendedKeyUsage)) - expected_ca_keyUsage = frozenset(("keyCertSign", "cRLSign")) - ## @var allowed_extensions - # Extensions allowed by RPKI profile. + def get_DER(self): + """ + Get the DER value of this certification request. + """ - allowed_extensions = frozenset((rpki.oids.basicConstraints, - rpki.oids.keyUsage, - rpki.oids.subjectInfoAccess, - rpki.oids.extendedKeyUsage)) + self.check() + if self.DER: + return self.DER + if self.POW: + self.DER = self.POW.derWrite() + return self.get_DER() + raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + def get_POW(self): + """ + Get the rpki.POW value of this certification request. + """ - def get_DER(self): - """ - Get the DER value of this certification request. - """ + self.check() + if not self.POW: # pylint: disable=E0203 + self.POW = rpki.POW.PKCS10.derRead(self.get_DER()) + return self.POW - self.check() - if self.DER: - return self.DER - if self.POW: - self.DER = self.POW.derWrite() - return self.get_DER() - raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + def getSubject(self): + """ + Extract the subject name from this certification request. + """ - def get_POW(self): - """ - Get the rpki.POW value of this certification request. - """ + return X501DN.from_POW(self.get_POW().getSubject()) - self.check() - if not self.POW: # pylint: disable=E0203 - self.POW = rpki.POW.PKCS10.derRead(self.get_DER()) - return self.POW + def getPublicKey(self): + """ + Extract the public key from this certification request. + """ - def getSubject(self): - """ - Extract the subject name from this certification request. - """ + return PublicKey(POW = self.get_POW().getPublicKey()) - return X501DN.from_POW(self.get_POW().getSubject()) + def get_SKI(self): + """ + Compute SKI for public key from this certification request. + """ - def getPublicKey(self): - """ - Extract the public key from this certification request. - """ + return self.getPublicKey().get_SKI() - return PublicKey(POW = self.get_POW().getPublicKey()) - def get_SKI(self): - """ - Compute SKI for public key from this certification request. - """ + def check_valid_request_common(self): + """ + Common code for checking this certification requests to see + whether they conform to the RPKI certificate profile. - return self.getPublicKey().get_SKI() + Throws an exception if the request isn't valid, so if this method + returns at all, the request is ok. + You probably don't want to call this directly, as it only performs + the checks that are common to all RPKI certificates. + """ - def check_valid_request_common(self): - """ - Common code for checking this certification requests to see - whether they conform to the RPKI certificate profile. + if not self.get_POW().verify(): + raise rpki.exceptions.BadPKCS10("PKCS #10 signature check failed") - Throws an exception if the request isn't valid, so if this method - returns at all, the request is ok. + ver = self.get_POW().getVersion() - You probably don't want to call this directly, as it only performs - the checks that are common to all RPKI certificates. - """ + if ver != 0: + raise rpki.exceptions.BadPKCS10("PKCS #10 request has bad version number %s" % ver) - if not self.get_POW().verify(): - raise rpki.exceptions.BadPKCS10("PKCS #10 signature check failed") + ku = self.get_POW().getKeyUsage() - ver = self.get_POW().getVersion() + if ku is not None and self.expected_ca_keyUsage != ku: + raise rpki.exceptions.BadPKCS10("PKCS #10 keyUsage doesn't match profile: %r" % ku) - if ver != 0: - raise rpki.exceptions.BadPKCS10("PKCS #10 request has bad version number %s" % ver) + forbidden_extensions = self.get_POW().getExtensionOIDs() - self.allowed_extensions - ku = self.get_POW().getKeyUsage() + if forbidden_extensions: + raise rpki.exceptions.BadExtension("Forbidden extension%s in PKCS #10 certificate request: %s" % ( + "" if len(forbidden_extensions) == 1 else "s", + ", ".join(forbidden_extensions))) - if ku is not None and self.expected_ca_keyUsage != ku: - raise rpki.exceptions.BadPKCS10("PKCS #10 keyUsage doesn't match profile: %r" % ku) - forbidden_extensions = self.get_POW().getExtensionOIDs() - self.allowed_extensions + def check_valid_request_ca(self): + """ + Check this certification request to see whether it's a valid + request for an RPKI CA certificate. - if forbidden_extensions: - raise rpki.exceptions.BadExtension("Forbidden extension%s in PKCS #10 certificate request: %s" % ( - "" if len(forbidden_extensions) == 1 else "s", - ", ".join(forbidden_extensions))) + Throws an exception if the request isn't valid, so if this method + returns at all, the request is ok. + """ + self.check_valid_request_common() - def check_valid_request_ca(self): - """ - Check this certification request to see whether it's a valid - request for an RPKI CA certificate. + alg = self.get_POW().getSignatureAlgorithm() + bc = self.get_POW().getBasicConstraints() + eku = self.get_POW().getEKU() + sia = self.get_POW().getSIA() - Throws an exception if the request isn't valid, so if this method - returns at all, the request is ok. - """ + if alg != rpki.oids.sha256WithRSAEncryption: + raise rpki.exceptions.BadPKCS10("PKCS #10 has bad signature algorithm for CA: %s" % alg) - self.check_valid_request_common() + if bc is None or not bc[0] or bc[1] is not None: + raise rpki.exceptions.BadPKCS10("PKCS #10 CA bad basicConstraints") - alg = self.get_POW().getSignatureAlgorithm() - bc = self.get_POW().getBasicConstraints() - eku = self.get_POW().getEKU() - sia = self.get_POW().getSIA() + if eku is not None: + raise rpki.exceptions.BadPKCS10("PKCS #10 CA EKU not allowed") - if alg != rpki.oids.sha256WithRSAEncryption: - raise rpki.exceptions.BadPKCS10("PKCS #10 has bad signature algorithm for CA: %s" % alg) + if sia is None: + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA missing") - if bc is None or not bc[0] or bc[1] is not None: - raise rpki.exceptions.BadPKCS10("PKCS #10 CA bad basicConstraints") + caRepository, rpkiManifest, signedObject, rpkiNotify = sia - if eku is not None: - raise rpki.exceptions.BadPKCS10("PKCS #10 CA EKU not allowed") + logger.debug("check_valid_request_ca(): sia: %r", sia) - if sia is None: - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA missing") + if signedObject: + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA must not have id-ad-signedObject") - caRepository, rpkiManifest, signedObject, rpkiNotify = sia + if not caRepository: + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA must have id-ad-caRepository") - logger.debug("check_valid_request_ca(): sia: %r", sia) + if not any(uri.startswith("rsync://") for uri in caRepository): + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-caRepository contains no rsync URIs") - if signedObject: - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA must not have id-ad-signedObject") + if any(uri.startswith("rsync://") and not uri.endswith("/") for uri in caRepository): + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-caRepository does not end with slash") - if not caRepository: - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA must have id-ad-caRepository") + if not rpkiManifest: + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA must have id-ad-rpkiManifest") - if not any(uri.startswith("rsync://") for uri in caRepository): - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-caRepository contains no rsync URIs") + if not any(uri.startswith("rsync://") for uri in rpkiManifest): + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-rpkiManifest contains no rsync URIs") - if any(uri.startswith("rsync://") and not uri.endswith("/") for uri in caRepository): - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-caRepository does not end with slash") + if any(uri.startswith("rsync://") and uri.endswith("/") for uri in rpkiManifest): + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-rpkiManifest ends with slash") - if not rpkiManifest: - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA must have id-ad-rpkiManifest") + if any(not uri.startswith("http://") and not uri.startswith("https://") for uri in rpkiNotify): + raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-rpkiNotify neither HTTP nor HTTPS") - if not any(uri.startswith("rsync://") for uri in rpkiManifest): - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-rpkiManifest contains no rsync URIs") + def check_valid_request_ee(self): + """ + Check this certification request to see whether it's a valid + request for an RPKI EE certificate. - if any(uri.startswith("rsync://") and uri.endswith("/") for uri in rpkiManifest): - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-rpkiManifest ends with slash") + Throws an exception if the request isn't valid, so if this method + returns at all, the request is ok. - if any(not uri.startswith("http://") and not uri.startswith("https://") for uri in rpkiNotify): - raise rpki.exceptions.BadPKCS10("PKCS #10 CA SIA id-ad-rpkiNotify neither HTTP nor HTTPS") + We're a bit less strict here than we are for either CA + certificates or BGPSEC router certificates, because the profile is + less tightly nailed down for unspecified-use RPKI EE certificates. + Future specific purposes may impose tighter constraints. - def check_valid_request_ee(self): - """ - Check this certification request to see whether it's a valid - request for an RPKI EE certificate. - - Throws an exception if the request isn't valid, so if this method - returns at all, the request is ok. - - We're a bit less strict here than we are for either CA - certificates or BGPSEC router certificates, because the profile is - less tightly nailed down for unspecified-use RPKI EE certificates. - Future specific purposes may impose tighter constraints. - - Note that this method does NOT apply to so-called "infrastructure" - EE certificates (eg, the EE certificates embedded in manifests and - ROAs); those are constrained fairly tightly, but they're also - generated internally so we don't need to check them as user or - protocol input. - """ + Note that this method does NOT apply to so-called "infrastructure" + EE certificates (eg, the EE certificates embedded in manifests and + ROAs); those are constrained fairly tightly, but they're also + generated internally so we don't need to check them as user or + protocol input. + """ - self.check_valid_request_common() + self.check_valid_request_common() - alg = self.get_POW().getSignatureAlgorithm() - bc = self.get_POW().getBasicConstraints() - sia = self.get_POW().getSIA() + alg = self.get_POW().getSignatureAlgorithm() + bc = self.get_POW().getBasicConstraints() + sia = self.get_POW().getSIA() - logger.debug("check_valid_request_ee(): sia: %r", sia) + logger.debug("check_valid_request_ee(): sia: %r", sia) - caRepository, rpkiManifest, signedObject, rpkiNotify = sia or (None, None, None, None) + caRepository, rpkiManifest, signedObject, rpkiNotify = sia or (None, None, None, None) - if alg not in (rpki.oids.sha256WithRSAEncryption, rpki.oids.ecdsa_with_SHA256): - raise rpki.exceptions.BadPKCS10("PKCS #10 has bad signature algorithm for EE: %s" % alg) + if alg not in (rpki.oids.sha256WithRSAEncryption, rpki.oids.ecdsa_with_SHA256): + raise rpki.exceptions.BadPKCS10("PKCS #10 has bad signature algorithm for EE: %s" % alg) - if bc is not None and (bc[0] or bc[1] is not None): - raise rpki.exceptions.BadPKCS10("PKCS #10 EE has bad basicConstraints") + if bc is not None and (bc[0] or bc[1] is not None): + raise rpki.exceptions.BadPKCS10("PKCS #10 EE has bad basicConstraints") - if caRepository: - raise rpki.exceptions.BadPKCS10("PKCS #10 EE must not have id-ad-caRepository") + if caRepository: + raise rpki.exceptions.BadPKCS10("PKCS #10 EE must not have id-ad-caRepository") - if rpkiManifest: - raise rpki.exceptions.BadPKCS10("PKCS #10 EE must not have id-ad-rpkiManifest") + if rpkiManifest: + raise rpki.exceptions.BadPKCS10("PKCS #10 EE must not have id-ad-rpkiManifest") - if signedObject and not any(uri.startswith("rsync://") for uri in signedObject): - raise rpki.exceptions.BadPKCS10("PKCS #10 EE SIA id-ad-signedObject contains no rsync URIs") + if signedObject and not any(uri.startswith("rsync://") for uri in signedObject): + raise rpki.exceptions.BadPKCS10("PKCS #10 EE SIA id-ad-signedObject contains no rsync URIs") - if rpkiNotify and any(not uri.startswith("http://") and not uri.startswith("https://") for uri in rpkiNotify): - raise rpki.exceptions.BadPKCS10("PKCS #10 EE SIA id-ad-rpkiNotify neither HTTP nor HTTPS") + if rpkiNotify and any(not uri.startswith("http://") and not uri.startswith("https://") for uri in rpkiNotify): + raise rpki.exceptions.BadPKCS10("PKCS #10 EE SIA id-ad-rpkiNotify neither HTTP nor HTTPS") - def check_valid_request_router(self): - """ - Check this certification request to see whether it's a valid - request for a BGPSEC router certificate. + def check_valid_request_router(self): + """ + Check this certification request to see whether it's a valid + request for a BGPSEC router certificate. - Throws an exception if the request isn't valid, so if this method - returns at all, the request is ok. + Throws an exception if the request isn't valid, so if this method + returns at all, the request is ok. - draft-ietf-sidr-bgpsec-pki-profiles 3.2 says follow RFC 6487 3 - except where explicitly overriden, and does not override for SIA. - But draft-ietf-sidr-bgpsec-pki-profiles also says that router - certificates don't get SIA, while RFC 6487 requires SIA. So what - do we do with SIA in PKCS #10 for router certificates? + draft-ietf-sidr-bgpsec-pki-profiles 3.2 says follow RFC 6487 3 + except where explicitly overriden, and does not override for SIA. + But draft-ietf-sidr-bgpsec-pki-profiles also says that router + certificates don't get SIA, while RFC 6487 requires SIA. So what + do we do with SIA in PKCS #10 for router certificates? - For the moment, ignore it, but make sure we don't include it in - the certificate when we get to the code that generates that. - """ + For the moment, ignore it, but make sure we don't include it in + the certificate when we get to the code that generates that. + """ - self.check_valid_request_ee() + self.check_valid_request_ee() - alg = self.get_POW().getSignatureAlgorithm() - eku = self.get_POW().getEKU() + alg = self.get_POW().getSignatureAlgorithm() + eku = self.get_POW().getEKU() - if alg != rpki.oids.ecdsa_with_SHA256: - raise rpki.exceptions.BadPKCS10("PKCS #10 has bad signature algorithm for router: %s" % alg) + if alg != rpki.oids.ecdsa_with_SHA256: + raise rpki.exceptions.BadPKCS10("PKCS #10 has bad signature algorithm for router: %s" % alg) - # Not really clear to me whether PKCS #10 should have EKU or not, so allow - # either, but insist that it be the right one if present. + # Not really clear to me whether PKCS #10 should have EKU or not, so allow + # either, but insist that it be the right one if present. - if eku is not None and rpki.oids.id_kp_bgpsec_router not in eku: - raise rpki.exceptions.BadPKCS10("PKCS #10 router must have EKU") + if eku is not None and rpki.oids.id_kp_bgpsec_router not in eku: + raise rpki.exceptions.BadPKCS10("PKCS #10 router must have EKU") - @classmethod - def create(cls, keypair, exts = None, is_ca = False, - caRepository = None, rpkiManifest = None, signedObject = None, - cn = None, sn = None, eku = None, rpkiNotify = None): - """ - Create a new request for a given keypair. - """ + @classmethod + def create(cls, keypair, exts = None, is_ca = False, + caRepository = None, rpkiManifest = None, signedObject = None, + cn = None, sn = None, eku = None, rpkiNotify = None): + """ + Create a new request for a given keypair. + """ - if cn is None: - cn = "".join(("%02X" % ord(i) for i in keypair.get_SKI())) + if cn is None: + cn = "".join(("%02X" % ord(i) for i in keypair.get_SKI())) - req = rpki.POW.PKCS10() - req.setVersion(0) - req.setSubject(X501DN.from_cn(cn, sn).get_POW()) - req.setPublicKey(keypair.get_POW()) + req = rpki.POW.PKCS10() + req.setVersion(0) + req.setSubject(X501DN.from_cn(cn, sn).get_POW()) + req.setPublicKey(keypair.get_POW()) - if is_ca: - req.setBasicConstraints(True, None) - req.setKeyUsage(cls.expected_ca_keyUsage) + if is_ca: + req.setBasicConstraints(True, None) + req.setKeyUsage(cls.expected_ca_keyUsage) - sia = (caRepository, rpkiManifest, signedObject, rpkiNotify) - if not all(s is None for s in sia): - req.setSIA(*tuple([str(s)] if isinstance(s, (str, unicode)) else s for s in sia)) + sia = (caRepository, rpkiManifest, signedObject, rpkiNotify) + if not all(s is None for s in sia): + req.setSIA(*tuple([str(s)] if isinstance(s, (str, unicode)) else s for s in sia)) - if eku: - req.setEKU(eku) + if eku: + req.setEKU(eku) - req.sign(keypair.get_POW(), rpki.POW.SHA256_DIGEST) - return cls(POW = req) + req.sign(keypair.get_POW(), rpki.POW.SHA256_DIGEST) + return cls(POW = req) ## @var generate_insecure_debug_only_rsa_key # Debugging hack to let us save throwaway RSA keys from one debug @@ -1230,919 +1230,919 @@ generate_insecure_debug_only_rsa_key = None class insecure_debug_only_rsa_key_generator(object): - def __init__(self, filename, keyno = 0): - try: - try: - import gdbm as dbm_du_jour - except ImportError: - import dbm as dbm_du_jour - self.keyno = long(keyno) - self.filename = filename - self.db = dbm_du_jour.open(filename, "c") - except: - logger.warning("insecure_debug_only_rsa_key_generator initialization FAILED, hack inoperative") - raise - - def __call__(self): - k = str(self.keyno) - try: - v = rpki.POW.Asymmetric.derReadPrivate(self.db[k]) - except KeyError: - v = rpki.POW.Asymmetric.generateRSA(2048) - self.db[k] = v.derWritePrivate() - self.keyno += 1 - return v + def __init__(self, filename, keyno = 0): + try: + try: + import gdbm as dbm_du_jour + except ImportError: + import dbm as dbm_du_jour + self.keyno = long(keyno) + self.filename = filename + self.db = dbm_du_jour.open(filename, "c") + except: + logger.warning("insecure_debug_only_rsa_key_generator initialization FAILED, hack inoperative") + raise + + def __call__(self): + k = str(self.keyno) + try: + v = rpki.POW.Asymmetric.derReadPrivate(self.db[k]) + except KeyError: + v = rpki.POW.Asymmetric.generateRSA(2048) + self.db[k] = v.derWritePrivate() + self.keyno += 1 + return v class PrivateKey(DER_object): - """ - Class to hold a Public/Private key pair. - """ - - POW_class = rpki.POW.Asymmetric - - def get_DER(self): """ - Get the DER value of this keypair. + Class to hold a Public/Private key pair. """ - self.check() - if self.DER: - return self.DER - if self.POW: - self.DER = self.POW.derWritePrivate() - return self.get_DER() - raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + POW_class = rpki.POW.Asymmetric - def get_POW(self): - """ - Get the rpki.POW value of this keypair. - """ + def get_DER(self): + """ + Get the DER value of this keypair. + """ - self.check() - if not self.POW: # pylint: disable=E0203 - self.POW = rpki.POW.Asymmetric.derReadPrivate(self.get_DER()) - return self.POW + self.check() + if self.DER: + return self.DER + if self.POW: + self.DER = self.POW.derWritePrivate() + return self.get_DER() + raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") - def get_PEM(self): - """ - Get the PEM representation of this keypair. - """ + def get_POW(self): + """ + Get the rpki.POW value of this keypair. + """ - return self.get_POW().pemWritePrivate() + self.check() + if not self.POW: # pylint: disable=E0203 + self.POW = rpki.POW.Asymmetric.derReadPrivate(self.get_DER()) + return self.POW - def _set_PEM(self, pem): - """ - Set the POW value of this keypair from a PEM string. - """ + def get_PEM(self): + """ + Get the PEM representation of this keypair. + """ - assert self.empty() - self.POW = self.POW_class.pemReadPrivate(pem) + return self.get_POW().pemWritePrivate() - def get_public_DER(self): - """ - Get the DER encoding of the public key from this keypair. - """ + def _set_PEM(self, pem): + """ + Set the POW value of this keypair from a PEM string. + """ - return self.get_POW().derWritePublic() + assert self.empty() + self.POW = self.POW_class.pemReadPrivate(pem) - def get_SKI(self): - """ - Calculate the SKI of this keypair. - """ + def get_public_DER(self): + """ + Get the DER encoding of the public key from this keypair. + """ - return self.get_POW().calculateSKI() + return self.get_POW().derWritePublic() - def get_public(self): - """ - Convert the public key of this keypair into a PublicKey object. - """ + def get_SKI(self): + """ + Calculate the SKI of this keypair. + """ - return PublicKey(DER = self.get_public_DER()) + return self.get_POW().calculateSKI() -class PublicKey(DER_object): - """ - Class to hold a public key. - """ + def get_public(self): + """ + Convert the public key of this keypair into a PublicKey object. + """ - POW_class = rpki.POW.Asymmetric + return PublicKey(DER = self.get_public_DER()) - def get_DER(self): +class PublicKey(DER_object): """ - Get the DER value of this public key. + Class to hold a public key. """ - self.check() - if self.DER: - return self.DER - if self.POW: - self.DER = self.POW.derWritePublic() - return self.get_DER() - raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + POW_class = rpki.POW.Asymmetric - def get_POW(self): - """ - Get the rpki.POW value of this public key. - """ + def get_DER(self): + """ + Get the DER value of this public key. + """ - self.check() - if not self.POW: # pylint: disable=E0203 - self.POW = rpki.POW.Asymmetric.derReadPublic(self.get_DER()) - return self.POW + self.check() + if self.DER: + return self.DER + if self.POW: + self.DER = self.POW.derWritePublic() + return self.get_DER() + raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") - def get_PEM(self): - """ - Get the PEM representation of this public key. - """ + def get_POW(self): + """ + Get the rpki.POW value of this public key. + """ - return self.get_POW().pemWritePublic() + self.check() + if not self.POW: # pylint: disable=E0203 + self.POW = rpki.POW.Asymmetric.derReadPublic(self.get_DER()) + return self.POW - def _set_PEM(self, pem): - """ - Set the POW value of this public key from a PEM string. - """ + def get_PEM(self): + """ + Get the PEM representation of this public key. + """ - assert self.empty() - self.POW = self.POW_class.pemReadPublic(pem) + return self.get_POW().pemWritePublic() - def get_SKI(self): - """ - Calculate the SKI of this public key. - """ + def _set_PEM(self, pem): + """ + Set the POW value of this public key from a PEM string. + """ - return self.get_POW().calculateSKI() + assert self.empty() + self.POW = self.POW_class.pemReadPublic(pem) + + def get_SKI(self): + """ + Calculate the SKI of this public key. + """ + + return self.get_POW().calculateSKI() class KeyParams(DER_object): - """ - Wrapper for OpenSSL's asymmetric key parameter classes. - """ + """ + Wrapper for OpenSSL's asymmetric key parameter classes. + """ - POW_class = rpki.POW.AsymmetricParams + POW_class = rpki.POW.AsymmetricParams - @classmethod - def generateEC(cls, curve = rpki.POW.EC_P256_CURVE): - return cls(POW = rpki.POW.AsymmetricParams.generateEC(curve = curve)) + @classmethod + def generateEC(cls, curve = rpki.POW.EC_P256_CURVE): + return cls(POW = rpki.POW.AsymmetricParams.generateEC(curve = curve)) class RSA(PrivateKey): - """ - Class to hold an RSA key pair. - """ - - @classmethod - def generate(cls, keylength = 2048, quiet = False): """ - Generate a new keypair. + Class to hold an RSA key pair. """ - if not quiet: - logger.debug("Generating new %d-bit RSA key", keylength) - if generate_insecure_debug_only_rsa_key is not None: - return cls(POW = generate_insecure_debug_only_rsa_key()) - else: - return cls(POW = rpki.POW.Asymmetric.generateRSA(keylength)) + @classmethod + def generate(cls, keylength = 2048, quiet = False): + """ + Generate a new keypair. + """ -class ECDSA(PrivateKey): - """ - Class to hold an ECDSA key pair. - """ + if not quiet: + logger.debug("Generating new %d-bit RSA key", keylength) + if generate_insecure_debug_only_rsa_key is not None: + return cls(POW = generate_insecure_debug_only_rsa_key()) + else: + return cls(POW = rpki.POW.Asymmetric.generateRSA(keylength)) - @classmethod - def generate(cls, params = None, quiet = False): +class ECDSA(PrivateKey): """ - Generate a new keypair. + Class to hold an ECDSA key pair. """ - if params is None: - if not quiet: - logger.debug("Generating new ECDSA key parameters") - params = KeyParams.generateEC() + @classmethod + def generate(cls, params = None, quiet = False): + """ + Generate a new keypair. + """ - assert isinstance(params, KeyParams) + if params is None: + if not quiet: + logger.debug("Generating new ECDSA key parameters") + params = KeyParams.generateEC() - if not quiet: - logger.debug("Generating new ECDSA key") + assert isinstance(params, KeyParams) - return cls(POW = rpki.POW.Asymmetric.generateFromParams(params.get_POW())) + if not quiet: + logger.debug("Generating new ECDSA key") + + return cls(POW = rpki.POW.Asymmetric.generateFromParams(params.get_POW())) class CMS_object(DER_object): - """ - Abstract class to hold a CMS object. - """ + """ + Abstract class to hold a CMS object. + """ - econtent_oid = rpki.oids.id_data - POW_class = rpki.POW.CMS + econtent_oid = rpki.oids.id_data + POW_class = rpki.POW.CMS - ## @var dump_on_verify_failure - # Set this to True to get dumpasn1 dumps of ASN.1 on CMS verify failures. + ## @var dump_on_verify_failure + # Set this to True to get dumpasn1 dumps of ASN.1 on CMS verify failures. - dump_on_verify_failure = True + dump_on_verify_failure = True - ## @var debug_cms_certs - # Set this to True to log a lot of chatter about CMS certificates. + ## @var debug_cms_certs + # Set this to True to log a lot of chatter about CMS certificates. - debug_cms_certs = False + debug_cms_certs = False - ## @var dump_using_dumpasn1 - # Set this to use external dumpasn1 program, which is prettier and - # more informative than OpenSSL's CMS text dump, but which won't - # work if the dumpasn1 program isn't installed. + ## @var dump_using_dumpasn1 + # Set this to use external dumpasn1 program, which is prettier and + # more informative than OpenSSL's CMS text dump, but which won't + # work if the dumpasn1 program isn't installed. - dump_using_dumpasn1 = False + dump_using_dumpasn1 = False - ## @var require_crls - # Set this to False to make CMS CRLs optional in the cases where we - # would otherwise require them. Some day this option should go away - # and CRLs should be uncondtionally mandatory in such cases. + ## @var require_crls + # Set this to False to make CMS CRLs optional in the cases where we + # would otherwise require them. Some day this option should go away + # and CRLs should be uncondtionally mandatory in such cases. - require_crls = False + require_crls = False - ## @var allow_extra_certs - # Set this to True to allow CMS messages to contain CA certificates. + ## @var allow_extra_certs + # Set this to True to allow CMS messages to contain CA certificates. - allow_extra_certs = False + allow_extra_certs = False + + ## @var allow_extra_crls + # Set this to True to allow CMS messages to contain multiple CRLs. - ## @var allow_extra_crls - # Set this to True to allow CMS messages to contain multiple CRLs. + allow_extra_crls = False - allow_extra_crls = False + ## @var print_on_der_error + # Set this to True to log alleged DER when we have trouble parsing + # it, in case it's really a Perl backtrace or something. - ## @var print_on_der_error - # Set this to True to log alleged DER when we have trouble parsing - # it, in case it's really a Perl backtrace or something. + print_on_der_error = True - print_on_der_error = True + def get_DER(self): + """ + Get the DER value of this CMS_object. + """ - def get_DER(self): - """ - Get the DER value of this CMS_object. - """ + self.check() + if self.DER: + return self.DER + if self.POW: + self.DER = self.POW.derWrite() + return self.get_DER() + raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") - self.check() - if self.DER: - return self.DER - if self.POW: - self.DER = self.POW.derWrite() - return self.get_DER() - raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + def get_POW(self): + """ + Get the rpki.POW value of this CMS_object. + """ - def get_POW(self): - """ - Get the rpki.POW value of this CMS_object. - """ + self.check() + if not self.POW: # pylint: disable=E0203 + self.POW = self.POW_class.derRead(self.get_DER()) + return self.POW - self.check() - if not self.POW: # pylint: disable=E0203 - self.POW = self.POW_class.derRead(self.get_DER()) - return self.POW + def get_signingTime(self): + """ + Extract signingTime from CMS signed attributes. + """ - def get_signingTime(self): - """ - Extract signingTime from CMS signed attributes. - """ + return self.get_POW().signingTime() - return self.get_POW().signingTime() + def verify(self, ta): + """ + Verify CMS wrapper and store inner content. + """ - def verify(self, ta): - """ - Verify CMS wrapper and store inner content. - """ + try: + cms = self.get_POW() + except: + if self.print_on_der_error: + logger.debug("Problem parsing DER CMS message, might not really be DER: %r", + self.get_DER()) + raise rpki.exceptions.UnparsableCMSDER - try: - cms = self.get_POW() - except: - if self.print_on_der_error: - logger.debug("Problem parsing DER CMS message, might not really be DER: %r", - self.get_DER()) - raise rpki.exceptions.UnparsableCMSDER - - if cms.eContentType() != self.econtent_oid: - raise rpki.exceptions.WrongEContentType("Got CMS eContentType %s, expected %s" % ( - cms.eContentType(), self.econtent_oid)) - - certs = [X509(POW = x) for x in cms.certs()] - crls = [CRL(POW = c) for c in cms.crls()] - - if self.debug_cms_certs: - for x in certs: - logger.debug("Received CMS cert issuer %s subject %s SKI %s", - x.getIssuer(), x.getSubject(), x.hSKI()) - for c in crls: - logger.debug("Received CMS CRL issuer %r", c.getIssuer()) - - store = rpki.POW.X509Store() - - now = rpki.sundial.now() - - trusted_ee = None - - for x in X509.normalize_chain(ta): - if self.debug_cms_certs: - logger.debug("CMS trusted cert issuer %s subject %s SKI %s", - x.getIssuer(), x.getSubject(), x.hSKI()) - if x.getNotAfter() < now: - raise rpki.exceptions.TrustedCMSCertHasExpired("Trusted CMS certificate has expired", - "%s (%s)" % (x.getSubject(), x.hSKI())) - if not x.is_CA(): - if trusted_ee is None: - trusted_ee = x - else: - raise rpki.exceptions.MultipleCMSEECert("Multiple CMS EE certificates", *("%s (%s)" % ( - x.getSubject(), x.hSKI()) for x in ta if not x.is_CA())) - store.addTrust(x.get_POW()) - - if trusted_ee: - if self.debug_cms_certs: - logger.debug("Trusted CMS EE cert issuer %s subject %s SKI %s", - trusted_ee.getIssuer(), trusted_ee.getSubject(), trusted_ee.hSKI()) - if len(certs) > 1 or (len(certs) == 1 and - (certs[0].getSubject() != trusted_ee.getSubject() or - certs[0].getPublicKey() != trusted_ee.getPublicKey())): - raise rpki.exceptions.UnexpectedCMSCerts("Unexpected CMS certificates", *("%s (%s)" % ( - x.getSubject(), x.hSKI()) for x in certs)) - if crls: - raise rpki.exceptions.UnexpectedCMSCRLs("Unexpected CRLs", *("%s (%s)" % ( - c.getIssuer(), c.hAKI()) for c in crls)) - - else: - untrusted_ee = [x for x in certs if not x.is_CA()] - if len(untrusted_ee) < 1: - raise rpki.exceptions.MissingCMSEEcert - if len(untrusted_ee) > 1 or (not self.allow_extra_certs and len(certs) > len(untrusted_ee)): - raise rpki.exceptions.UnexpectedCMSCerts("Unexpected CMS certificates", *("%s (%s)" % ( - x.getSubject(), x.hSKI()) for x in certs)) - if len(crls) < 1: - if self.require_crls: - raise rpki.exceptions.MissingCMSCRL - else: - logger.warning("MISSING CMS CRL! Ignoring per self.require_crls setting") - if len(crls) > 1 and not self.allow_extra_crls: - raise rpki.exceptions.UnexpectedCMSCRLs("Unexpected CRLs", *("%s (%s)" % ( - c.getIssuer(), c.hAKI()) for c in crls)) - - for x in certs: - if x.getNotAfter() < now: - raise rpki.exceptions.CMSCertHasExpired("CMS certificate has expired", "%s (%s)" % ( - x.getSubject(), x.hSKI())) - - for c in crls: - if c.getNextUpdate() < now: - logger.warning("Stale BPKI CMS CRL (%s %s %s)", c.getNextUpdate(), c.getIssuer(), c.hAKI()) - - try: - content = cms.verify(store) - except: - if self.dump_on_verify_failure: - if self.dump_using_dumpasn1: - dbg = self.dumpasn1() - else: - dbg = cms.pprint() - logger.warning("CMS verification failed, dumping ASN.1 (%d octets):", len(self.get_DER())) - for line in dbg.splitlines(): - logger.warning(line) - raise rpki.exceptions.CMSVerificationFailed("CMS verification failed") + if cms.eContentType() != self.econtent_oid: + raise rpki.exceptions.WrongEContentType("Got CMS eContentType %s, expected %s" % ( + cms.eContentType(), self.econtent_oid)) - return content + certs = [X509(POW = x) for x in cms.certs()] + crls = [CRL(POW = c) for c in cms.crls()] - def extract(self): - """ - Extract and store inner content from CMS wrapper without verifying - the CMS. + if self.debug_cms_certs: + for x in certs: + logger.debug("Received CMS cert issuer %s subject %s SKI %s", + x.getIssuer(), x.getSubject(), x.hSKI()) + for c in crls: + logger.debug("Received CMS CRL issuer %r", c.getIssuer()) + + store = rpki.POW.X509Store() + + now = rpki.sundial.now() + + trusted_ee = None + + for x in X509.normalize_chain(ta): + if self.debug_cms_certs: + logger.debug("CMS trusted cert issuer %s subject %s SKI %s", + x.getIssuer(), x.getSubject(), x.hSKI()) + if x.getNotAfter() < now: + raise rpki.exceptions.TrustedCMSCertHasExpired("Trusted CMS certificate has expired", + "%s (%s)" % (x.getSubject(), x.hSKI())) + if not x.is_CA(): + if trusted_ee is None: + trusted_ee = x + else: + raise rpki.exceptions.MultipleCMSEECert("Multiple CMS EE certificates", *("%s (%s)" % ( + x.getSubject(), x.hSKI()) for x in ta if not x.is_CA())) + store.addTrust(x.get_POW()) + + if trusted_ee: + if self.debug_cms_certs: + logger.debug("Trusted CMS EE cert issuer %s subject %s SKI %s", + trusted_ee.getIssuer(), trusted_ee.getSubject(), trusted_ee.hSKI()) + if len(certs) > 1 or (len(certs) == 1 and + (certs[0].getSubject() != trusted_ee.getSubject() or + certs[0].getPublicKey() != trusted_ee.getPublicKey())): + raise rpki.exceptions.UnexpectedCMSCerts("Unexpected CMS certificates", *("%s (%s)" % ( + x.getSubject(), x.hSKI()) for x in certs)) + if crls: + raise rpki.exceptions.UnexpectedCMSCRLs("Unexpected CRLs", *("%s (%s)" % ( + c.getIssuer(), c.hAKI()) for c in crls)) - DANGER WILL ROBINSON!!! + else: + untrusted_ee = [x for x in certs if not x.is_CA()] + if len(untrusted_ee) < 1: + raise rpki.exceptions.MissingCMSEEcert + if len(untrusted_ee) > 1 or (not self.allow_extra_certs and len(certs) > len(untrusted_ee)): + raise rpki.exceptions.UnexpectedCMSCerts("Unexpected CMS certificates", *("%s (%s)" % ( + x.getSubject(), x.hSKI()) for x in certs)) + if len(crls) < 1: + if self.require_crls: + raise rpki.exceptions.MissingCMSCRL + else: + logger.warning("MISSING CMS CRL! Ignoring per self.require_crls setting") + if len(crls) > 1 and not self.allow_extra_crls: + raise rpki.exceptions.UnexpectedCMSCRLs("Unexpected CRLs", *("%s (%s)" % ( + c.getIssuer(), c.hAKI()) for c in crls)) + + for x in certs: + if x.getNotAfter() < now: + raise rpki.exceptions.CMSCertHasExpired("CMS certificate has expired", "%s (%s)" % ( + x.getSubject(), x.hSKI())) + + for c in crls: + if c.getNextUpdate() < now: + logger.warning("Stale BPKI CMS CRL (%s %s %s)", c.getNextUpdate(), c.getIssuer(), c.hAKI()) + + try: + content = cms.verify(store) + except: + if self.dump_on_verify_failure: + if self.dump_using_dumpasn1: + dbg = self.dumpasn1() + else: + dbg = cms.pprint() + logger.warning("CMS verification failed, dumping ASN.1 (%d octets):", len(self.get_DER())) + for line in dbg.splitlines(): + logger.warning(line) + raise rpki.exceptions.CMSVerificationFailed("CMS verification failed") + + return content + + def extract(self): + """ + Extract and store inner content from CMS wrapper without verifying + the CMS. + + DANGER WILL ROBINSON!!! + + Do not use this method on unvalidated data. Use the verify() + method instead. + + If you don't understand this warning, don't use this method. + """ + + try: + cms = self.get_POW() + except: + raise rpki.exceptions.UnparsableCMSDER + + if cms.eContentType() != self.econtent_oid: + raise rpki.exceptions.WrongEContentType("Got CMS eContentType %s, expected %s" % ( + cms.eContentType(), self.econtent_oid)) + + return cms.verify(rpki.POW.X509Store(), None, + (rpki.POW.CMS_NOCRL | rpki.POW.CMS_NO_SIGNER_CERT_VERIFY | + rpki.POW.CMS_NO_ATTR_VERIFY | rpki.POW.CMS_NO_CONTENT_VERIFY)) + + + def sign(self, keypair, certs, crls = None, no_certs = False): + """ + Sign and wrap inner content. + """ + + if isinstance(certs, X509): + cert = certs + certs = () + else: + cert = certs[0] + certs = certs[1:] - Do not use this method on unvalidated data. Use the verify() - method instead. + if crls is None: + crls = () + elif isinstance(crls, CRL): + crls = (crls,) - If you don't understand this warning, don't use this method. - """ + if self.debug_cms_certs: + logger.debug("Signing with cert issuer %s subject %s SKI %s", + cert.getIssuer(), cert.getSubject(), cert.hSKI()) + for i, c in enumerate(certs): + logger.debug("Additional cert %d issuer %s subject %s SKI %s", + i, c.getIssuer(), c.getSubject(), c.hSKI()) - try: - cms = self.get_POW() - except: - raise rpki.exceptions.UnparsableCMSDER + self._sign(cert.get_POW(), + keypair.get_POW(), + [x.get_POW() for x in certs], + [c.get_POW() for c in crls], + rpki.POW.CMS_NOCERTS if no_certs else 0) - if cms.eContentType() != self.econtent_oid: - raise rpki.exceptions.WrongEContentType("Got CMS eContentType %s, expected %s" % ( - cms.eContentType(), self.econtent_oid)) + @property + def creation_timestamp(self): + """ + Time at which this object was created. + """ - return cms.verify(rpki.POW.X509Store(), None, - (rpki.POW.CMS_NOCRL | rpki.POW.CMS_NO_SIGNER_CERT_VERIFY | - rpki.POW.CMS_NO_ATTR_VERIFY | rpki.POW.CMS_NO_CONTENT_VERIFY)) + return self.get_signingTime() - def sign(self, keypair, certs, crls = None, no_certs = False): - """ - Sign and wrap inner content. +class Wrapped_CMS_object(CMS_object): """ + Abstract class to hold CMS objects wrapping non-DER content (eg, XML + or VCard). - if isinstance(certs, X509): - cert = certs - certs = () - else: - cert = certs[0] - certs = certs[1:] - - if crls is None: - crls = () - elif isinstance(crls, CRL): - crls = (crls,) - - if self.debug_cms_certs: - logger.debug("Signing with cert issuer %s subject %s SKI %s", - cert.getIssuer(), cert.getSubject(), cert.hSKI()) - for i, c in enumerate(certs): - logger.debug("Additional cert %d issuer %s subject %s SKI %s", - i, c.getIssuer(), c.getSubject(), c.hSKI()) - - self._sign(cert.get_POW(), - keypair.get_POW(), - [x.get_POW() for x in certs], - [c.get_POW() for c in crls], - rpki.POW.CMS_NOCERTS if no_certs else 0) - - @property - def creation_timestamp(self): - """ - Time at which this object was created. + CMS-wrapped objects are a little different from the other DER_object + types because the signed object is CMS wrapping some other kind of + inner content. A Wrapped_CMS_object is the outer CMS wrapped object + so that the usual DER and PEM operations do the obvious things, and + the inner content is handle via separate methods. """ - return self.get_signingTime() + other_clear = ("content",) + def get_content(self): + """ + Get the inner content of this Wrapped_CMS_object. + """ -class Wrapped_CMS_object(CMS_object): - """ - Abstract class to hold CMS objects wrapping non-DER content (eg, XML - or VCard). + if self.content is None: + raise rpki.exceptions.CMSContentNotSet("Inner content of CMS object %r is not set" % self) + return self.content - CMS-wrapped objects are a little different from the other DER_object - types because the signed object is CMS wrapping some other kind of - inner content. A Wrapped_CMS_object is the outer CMS wrapped object - so that the usual DER and PEM operations do the obvious things, and - the inner content is handle via separate methods. - """ + def set_content(self, content): + """ + Set the (inner) content of this Wrapped_CMS_object, clearing the wrapper. + """ - other_clear = ("content",) + self.clear() + self.content = content - def get_content(self): - """ - Get the inner content of this Wrapped_CMS_object. - """ + def verify(self, ta): + """ + Verify CMS wrapper and store inner content. + """ - if self.content is None: - raise rpki.exceptions.CMSContentNotSet("Inner content of CMS object %r is not set" % self) - return self.content + self.decode(CMS_object.verify(self, ta)) + return self.get_content() - def set_content(self, content): - """ - Set the (inner) content of this Wrapped_CMS_object, clearing the wrapper. - """ + def extract(self): + """ + Extract and store inner content from CMS wrapper without verifying + the CMS. - self.clear() - self.content = content + DANGER WILL ROBINSON!!! - def verify(self, ta): - """ - Verify CMS wrapper and store inner content. - """ + Do not use this method on unvalidated data. Use the verify() + method instead. - self.decode(CMS_object.verify(self, ta)) - return self.get_content() + If you don't understand this warning, don't use this method. + """ - def extract(self): - """ - Extract and store inner content from CMS wrapper without verifying - the CMS. + self.decode(CMS_object.extract(self)) + return self.get_content() - DANGER WILL ROBINSON!!! + def extract_if_needed(self): + """ + Extract inner content if needed. See caveats for .extract(), do + not use unless you really know what you are doing. + """ - Do not use this method on unvalidated data. Use the verify() - method instead. - - If you don't understand this warning, don't use this method. - """ + if self.content is None: + self.extract() - self.decode(CMS_object.extract(self)) - return self.get_content() + def _sign(self, cert, keypair, certs, crls, flags): + """ + Internal method to call POW to do CMS signature. This is split + out from the .sign() API method to handle differences in how + different CMS-based POW classes handle the inner content. + """ - def extract_if_needed(self): - """ - Extract inner content if needed. See caveats for .extract(), do - not use unless you really know what you are doing. - """ + cms = self.POW_class() + cms.sign(cert, keypair, self.encode(), certs, crls, self.econtent_oid, flags) + self.POW = cms - if self.content is None: - self.extract() - def _sign(self, cert, keypair, certs, crls, flags): +class DER_CMS_object(CMS_object): """ - Internal method to call POW to do CMS signature. This is split - out from the .sign() API method to handle differences in how - different CMS-based POW classes handle the inner content. + Abstract class for CMS-based objects with DER-encoded content + handled by C-level subclasses of rpki.POW.CMS. """ - cms = self.POW_class() - cms.sign(cert, keypair, self.encode(), certs, crls, self.econtent_oid, flags) - self.POW = cms + def _sign(self, cert, keypair, certs, crls, flags): + self.get_POW().sign(cert, keypair, certs, crls, self.econtent_oid, flags) -class DER_CMS_object(CMS_object): - """ - Abstract class for CMS-based objects with DER-encoded content - handled by C-level subclasses of rpki.POW.CMS. - """ + def extract_if_needed(self): + """ + Extract inner content if needed. See caveats for .extract(), do + not use unless you really know what you are doing. + """ - def _sign(self, cert, keypair, certs, crls, flags): - self.get_POW().sign(cert, keypair, certs, crls, self.econtent_oid, flags) - - - def extract_if_needed(self): - """ - Extract inner content if needed. See caveats for .extract(), do - not use unless you really know what you are doing. - """ - - try: - self.get_POW().getVersion() - except rpki.POW.NotVerifiedError: - self.extract() + try: + self.get_POW().getVersion() + except rpki.POW.NotVerifiedError: + self.extract() class SignedManifest(DER_CMS_object): - """ - Class to hold a signed manifest. - """ - - econtent_oid = rpki.oids.id_ct_rpkiManifest - POW_class = rpki.POW.Manifest - - def getThisUpdate(self): """ - Get thisUpdate value from this manifest. + Class to hold a signed manifest. """ - return self.get_POW().getThisUpdate() + econtent_oid = rpki.oids.id_ct_rpkiManifest + POW_class = rpki.POW.Manifest - def getNextUpdate(self): - """ - Get nextUpdate value from this manifest. - """ + def getThisUpdate(self): + """ + Get thisUpdate value from this manifest. + """ - return self.get_POW().getNextUpdate() + return self.get_POW().getThisUpdate() - @classmethod - def build(cls, serial, thisUpdate, nextUpdate, names_and_objs, keypair, certs, version = 0): - """ - Build a signed manifest. - """ + def getNextUpdate(self): + """ + Get nextUpdate value from this manifest. + """ - filelist = [] - for name, obj in names_and_objs: - filelist.append((name.rpartition("/")[2], sha256(obj.get_DER()))) - filelist.sort(key = lambda x: x[0]) + return self.get_POW().getNextUpdate() - obj = cls.POW_class() - obj.setVersion(version) - obj.setManifestNumber(serial) - obj.setThisUpdate(thisUpdate) - obj.setNextUpdate(nextUpdate) - obj.setAlgorithm(rpki.oids.id_sha256) - obj.addFiles(filelist) + @classmethod + def build(cls, serial, thisUpdate, nextUpdate, names_and_objs, keypair, certs, version = 0): + """ + Build a signed manifest. + """ - self = cls(POW = obj) - self.sign(keypair, certs) - return self + filelist = [] + for name, obj in names_and_objs: + filelist.append((name.rpartition("/")[2], sha256(obj.get_DER()))) + filelist.sort(key = lambda x: x[0]) -class ROA(DER_CMS_object): - """ - Class to hold a signed ROA. - """ + obj = cls.POW_class() + obj.setVersion(version) + obj.setManifestNumber(serial) + obj.setThisUpdate(thisUpdate) + obj.setNextUpdate(nextUpdate) + obj.setAlgorithm(rpki.oids.id_sha256) + obj.addFiles(filelist) - econtent_oid = rpki.oids.id_ct_routeOriginAttestation - POW_class = rpki.POW.ROA + self = cls(POW = obj) + self.sign(keypair, certs) + return self - @classmethod - def build(cls, asn, ipv4, ipv6, keypair, certs, version = 0): +class ROA(DER_CMS_object): """ - Build a ROA. + Class to hold a signed ROA. + """ + + econtent_oid = rpki.oids.id_ct_routeOriginAttestation + POW_class = rpki.POW.ROA + + @classmethod + def build(cls, asn, ipv4, ipv6, keypair, certs, version = 0): + """ + Build a ROA. + """ + + ipv4 = ipv4.to_POW_roa_tuple() if ipv4 else None + ipv6 = ipv6.to_POW_roa_tuple() if ipv6 else None + obj = cls.POW_class() + obj.setVersion(version) + obj.setASID(asn) + obj.setPrefixes(ipv4 = ipv4, ipv6 = ipv6) + self = cls(POW = obj) + self.sign(keypair, certs) + return self + + def tracking_data(self, uri): + """ + Return a string containing data we want to log when tracking how + objects move through the RPKI system. + """ + + msg = DER_CMS_object.tracking_data(self, uri) + try: + self.extract_if_needed() + asn = self.get_POW().getASID() + text = [] + for prefixes in self.get_POW().getPrefixes(): + if prefixes is not None: + for prefix, prefixlen, maxprefixlen in prefixes: + if maxprefixlen is None or prefixlen == maxprefixlen: + text.append("%s/%s" % (prefix, prefixlen)) + else: + text.append("%s/%s-%s" % (prefix, prefixlen, maxprefixlen)) + text.sort() + msg = "%s %s %s" % (msg, asn, ",".join(text)) + except: # pylint: disable=W0702 + pass + return msg + +class DeadDrop(object): """ + Dead-drop utility for storing copies of CMS messages for debugging or + audit. At the moment this uses Maildir mailbox format, as it has + approximately the right properties and a number of useful tools for + manipulating it already exist. + """ + + def __init__(self, name): + self.name = name + self.pid = os.getpid() + self.maildir = mailbox.Maildir(name, factory = None, create = True) + self.warned = False + + def dump(self, obj): + try: + now = time.time() + msg = email.mime.application.MIMEApplication(obj.get_DER(), "x-rpki") + msg["Date"] = email.utils.formatdate(now) + msg["Subject"] = "Process %s dump of %r" % (self.pid, obj) + msg["Message-ID"] = email.utils.make_msgid() + msg["X-RPKI-PID"] = str(self.pid) + msg["X-RPKI-Object"] = repr(obj) + msg["X-RPKI-Timestamp"] = "%f" % now + self.maildir.add(msg) + self.warned = False + except Exception, e: + if not self.warned: + logger.warning("Could not write to mailbox %s: %s", self.name, e) + self.warned = True - ipv4 = ipv4.to_POW_roa_tuple() if ipv4 else None - ipv6 = ipv6.to_POW_roa_tuple() if ipv6 else None - obj = cls.POW_class() - obj.setVersion(version) - obj.setASID(asn) - obj.setPrefixes(ipv4 = ipv4, ipv6 = ipv6) - self = cls(POW = obj) - self.sign(keypair, certs) - return self - - def tracking_data(self, uri): +class XML_CMS_object(Wrapped_CMS_object): """ - Return a string containing data we want to log when tracking how - objects move through the RPKI system. + Class to hold CMS-wrapped XML protocol data. """ - msg = DER_CMS_object.tracking_data(self, uri) - try: - self.extract_if_needed() - asn = self.get_POW().getASID() - text = [] - for prefixes in self.get_POW().getPrefixes(): - if prefixes is not None: - for prefix, prefixlen, maxprefixlen in prefixes: - if maxprefixlen is None or prefixlen == maxprefixlen: - text.append("%s/%s" % (prefix, prefixlen)) - else: - text.append("%s/%s-%s" % (prefix, prefixlen, maxprefixlen)) - text.sort() - msg = "%s %s %s" % (msg, asn, ",".join(text)) - except: # pylint: disable=W0702 - pass - return msg - -class DeadDrop(object): - """ - Dead-drop utility for storing copies of CMS messages for debugging or - audit. At the moment this uses Maildir mailbox format, as it has - approximately the right properties and a number of useful tools for - manipulating it already exist. - """ - - def __init__(self, name): - self.name = name - self.pid = os.getpid() - self.maildir = mailbox.Maildir(name, factory = None, create = True) - self.warned = False - - def dump(self, obj): - try: - now = time.time() - msg = email.mime.application.MIMEApplication(obj.get_DER(), "x-rpki") - msg["Date"] = email.utils.formatdate(now) - msg["Subject"] = "Process %s dump of %r" % (self.pid, obj) - msg["Message-ID"] = email.utils.make_msgid() - msg["X-RPKI-PID"] = str(self.pid) - msg["X-RPKI-Object"] = repr(obj) - msg["X-RPKI-Timestamp"] = "%f" % now - self.maildir.add(msg) - self.warned = False - except Exception, e: - if not self.warned: - logger.warning("Could not write to mailbox %s: %s", self.name, e) - self.warned = True - -class XML_CMS_object(Wrapped_CMS_object): - """ - Class to hold CMS-wrapped XML protocol data. - """ + econtent_oid = rpki.oids.id_ct_xml - econtent_oid = rpki.oids.id_ct_xml + ## @var dump_outbound_cms + # If set, we write all outbound XML-CMS PDUs to disk, for debugging. + # If set, value should be a DeadDrop object. - ## @var dump_outbound_cms - # If set, we write all outbound XML-CMS PDUs to disk, for debugging. - # If set, value should be a DeadDrop object. + dump_outbound_cms = None - dump_outbound_cms = None + ## @var dump_inbound_cms + # If set, we write all inbound XML-CMS PDUs to disk, for debugging. + # If set, value should be a DeadDrop object. - ## @var dump_inbound_cms - # If set, we write all inbound XML-CMS PDUs to disk, for debugging. - # If set, value should be a DeadDrop object. + dump_inbound_cms = None - dump_inbound_cms = None + ## @var check_inbound_schema + # If set, perform RelaxNG schema check on inbound messages. - ## @var check_inbound_schema - # If set, perform RelaxNG schema check on inbound messages. + check_inbound_schema = True - check_inbound_schema = True + ## @var check_outbound_schema + # If set, perform RelaxNG schema check on outbound messages. - ## @var check_outbound_schema - # If set, perform RelaxNG schema check on outbound messages. + check_outbound_schema = True - check_outbound_schema = True + def encode(self): + """ + Encode inner content for signing. + """ - def encode(self): - """ - Encode inner content for signing. - """ + return lxml.etree.tostring(self.get_content(), + pretty_print = True, + encoding = self.encoding, + xml_declaration = True) - return lxml.etree.tostring(self.get_content(), - pretty_print = True, - encoding = self.encoding, - xml_declaration = True) + def decode(self, xml): + """ + Decode XML and set inner content. + """ - def decode(self, xml): - """ - Decode XML and set inner content. - """ + self.content = lxml.etree.fromstring(xml) - self.content = lxml.etree.fromstring(xml) + def pretty_print_content(self): + """ + Pretty print XML content of this message. + """ - def pretty_print_content(self): - """ - Pretty print XML content of this message. - """ + return lxml.etree.tostring(self.get_content(), + pretty_print = True, + encoding = self.encoding, + xml_declaration = True) - return lxml.etree.tostring(self.get_content(), - pretty_print = True, - encoding = self.encoding, - xml_declaration = True) + def schema_check(self): + """ + Handle XML RelaxNG schema check. + """ - def schema_check(self): - """ - Handle XML RelaxNG schema check. - """ + try: + self.schema.assertValid(self.get_content()) + except lxml.etree.DocumentInvalid: + logger.error("PDU failed schema check") + for line in self.pretty_print_content().splitlines(): + logger.warning(line) + raise - try: - self.schema.assertValid(self.get_content()) - except lxml.etree.DocumentInvalid: - logger.error("PDU failed schema check") - for line in self.pretty_print_content().splitlines(): - logger.warning(line) - raise + def dump_to_disk(self, prefix): + """ + Write DER of current message to disk, for debugging. + """ - def dump_to_disk(self, prefix): - """ - Write DER of current message to disk, for debugging. - """ - - f = open(prefix + rpki.sundial.now().isoformat() + "Z.cms", "wb") - f.write(self.get_DER()) - f.close() + f = open(prefix + rpki.sundial.now().isoformat() + "Z.cms", "wb") + f.write(self.get_DER()) + f.close() - def wrap(self, msg, keypair, certs, crls = None): - """ - Wrap an XML PDU in CMS and return its DER encoding. - """ + def wrap(self, msg, keypair, certs, crls = None): + """ + Wrap an XML PDU in CMS and return its DER encoding. + """ + + self.set_content(msg) + if self.check_outbound_schema: + self.schema_check() + self.sign(keypair, certs, crls) + if self.dump_outbound_cms: + self.dump_outbound_cms.dump(self) + return self.get_DER() + + def unwrap(self, ta): + """ + Unwrap a CMS-wrapped XML PDU and return Python objects. + """ + + if self.dump_inbound_cms: + self.dump_inbound_cms.dump(self) + self.verify(ta) + if self.check_inbound_schema: + self.schema_check() + return self.get_content() + + def check_replay(self, timestamp, *context): + """ + Check CMS signing-time in this object against a recorded + timestamp. Raises an exception if the recorded timestamp is more + recent, otherwise returns the new timestamp. + """ + + new_timestamp = self.get_signingTime() + if timestamp is not None and timestamp > new_timestamp: + if context: + context = " (" + " ".join(context) + ")" + raise rpki.exceptions.CMSReplay( + "CMS replay: last message %s, this message %s%s" % ( + timestamp, new_timestamp, context)) + return new_timestamp + + def check_replay_sql(self, obj, *context): + """ + Like .check_replay() but gets recorded timestamp from + "last_cms_timestamp" field of an SQL object and stores the new + timestamp back in that same field. + """ + + obj.last_cms_timestamp = self.check_replay(obj.last_cms_timestamp, *context) + obj.save() - self.set_content(msg) - if self.check_outbound_schema: - self.schema_check() - self.sign(keypair, certs, crls) - if self.dump_outbound_cms: - self.dump_outbound_cms.dump(self) - return self.get_DER() +class SignedReferral(XML_CMS_object): + encoding = "us-ascii" + schema = rpki.relaxng.oob_setup - def unwrap(self, ta): +class Ghostbuster(Wrapped_CMS_object): """ - Unwrap a CMS-wrapped XML PDU and return Python objects. + Class to hold Ghostbusters record (CMS-wrapped VCard). This is + quite minimal because we treat the VCard as an opaque byte string + managed by the back-end. """ - if self.dump_inbound_cms: - self.dump_inbound_cms.dump(self) - self.verify(ta) - if self.check_inbound_schema: - self.schema_check() - return self.get_content() + econtent_oid = rpki.oids.id_ct_rpkiGhostbusters - def check_replay(self, timestamp, *context): - """ - Check CMS signing-time in this object against a recorded - timestamp. Raises an exception if the recorded timestamp is more - recent, otherwise returns the new timestamp. - """ + def encode(self): + """ + Encode inner content for signing. At the moment we're treating + the VCard as an opaque byte string, so no encoding needed here. + """ - new_timestamp = self.get_signingTime() - if timestamp is not None and timestamp > new_timestamp: - if context: - context = " (" + " ".join(context) + ")" - raise rpki.exceptions.CMSReplay( - "CMS replay: last message %s, this message %s%s" % ( - timestamp, new_timestamp, context)) - return new_timestamp + return self.get_content() - def check_replay_sql(self, obj, *context): - """ - Like .check_replay() but gets recorded timestamp from - "last_cms_timestamp" field of an SQL object and stores the new - timestamp back in that same field. - """ + def decode(self, vcard): + """ + Decode XML and set inner content. At the moment we're treating + the VCard as an opaque byte string, so no encoding needed here. + """ - obj.last_cms_timestamp = self.check_replay(obj.last_cms_timestamp, *context) - obj.save() + self.content = vcard -class SignedReferral(XML_CMS_object): - encoding = "us-ascii" - schema = rpki.relaxng.oob_setup + @classmethod + def build(cls, vcard, keypair, certs): + """ + Build a Ghostbuster record. + """ -class Ghostbuster(Wrapped_CMS_object): - """ - Class to hold Ghostbusters record (CMS-wrapped VCard). This is - quite minimal because we treat the VCard as an opaque byte string - managed by the back-end. - """ + self = cls() + self.set_content(vcard) + self.sign(keypair, certs) + return self - econtent_oid = rpki.oids.id_ct_rpkiGhostbusters - def encode(self): +class CRL(DER_object): """ - Encode inner content for signing. At the moment we're treating - the VCard as an opaque byte string, so no encoding needed here. + Class to hold a Certificate Revocation List. """ - return self.get_content() + POW_class = rpki.POW.CRL - def decode(self, vcard): - """ - Decode XML and set inner content. At the moment we're treating - the VCard as an opaque byte string, so no encoding needed here. - """ + def get_DER(self): + """ + Get the DER value of this CRL. + """ - self.content = vcard + self.check() + if self.DER: + return self.DER + if self.POW: + self.DER = self.POW.derWrite() + return self.get_DER() + raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") - @classmethod - def build(cls, vcard, keypair, certs): - """ - Build a Ghostbuster record. - """ + def get_POW(self): + """ + Get the rpki.POW value of this CRL. + """ - self = cls() - self.set_content(vcard) - self.sign(keypair, certs) - return self + self.check() + if not self.POW: # pylint: disable=E0203 + self.POW = rpki.POW.CRL.derRead(self.get_DER()) + return self.POW + def getThisUpdate(self): + """ + Get thisUpdate value from this CRL. + """ -class CRL(DER_object): - """ - Class to hold a Certificate Revocation List. - """ + return self.get_POW().getThisUpdate() - POW_class = rpki.POW.CRL + def getNextUpdate(self): + """ + Get nextUpdate value from this CRL. + """ - def get_DER(self): - """ - Get the DER value of this CRL. - """ + return self.get_POW().getNextUpdate() - self.check() - if self.DER: - return self.DER - if self.POW: - self.DER = self.POW.derWrite() - return self.get_DER() - raise rpki.exceptions.DERObjectConversionError("No conversion path to DER available") + def getIssuer(self): + """ + Get issuer value of this CRL. + """ - def get_POW(self): - """ - Get the rpki.POW value of this CRL. - """ + return X501DN.from_POW(self.get_POW().getIssuer()) - self.check() - if not self.POW: # pylint: disable=E0203 - self.POW = rpki.POW.CRL.derRead(self.get_DER()) - return self.POW + def getCRLNumber(self): + """ + Get CRL Number value for this CRL. + """ - def getThisUpdate(self): - """ - Get thisUpdate value from this CRL. - """ + return self.get_POW().getCRLNumber() - return self.get_POW().getThisUpdate() + @classmethod + def generate(cls, keypair, issuer, serial, thisUpdate, nextUpdate, revokedCertificates, version = 1): + """ + Generate a new CRL. + """ - def getNextUpdate(self): - """ - Get nextUpdate value from this CRL. - """ + crl = rpki.POW.CRL() + crl.setVersion(version) + crl.setIssuer(issuer.getSubject().get_POW()) + crl.setThisUpdate(thisUpdate) + crl.setNextUpdate(nextUpdate) + crl.setAKI(issuer.get_SKI()) + crl.setCRLNumber(serial) + crl.addRevocations(revokedCertificates) + crl.sign(keypair.get_POW()) + return cls(POW = crl) - return self.get_POW().getNextUpdate() + @property + def creation_timestamp(self): + """ + Time at which this object was created. + """ - def getIssuer(self): - """ - Get issuer value of this CRL. - """ - - return X501DN.from_POW(self.get_POW().getIssuer()) - - def getCRLNumber(self): - """ - Get CRL Number value for this CRL. - """ - - return self.get_POW().getCRLNumber() - - @classmethod - def generate(cls, keypair, issuer, serial, thisUpdate, nextUpdate, revokedCertificates, version = 1): - """ - Generate a new CRL. - """ - - crl = rpki.POW.CRL() - crl.setVersion(version) - crl.setIssuer(issuer.getSubject().get_POW()) - crl.setThisUpdate(thisUpdate) - crl.setNextUpdate(nextUpdate) - crl.setAKI(issuer.get_SKI()) - crl.setCRLNumber(serial) - crl.addRevocations(revokedCertificates) - crl.sign(keypair.get_POW()) - return cls(POW = crl) - - @property - def creation_timestamp(self): - """ - Time at which this object was created. - """ - - return self.getThisUpdate() + return self.getThisUpdate() ## @var uri_dispatch_map # Map of known URI filename extensions and corresponding classes. @@ -2157,8 +2157,8 @@ uri_dispatch_map = { } def uri_dispatch(uri): - """ - Return the Python class object corresponding to a given URI. - """ + """ + Return the Python class object corresponding to a given URI. + """ - return uri_dispatch_map[os.path.splitext(uri)[1]] + return uri_dispatch_map[os.path.splitext(uri)[1]] |