diff options
Diffstat (limited to 'rtr-origin/rtr-origin.py')
-rwxr-xr-x[-rw-r--r--] | rtr-origin/rtr-origin.py | 253 |
1 files changed, 171 insertions, 82 deletions
diff --git a/rtr-origin/rtr-origin.py b/rtr-origin/rtr-origin.py index ebb729ec..e81f6963 100644..100755 --- a/rtr-origin/rtr-origin.py +++ b/rtr-origin/rtr-origin.py @@ -1,34 +1,36 @@ -""" -Router origin-authentication rpki-router protocol implementation. See -draft-ietf-sidr-rpki-rtr in fine Internet-Draft repositories near you. - -Run the program with the --help argument for usage information, or see -documentation for the *_main() functions. - - -$Id$ - -Copyright (C) 2009-2010 Internet Systems Consortium ("ISC") - -Permission to use, copy, modify, and distribute this software for any -purpose with or without fee is hereby granted, provided that the above -copyright notice and this permission notice appear in all copies. - -THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH -REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY -AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, -INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM -LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE -OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR -PERFORMANCE OF THIS SOFTWARE. -""" +#!/usr/bin/env python + +# Router origin-authentication rpki-router protocol implementation. See +# draft-ietf-sidr-rpki-rtr in fine Internet-Draft repositories near you. +# +# Run the program with the --help argument for usage information, or see +# documentation for the *_main() functions. +# +# +# $Id$ +# +# Copyright (C) 2009-2010 Internet Systems Consortium ("ISC") +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notice and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH +# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +# AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, +# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE +# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. import sys, os, struct, time, glob, socket, fcntl, signal, syslog -import asyncore, asynchat, subprocess, traceback, getopt, bisect +import asyncore, asynchat, subprocess, traceback, getopt, bisect, random + class IgnoreThisRecord(Exception): pass + class timestamp(int): """ Wrapper around time module. @@ -72,6 +74,23 @@ class v6addr(ipaddr): size = 16 +def read_current(): + """ + Read current serial number and nonce. Return None for both if + serial and nonce not recorded. For backwards compatibility, treat + file containing just a serial number as having a nonce of zero. + """ + try: + f = open("current", "r") + values = tuple(int(s) for s in f.read().split()) + f.close() + return values[0], values[1] + except IndexError: + return values[0], 0 + except IOError: + return None, None + + class read_buffer(object): """ Wrapper around synchronous/asynchronous read state. @@ -174,9 +193,9 @@ class pdu(object): Send a content of a file as a cache response. Caller should catch IOError. """ f = open(filename, "rb") - server.push_pdu(cache_response()) + server.push_pdu(cache_response(nonce = server.current_nonce)) server.push_file(f) - server.push_pdu(end_of_data(serial = server.current_serial)) + server.push_pdu(end_of_data(serial = server.current_serial, nonce = server.current_nonce)) def send_nodata(self, server): """ @@ -186,39 +205,71 @@ class pdu(object): class pdu_with_serial(pdu): """ - Base class for PDUs consisting of just a serial number. + Base class for PDUs consisting of just a serial number and nonce. """ header_struct = struct.Struct("!BBHLL") - def __init__(self, serial = None): + def __init__(self, serial = None, nonce = None): if serial is not None: - if isinstance(serial, str): - serial = int(serial) assert isinstance(serial, int) self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce def __str__(self): - return "[%s, serial #%s]" % (self.__class__.__name__, self.serial) + return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) def to_pdu(self): """ - Generate the wire format PDU for this prefix. + Generate the wire format PDU. """ if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size, self.serial) + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size, self.serial) return self._pdu def got_pdu(self, reader): if not reader.ready(): return None b = reader.get(self.header_struct.size) - version, pdu_type, zero, length, self.serial = self.header_struct.unpack(b) - assert zero == 0 + version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) assert length == 12 assert b == self.to_pdu() return self +class pdu_nonce(pdu): + """ + Base class for PDUs consisting of just a nonce. + """ + + header_struct = struct.Struct("!BBHL") + + def __init__(self, nonce = None): + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert length == 8 + assert b == self.to_pdu() + return self + class pdu_empty(pdu): """ Base class for empty PDUs. @@ -260,10 +311,10 @@ class serial_notify(pdu_with_serial): reset_query, depending on what we already know. """ log(self) - if client.current_serial is None: + if client.current_serial is None or client.current_nonce != self.nonce: client.push_pdu(reset_query()) elif self.serial != client.current_serial: - client.push_pdu(serial_query(serial = client.current_serial)) + client.push_pdu(serial_query(serial = client.current_serial, nonce = client.current_nonce)) else: log("[Notify did not change serial number, ignoring]") @@ -283,13 +334,16 @@ class serial_query(pdu_with_serial): log(self) if server.get_serial() is None: self.send_nodata(server) - elif int(server.current_serial) == self.serial: + elif server.current_nonce != self.nonce: + log("[Client requested wrong nonce, resetting client]") + server.push_pdu(cache_reset()) + elif server.current_serial == self.serial: log("[Client is already current, sending empty IXFR]") - server.push_pdu(cache_response()) - server.push_pdu(end_of_data(serial = server.current_serial)) + server.push_pdu(cache_response(nonce = server.current_nonce)) + server.push_pdu(end_of_data(serial = server.current_serial, nonce = server.current_nonce)) else: try: - self.send_file(server, "%s.ix.%s" % (server.current_serial, self.serial)) + self.send_file(server, "%d.ix.%d" % (server.current_serial, self.serial)) except IOError: server.push_pdu(cache_reset()) @@ -309,12 +363,12 @@ class reset_query(pdu_empty): self.send_nodata(server) else: try: - fn = "%s.ax" % server.current_serial + fn = "%d.ax" % server.current_serial self.send_file(server, fn) except IOError: server.push_pdu(error_report(errno = error_report.codes["Internal Error"], errpdu = self, errmsg = "Couldn't open %s" % fn)) -class cache_response(pdu_empty): +class cache_response(pdu_nonce): """ Incremental Response PDU. """ @@ -334,6 +388,7 @@ class end_of_data(pdu_with_serial): """ log(self) client.current_serial = self.serial + client.current_nonce = self.nonce class cache_reset(pdu_empty): """ @@ -386,12 +441,12 @@ class prefix(pdu): return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, ":".join(("%02X" % ord(b) for b in self.to_pdu()))) def show(self): - print "# Class: ", self.__class__.__name__ - print "# ASN: ", self.asn - print "# Prefix: ", self.prefix - print "# Prefixlen: ", self.prefixlen - print "# MaxPrefixlen:", self.max_prefixlen - print "# Announce: ", self.announce + log("# Class: %s" % self.__class__.__name__) + log("# ASN: %s" % self.asn) + log("# Prefix: %s" % self.prefix) + log("# Prefixlen: %s" % self.prefixlen) + log("# MaxPrefixlen: %s" % self.max_prefixlen) + log("# Announce: %s" % self.announce) def check(self): """ @@ -587,6 +642,11 @@ class prefix_set(list): p = r.retry() self.append(p) + @staticmethod + def seq_ge(a, b): + return ((a - b) % (1 << 32)) < (1 << 31) + + class axfr_set(prefix_set): """ Object representing a complete set of prefixes, that is, one @@ -661,20 +721,49 @@ class axfr_set(prefix_set): f.write(p.to_pdu()) f.close() + def destroy_old_data(self): + """ + Destroy old data files, presumably because our nonce changed and + the old serial numbers are no longer valid. + """ + for i in glob.iglob("*.ix.*"): + os.unlink(i) + for i in glob.iglob("*.ax"): + if i != self.filename(): + os.unlink(i) + + @staticmethod + def new_nonce(): + """ + Create and return a new nonce value. + """ + if force_zero_nonce: + return 0 + try: + return random.SystemRandom().getrandbits(16) + except NotImplementedError: + return random.getrandbits(16) + def mark_current(self): """ - Mark the current serial number as current. + Save current serial number and nonce, creating new nonce if + necessary. Creating a new nonce triggers cleanup of old state, as + the new nonce invalidates all old serial numbers. """ + old_serial, nonce = read_current() + if old_serial is None or self.seq_ge(old_serial, self.serial): + log("Deleting old data and creating new nonce") + self.destroy_old_data() + nonce = self.new_nonce() tmpfn = "current.%d.tmp" % os.getpid() try: f = open(tmpfn, "w") - f.write("%d\n" % self.serial) + f.write("%d %d\n" % (self.serial, nonce)) f.close() os.rename(tmpfn, "current") - except: + finally: if os.path.exists(tmpfn): os.unlink(tmpfn) - raise def save_ixfr(self, other): """ @@ -709,9 +798,9 @@ class axfr_set(prefix_set): """ Print this axfr_set. """ - print "# AXFR %d (%s)" % (self.serial, self.serial) + log("# AXFR %d (%s)" % (self.serial, self.serial)) for p in self: - print p + log(p) @staticmethod def read_bgpdump(filename): @@ -786,10 +875,10 @@ class ixfr_set(prefix_set): """ Print this ixfr_set. """ - print "# IXFR %d (%s) -> %d (%s)" % (self.from_serial, self.from_serial, - self.to_serial, self.to_serial) + log("# IXFR %d (%s) -> %d (%s)" % (self.from_serial, self.from_serial, + self.to_serial, self.to_serial)) for p in self: - print p + log(p) class file_producer(object): """ @@ -980,13 +1069,7 @@ class server_channel(pdu_channel): happen, but maybe we got started in server mode while the cronjob mode instance is still building its database. """ - try: - f = open("current", "r") - self.current_serial = f.read().strip() - assert self.current_serial.isdigit() - f.close() - except IOError: - self.current_serial = None + self.current_serial, self.current_nonce = read_current() return self.current_serial def check_serial(self): @@ -1000,8 +1083,8 @@ class server_channel(pdu_channel): """ Cronjob instance kicked us, send a notify message. """ - if self.check_serial(): - self.push_pdu(serial_notify(serial = self.current_serial)) + if self.check_serial() is not None: + self.push_pdu(serial_notify(serial = self.current_serial, nonce = self.current_nonce)) else: log("Cronjob kicked me without a valid current serial number") @@ -1011,6 +1094,7 @@ class client_channel(pdu_channel): """ current_serial = None + current_nonce = None def __init__(self, sock, proc, killsig): self.killsig = killsig @@ -1182,7 +1266,7 @@ def cronjob_main(argv): for f in glob.iglob("*.ax"): t = timestamp(os.stat(f).st_mtime) if t < cutoff: - print "# Deleting old file %s, timestamp %s" % (f, t) + log("# Deleting old file %s, timestamp %s" % (f, t)) os.unlink(f) pdus = axfr_set.parse_rcynic(argv[0]) @@ -1192,28 +1276,31 @@ def cronjob_main(argv): pdus.save_ixfr(axfr_set.load(axfr)) pdus.mark_current() - print "# New serial is %s" % pdus.serial + log("# New serial is %d (%s)" % (pdus.serial, pdus.serial)) try: os.stat(kickme_dir) except OSError: - print '# Creating directory "%s"' % kickme_dir + log('# Creating directory "%s"' % kickme_dir) os.makedirs(kickme_dir) - msg = "Good morning, serial %s is ready" % pdus.serial + msg = "Good morning, serial %d is ready" % pdus.serial sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) for name in glob.iglob("%s.*" % kickme_base): try: - print "# Kicking %s" % name + log("# Kicking %s" % name) sock.sendto(msg, name) except: - print "# Failed to kick %s" % name + log("# Failed to kick %s" % name) sock.close() old_ixfrs.sort() for ixfr in old_ixfrs: - print "# Deleting old file %s" % ixfr - os.unlink(ixfr) + try: + log("# Deleting old file %s" % ixfr) + os.unlink(ixfr) + except OSError: + pass def show_main(argv): """ @@ -1311,10 +1398,10 @@ def client_main(argv): else: sys.exit("Unexpected arguments: %r" % (argv,)) while True: - if client.current_serial is None: + if client.current_serial is None or client.current_nonce is None: client.push_pdu(reset_query()) else: - client.push_pdu(serial_query(serial = client.current_serial)) + client.push_pdu(serial_query(serial = client.current_serial, nonce = client.current_nonce)) wakeup = time.time() + 600 while wakeup > time.time(): asyncore.loop(timeout = wakeup - time.time(), count = 1) @@ -1361,8 +1448,6 @@ def bgpdump_main(argv): axfrs.append(db.filename()) log("DB serial now %d (%s)" % (db.serial, db.serial)) - print "Finished generating AXFRs, last is", axfrs[-1] - del axfrs[-1] for axfr in axfrs: @@ -1383,6 +1468,8 @@ print_roa = os.path.normpath(os.path.join(sys.path[0], "..", "utils", if not os.path.exists(print_roa): print_roa = "print_roa" +force_zero_nonce = False + mode = None kickme_dir = "sockets" @@ -1405,11 +1492,13 @@ def usage(): print func.__doc__ sys.exit(0) -opts, argv = getopt.getopt(sys.argv[1:], "h?", ["help"] + main_dispatch.keys()) +opts, argv = getopt.getopt(sys.argv[1:], "hz?", ["help", "zero-nonce"] + main_dispatch.keys()) for o, a in opts: if o in ("-h", "--help", "-?"): usage() - if len(o) > 2 and o[2:] in main_dispatch: + elif o in ("-z", "--zero-nonce"): + force_zero_nonce = True + elif len(o) > 2 and o[2:] in main_dispatch: if mode is not None: sys.exit("Conflicting modes specified") mode = o[2:] |