diff options
Diffstat (limited to 'rpki/rtr')
-rwxr-xr-x | rpki/rtr/bgpdump.py | 485 | ||||
-rw-r--r-- | rpki/rtr/channels.py | 366 | ||||
-rw-r--r-- | rpki/rtr/client.py | 816 | ||||
-rw-r--r-- | rpki/rtr/generator.py | 959 | ||||
-rw-r--r-- | rpki/rtr/main.py | 85 | ||||
-rw-r--r-- | rpki/rtr/pdus.py | 990 | ||||
-rw-r--r-- | rpki/rtr/server.py | 874 |
7 files changed, 2293 insertions, 2282 deletions
diff --git a/rpki/rtr/bgpdump.py b/rpki/rtr/bgpdump.py index fc3ae9df..22ac0d83 100755 --- a/rpki/rtr/bgpdump.py +++ b/rpki/rtr/bgpdump.py @@ -39,292 +39,295 @@ from rpki.rtr.channels import Timestamp class IgnoreThisRecord(Exception): - pass + pass class PrefixPDU(rpki.rtr.generator.PrefixPDU): - @staticmethod - def from_bgpdump(line, rib_dump): - try: - assert isinstance(rib_dump, bool) - fields = line.split("|") - - # Parse prefix, including figuring out IP protocol version - cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU - self = cls() - self.timestamp = Timestamp(fields[1]) - p, l = fields[5].split("/") - self.prefix = rpki.POW.IPAddress(p) - self.prefixlen = self.max_prefixlen = int(l) - - # Withdrawals don't have AS paths, so be careful - assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W") - if fields[2] == "W": - self.asn = 0 - self.announce = 0 - else: - self.announce = 1 - if not fields[6] or "{" in fields[6] or "(" in fields[6]: - raise IgnoreThisRecord - a = fields[6].split()[-1] - if "." in a: - a = [int(s) for s in a.split(".")] - if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535: - logging.warn("Bad dotted ASNum %r, ignoring record", fields[6]) + @staticmethod + def from_bgpdump(line, rib_dump): + try: + assert isinstance(rib_dump, bool) + fields = line.split("|") + + # Parse prefix, including figuring out IP protocol version + cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU + self = cls(version = min(rpki.rtr.pdus.PDU.version_map)) + self.timestamp = Timestamp(fields[1]) + p, l = fields[5].split("/") + self.prefix = rpki.POW.IPAddress(p) + self.prefixlen = self.max_prefixlen = int(l) + + # Withdrawals don't have AS paths, so be careful + assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W") + if fields[2] == "W": + self.asn = 0 + self.announce = 0 + else: + self.announce = 1 + if not fields[6] or "{" in fields[6] or "(" in fields[6]: + raise IgnoreThisRecord + a = fields[6].split()[-1] + if "." in a: + a = [int(s) for s in a.split(".")] + if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535: + logging.warn("Bad dotted ASNum %r, ignoring record", fields[6]) + raise IgnoreThisRecord + a = (a[0] << 16) | a[1] + else: + a = int(a) + self.asn = a + + self.check() + return self + + except IgnoreThisRecord: + raise + + except Exception, e: + logging.warn("Ignoring line %r: %s", line, e) raise IgnoreThisRecord - a = (a[0] << 16) | a[1] - else: - a = int(a) - self.asn = a - self.check() - return self - except IgnoreThisRecord: - raise +class AXFRSet(rpki.rtr.generator.AXFRSet): - except Exception, e: - logging.warn("Ignoring line %r: %s", line, e) - raise IgnoreThisRecord + serial = None + + @staticmethod + def read_bgpdump(filename): + assert filename.endswith(".bz2") + logging.debug("Reading %s", filename) + bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE) + bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE) + return bgpdump.stdout + + @classmethod + def parse_bgpdump_rib_dump(cls, filename): + # pylint: disable=W0201 + assert os.path.basename(filename).startswith("ribs.") + self = cls(version = min(rpki.rtr.pdus.PDU.version_map)) + self.serial = None + for line in cls.read_bgpdump(filename): + try: + pfx = PrefixPDU.from_bgpdump(line, rib_dump = True) + except IgnoreThisRecord: + continue + self.append(pfx) + self.serial = pfx.timestamp + if self.serial is None: + sys.exit("Failed to parse anything useful from %s" % filename) + self.sort() + for i in xrange(len(self) - 2, -1, -1): + if self[i] == self[i + 1]: + del self[i + 1] + return self + + def parse_bgpdump_update(self, filename): + assert os.path.basename(filename).startswith("updates.") + for line in self.read_bgpdump(filename): + try: + pfx = PrefixPDU.from_bgpdump(line, rib_dump = False) + except IgnoreThisRecord: + continue + announce = pfx.announce + pfx.announce = 1 + i = bisect.bisect_left(self, pfx) + if announce: + if i >= len(self) or pfx != self[i]: + self.insert(i, pfx) + else: + while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen: + del self[i] + self.serial = pfx.timestamp -class AXFRSet(rpki.rtr.generator.AXFRSet): +def bgpdump_convert_main(args): + """ + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + Simulate route origin data from a set of BGP dump files. + argv is an ordered list of filenames. Each file must be a BGP RIB + dumps, a BGP UPDATE dumps, or an AXFR dump in the format written by + this program's --cronjob command. The first file must be a RIB dump + or AXFR dump, it cannot be an UPDATE dump. Output will be a set of + AXFR and IXFR files with timestamps derived from the BGP dumps, + which can be used as input to this program's --server command for + test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL. + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + """ + + first = True + db = None + axfrs = [] + version = max(rpki.rtr.pdus.PDU.version_map.iterkeys()) + + for filename in args.files: + + if ".ax.v" in filename: + logging.debug("Reading %s", filename) + db = AXFRSet.load(filename) + + elif os.path.basename(filename).startswith("ribs."): + db = AXFRSet.parse_bgpdump_rib_dump(filename) + db.save_axfr() + + elif not first: + assert db is not None + db.parse_bgpdump_update(filename) + db.save_axfr() - @staticmethod - def read_bgpdump(filename): - assert filename.endswith(".bz2") - logging.debug("Reading %s", filename) - bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE) - bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE) - return bgpdump.stdout - - @classmethod - def parse_bgpdump_rib_dump(cls, filename): - assert os.path.basename(filename).startswith("ribs.") - self = cls() - self.serial = None - for line in cls.read_bgpdump(filename): - try: - pfx = PrefixPDU.from_bgpdump(line, rib_dump = True) - except IgnoreThisRecord: - continue - self.append(pfx) - self.serial = pfx.timestamp - if self.serial is None: - sys.exit("Failed to parse anything useful from %s" % filename) - self.sort() - for i in xrange(len(self) - 2, -1, -1): - if self[i] == self[i + 1]: - del self[i + 1] - return self - - def parse_bgpdump_update(self, filename): - assert os.path.basename(filename).startswith("updates.") - for line in self.read_bgpdump(filename): - try: - pfx = PrefixPDU.from_bgpdump(line, rib_dump = False) - except IgnoreThisRecord: - continue - announce = pfx.announce - pfx.announce = 1 - i = bisect.bisect_left(self, pfx) - if announce: - if i >= len(self) or pfx != self[i]: - self.insert(i, pfx) - else: - while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen: - del self[i] - self.serial = pfx.timestamp + else: + sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename) + logging.debug("DB serial now %d (%s)", db.serial, db.serial) + if first and rpki.rtr.server.read_current(version) == (None, None): + db.mark_current() + first = False -def bgpdump_convert_main(args): - """ - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - Simulate route origin data from a set of BGP dump files. - argv is an ordered list of filenames. Each file must be a BGP RIB - dumps, a BGP UPDATE dumps, or an AXFR dump in the format written by - this program's --cronjob command. The first file must be a RIB dump - or AXFR dump, it cannot be an UPDATE dump. Output will be a set of - AXFR and IXFR files with timestamps derived from the BGP dumps, - which can be used as input to this program's --server command for - test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL. - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - """ - - first = True - db = None - axfrs = [] - version = max(rpki.rtr.pdus.PDU.version_map.iterkeys()) - - for filename in args.files: - - if ".ax.v" in filename: - logging.debug("Reading %s", filename) - db = AXFRSet.load(filename) - - elif os.path.basename(filename).startswith("ribs."): - db = AXFRSet.parse_bgpdump_rib_dump(filename) - db.save_axfr() - - elif not first: - assert db is not None - db.parse_bgpdump_update(filename) - db.save_axfr() - - else: - sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename) - - logging.debug("DB serial now %d (%s)", db.serial, db.serial) - if first and rpki.rtr.server.read_current(version) == (None, None): - db.mark_current() - first = False - - for axfr in axfrs: - logging.debug("Loading %s", axfr) - ax = AXFRSet.load(axfr) - logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial) - db.save_ixfr(ax) - del ax - - axfrs.append(db.filename()) + for axfr in axfrs: + logging.debug("Loading %s", axfr) + ax = AXFRSet.load(axfr) + logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial) + db.save_ixfr(ax) + del ax + + axfrs.append(db.filename()) def bgpdump_select_main(args): - """ - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - Simulate route origin data from a set of BGP dump files. - Set current serial number to correspond to an .ax file created by - converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL. - * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * - """ + """ + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + Simulate route origin data from a set of BGP dump files. + Set current serial number to correspond to an .ax file created by + converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL. + * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! * + """ - head, sep, tail = os.path.basename(args.ax_file).partition(".") - if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit(): - sys.exit("Argument must be name of a .ax file") + head, sep, tail = os.path.basename(args.ax_file).partition(".") + if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit(): + sys.exit("Argument must be name of a .ax file") - serial = Timestamp(head) - version = int(tail[4:]) + serial = Timestamp(head) + version = int(tail[4:]) - if version not in rpki.rtr.pdus.PDU.version_map: - sys.exit("Unknown protocol version %d" % version) + if version not in rpki.rtr.pdus.PDU.version_map: + sys.exit("Unknown protocol version %d" % version) - nonce = rpki.rtr.server.read_current(version)[1] - if nonce is None: - nonce = rpki.rtr.generator.new_nonce() + nonce = rpki.rtr.server.read_current(version)[1] + if nonce is None: + nonce = rpki.rtr.generator.AXFRSet.new_nonce(force_zero_nonce = False) - rpki.rtr.server.write_current(serial, nonce, version) - rpki.rtr.generator.kick_all(serial) + rpki.rtr.server.write_current(serial, nonce, version) + rpki.rtr.generator.kick_all(serial) class BGPDumpReplayClock(object): - """ - Internal clock for replaying BGP dump files. + """ + Internal clock for replaying BGP dump files. - * DANGER WILL ROBINSON! * - * DEBUGGING AND TEST USE ONLY! * + * DANGER WILL ROBINSON! * + * DEBUGGING AND TEST USE ONLY! * - This class replaces the normal on-disk serial number mechanism with - an in-memory version based on pre-computed data. + This class replaces the normal on-disk serial number mechanism with + an in-memory version based on pre-computed data. - bgpdump_server_main() uses this hack to replay historical data for - testing purposes. DO NOT USE THIS IN PRODUCTION. + bgpdump_server_main() uses this hack to replay historical data for + testing purposes. DO NOT USE THIS IN PRODUCTION. - You have been warned. - """ + You have been warned. + """ - def __init__(self): - self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")] - self.timestamps.sort() - self.offset = self.timestamps[0] - int(time.time()) - self.nonce = rpki.rtr.generator.new_nonce() + def __init__(self): + self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")] + self.timestamps.sort() + self.offset = self.timestamps[0] - int(time.time()) + self.nonce = rpki.rtr.generator.AXFRSet.new_nonce(force_zero_nonce = False) - def __nonzero__(self): - return len(self.timestamps) > 0 + def __nonzero__(self): + return len(self.timestamps) > 0 - def now(self): - return Timestamp.now(self.offset) + def now(self): + return Timestamp.now(self.offset) - def read_current(self, version): - now = self.now() - while len(self.timestamps) > 1 and now >= self.timestamps[1]: - del self.timestamps[0] - return self.timestamps[0], self.nonce + def read_current(self, version): + now = self.now() + while len(self.timestamps) > 1 and now >= self.timestamps[1]: + del self.timestamps[0] + return self.timestamps[0], self.nonce - def siesta(self): - now = self.now() - if len(self.timestamps) <= 1: - return None - elif now < self.timestamps[1]: - return self.timestamps[1] - now - else: - return 1 + def siesta(self): + now = self.now() + if len(self.timestamps) <= 1: + return None + elif now < self.timestamps[1]: + return self.timestamps[1] - now + else: + return 1 def bgpdump_server_main(args): - """ - Simulate route origin data from a set of BGP dump files. + """ + Simulate route origin data from a set of BGP dump files. + + * DANGER WILL ROBINSON! * + * DEBUGGING AND TEST USE ONLY! * + + This is a clone of server_main() which replaces the external serial + number updates triggered via the kickme channel by cronjob_main with + an internal clocking mechanism to replay historical test data. - * DANGER WILL ROBINSON! * - * DEBUGGING AND TEST USE ONLY! * + DO NOT USE THIS IN PRODUCTION. - This is a clone of server_main() which replaces the external serial - number updates triggered via the kickme channel by cronjob_main with - an internal clocking mechanism to replay historical test data. + You have been warned. + """ - DO NOT USE THIS IN PRODUCTION. + logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server.hostport_tag())) - You have been warned. - """ + logger.debug("[Starting]") - logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server._hostport_tag())) + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) - logger.debug("[Starting]") + # Yes, this really does replace a global function defined in another + # module with a bound method to our clock object. Fun stuff, huh? + # + clock = BGPDumpReplayClock() + rpki.rtr.server.read_current = clock.read_current - if args.rpki_rtr_dir: try: - os.chdir(args.rpki_rtr_dir) - except OSError, e: - sys.exit(e) - - # Yes, this really does replace a global function defined in another - # module with a bound method to our clock object. Fun stuff, huh? - # - clock = BGPDumpReplayClock() - rpki.rtr.server.read_current = clock.read_current - - try: - server = rpki.rtr.server.ServerChannel(logger = logger) - old_serial = server.get_serial() - logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial) - while clock: - new_serial = server.get_serial() - if old_serial != new_serial: - logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial) - server.notify() - old_serial = new_serial - asyncore.loop(timeout = clock.siesta(), count = 1) - except KeyboardInterrupt: - sys.exit(0) + server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire) + old_serial = server.get_serial() + logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial) + while clock: + new_serial = server.get_serial() + if old_serial != new_serial: + logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial) + server.notify() + old_serial = new_serial + asyncore.loop(timeout = clock.siesta(), count = 1) + except KeyboardInterrupt: + sys.exit(0) def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__, - help = "Convert bgpdump to fake ROAs") - subparser.set_defaults(func = bgpdump_convert_main, default_log_to = "syslog") - subparser.add_argument("files", nargs = "+", help = "input files") - - subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__, - help = "Set current serial number for fake ROA data") - subparser.set_defaults(func = bgpdump_select_main, default_log_to = "syslog") - subparser.add_argument("ax_file", help = "name of the .ax to select") - - subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__, - help = "Replay fake ROAs generated from historical data") - subparser.set_defaults(func = bgpdump_server_main, default_log_to = "syslog") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + """ + Set up argparse stuff for commands in this module. + """ + + subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__, + help = "Convert bgpdump to fake ROAs") + subparser.set_defaults(func = bgpdump_convert_main, default_log_destination = "syslog") + subparser.add_argument("files", nargs = "+", help = "input files") + + subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__, + help = "Set current serial number for fake ROA data") + subparser.set_defaults(func = bgpdump_select_main, default_log_destination = "syslog") + subparser.add_argument("ax_file", help = "name of the .ax to select") + + subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__, + help = "Replay fake ROAs generated from historical data") + subparser.set_defaults(func = bgpdump_server_main, default_log_destination = "syslog") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") diff --git a/rpki/rtr/channels.py b/rpki/rtr/channels.py index d14c024d..a4dccbc1 100644 --- a/rpki/rtr/channels.py +++ b/rpki/rtr/channels.py @@ -32,215 +32,217 @@ import rpki.rtr.pdus class Timestamp(int): - """ - Wrapper around time module. - """ - - def __new__(cls, t): - # __new__() is a static method, not a class method, hence the odd calling sequence. - return super(Timestamp, cls).__new__(cls, t) - - @classmethod - def now(cls, delta = 0): - return cls(time.time() + delta) - - def __str__(self): - return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self)) - - -class ReadBuffer(object): - """ - Wrapper around synchronous/asynchronous read state. - - This also handles tracking the current protocol version, - because it has to go somewhere and there's no better place. - """ - - def __init__(self): - self.buffer = "" - self.version = None - - def update(self, need, callback): """ - Update count of needed bytes and callback, then dispatch to callback. + Wrapper around time module. """ - self.need = need - self.callback = callback - return self.retry() + def __new__(cls, t): + # __new__() is a static method, not a class method, hence the odd calling sequence. + return super(Timestamp, cls).__new__(cls, t) - def retry(self): - """ - Try dispatching to the callback again. - """ + @classmethod + def now(cls, delta = 0): + return cls(time.time() + delta) - return self.callback(self) + def __str__(self): + return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self)) - def available(self): - """ - How much data do we have available in this buffer? - """ - return len(self.buffer) - - def needed(self): - """ - How much more data does this buffer need to become ready? +class ReadBuffer(object): """ + Wrapper around synchronous/asynchronous read state. - return self.need - self.available() - - def ready(self): - """ - Is this buffer ready to read yet? + This also handles tracking the current protocol version, + because it has to go somewhere and there's no better place. """ - return self.available() >= self.need + def __init__(self): + self.buffer = "" + self.version = None + self.need = None + self.callback = None - def get(self, n): - """ - Hand some data to the caller. - """ + def update(self, need, callback): + """ + Update count of needed bytes and callback, then dispatch to callback. + """ - b = self.buffer[:n] - self.buffer = self.buffer[n:] - return b + self.need = need + self.callback = callback + return self.retry() - def put(self, b): - """ - Accumulate some data. - """ + def retry(self): + """ + Try dispatching to the callback again. + """ - self.buffer += b + return self.callback(self) - def check_version(self, version): - """ - Track version number of PDUs read from this buffer. - Once set, the version must not change. - """ + def available(self): + """ + How much data do we have available in this buffer? + """ - if self.version is not None and version != self.version: - raise rpki.rtr.pdus.CorruptData( - "Received PDU version %d, expected %d" % (version, self.version)) - if self.version is None and version not in rpki.rtr.pdus.PDU.version_map: - raise rpki.rtr.pdus.UnsupportedProtocolVersion( - "Received PDU version %s, known versions %s" % ( - version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map))) - self.version = version + return len(self.buffer) + def needed(self): + """ + How much more data does this buffer need to become ready? + """ -class PDUChannel(asynchat.async_chat, object): - """ - asynchat subclass that understands our PDUs. This just handles - network I/O. Specific engines (client, server) should be subclasses - of this with methods that do something useful with the resulting - PDUs. - """ - - def __init__(self, root_pdu_class, sock = None): - asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super() - self.reader = ReadBuffer() - assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU) - self.root_pdu_class = root_pdu_class - - @property - def version(self): - return self.reader.version - - @version.setter - def version(self, version): - self.reader.check_version(version) - - def start_new_pdu(self): - """ - Start read of a new PDU. - """ - - try: - p = self.root_pdu_class.read_pdu(self.reader) - while p is not None: - self.deliver_pdu(p) - p = self.root_pdu_class.read_pdu(self.reader) - except rpki.rtr.pdus.PDUException, e: - self.push_pdu(e.make_error_report(version = self.version)) - self.close_when_done() - else: - assert not self.reader.ready() - self.set_terminator(self.reader.needed()) - - def collect_incoming_data(self, data): - """ - Collect data into the read buffer. - """ - - self.reader.put(data) - - def found_terminator(self): - """ - Got requested data, see if we now have a PDU. If so, pass it - along, then restart cycle for a new PDU. - """ - - p = self.reader.retry() - if p is None: - self.set_terminator(self.reader.needed()) - else: - self.deliver_pdu(p) - self.start_new_pdu() - - def push_pdu(self, pdu): - """ - Write PDU to stream. - """ + return self.need - self.available() - try: - self.push(pdu.to_pdu()) - except OSError, e: - if e.errno != errno.EAGAIN: - raise + def ready(self): + """ + Is this buffer ready to read yet? + """ - def log(self, msg): - """ - Intercept asyncore's logging. - """ + return self.available() >= self.need - logging.info(msg) + def get(self, n): + """ + Hand some data to the caller. + """ - def log_info(self, msg, tag = "info"): - """ - Intercept asynchat's logging. - """ + b = self.buffer[:n] + self.buffer = self.buffer[n:] + return b - logging.info("asynchat: %s: %s", tag, msg) + def put(self, b): + """ + Accumulate some data. + """ - def handle_error(self): - """ - Handle errors caught by asyncore main loop. - """ + self.buffer += b - logging.exception("[Unhandled exception]") - logging.critical("[Exiting after unhandled exception]") - sys.exit(1) + def check_version(self, version): + """ + Track version number of PDUs read from this buffer. + Once set, the version must not change. + """ - def init_file_dispatcher(self, fd): - """ - Kludge to plug asyncore.file_dispatcher into asynchat. Call from - subclass's __init__() method, after calling - PDUChannel.__init__(), and don't read this on a full stomach. - """ + if self.version is not None and version != self.version: + raise rpki.rtr.pdus.CorruptData( + "Received PDU version %d, expected %d" % (version, self.version)) + if self.version is None and version not in rpki.rtr.pdus.PDU.version_map: + raise rpki.rtr.pdus.UnsupportedProtocolVersion( + "Received PDU version %s, known versions %s" % ( + version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map))) + self.version = version - self.connected = True - self._fileno = fd - self.socket = asyncore.file_wrapper(fd) - self.add_channel() - flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) - flags = flags | os.O_NONBLOCK - fcntl.fcntl(fd, fcntl.F_SETFL, flags) - def handle_close(self): - """ - Exit when channel closed. +class PDUChannel(asynchat.async_chat, object): """ - - asynchat.async_chat.handle_close(self) - sys.exit(0) + asynchat subclass that understands our PDUs. This just handles + network I/O. Specific engines (client, server) should be subclasses + of this with methods that do something useful with the resulting + PDUs. + """ + + def __init__(self, root_pdu_class, sock = None): + asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super() + self.reader = ReadBuffer() + assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU) + self.root_pdu_class = root_pdu_class + + @property + def version(self): + return self.reader.version + + @version.setter + def version(self, version): + self.reader.check_version(version) + + def start_new_pdu(self): + """ + Start read of a new PDU. + """ + + try: + p = self.root_pdu_class.read_pdu(self.reader) + while p is not None: + self.deliver_pdu(p) + p = self.root_pdu_class.read_pdu(self.reader) + except rpki.rtr.pdus.PDUException, e: + self.push_pdu(e.make_error_report(version = self.version)) + self.close_when_done() + else: + assert not self.reader.ready() + self.set_terminator(self.reader.needed()) + + def collect_incoming_data(self, data): + """ + Collect data into the read buffer. + """ + + self.reader.put(data) + + def found_terminator(self): + """ + Got requested data, see if we now have a PDU. If so, pass it + along, then restart cycle for a new PDU. + """ + + p = self.reader.retry() + if p is None: + self.set_terminator(self.reader.needed()) + else: + self.deliver_pdu(p) + self.start_new_pdu() + + def push_pdu(self, pdu): + """ + Write PDU to stream. + """ + + try: + self.push(pdu.to_pdu()) + except OSError, e: + if e.errno != errno.EAGAIN: + raise + + def log(self, msg): + """ + Intercept asyncore's logging. + """ + + logging.info(msg) + + def log_info(self, msg, tag = "info"): + """ + Intercept asynchat's logging. + """ + + logging.info("asynchat: %s: %s", tag, msg) + + def handle_error(self): + """ + Handle errors caught by asyncore main loop. + """ + + logging.exception("[Unhandled exception]") + logging.critical("[Exiting after unhandled exception]") + sys.exit(1) + + def init_file_dispatcher(self, fd): + """ + Kludge to plug asyncore.file_dispatcher into asynchat. Call from + subclass's __init__() method, after calling + PDUChannel.__init__(), and don't read this on a full stomach. + """ + + self.connected = True + self._fileno = fd + self.socket = asyncore.file_wrapper(fd) + self.add_channel() + flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + def handle_close(self): + """ + Exit when channel closed. + """ + + asynchat.async_chat.handle_close(self) + sys.exit(0) diff --git a/rpki/rtr/client.py b/rpki/rtr/client.py index a35ab81d..a8348087 100644 --- a/rpki/rtr/client.py +++ b/rpki/rtr/client.py @@ -37,13 +37,13 @@ from rpki.rtr.channels import Timestamp class PDU(rpki.rtr.pdus.PDU): - def consume(self, client): - """ - Handle results in test client. Default behavior is just to print - out the PDU; data PDU subclasses may override this. - """ + def consume(self, client): + """ + Handle results in test client. Default behavior is just to print + out the PDU; data PDU subclasses may override this. + """ - logging.debug(self) + logging.debug(self) clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU) @@ -52,407 +52,407 @@ clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU) @clone_pdu class SerialNotifyPDU(rpki.rtr.pdus.SerialNotifyPDU): - def consume(self, client): - """ - Respond to a SerialNotifyPDU with either a SerialQueryPDU or a - ResetQueryPDU, depending on what we already know. - """ + def consume(self, client): + """ + Respond to a SerialNotifyPDU with either a SerialQueryPDU or a + ResetQueryPDU, depending on what we already know. + """ - logging.debug(self) - if client.serial is None or client.nonce != self.nonce: - client.push_pdu(ResetQueryPDU(version = client.version)) - elif self.serial != client.serial: - client.push_pdu(SerialQueryPDU(version = client.version, - serial = client.serial, - nonce = client.nonce)) - else: - logging.debug("[Notify did not change serial number, ignoring]") + logging.debug(self) + if client.serial is None or client.nonce != self.nonce: + client.push_pdu(ResetQueryPDU(version = client.version)) + elif self.serial != client.serial: + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.serial, + nonce = client.nonce)) + else: + logging.debug("[Notify did not change serial number, ignoring]") @clone_pdu class CacheResponsePDU(rpki.rtr.pdus.CacheResponsePDU): - def consume(self, client): - """ - Handle CacheResponsePDU. - """ + def consume(self, client): + """ + Handle CacheResponsePDU. + """ - logging.debug(self) - if self.nonce != client.nonce: - logging.debug("[Nonce changed, resetting]") - client.cache_reset() + logging.debug(self) + if self.nonce != client.nonce: + logging.debug("[Nonce changed, resetting]") + client.cache_reset() @clone_pdu class EndOfDataPDUv0(rpki.rtr.pdus.EndOfDataPDUv0): - def consume(self, client): - """ - Handle EndOfDataPDU response. - """ + def consume(self, client): + """ + Handle EndOfDataPDU response. + """ - logging.debug(self) - client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) + logging.debug(self) + client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) @clone_pdu class EndOfDataPDUv1(rpki.rtr.pdus.EndOfDataPDUv1): - def consume(self, client): - """ - Handle EndOfDataPDU response. - """ + def consume(self, client): + """ + Handle EndOfDataPDU response. + """ - logging.debug(self) - client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) + logging.debug(self) + client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) @clone_pdu class CacheResetPDU(rpki.rtr.pdus.CacheResetPDU): - def consume(self, client): - """ - Handle CacheResetPDU response, by issuing a ResetQueryPDU. - """ + def consume(self, client): + """ + Handle CacheResetPDU response, by issuing a ResetQueryPDU. + """ - logging.debug(self) - client.cache_reset() - client.push_pdu(ResetQueryPDU(version = client.version)) + logging.debug(self) + client.cache_reset() + client.push_pdu(ResetQueryPDU(version = client.version)) class PrefixPDU(rpki.rtr.pdus.PrefixPDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - def consume(self, client): """ - Handle one incoming prefix PDU + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. """ - logging.debug(self) - client.consume_prefix(self) + def consume(self, client): + """ + Handle one incoming prefix PDU + """ + + logging.debug(self) + client.consume_prefix(self) @clone_pdu class IPv4PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv4PrefixPDU): - pass + pass @clone_pdu class IPv6PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv6PrefixPDU): - pass + pass @clone_pdu class ErrorReportPDU(PDU, rpki.rtr.pdus.ErrorReportPDU): - pass + pass @clone_pdu class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU): - """ - Router Key PDU. - """ - - def consume(self, client): """ - Handle one incoming Router Key PDU + Router Key PDU. """ - logging.debug(self) - client.consume_routerkey(self) + def consume(self, client): + """ + Handle one incoming Router Key PDU + """ + logging.debug(self) + client.consume_routerkey(self) -class ClientChannel(rpki.rtr.channels.PDUChannel): - """ - Client protocol engine, handles upcalls from PDUChannel. - """ - - serial = None - nonce = None - sql = None - host = None - port = None - cache_id = None - refresh = rpki.rtr.pdus.default_refresh - retry = rpki.rtr.pdus.default_retry - expire = rpki.rtr.pdus.default_expire - updated = Timestamp(0) - - def __init__(self, sock, proc, killsig, args, host = None, port = None): - self.killsig = killsig - self.proc = proc - self.args = args - self.host = args.host if host is None else host - self.port = args.port if port is None else port - super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU) - if args.force_version is not None: - self.version = args.force_version - self.start_new_pdu() - if args.sql_database: - self.setup_sql() - - @classmethod - def ssh(cls, args): - """ - Set up ssh connection and start listening for first PDU. - """ - if args.port is None: - argv = ("ssh", "-s", args.host, "rpki-rtr") - else: - argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr") - logging.debug("[Running ssh: %s]", " ".join(argv)) - s = socket.socketpair() - return cls(sock = s[1], - proc = subprocess.Popen(argv, executable = "/usr/bin/ssh", - stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGKILL, args = args) - - @classmethod - def tcp(cls, args): - """ - Set up TCP connection and start listening for first PDU. +class ClientChannel(rpki.rtr.channels.PDUChannel): """ - - logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port) - try: - addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM) - except socket.error, e: - logging.debug("[socket.getaddrinfo() failed: %s]", e) - else: - for ai in addrinfo: - af, socktype, proto, cn, sa = ai # pylint: disable=W0612 - logging.debug("[Trying addr %s port %s]", sa[0], sa[1]) + Client protocol engine, handles upcalls from PDUChannel. + """ + + serial = None + nonce = None + sql = None + host = None + port = None + cache_id = None + refresh = rpki.rtr.pdus.default_refresh + retry = rpki.rtr.pdus.default_retry + expire = rpki.rtr.pdus.default_expire + updated = Timestamp(0) + + def __init__(self, sock, proc, killsig, args, host = None, port = None): + self.killsig = killsig + self.proc = proc + self.args = args + self.host = args.host if host is None else host + self.port = args.port if port is None else port + super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU) + if args.force_version is not None: + self.version = args.force_version + self.start_new_pdu() + if args.sql_database: + self.setup_sql() + + @classmethod + def ssh(cls, args): + """ + Set up ssh connection and start listening for first PDU. + """ + + if args.port is None: + argv = ("ssh", "-s", args.host, "rpki-rtr") + else: + argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr") + logging.debug("[Running ssh: %s]", " ".join(argv)) + s = socket.socketpair() + return cls(sock = s[1], + proc = subprocess.Popen(argv, executable = "/usr/bin/ssh", + stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGKILL, args = args) + + @classmethod + def tcp(cls, args): + """ + Set up TCP connection and start listening for first PDU. + """ + + logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port) try: - s = socket.socket(af, socktype, proto) + addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM) except socket.error, e: - logging.debug("[socket.socket() failed: %s]", e) - continue + logging.debug("[socket.getaddrinfo() failed: %s]", e) + else: + for ai in addrinfo: + af, socktype, proto, cn, sa = ai # pylint: disable=W0612 + logging.debug("[Trying addr %s port %s]", sa[0], sa[1]) + try: + s = socket.socket(af, socktype, proto) + except socket.error, e: + logging.debug("[socket.socket() failed: %s]", e) + continue + try: + s.connect(sa) + except socket.error, e: + logging.exception("[socket.connect() failed: %s]", e) + s.close() + continue + return cls(sock = s, proc = None, killsig = None, args = args) + sys.exit(1) + + @classmethod + def loopback(cls, args): + """ + Set up loopback connection and start listening for first PDU. + """ + + s = socket.socketpair() + logging.debug("[Using direct subprocess kludge for testing]") + argv = (sys.executable, sys.argv[0], "server") + return cls(sock = s[1], + proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGINT, args = args, + host = args.host or "none", port = args.port or "none") + + @classmethod + def tls(cls, args): + """ + Set up TLS connection and start listening for first PDU. + + NB: This uses OpenSSL's "s_client" command, which does not + check server certificates properly, so this is not suitable for + production use. Fixing this would be a trivial change, it just + requires using a client program which does check certificates + properly (eg, gnutls-cli, or stunnel's client mode if that works + for such purposes this week). + """ + + argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port)) + logging.debug("[Running: %s]", " ".join(argv)) + s = socket.socketpair() + return cls(sock = s[1], + proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGKILL, args = args) + + def setup_sql(self): + """ + Set up an SQLite database to contain the table we receive. If + necessary, we will create the database. + """ + + import sqlite3 + missing = not os.path.exists(self.args.sql_database) + self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES) + self.sql.text_factory = str + cur = self.sql.cursor() + cur.execute("PRAGMA foreign_keys = on") + if missing: + cur.execute(''' + CREATE TABLE cache ( + cache_id INTEGER PRIMARY KEY NOT NULL, + host TEXT NOT NULL, + port TEXT NOT NULL, + version INTEGER, + nonce INTEGER, + serial INTEGER, + updated INTEGER, + refresh INTEGER, + retry INTEGER, + expire INTEGER, + UNIQUE (host, port))''') + cur.execute(''' + CREATE TABLE prefix ( + cache_id INTEGER NOT NULL + REFERENCES cache(cache_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + asn INTEGER NOT NULL, + prefix TEXT NOT NULL, + prefixlen INTEGER NOT NULL, + max_prefixlen INTEGER NOT NULL, + UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''') + cur.execute(''' + CREATE TABLE routerkey ( + cache_id INTEGER NOT NULL + REFERENCES cache(cache_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + asn INTEGER NOT NULL, + ski TEXT NOT NULL, + key TEXT NOT NULL, + UNIQUE (cache_id, asn, ski), + UNIQUE (cache_id, asn, key))''') + elif self.args.reset_session: + cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) + cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated " + "FROM cache WHERE host = ? AND port = ?", + (self.host, self.port)) try: - s.connect(sa) - except socket.error, e: - logging.exception("[socket.connect() failed: %s]", e) - s.close() - continue - return cls(sock = s, proc = None, killsig = None, args = args) - sys.exit(1) - - @classmethod - def loopback(cls, args): - """ - Set up loopback connection and start listening for first PDU. - """ - - s = socket.socketpair() - logging.debug("[Using direct subprocess kludge for testing]") - argv = (sys.executable, sys.argv[0], "server") - return cls(sock = s[1], - proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGINT, args = args, - host = args.host or "none", port = args.port or "none") - - @classmethod - def tls(cls, args): - """ - Set up TLS connection and start listening for first PDU. - - NB: This uses OpenSSL's "s_client" command, which does not - check server certificates properly, so this is not suitable for - production use. Fixing this would be a trivial change, it just - requires using a client program which does check certificates - properly (eg, gnutls-cli, or stunnel's client mode if that works - for such purposes this week). - """ - - argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port)) - logging.debug("[Running: %s]", " ".join(argv)) - s = socket.socketpair() - return cls(sock = s[1], - proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGKILL, args = args) - - def setup_sql(self): - """ - Set up an SQLite database to contain the table we receive. If - necessary, we will create the database. - """ - - import sqlite3 - missing = not os.path.exists(self.args.sql_database) - self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES) - self.sql.text_factory = str - cur = self.sql.cursor() - cur.execute("PRAGMA foreign_keys = on") - if missing: - cur.execute(''' - CREATE TABLE cache ( - cache_id INTEGER PRIMARY KEY NOT NULL, - host TEXT NOT NULL, - port TEXT NOT NULL, - version INTEGER, - nonce INTEGER, - serial INTEGER, - updated INTEGER, - refresh INTEGER, - retry INTEGER, - expire INTEGER, - UNIQUE (host, port))''') - cur.execute(''' - CREATE TABLE prefix ( - cache_id INTEGER NOT NULL - REFERENCES cache(cache_id) - ON DELETE CASCADE - ON UPDATE CASCADE, - asn INTEGER NOT NULL, - prefix TEXT NOT NULL, - prefixlen INTEGER NOT NULL, - max_prefixlen INTEGER NOT NULL, - UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''') - cur.execute(''' - CREATE TABLE routerkey ( - cache_id INTEGER NOT NULL - REFERENCES cache(cache_id) - ON DELETE CASCADE - ON UPDATE CASCADE, - asn INTEGER NOT NULL, - ski TEXT NOT NULL, - key TEXT NOT NULL, - UNIQUE (cache_id, asn, ski), - UNIQUE (cache_id, asn, key))''') - elif self.args.reset_session: - cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) - cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated " - "FROM cache WHERE host = ? AND port = ?", - (self.host, self.port)) - try: - self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone() - if version is not None and self.version is not None and version != self.version: - cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) - raise TypeError # Simulate lookup failure case - if version is not None: - self.version = version - if refresh is not None: + self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone() + if version is not None and self.version is not None and version != self.version: + cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port)) + raise TypeError # Simulate lookup failure case + if version is not None: + self.version = version + if refresh is not None: + self.refresh = refresh + if retry is not None: + self.retry = retry + if expire is not None: + self.expire = expire + if updated is not None: + self.updated = Timestamp(updated) + except TypeError: + cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port)) + self.cache_id = cur.lastrowid + self.sql.commit() + logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]", + self.cache_id, self.version, self.nonce, + self.serial, self.refresh, self.retry, self.expire, self.updated) + + def cache_reset(self): + """ + Handle CacheResetPDU actions. + """ + + self.serial = None + if self.sql: + cur = self.sql.cursor() + cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,)) + cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,)) + cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id)) + self.sql.commit() + + def end_of_data(self, version, serial, nonce, refresh, retry, expire): + """ + Handle EndOfDataPDU actions. + """ + + assert version == self.version + self.serial = serial + self.nonce = nonce self.refresh = refresh - if retry is not None: - self.retry = retry - if expire is not None: - self.expire = expire - if updated is not None: - self.updated = Timestamp(updated) - except TypeError: - cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port)) - self.cache_id = cur.lastrowid - self.sql.commit() - logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]", - self.cache_id, self.version, self.nonce, - self.serial, self.refresh, self.retry, self.expire, self.updated) - - def cache_reset(self): - """ - Handle CacheResetPDU actions. - """ - - self.serial = None - if self.sql: - cur = self.sql.cursor() - cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,)) - cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,)) - cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id)) - self.sql.commit() - - def end_of_data(self, version, serial, nonce, refresh, retry, expire): - """ - Handle EndOfDataPDU actions. - """ - - assert version == self.version - self.serial = serial - self.nonce = nonce - self.refresh = refresh - self.retry = retry - self.expire = expire - self.updated = Timestamp.now() - if self.sql: - self.sql.execute("UPDATE cache SET" - " version = ?, serial = ?, nonce = ?," - " refresh = ?, retry = ?, expire = ?," - " updated = ? " - "WHERE cache_id = ?", - (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id)) - self.sql.commit() - - def consume_prefix(self, prefix): - """ - Handle one prefix PDU. - """ - - if self.sql: - values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen) - if prefix.announce: - self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) " - "VALUES (?, ?, ?, ?, ?)", - values) - else: - self.sql.execute("DELETE FROM prefix " - "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?", - values) - - def consume_routerkey(self, routerkey): - """ - Handle one Router Key PDU. - """ - - if self.sql: - values = (self.cache_id, routerkey.asn, - base64.urlsafe_b64encode(routerkey.ski).rstrip("="), - base64.b64encode(routerkey.key)) - if routerkey.announce: - self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) " - "VALUES (?, ?, ?, ?)", - values) - else: - self.sql.execute("DELETE FROM routerkey " - "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)", - values) - - def deliver_pdu(self, pdu): - """ - Handle received PDU. - """ - - pdu.consume(self) - - def push_pdu(self, pdu): - """ - Log outbound PDU then write it to stream. - """ - - logging.debug(pdu) - super(ClientChannel, self).push_pdu(pdu) - - def cleanup(self): - """ - Force clean up this client's child process. If everything goes - well, child will have exited already before this method is called, - but we may need to whack it with a stick if something breaks. - """ - - if self.proc is not None and self.proc.returncode is None: - try: - os.kill(self.proc.pid, self.killsig) - except OSError: - pass - - def handle_close(self): - """ - Intercept close event so we can log it, then shut down. - """ - - logging.debug("Server closed channel") - super(ClientChannel, self).handle_close() + self.retry = retry + self.expire = expire + self.updated = Timestamp.now() + if self.sql: + self.sql.execute("UPDATE cache SET" + " version = ?, serial = ?, nonce = ?," + " refresh = ?, retry = ?, expire = ?," + " updated = ? " + "WHERE cache_id = ?", + (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id)) + self.sql.commit() + + def consume_prefix(self, prefix): + """ + Handle one prefix PDU. + """ + + if self.sql: + values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen) + if prefix.announce: + self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) " + "VALUES (?, ?, ?, ?, ?)", + values) + else: + self.sql.execute("DELETE FROM prefix " + "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?", + values) + + def consume_routerkey(self, routerkey): + """ + Handle one Router Key PDU. + """ + + if self.sql: + values = (self.cache_id, routerkey.asn, + base64.urlsafe_b64encode(routerkey.ski).rstrip("="), + base64.b64encode(routerkey.key)) + if routerkey.announce: + self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) " + "VALUES (?, ?, ?, ?)", + values) + else: + self.sql.execute("DELETE FROM routerkey " + "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)", + values) + + def deliver_pdu(self, pdu): + """ + Handle received PDU. + """ + + pdu.consume(self) + + def push_pdu(self, pdu): + """ + Log outbound PDU then write it to stream. + """ + + logging.debug(pdu) + super(ClientChannel, self).push_pdu(pdu) + + def cleanup(self): + """ + Force clean up this client's child process. If everything goes + well, child will have exited already before this method is called, + but we may need to whack it with a stick if something breaks. + """ + + if self.proc is not None and self.proc.returncode is None: + try: + os.kill(self.proc.pid, self.killsig) + except OSError: + pass + + def handle_close(self): + """ + Intercept close event so we can log it, then shut down. + """ + + logging.debug("Server closed channel") + super(ClientChannel, self).handle_close() # Hack to let us subclass this from scripts without needing to rewrite client_main(). @@ -460,73 +460,73 @@ class ClientChannel(rpki.rtr.channels.PDUChannel): ClientChannelClass = ClientChannel def client_main(args): - """ - Test client, intended primarily for debugging. - """ + """ + Test client, intended primarily for debugging. + """ - logging.debug("[Startup]") + logging.debug("[Startup]") - assert issubclass(ClientChannelClass, ClientChannel) - constructor = getattr(ClientChannelClass, args.protocol) + assert issubclass(ClientChannelClass, ClientChannel) + constructor = getattr(ClientChannelClass, args.protocol) - client = None - try: - client = constructor(args) + client = None + try: + client = constructor(args) - polled = client.updated - wakeup = None + polled = client.updated + wakeup = None - while True: + while True: - now = Timestamp.now() + now = Timestamp.now() - if client.serial is not None and now > client.updated + client.expire: - logging.info("[Expiring client data: serial %s, last updated %s, expire %s]", - client.serial, client.updated, client.expire) - client.cache_reset() + if client.serial is not None and now > client.updated + client.expire: + logging.info("[Expiring client data: serial %s, last updated %s, expire %s]", + client.serial, client.updated, client.expire) + client.cache_reset() - if client.serial is None or client.nonce is None: - polled = now - client.push_pdu(ResetQueryPDU(version = client.version)) + if client.serial is None or client.nonce is None: + polled = now + client.push_pdu(ResetQueryPDU(version = client.version)) - elif now >= client.updated + client.refresh: - polled = now - client.push_pdu(SerialQueryPDU(version = client.version, - serial = client.serial, - nonce = client.nonce)) + elif now >= client.updated + client.refresh: + polled = now + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.serial, + nonce = client.nonce)) - remaining = 1 + remaining = 1 - while remaining > 0: - now = Timestamp.now() - timer = client.retry if (now >= client.updated + client.refresh) else client.refresh - wokeup = wakeup - wakeup = max(now, Timestamp(max(polled, client.updated) + timer)) - remaining = wakeup - now - if wakeup != wokeup: - logging.info("[Last client poll %s, next %s]", polled, wakeup) - asyncore.loop(timeout = remaining, count = 1) + while remaining > 0: + now = Timestamp.now() + timer = client.retry if (now >= client.updated + client.refresh) else client.refresh + wokeup = wakeup + wakeup = max(now, Timestamp(max(polled, client.updated) + timer)) + remaining = wakeup - now + if wakeup != wokeup: + logging.info("[Last client poll %s, next %s]", polled, wakeup) + asyncore.loop(timeout = remaining, count = 1) - except KeyboardInterrupt: - sys.exit(0) + except KeyboardInterrupt: + sys.exit(0) - finally: - if client is not None: - client.cleanup() + finally: + if client is not None: + client.cleanup() def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - subparser = subparsers.add_parser("client", description = client_main.__doc__, - help = "Test client for RPKI-RTR protocol") - subparser.set_defaults(func = client_main, default_log_to = "stderr") - subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state") - subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version") - subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database") - subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol") - subparser.add_argument("host", nargs = "?", help = "server host") - subparser.add_argument("port", nargs = "?", help = "server port") - return subparser + """ + Set up argparse stuff for commands in this module. + """ + + subparser = subparsers.add_parser("client", description = client_main.__doc__, + help = "Test client for RPKI-RTR protocol") + subparser.set_defaults(func = client_main, default_log_destination = "stderr") + subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state") + subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version") + subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database") + subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol") + subparser.add_argument("host", nargs = "?", help = "server host") + subparser.add_argument("port", nargs = "?", help = "server port") + return subparser diff --git a/rpki/rtr/generator.py b/rpki/rtr/generator.py index 26e25b6e..4536de30 100644 --- a/rpki/rtr/generator.py +++ b/rpki/rtr/generator.py @@ -36,540 +36,553 @@ import rpki.rtr.server from rpki.rtr.channels import Timestamp -class PrefixPDU(rpki.rtr.pdus.PrefixPDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - @staticmethod - def from_text(version, asn, addr): - """ - Construct a prefix from its text form. - """ +from rpki.rcynicdb.iterator import authenticated_objects - cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU - self = cls(version = version) - self.asn = long(asn) - p, l = addr.split("/") - self.prefix = rpki.POW.IPAddress(p) - if "-" in l: - self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-")) - else: - self.prefixlen = self.max_prefixlen = int(l) - self.announce = 1 - self.check() - return self - - @staticmethod - def from_roa(version, asn, prefix_tuple): - """ - Construct a prefix from a ROA. +class PrefixPDU(rpki.rtr.pdus.PrefixPDU): """ - - address, length, maxlength = prefix_tuple - cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU - self = cls(version = version) - self.asn = asn - self.prefix = address - self.prefixlen = length - self.max_prefixlen = length if maxlength is None else maxlength - self.announce = 1 - self.check() - return self + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + @staticmethod + def from_text(version, asn, addr): + """ + Construct a prefix from its text form. + """ + + cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU + self = cls(version = version) + self.asn = long(asn) + p, l = addr.split("/") + self.prefix = rpki.POW.IPAddress(p) + if "-" in l: + self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-")) + else: + self.prefixlen = self.max_prefixlen = int(l) + self.announce = 1 + self.check() + return self + + @staticmethod + def from_roa(version, asn, prefix_tuple): + """ + Construct a prefix from a ROA. + """ + + address, length, maxlength = prefix_tuple + cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU + self = cls(version = version) + self.asn = asn + self.prefix = address + self.prefixlen = length + self.max_prefixlen = length if maxlength is None else maxlength + self.announce = 1 + self.check() + return self class IPv4PrefixPDU(PrefixPDU): - """ - IPv4 flavor of a prefix. - """ + """ + IPv4 flavor of a prefix. + """ - pdu_type = 4 - address_byte_count = 4 + pdu_type = 4 + address_byte_count = 4 class IPv6PrefixPDU(PrefixPDU): - """ - IPv6 flavor of a prefix. - """ - - pdu_type = 6 - address_byte_count = 16 - -class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU): - """ - Router Key PDU. - """ - - @classmethod - def from_text(cls, version, asn, gski, key): """ - Construct a router key from its text form. + IPv6 flavor of a prefix. """ - self = cls(version = version) - self.asn = long(asn) - self.ski = base64.urlsafe_b64decode(gski + "=") - self.key = base64.b64decode(key) - self.announce = 1 - self.check() - return self + pdu_type = 6 + address_byte_count = 16 - @classmethod - def from_certificate(cls, version, asn, ski, key): +class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU): """ - Construct a router key from a certificate. + Router Key PDU. """ - self = cls(version = version) - self.asn = asn - self.ski = ski - self.key = key - self.announce = 1 - self.check() - return self + announce = None + ski = None + asn = None + key = None + @classmethod + def from_text(cls, version, asn, gski, key): + """ + Construct a router key from its text form. + """ -class ROA(rpki.POW.ROA): # pylint: disable=W0232 - """ - Minor additions to rpki.POW.ROA. - """ - - @classmethod - def derReadFile(cls, fn): # pylint: disable=E1002 - self = super(ROA, cls).derReadFile(fn) - self.extractWithoutVerifying() - return self - - @property - def prefixes(self): - v4, v6 = self.getPrefixes() - if v4 is not None: - for p in v4: - yield p - if v6 is not None: - for p in v6: - yield p + self = cls(version = version) + self.asn = long(asn) + self.ski = base64.urlsafe_b64decode(gski + "=") + self.key = base64.b64decode(key) + self.announce = 1 + self.check() + return self -class X509(rpki.POW.X509): # pylint: disable=W0232 - """ - Minor additions to rpki.POW.X509. - """ + @classmethod + def from_certificate(cls, version, asn, ski, key): + """ + Construct a router key from a certificate. + """ - @property - def asns(self): - resources = self.getRFC3779() - if resources is not None and resources[0] is not None: - for min_asn, max_asn in resources[0]: - for asn in xrange(min_asn, max_asn + 1): - yield asn + self = cls(version = version) + self.asn = asn + self.ski = ski + self.key = key + self.announce = 1 + self.check() + return self -class PDUSet(list): - """ - Object representing a set of PDUs, that is, one versioned and - (theoretically) consistant set of prefixes and router keys extracted - from rcynic's output. - """ - - def __init__(self, version): - assert version in rpki.rtr.pdus.PDU.version_map - super(PDUSet, self).__init__() - self.version = version - - @classmethod - def _load_file(cls, filename, version): +class ROA(rpki.POW.ROA): # pylint: disable=W0232 """ - Low-level method to read PDUSet from a file. + Minor additions to rpki.POW.ROA. """ - self = cls(version = version) - f = open(filename, "rb") - r = rpki.rtr.channels.ReadBuffer() - while True: - p = rpki.rtr.pdus.PDU.read_pdu(r) - while p is None: - b = f.read(r.needed()) - if b == "": - assert r.available() == 0 - return self - r.put(b) - p = r.retry() - assert p.version == self.version - self.append(p) - - @staticmethod - def seq_ge(a, b): - return ((a - b) % (1 << 32)) < (1 << 31) + @classmethod + def derReadFile(cls, fn): + # pylint: disable=E1002 + self = super(ROA, cls).derReadFile(fn) + self.extractWithoutVerifying() + return self + @property + def prefixes(self): + v4, v6 = self.getPrefixes() # pylint: disable=E1101 + if v4 is not None: + for p in v4: + yield p + if v6 is not None: + for p in v6: + yield p -class AXFRSet(PDUSet): - """ - Object representing a complete set of PDUs, that is, one versioned - and (theoretically) consistant set of prefixes and router - certificates extracted from rcynic's output, all with the announce - field set. - """ - - @classmethod - def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None): +class X509(rpki.POW.X509): # pylint: disable=W0232 """ - Parse ROAS and router certificates fetched (and validated!) by - rcynic to create a new AXFRSet. - - In normal operation, we use os.walk() and the rpki.POW library to - parse these data directly, but we can, if so instructed, use - external programs instead, for testing, simulation, or to provide - a way to inject local data. - - At some point the ability to parse these data from external - programs may move to a separate constructor function, so that we - can make this one a bit simpler and faster. + Minor additions to rpki.POW.X509. """ - self = cls(version = version) - self.serial = rpki.rtr.channels.Timestamp.now() - - include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version] - - if scan_roas is None or (scan_routercerts is None and include_routercerts): - for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612 - for fn in files: - if scan_roas is None and fn.endswith(".roa"): - roa = ROA.derReadFile(os.path.join(root, fn)) - asn = roa.getASID() - self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple) - for prefix_tuple in roa.prefixes) - if include_routercerts and scan_routercerts is None and fn.endswith(".cer"): - x = X509.derReadFile(os.path.join(root, fn)) - eku = x.getEKU() - if eku is not None and rpki.oids.id_kp_bgpsec_router in eku: - ski = x.getSKI() - key = x.getPublicKey().derWritePublic() - self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key) - for asn in x.asns) - - if scan_roas is not None: - try: - p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE) - for line in p.stdout: - line = line.split() - asn = line[1] - self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr) - for addr in line[2:]) - except OSError, e: - sys.exit("Could not run %s: %s" % (scan_roas, e)) - - if include_routercerts and scan_routercerts is not None: - try: - p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE) - for line in p.stdout: - line = line.split() - gski = line[0] - key = line[-1] - self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key) - for asn in line[1:-1]) - except OSError, e: - sys.exit("Could not run %s: %s" % (scan_routercerts, e)) - - self.sort() - for i in xrange(len(self) - 2, -1, -1): - if self[i] == self[i + 1]: - del self[i + 1] - return self - - @classmethod - def load(cls, filename): - """ - Load an AXFRSet from a file, parse filename to obtain version and serial. - """ + @property + def asns(self): + resources = self.getRFC3779() # pylint: disable=E1101 + if resources is not None and resources[0] is not None: + for min_asn, max_asn in resources[0]: + for asn in xrange(min_asn, max_asn + 1): + yield asn - fn1, fn2, fn3 = os.path.basename(filename).split(".") - assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit() - version = int(fn3[1:]) - self = cls._load_file(filename, version) - self.serial = rpki.rtr.channels.Timestamp(fn1) - return self - def filename(self): - """ - Generate filename for this AXFRSet. +class PDUSet(list): """ + Object representing a set of PDUs, that is, one versioned and + (theoretically) consistant set of prefixes and router keys extracted + from rcynic's output. + """ + + def __init__(self, version): + assert version in rpki.rtr.pdus.PDU.version_map + super(PDUSet, self).__init__() + self.version = version + + @classmethod + def _load_file(cls, filename, version): + """ + Low-level method to read PDUSet from a file. + """ + + self = cls(version = version) + f = open(filename, "rb") + r = rpki.rtr.channels.ReadBuffer() + while True: + p = rpki.rtr.pdus.PDU.read_pdu(r) + while p is None: + b = f.read(r.needed()) + if b == "": + assert r.available() == 0 + return self + r.put(b) + p = r.retry() + assert p.version == self.version + self.append(p) + + @staticmethod + def seq_ge(a, b): + return ((a - b) % (1 << 32)) < (1 << 31) - return "%d.ax.v%d" % (self.serial, self.version) - @classmethod - def load_current(cls, version): - """ - Load current AXFRSet. Return None if can't. +class AXFRSet(PDUSet): """ + Object representing a complete set of PDUs, that is, one versioned + and (theoretically) consistant set of prefixes and router + certificates extracted from rcynic's output, all with the announce + field set. + """ + + class_map = dict(cer = X509, roa = ROA) + + serial = None + + @classmethod + def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None): + """ + Parse ROAS and router certificates fetched (and validated!) by + rcynic to create a new AXFRSet. + + In normal operation, we parse these data directly from whatever rcynic is using + as a validator this week, but we can, if so instructed, use external programs + instead, for testing, simulation, or to provide a way to inject local data. + + At some point the ability to parse these data from external + programs may move to a separate constructor function, so that we + can make this one a bit simpler and faster. + """ + + self = cls(version = version) + self.serial = rpki.rtr.channels.Timestamp.now() + + include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version] + + if scan_roas is None: + for uri, roa in authenticated_objects(rcynic_dir, uri_suffix = ".roa", class_map = self.class_map): + roa.extractWithoutVerifying() + asn = roa.getASID() + self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple) + for prefix_tuple in roa.prefixes) + + if scan_routercerts is None and include_routercerts: + for uri, cer in authenticated_objects(rcynic_dir, uri_suffix = ".cer", class_map = self.class_map): + eku = cer.getEKU() + if eku is not None and rpki.oids.id_kp_bgpsec_router in eku: + ski = cer.getSKI() + key = cer.getPublicKey().derWritePublic() + self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key) + for asn in cer.asns) + + if scan_roas is not None: + try: + p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE) + for line in p.stdout: + line = line.split() + asn = line[1] + self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr) + for addr in line[2:]) + except OSError, e: + sys.exit("Could not run %s: %s" % (scan_roas, e)) + + if include_routercerts and scan_routercerts is not None: + try: + p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE) + for line in p.stdout: + line = line.split() + gski = line[0] + key = line[-1] + self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key) + for asn in line[1:-1]) + except OSError, e: + sys.exit("Could not run %s: %s" % (scan_routercerts, e)) + + self.sort() + for i in xrange(len(self) - 2, -1, -1): + if self[i] == self[i + 1]: + del self[i + 1] + return self + + @classmethod + def load(cls, filename): + """ + Load an AXFRSet from a file, parse filename to obtain version and serial. + """ + + fn1, fn2, fn3 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit() + version = int(fn3[1:]) + self = cls._load_file(filename, version) + self.serial = rpki.rtr.channels.Timestamp(fn1) + return self + + def filename(self): + """ + Generate filename for this AXFRSet. + """ + + return "%d.ax.v%d" % (self.serial, self.version) + + @classmethod + def load_current(cls, version): + """ + Load current AXFRSet. Return None if can't. + """ + + serial = rpki.rtr.server.read_current(version)[0] + if serial is None: + return None + try: + return cls.load("%d.ax.v%d" % (serial, version)) + except IOError: + return None + + def save_axfr(self): + """ + Write AXFRSet to file with magic filename. + """ + + f = open(self.filename(), "wb") + for p in self: + f.write(p.to_pdu()) + f.close() + + def destroy_old_data(self): + """ + Destroy old data files, presumably because our nonce changed and + the old serial numbers are no longer valid. + """ + + for i in glob.iglob("*.ix.*.v%d" % self.version): + os.unlink(i) + for i in glob.iglob("*.ax.v%d" % self.version): + if i != self.filename(): + os.unlink(i) + + @staticmethod + def new_nonce(force_zero_nonce): + """ + Create and return a new nonce value. + """ + + if force_zero_nonce: + return 0 + try: + return int(random.SystemRandom().getrandbits(16)) + except NotImplementedError: + return int(random.getrandbits(16)) + + def mark_current(self, force_zero_nonce = False): + """ + Save current serial number and nonce, creating new nonce if + necessary. Creating a new nonce triggers cleanup of old state, as + the new nonce invalidates all old serial numbers. + """ + + assert self.version in rpki.rtr.pdus.PDU.version_map + old_serial, nonce = rpki.rtr.server.read_current(self.version) + if old_serial is None or self.seq_ge(old_serial, self.serial): + logging.debug("Creating new nonce and deleting stale data") + nonce = self.new_nonce(force_zero_nonce) + self.destroy_old_data() + rpki.rtr.server.write_current(self.serial, nonce, self.version) + + def save_ixfr(self, other): + """ + Comparing this AXFRSet with an older one and write the resulting + IXFRSet to file with magic filename. Since we store PDUSets + in sorted order, computing the difference is a trivial linear + comparison. + """ + + f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb") + old = other + new = self + len_old = len(old) + len_new = len(new) + i_old = i_new = 0 + while i_old < len_old and i_new < len_new: + if old[i_old] < new[i_new]: + f.write(old[i_old].to_pdu(announce = 0)) + i_old += 1 + elif old[i_old] > new[i_new]: + f.write(new[i_new].to_pdu(announce = 1)) + i_new += 1 + else: + i_old += 1 + i_new += 1 + for i in xrange(i_old, len_old): + f.write(old[i].to_pdu(announce = 0)) + for i in xrange(i_new, len_new): + f.write(new[i].to_pdu(announce = 1)) + f.close() + + def show(self): + """ + Print this AXFRSet. + """ + + logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version) + for p in self: + logging.debug(p) - serial = rpki.rtr.server.read_current(version)[0] - if serial is None: - return None - try: - return cls.load("%d.ax.v%d" % (serial, version)) - except IOError: - return None - def save_axfr(self): +class IXFRSet(PDUSet): """ - Write AXFRSet to file with magic filename. + Object representing an incremental set of PDUs, that is, the + differences between one versioned and (theoretically) consistant set + of prefixes and router certificates extracted from rcynic's output + and another, with the announce fields set or cleared as necessary to + indicate the changes. """ - f = open(self.filename(), "wb") - for p in self: - f.write(p.to_pdu()) - f.close() + from_serial = None + to_serial = None - def destroy_old_data(self): - """ - Destroy old data files, presumably because our nonce changed and - the old serial numbers are no longer valid. - """ + @classmethod + def load(cls, filename): + """ + Load an IXFRSet from a file, parse filename to obtain version and serials. + """ - for i in glob.iglob("*.ix.*.v%d" % self.version): - os.unlink(i) - for i in glob.iglob("*.ax.v%d" % self.version): - if i != self.filename(): - os.unlink(i) + fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit() + version = int(fn4[1:]) + self = cls._load_file(filename, version) + self.from_serial = rpki.rtr.channels.Timestamp(fn3) + self.to_serial = rpki.rtr.channels.Timestamp(fn1) + return self - @staticmethod - def new_nonce(force_zero_nonce): - """ - Create and return a new nonce value. - """ + def filename(self): + """ + Generate filename for this IXFRSet. + """ - if force_zero_nonce: - return 0 - try: - return int(random.SystemRandom().getrandbits(16)) - except NotImplementedError: - return int(random.getrandbits(16)) + return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version) - def mark_current(self, force_zero_nonce = False): - """ - Save current serial number and nonce, creating new nonce if - necessary. Creating a new nonce triggers cleanup of old state, as - the new nonce invalidates all old serial numbers. - """ + def show(self): + """ + Print this IXFRSet. + """ - assert self.version in rpki.rtr.pdus.PDU.version_map - old_serial, nonce = rpki.rtr.server.read_current(self.version) - if old_serial is None or self.seq_ge(old_serial, self.serial): - logging.debug("Creating new nonce and deleting stale data") - nonce = self.new_nonce(force_zero_nonce) - self.destroy_old_data() - rpki.rtr.server.write_current(self.serial, nonce, self.version) + logging.debug("# IXFR %d (%s) -> %d (%s) v%d", + self.from_serial, self.from_serial, + self.to_serial, self.to_serial, + self.version) + for p in self: + logging.debug(p) - def save_ixfr(self, other): - """ - Comparing this AXFRSet with an older one and write the resulting - IXFRSet to file with magic filename. Since we store PDUSets - in sorted order, computing the difference is a trivial linear - comparison. - """ - f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb") - old = other - new = self - len_old = len(old) - len_new = len(new) - i_old = i_new = 0 - while i_old < len_old and i_new < len_new: - if old[i_old] < new[i_new]: - f.write(old[i_old].to_pdu(announce = 0)) - i_old += 1 - elif old[i_old] > new[i_new]: - f.write(new[i_new].to_pdu(announce = 1)) - i_new += 1 - else: - i_old += 1 - i_new += 1 - for i in xrange(i_old, len_old): - f.write(old[i].to_pdu(announce = 0)) - for i in xrange(i_new, len_new): - f.write(new[i].to_pdu(announce = 1)) - f.close() - - def show(self): +def kick_all(serial): """ - Print this AXFRSet. + Kick any existing server processes to wake them up. """ - logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version) - for p in self: - logging.debug(p) + try: + os.stat(rpki.rtr.server.kickme_dir) + except OSError: + logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir) + os.makedirs(rpki.rtr.server.kickme_dir) + + msg = "Good morning, serial %d is ready" % serial + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base): + try: + logging.debug("# Kicking %s", name) + sock.sendto(msg, name) + except socket.error: + try: + logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name) + os.unlink(name) + except Exception, e: + logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e) + except Exception, e: + logging.warning("# Failed to kick %s and don't understand why: %s", name, e) + sock.close() -class IXFRSet(PDUSet): - """ - Object representing an incremental set of PDUs, that is, the - differences between one versioned and (theoretically) consistant set - of prefixes and router certificates extracted from rcynic's output - and another, with the announce fields set or cleared as necessary to - indicate the changes. - """ - - @classmethod - def load(cls, filename): - """ - Load an IXFRSet from a file, parse filename to obtain version and serials. +def cronjob_main(args): """ + Run this right after running rcynic to wade through the ROAs and + router certificates that rcynic collects and translate that data + into the form used in the rpki-router protocol. Output is an + updated database containing both full dumps (AXFR) and incremental + dumps against a specific prior version (IXFR). After updating the + database, kicks any active servers, so that they can notify their + clients that a new version is available. + """ + + if args.rpki_rtr_dir: + try: + if not os.path.isdir(args.rpki_rtr_dir): + os.makedirs(args.rpki_rtr_dir) + os.chdir(args.rpki_rtr_dir) + except OSError, e: + logging.critical(str(e)) + sys.exit(1) + + for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True): + + logging.debug("# Generating updates for protocol version %d", version) + + old_ixfrs = glob.glob("*.ix.*.v%d" % version) + + current = rpki.rtr.server.read_current(version)[0] + cutoff = Timestamp.now(-(24 * 60 * 60)) + for f in glob.iglob("*.ax.v%d" % version): + t = Timestamp(int(f.split(".")[0])) + if t < cutoff and t != current: + logging.debug("# Deleting old file %s, timestamp %s", f, t) + os.unlink(f) + + pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts) + if pdus == rpki.rtr.generator.AXFRSet.load_current(version): + logging.debug("# No change, new serial not needed") + continue + pdus.save_axfr() + for axfr in glob.iglob("*.ax.v%d" % version): + if axfr != pdus.filename(): + pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr)) + pdus.mark_current(args.force_zero_nonce) + + logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial) + + rpki.rtr.generator.kick_all(pdus.serial) + + old_ixfrs.sort() + for ixfr in old_ixfrs: + try: + logging.debug("# Deleting old file %s", ixfr) + os.unlink(ixfr) + except OSError: + pass - fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".") - assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit() - version = int(fn4[1:]) - self = cls._load_file(filename, version) - self.from_serial = rpki.rtr.channels.Timestamp(fn3) - self.to_serial = rpki.rtr.channels.Timestamp(fn1) - return self - def filename(self): +def show_main(args): """ - Generate filename for this IXFRSet. + Display current rpki-rtr server database in textual form. """ - return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version) + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) - def show(self): - """ - Print this IXFRSet. - """ + g = glob.glob("*.ax.v*") + g.sort() + for f in g: + rpki.rtr.generator.AXFRSet.load(f).show() - logging.debug("# IXFR %d (%s) -> %d (%s) v%d", - self.from_serial, self.from_serial, - self.to_serial, self.to_serial, - self.version) - for p in self: - logging.debug(p) + g = glob.glob("*.ix.*.v*") + g.sort() + for f in g: + rpki.rtr.generator.IXFRSet.load(f).show() +def argparse_setup(subparsers): + """ + Set up argparse stuff for commands in this module. + """ -def kick_all(serial): - """ - Kick any existing server processes to wake them up. - """ - - try: - os.stat(rpki.rtr.server.kickme_dir) - except OSError: - logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir) - os.makedirs(rpki.rtr.server.kickme_dir) - - msg = "Good morning, serial %d is ready" % serial - sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) - for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base): - try: - logging.debug("# Kicking %s", name) - sock.sendto(msg, name) - except socket.error: - try: - logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name) - os.unlink(name) - except Exception, e: - logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e) - except Exception, e: - logging.warning("# Failed to kick %s and don't understand why: %s", name, e) - sock.close() - - -def cronjob_main(args): - """ - Run this right after running rcynic to wade through the ROAs and - router certificates that rcynic collects and translate that data - into the form used in the rpki-router protocol. Output is an - updated database containing both full dumps (AXFR) and incremental - dumps against a specific prior version (IXFR). After updating the - database, kicks any active servers, so that they can notify their - clients that a new version is available. - """ - - if args.rpki_rtr_dir: - try: - if not os.path.isdir(args.rpki_rtr_dir): - os.makedirs(args.rpki_rtr_dir) - os.chdir(args.rpki_rtr_dir) - except OSError, e: - logging.critical(str(e)) - sys.exit(1) - - for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True): - - logging.debug("# Generating updates for protocol version %d", version) - - old_ixfrs = glob.glob("*.ix.*.v%d" % version) - - current = rpki.rtr.server.read_current(version)[0] - cutoff = Timestamp.now(-(24 * 60 * 60)) - for f in glob.iglob("*.ax.v%d" % version): - t = Timestamp(int(f.split(".")[0])) - if t < cutoff and t != current: - logging.debug("# Deleting old file %s, timestamp %s", f, t) - os.unlink(f) - - pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts) - if pdus == rpki.rtr.generator.AXFRSet.load_current(version): - logging.debug("# No change, new serial not needed") - continue - pdus.save_axfr() - for axfr in glob.iglob("*.ax.v%d" % version): - if axfr != pdus.filename(): - pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr)) - pdus.mark_current(args.force_zero_nonce) - - logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial) - - rpki.rtr.generator.kick_all(pdus.serial) - - old_ixfrs.sort() - for ixfr in old_ixfrs: - try: - logging.debug("# Deleting old file %s", ixfr) - os.unlink(ixfr) - except OSError: - pass - - -def show_main(args): - """ - Display current rpki-rtr server database in textual form. - """ - - if args.rpki_rtr_dir: - try: - os.chdir(args.rpki_rtr_dir) - except OSError, e: - sys.exit(e) - - g = glob.glob("*.ax.v*") - g.sort() - for f in g: - rpki.rtr.generator.AXFRSet.load(f).show() - - g = glob.glob("*.ix.*.v*") - g.sort() - for f in g: - rpki.rtr.generator.IXFRSet.load(f).show() + subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__, + help = "Generate RPKI-RTR database from rcynic output") + subparser.set_defaults(func = cronjob_main, default_log_destination = "syslog") + subparser.add_argument("--scan-roas", help = "specify an external scan_roas program") + subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program") + subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero") + subparser.add_argument("rcynic_dir", nargs = "?", help = "directory containing validated rcynic output tree") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") -def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__, - help = "Generate RPKI-RTR database from rcynic output") - subparser.set_defaults(func = cronjob_main, default_log_to = "syslog") - subparser.add_argument("--scan-roas", help = "specify an external scan_roas program") - subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program") - subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero") - subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") - - subparser = subparsers.add_parser("show", description = show_main.__doc__, - help = "Display content of RPKI-RTR database") - subparser.set_defaults(func = show_main, default_log_to = "stderr") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + subparser = subparsers.add_parser("show", description = show_main.__doc__, + help = "Display content of RPKI-RTR database") + subparser.set_defaults(func = show_main, default_log_destination = "stderr") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") diff --git a/rpki/rtr/main.py b/rpki/rtr/main.py index 12de30cc..b915f809 100644 --- a/rpki/rtr/main.py +++ b/rpki/rtr/main.py @@ -25,70 +25,35 @@ import os import sys import time import logging -import logging.handlers -import argparse +import rpki.config -class Formatter(logging.Formatter): - - converter = time.gmtime - - def __init__(self, debug, fmt, datefmt): - self.debug = debug - super(Formatter, self).__init__(fmt, datefmt) - - def format(self, record): - if getattr(record, "connection", None) is None: - record.connection = "" - return super(Formatter, self).format(record) - - def formatException(self, ei): - if self.debug: - return super(Formatter, self).formatException(ei) - else: - return str(ei[1]) def main(): - os.environ["TZ"] = "UTC" - time.tzset() - - from rpki.rtr.server import argparse_setup as argparse_setup_server - from rpki.rtr.client import argparse_setup as argparse_setup_client - from rpki.rtr.generator import argparse_setup as argparse_setup_generator + os.environ["TZ"] = "UTC" + time.tzset() - if "rpki.rtr.bgpdump" in sys.modules: - from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump - else: - def argparse_setup_bgpdump(ignored): - pass + from rpki.rtr.server import argparse_setup as argparse_setup_server + from rpki.rtr.client import argparse_setup as argparse_setup_client + from rpki.rtr.generator import argparse_setup as argparse_setup_generator - argparser = argparse.ArgumentParser(description = __doc__) - argparser.add_argument("--debug", action = "store_true", help = "debugging mode") - argparser.add_argument("--log-level", default = "debug", - choices = ("debug", "info", "warning", "error", "critical"), - type = lambda s: s.lower()) - argparser.add_argument("--log-to", - choices = ("syslog", "stderr")) - subparsers = argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode") - argparse_setup_server(subparsers) - argparse_setup_client(subparsers) - argparse_setup_generator(subparsers) - argparse_setup_bgpdump(subparsers) - args = argparser.parse_args() - - fmt = "rpki-rtr/" + args.mode + "%(connection)s[%(process)d] %(message)s" - - if (args.log_to or args.default_log_to) == "stderr": - handler = logging.StreamHandler() - fmt = "%(asctime)s " + fmt - elif os.path.exists("/dev/log"): - handler = logging.handlers.SysLogHandler("/dev/log") - else: - handler = logging.handlers.SysLogHandler() - - handler.setFormatter(Formatter(args.debug, fmt, "%Y-%m-%dT%H:%M:%SZ")) - logging.root.addHandler(handler) - logging.root.setLevel(int(getattr(logging, args.log_level.upper()))) - - return args.func(args) + if "rpki.rtr.bgpdump" in sys.modules: + from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump + else: + def argparse_setup_bgpdump(ignored): + pass + + cfg = rpki.config.argparser(section = "rpki-rtr", doc = __doc__) + cfg.argparser.add_argument("--debug", action = "store_true", help = "debugging mode") + cfg.add_logging_arguments() + subparsers = cfg.argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode") + argparse_setup_server(subparsers) + argparse_setup_client(subparsers) + argparse_setup_generator(subparsers) + argparse_setup_bgpdump(subparsers) + args = cfg.argparser.parse_args() + + cfg.configure_logging(args = args, ident = "rpki-rtr/" + args.mode) + + return args.func(args) diff --git a/rpki/rtr/pdus.py b/rpki/rtr/pdus.py index 0d2e5928..3fb7457d 100644 --- a/rpki/rtr/pdus.py +++ b/rpki/rtr/pdus.py @@ -28,292 +28,300 @@ import rpki.POW # Exceptions class PDUException(Exception): - """ - Parent exception type for exceptions that signal particular protocol - errors. String value of exception instance will be the message to - put in the ErrorReportPDU, error_report_code value of exception - will be the numeric code to use. - """ - - def __init__(self, msg = None, pdu = None): - super(PDUException, self).__init__() - assert msg is None or isinstance(msg, (str, unicode)) - self.error_report_msg = msg - self.error_report_pdu = pdu - - def __str__(self): - return self.error_report_msg or self.__class__.__name__ - - def make_error_report(self, version): - return ErrorReportPDU(version = version, - errno = self.error_report_code, - errmsg = self.error_report_msg, - errpdu = self.error_report_pdu) + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the ErrorReportPDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + super(PDUException, self).__init__() + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self, version): + return ErrorReportPDU(version = version, + errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) class UnsupportedProtocolVersion(PDUException): - error_report_code = 4 + error_report_code = 4 class UnsupportedPDUType(PDUException): - error_report_code = 5 + error_report_code = 5 class CorruptData(PDUException): - error_report_code = 0 + error_report_code = 0 # Decorators def wire_pdu(cls, versions = None): - """ - Class decorator to add a PDU class to the set of known PDUs - for all supported protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for all supported protocol versions. + """ - for v in PDU.version_map.iterkeys() if versions is None else versions: - assert cls.pdu_type not in PDU.version_map[v] - PDU.version_map[v][cls.pdu_type] = cls - return cls + for v in PDU.version_map.iterkeys() if versions is None else versions: + assert cls.pdu_type not in PDU.version_map[v] + PDU.version_map[v][cls.pdu_type] = cls + return cls def wire_pdu_only(*versions): - """ - Class decorator to add a PDU class to the set of known PDUs - for specific protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for specific protocol versions. + """ - assert versions and all(v in PDU.version_map for v in versions) - return lambda cls: wire_pdu(cls, versions) + assert versions and all(v in PDU.version_map for v in versions) + return lambda cls: wire_pdu(cls, versions) def clone_pdu_root(root_pdu_class): - """ - Replace a PDU root class's version_map with a two-level deep copy of itself, - and return a class decorator which subclasses can use to replace their - parent classes with themselves in the resulting cloned version map. + """ + Replace a PDU root class's version_map with a two-level deep copy of itself, + and return a class decorator which subclasses can use to replace their + parent classes with themselves in the resulting cloned version map. - This function is not itself a decorator, it returns one. - """ + This function is not itself a decorator, it returns one. + """ - root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) + root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) - def decorator(cls): - for pdu_map in root_pdu_class.version_map.itervalues(): - for pdu_type, pdu_class in pdu_map.items(): - if pdu_class in cls.__bases__: - pdu_map[pdu_type] = cls - return cls + def decorator(cls): + for pdu_map in root_pdu_class.version_map.itervalues(): + for pdu_type, pdu_class in pdu_map.items(): + if pdu_class in cls.__bases__: + pdu_map[pdu_type] = cls + return cls - return decorator + return decorator # PDUs class PDU(object): - """ - Base PDU. Real PDUs are subclasses of this class. - """ - - version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu - - _pdu = None # Cached when first generated + """ + Base PDU. Real PDUs are subclasses of this class. + """ - header_struct = struct.Struct("!BB2xL") + version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu - def __init__(self, version): - assert version in self.version_map - self.version = version + _pdu = None # Cached when first generated - def __cmp__(self, other): - return cmp(self.to_pdu(), other.to_pdu()) + header_struct = struct.Struct("!BB2xL") - @property - def default_version(self): - return max(self.version_map.iterkeys()) + pdu_type = None - def check(self): - pass + def __init__(self, version): + assert version in self.version_map + self.version = version - @classmethod - def read_pdu(cls, reader): - return reader.update(need = cls.header_struct.size, callback = cls.got_header) + def __cmp__(self, other): + return cmp(self.to_pdu(), other.to_pdu()) - @classmethod - def got_header(cls, reader): - if not reader.ready(): - return None - assert reader.available() >= cls.header_struct.size - version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) - reader.check_version(version) - if pdu_type not in cls.version_map[version]: - raise UnsupportedPDUType( - "Received unsupported PDU type %d" % pdu_type) - if length < 8: - raise CorruptData( - "Received PDU with length %d, which is too short to be valid" % length) - self = cls.version_map[version][pdu_type](version = version) - return reader.update(need = length, callback = self.got_pdu) + def to_pdu(self, announce = None): + return NotImplementedError + @property + def default_version(self): + return max(self.version_map.iterkeys()) -class PDUWithSerial(PDU): - """ - Base class for PDUs consisting of just a serial number and nonce. - """ + def check(self): + pass - header_struct = struct.Struct("!BBHLL") + @classmethod + def read_pdu(cls, reader): + return reader.update(need = cls.header_struct.size, callback = cls.got_header) - def __init__(self, version, serial = None, nonce = None): - super(PDUWithSerial, self).__init__(version) - if serial is not None: - assert isinstance(serial, int) - self.serial = serial - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce + @classmethod + def got_header(cls, reader): + if not reader.ready(): + return None + assert reader.available() >= cls.header_struct.size + version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) + reader.check_version(version) + if pdu_type not in cls.version_map[version]: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) + self = cls.version_map[version][pdu_type](version = version) + return reader.update(need = length, callback = self.got_pdu) - def __str__(self): - return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) - def to_pdu(self): +class PDUWithSerial(PDU): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a serial number and nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 12: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHLL") + + def __init__(self, version, serial = None, nonce = None): + super(PDUWithSerial, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU. + """ + + assert announce is None + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self class PDUWithNonce(PDU): - """ - Base class for PDUs consisting of just a nonce. - """ - - header_struct = struct.Struct("!BBHL") - - def __init__(self, version, nonce = None): - super(PDUWithNonce, self).__init__(version) - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce - - def __str__(self): - return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) - - def to_pdu(self): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) - return self._pdu + header_struct = struct.Struct("!BBHL") - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + def __init__(self, version, nonce = None): + super(PDUWithNonce, self).__init__(version) + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + def __str__(self): + return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) -class PDUEmpty(PDU): - """ - Base class for empty PDUs. - """ + def to_pdu(self, announce = None): + """ + Generate the wire format PDU. + """ - header_struct = struct.Struct("!BBHL") + assert announce is None + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) + return self._pdu - def __str__(self): - return "[%s]" % self.__class__.__name__ + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self - def to_pdu(self): + +class PDUEmpty(PDU): """ - Generate the wire format PDU for this prefix. + Base class for empty PDUs. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, zero, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if zero != 0: - raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHL") + + def __str__(self): + return "[%s]" % self.__class__.__name__ + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this prefix. + """ + + assert announce is None + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, zero, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class SerialNotifyPDU(PDUWithSerial): - """ - Serial Notify PDU. - """ + """ + Serial Notify PDU. + """ - pdu_type = 0 + pdu_type = 0 @wire_pdu class SerialQueryPDU(PDUWithSerial): - """ - Serial Query PDU. - """ + """ + Serial Query PDU. + """ - pdu_type = 1 + pdu_type = 1 - def __init__(self, version, serial = None, nonce = None): - super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) + def __init__(self, version, serial = None, nonce = None): + super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) @wire_pdu class ResetQueryPDU(PDUEmpty): - """ - Reset Query PDU. - """ + """ + Reset Query PDU. + """ - pdu_type = 2 + pdu_type = 2 - def __init__(self, version): - super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) + def __init__(self, version): + super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) @wire_pdu class CacheResponsePDU(PDUWithNonce): - """ - Cache Response PDU. - """ + """ + Cache Response PDU. + """ - pdu_type = 3 + pdu_type = 3 def EndOfDataPDU(version, *args, **kwargs): - """ - Factory for the EndOfDataPDU classes, which take different forms in - different protocol versions. - """ + """ + Factory for the EndOfDataPDU classes, which take different forms in + different protocol versions. + """ - if version == 0: - return EndOfDataPDUv0(version, *args, **kwargs) - if version == 1: - return EndOfDataPDUv1(version, *args, **kwargs) - raise NotImplementedError + if version == 0: + return EndOfDataPDUv0(version, *args, **kwargs) + if version == 1: + return EndOfDataPDUv1(version, *args, **kwargs) + raise NotImplementedError # Min, max, and default values, from the current RFC 6810 bis I-D. @@ -324,325 +332,345 @@ def EndOfDataPDU(version, *args, **kwargs): default_refresh = 3600 def valid_refresh(refresh): - if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: - raise ValueError - return refresh + if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: + raise ValueError + return refresh default_retry = 600 def valid_retry(retry): - if not isinstance(retry, int) or retry < 120 or retry > 7200: - raise ValueError - return retry + if not isinstance(retry, int) or retry < 120 or retry > 7200: + raise ValueError + return retry default_expire = 7200 def valid_expire(expire): - if not isinstance(expire, int) or expire < 600 or expire > 172800: - raise ValueError - return expire + if not isinstance(expire, int) or expire < 600 or expire > 172800: + raise ValueError + return expire @wire_pdu_only(0) class EndOfDataPDUv0(PDUWithSerial): - """ - End of Data PDU, protocol version 0. - """ + """ + End of Data PDU, protocol version 0. + """ - pdu_type = 7 + pdu_type = 7 - def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): - super(EndOfDataPDUv0, self).__init__(version, serial, nonce) - self.refresh = valid_refresh(default_refresh if refresh is None else refresh) - self.retry = valid_retry( default_retry if retry is None else retry) - self.expire = valid_expire( default_expire if expire is None else expire) + def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): + super(EndOfDataPDUv0, self).__init__(version, serial, nonce) + self.refresh = valid_refresh(default_refresh if refresh is None else refresh) + self.retry = valid_retry( default_retry if retry is None else retry) + self.expire = valid_expire( default_expire if expire is None else expire) @wire_pdu_only(1) class EndOfDataPDUv1(EndOfDataPDUv0): - """ - End of Data PDU, protocol version 1. - """ - - header_struct = struct.Struct("!BBHLLLLL") - - def __str__(self): - return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( - self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) - - def to_pdu(self): """ - Generate the wire format PDU. + End of Data PDU, protocol version 1. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial, - self.refresh, self.retry, self.expire) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ - = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 24: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHLLLLL") + + def __str__(self): + return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( + self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU. + """ + + assert announce is None + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial, + self.refresh, self.retry, self.expire) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ + = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 24: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class CacheResetPDU(PDUEmpty): - """ - Cache reset PDU. - """ + """ + Cache reset PDU. + """ - pdu_type = 8 + pdu_type = 8 class PrefixPDU(PDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - header_struct = struct.Struct("!BB2xLBBBx") - asnum_struct = struct.Struct("!L") - - def __str__(self): - plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def show(self): - logging.debug("# Class: %s", self.__class__.__name__) - logging.debug("# ASN: %s", self.asn) - logging.debug("# Prefix: %s", self.prefix) - logging.debug("# Prefixlen: %s", self.prefixlen) - logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) - logging.debug("# Announce: %s", self.announce) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if self.prefix.bits != self.address_byte_count * 8: - raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) - if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: - raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) - if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: - raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - """ - Generate the wire format PDU for this prefix. - """ - - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, - announce if announce is not None else self.announce, - self.prefixlen, self.max_prefixlen) + - self.prefix.toBytes() + - self.asnum_struct.pack(self.asn)) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b1 = reader.get(self.header_struct.size) - b2 = reader.get(self.address_byte_count) - b3 = reader.get(self.asnum_struct.size) - version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) - assert version == self.version and pdu_type == self.pdu_type - if length != len(b1) + len(b2) + len(b3): - raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) - self.prefix = rpki.POW.IPAddress.fromBytes(b2) - self.asn = self.asnum_struct.unpack(b3)[0] - assert b1 + b2 + b3 == self.to_pdu() - return self + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + header_struct = struct.Struct("!BB2xLBBBx") + asnum_struct = struct.Struct("!L") + address_byte_count = 0 + + def __init__(self, version): + super(PrefixPDU, self).__init__(version) + self.asn = None + self.prefix = None + self.prefixlen = None + self.max_prefixlen = None + self.announce = None + + def __str__(self): + plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def show(self): + logging.debug("# Class: %s", self.__class__.__name__) + logging.debug("# ASN: %s", self.asn) + logging.debug("# Prefix: %s", self.prefix) + logging.debug("# Prefixlen: %s", self.prefixlen) + logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) + logging.debug("# Announce: %s", self.announce) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefix.bits != self.address_byte_count * 8: + raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this prefix. + """ + + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, + announce if announce is not None else self.announce, + self.prefixlen, self.max_prefixlen) + + self.prefix.toBytes() + + self.asnum_struct.pack(self.asn)) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b1 = reader.get(self.header_struct.size) + b2 = reader.get(self.address_byte_count) + b3 = reader.get(self.asnum_struct.size) + version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) + assert version == self.version and pdu_type == self.pdu_type + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) + self.prefix = rpki.POW.IPAddress.fromBytes(b2) + self.asn = self.asnum_struct.unpack(b3)[0] + assert b1 + b2 + b3 == self.to_pdu() + return self @wire_pdu class IPv4PrefixPDU(PrefixPDU): - """ - IPv4 flavor of a prefix. - """ + """ + IPv4 flavor of a prefix. + """ - pdu_type = 4 - address_byte_count = 4 + pdu_type = 4 + address_byte_count = 4 @wire_pdu class IPv6PrefixPDU(PrefixPDU): - """ - IPv6 flavor of a prefix. - """ + """ + IPv6 flavor of a prefix. + """ - pdu_type = 6 - address_byte_count = 16 + pdu_type = 6 + address_byte_count = 16 @wire_pdu_only(1) class RouterKeyPDU(PDU): - """ - Router Key PDU. - """ - - pdu_type = 9 - - header_struct = struct.Struct("!BBBxL20sL") - - def __str__(self): - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, - base64.urlsafe_b64encode(self.ski).rstrip("="), - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if len(self.ski) != 20: - raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) - pdulen = self.header_struct.size + len(self.key) - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + len(self.key) - pdu = (self.header_struct.pack(self.version, - self.pdu_type, - announce if announce is not None else self.announce, - pdulen, - self.ski, - self.asn) - + self.key) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - if remaining <= 0: - raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) - self.key = reader.get(remaining) - assert header + self.key == self.to_pdu() - return self + """ + Router Key PDU. + """ + + pdu_type = 9 + + header_struct = struct.Struct("!BBBxL20sL") + + def __init__(self, version): + super(RouterKeyPDU, self).__init__(version) + self.announce = None + self.ski = None + self.asn = None + self.key = None + + def __str__(self): + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, + base64.urlsafe_b64encode(self.ski).rstrip("="), + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if len(self.ski) != 20: + raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) + pdulen = self.header_struct.size + len(self.key) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + len(self.key) + pdu = (self.header_struct.pack(self.version, + self.pdu_type, + announce if announce is not None else self.announce, + pdulen, + self.ski, + self.asn) + + self.key) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + if remaining <= 0: + raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) + self.key = reader.get(remaining) + assert header + self.key == self.to_pdu() + return self @wire_pdu class ErrorReportPDU(PDU): - """ - Error Report PDU. - """ - - pdu_type = 10 - - header_struct = struct.Struct("!BBHL") - string_struct = struct.Struct("!L") - - errors = { - 2 : "No Data Available" } - - fatal = { - 0 : "Corrupt Data", - 1 : "Internal Error", - 3 : "Invalid Request", - 4 : "Unsupported Protocol Version", - 5 : "Unsupported PDU Type", - 6 : "Withdrawal of Unknown Record", - 7 : "Duplicate Announcement Received" } - - assert set(errors) & set(fatal) == set() - - errors.update(fatal) - - codes = dict((v, k) for k, v in errors.items()) - - def __init__(self, version, errno = None, errpdu = None, errmsg = None): - super(ErrorReportPDU, self).__init__(version) - assert errno is None or errno in self.errors - self.errno = errno - self.errpdu = errpdu - self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] - - def __str__(self): - return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) - - def to_counted_string(self, s): - return self.string_struct.pack(len(s)) + s - - def read_counted_string(self, reader, remaining): - assert remaining >= self.string_struct.size - n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] - assert remaining >= self.string_struct.size + n - return n, reader.get(n), (remaining - self.string_struct.size - n) - - def to_pdu(self): - """ - Generate the wire format PDU for this error report. - """ - - if self._pdu is None: - assert isinstance(self.errno, int) - assert not isinstance(self.errpdu, ErrorReportPDU) - p = self.errpdu - if p is None: - p = "" - elif isinstance(p, PDU): - p = p.to_pdu() - assert isinstance(p, str) - pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) - self._pdu += self.to_counted_string(p) - self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.errno, length = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) - self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) - if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: - raise CorruptData("Got PDU length %d, expected %d" % ( - length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) - assert (header - + self.to_counted_string(self.errpdu) - + self.to_counted_string(self.errmsg.encode("utf8")) - == self.to_pdu()) - return self + """ + Error Report PDU. + """ + + pdu_type = 10 + + header_struct = struct.Struct("!BBHL") + string_struct = struct.Struct("!L") + + errors = { + 2 : "No Data Available" } + + fatal = { + 0 : "Corrupt Data", + 1 : "Internal Error", + 3 : "Invalid Request", + 4 : "Unsupported Protocol Version", + 5 : "Unsupported PDU Type", + 6 : "Withdrawal of Unknown Record", + 7 : "Duplicate Announcement Received" } + + assert set(errors) & set(fatal) == set() + + errors.update(fatal) + + codes = dict((v, k) for k, v in errors.items()) + + def __init__(self, version, errno = None, errpdu = None, errmsg = None): + super(ErrorReportPDU, self).__init__(version) + assert errno is None or errno in self.errors + self.errno = errno + self.errpdu = errpdu + self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] + self.pdulen = None + self.errlen = None + + def __str__(self): + return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) + + def to_counted_string(self, s): + return self.string_struct.pack(len(s)) + s + + def read_counted_string(self, reader, remaining): + assert remaining >= self.string_struct.size + n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] + assert remaining >= self.string_struct.size + n + return n, reader.get(n), (remaining - self.string_struct.size - n) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this error report. + """ + + assert announce is None + if self._pdu is None: + assert isinstance(self.errno, int) + assert not isinstance(self.errpdu, ErrorReportPDU) + p = self.errpdu + if p is None: + p = "" + elif isinstance(p, PDU): + p = p.to_pdu() + assert isinstance(p, str) + pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) + self._pdu += self.to_counted_string(p) + self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.errno, length = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) + self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) + assert (header + + self.to_counted_string(self.errpdu) + + self.to_counted_string(self.errmsg.encode("utf8")) + == self.to_pdu()) + return self diff --git a/rpki/rtr/server.py b/rpki/rtr/server.py index 2ea3a040..c08320fc 100644 --- a/rpki/rtr/server.py +++ b/rpki/rtr/server.py @@ -44,37 +44,37 @@ kickme_base = os.path.join(kickme_dir, "kickme") class PDU(rpki.rtr.pdus.PDU): - """ - Generic server PDU. - """ - - def send_file(self, server, filename): """ - Send a content of a file as a cache response. Caller should catch IOError. + Generic server PDU. """ - fn2 = os.path.splitext(filename)[1] - assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version - - f = open(filename, "rb") - server.push_pdu(CacheResponsePDU(version = server.version, - nonce = server.current_nonce)) - server.push_file(f) - server.push_pdu(EndOfDataPDU(version = server.version, - serial = server.current_serial, - nonce = server.current_nonce, - refresh = server.refresh, - retry = server.retry, - expire = server.expire)) - - def send_nodata(self, server): - """ - Send a nodata error. - """ + def send_file(self, server, filename): + """ + Send a content of a file as a cache response. Caller should catch IOError. + """ + + fn2 = os.path.splitext(filename)[1] + assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version - server.push_pdu(ErrorReportPDU(version = server.version, - errno = ErrorReportPDU.codes["No Data Available"], - errpdu = self)) + f = open(filename, "rb") + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) + server.push_file(f) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce, + refresh = server.refresh, + retry = server.retry, + expire = server.expire)) + + def send_nodata(self, server): + """ + Send a nodata error. + """ + + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["No Data Available"], + errpdu = self)) clone_pdu = clone_pdu_root(PDU) @@ -82,513 +82,513 @@ clone_pdu = clone_pdu_root(PDU) @clone_pdu class SerialQueryPDU(PDU, rpki.rtr.pdus.SerialQueryPDU): - """ - Serial Query PDU. - """ - - def serve(self, server): - """ - Received a serial query, send incremental transfer in response. - If client is already up to date, just send an empty incremental - transfer. """ - - server.logger.debug(self) - if server.get_serial() is None: - self.send_nodata(server) - elif server.current_nonce != self.nonce: - server.logger.info("[Client requested wrong nonce, resetting client]") - server.push_pdu(CacheResetPDU(version = server.version)) - elif server.current_serial == self.serial: - server.logger.debug("[Client is already current, sending empty IXFR]") - server.push_pdu(CacheResponsePDU(version = server.version, - nonce = server.current_nonce)) - server.push_pdu(EndOfDataPDU(version = server.version, - serial = server.current_serial, - nonce = server.current_nonce, - refresh = server.refresh, - retry = server.retry, - expire = server.expire)) - elif disable_incrementals: - server.push_pdu(CacheResetPDU(version = server.version)) - else: - try: - self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version)) - except IOError: - server.push_pdu(CacheResetPDU(version = server.version)) + Serial Query PDU. + """ + + def serve(self, server): + """ + Received a serial query, send incremental transfer in response. + If client is already up to date, just send an empty incremental + transfer. + """ + + server.logger.debug(self) + if server.get_serial() is None: + self.send_nodata(server) + elif server.current_nonce != self.nonce: + server.logger.info("[Client requested wrong nonce, resetting client]") + server.push_pdu(CacheResetPDU(version = server.version)) + elif server.current_serial == self.serial: + server.logger.debug("[Client is already current, sending empty IXFR]") + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce, + refresh = server.refresh, + retry = server.retry, + expire = server.expire)) + elif disable_incrementals: + server.push_pdu(CacheResetPDU(version = server.version)) + else: + try: + self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version)) + except IOError: + server.push_pdu(CacheResetPDU(version = server.version)) @clone_pdu class ResetQueryPDU(PDU, rpki.rtr.pdus.ResetQueryPDU): - """ - Reset Query PDU. - """ - - def serve(self, server): """ - Received a reset query, send full current state in response. + Reset Query PDU. """ - server.logger.debug(self) - if server.get_serial() is None: - self.send_nodata(server) - else: - try: - fn = "%d.ax.v%d" % (server.current_serial, server.version) - self.send_file(server, fn) - except IOError: - server.push_pdu(ErrorReportPDU(version = server.version, - errno = ErrorReportPDU.codes["Internal Error"], - errpdu = self, - errmsg = "Couldn't open %s" % fn)) + def serve(self, server): + """ + Received a reset query, send full current state in response. + """ + + server.logger.debug(self) + if server.get_serial() is None: + self.send_nodata(server) + else: + try: + fn = "%d.ax.v%d" % (server.current_serial, server.version) + self.send_file(server, fn) + except IOError: + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["Internal Error"], + errpdu = self, + errmsg = "Couldn't open %s" % fn)) @clone_pdu class ErrorReportPDU(rpki.rtr.pdus.ErrorReportPDU): - """ - Error Report PDU. - """ - - def serve(self, server): """ - Received an ErrorReportPDU from client. Not much we can do beyond - logging it, then killing the connection if error was fatal. + Error Report PDU. """ - server.logger.error(self) - if self.errno in self.fatal: - server.logger.error("[Shutting down due to reported fatal protocol error]") - sys.exit(1) + def serve(self, server): + """ + Received an ErrorReportPDU from client. Not much we can do beyond + logging it, then killing the connection if error was fatal. + """ + + server.logger.error(self) + if self.errno in self.fatal: + server.logger.error("[Shutting down due to reported fatal protocol error]") + sys.exit(1) def read_current(version): - """ - Read current serial number and nonce. Return None for both if - serial and nonce not recorded. For backwards compatibility, treat - file containing just a serial number as having a nonce of zero. - """ - - if version is None: - return None, None - try: - with open("current.v%d" % version, "r") as f: - values = tuple(int(s) for s in f.read().split()) - return values[0], values[1] - except IndexError: - return values[0], 0 - except IOError: - return None, None + """ + Read current serial number and nonce. Return None for both if + serial and nonce not recorded. For backwards compatibility, treat + file containing just a serial number as having a nonce of zero. + """ + + if version is None: + return None, None + try: + with open("current.v%d" % version, "r") as f: + values = tuple(int(s) for s in f.read().split()) + return values[0], values[1] + except IndexError: + return values[0], 0 + except IOError: + return None, None def write_current(serial, nonce, version): - """ - Write serial number and nonce. - """ + """ + Write serial number and nonce. + """ - curfn = "current.v%d" % version - tmpfn = curfn + "%d.tmp" % os.getpid() - with open(tmpfn, "w") as f: - f.write("%d %d\n" % (serial, nonce)) - os.rename(tmpfn, curfn) + curfn = "current.v%d" % version + tmpfn = curfn + "%d.tmp" % os.getpid() + with open(tmpfn, "w") as f: + f.write("%d %d\n" % (serial, nonce)) + os.rename(tmpfn, curfn) class FileProducer(object): - """ - File-based producer object for asynchat. - """ + """ + File-based producer object for asynchat. + """ - def __init__(self, handle, buffersize): - self.handle = handle - self.buffersize = buffersize + def __init__(self, handle, buffersize): + self.handle = handle + self.buffersize = buffersize - def more(self): - return self.handle.read(self.buffersize) + def more(self): + return self.handle.read(self.buffersize) class ServerWriteChannel(rpki.rtr.channels.PDUChannel): - """ - Kludge to deal with ssh's habit of sometimes (compile time option) - invoking us with two unidirectional pipes instead of one - bidirectional socketpair. All the server logic is in the - ServerChannel class, this class just deals with sending the - server's output to a different file descriptor. - """ - - def __init__(self): """ - Set up stdout. + Kludge to deal with ssh's habit of sometimes (compile time option) + invoking us with two unidirectional pipes instead of one + bidirectional socketpair. All the server logic is in the + ServerChannel class, this class just deals with sending the + server's output to a different file descriptor. """ - super(ServerWriteChannel, self).__init__(root_pdu_class = PDU) - self.init_file_dispatcher(sys.stdout.fileno()) + def __init__(self): + """ + Set up stdout. + """ - def readable(self): - """ - This channel is never readable. - """ + super(ServerWriteChannel, self).__init__(root_pdu_class = PDU) + self.init_file_dispatcher(sys.stdout.fileno()) - return False + def readable(self): + """ + This channel is never readable. + """ - def push_file(self, f): - """ - Write content of a file to stream. - """ + return False - try: - self.push_with_producer(FileProducer(f, self.ac_out_buffer_size)) - except OSError, e: - if e.errno != errno.EAGAIN: - raise + def push_file(self, f): + """ + Write content of a file to stream. + """ + try: + self.push_with_producer(FileProducer(f, self.ac_out_buffer_size)) + except OSError, e: + if e.errno != errno.EAGAIN: + raise -class ServerChannel(rpki.rtr.channels.PDUChannel): - """ - Server protocol engine, handles upcalls from PDUChannel to - implement protocol logic. - """ - def __init__(self, logger, refresh, retry, expire): +class ServerChannel(rpki.rtr.channels.PDUChannel): """ - Set up stdin and stdout as connection and start listening for - first PDU. + Server protocol engine, handles upcalls from PDUChannel to + implement protocol logic. """ - super(ServerChannel, self).__init__(root_pdu_class = PDU) - self.init_file_dispatcher(sys.stdin.fileno()) - self.writer = ServerWriteChannel() - self.logger = logger - self.refresh = refresh - self.retry = retry - self.expire = expire - self.get_serial() - self.start_new_pdu() - - def writable(self): - """ - This channel is never writable. - """ + def __init__(self, logger, refresh, retry, expire): + """ + Set up stdin and stdout as connection and start listening for + first PDU. + """ - return False + super(ServerChannel, self).__init__(root_pdu_class = PDU) + self.init_file_dispatcher(sys.stdin.fileno()) + self.writer = ServerWriteChannel() + self.logger = logger + self.refresh = refresh + self.retry = retry + self.expire = expire + self.get_serial() + self.start_new_pdu() - def push(self, data): - """ - Redirect to writer channel. - """ + def writable(self): + """ + This channel is never writable. + """ - return self.writer.push(data) + return False - def push_with_producer(self, producer): - """ - Redirect to writer channel. - """ + def push(self, data): + """ + Redirect to writer channel. + """ - return self.writer.push_with_producer(producer) + return self.writer.push(data) - def push_pdu(self, pdu): - """ - Redirect to writer channel. - """ + def push_with_producer(self, producer): + """ + Redirect to writer channel. + """ - return self.writer.push_pdu(pdu) + return self.writer.push_with_producer(producer) - def push_file(self, f): - """ - Redirect to writer channel. - """ + def push_pdu(self, pdu): + """ + Redirect to writer channel. + """ - return self.writer.push_file(f) + return self.writer.push_pdu(pdu) - def deliver_pdu(self, pdu): - """ - Handle received PDU. - """ + def push_file(self, f): + """ + Redirect to writer channel. + """ - pdu.serve(self) + return self.writer.push_file(f) - def get_serial(self): - """ - Read, cache, and return current serial number, or None if we can't - find the serial number file. The latter condition should never - happen, but maybe we got started in server mode while the cronjob - mode instance is still building its database. - """ + def deliver_pdu(self, pdu): + """ + Handle received PDU. + """ - self.current_serial, self.current_nonce = read_current(self.version) - return self.current_serial + pdu.serve(self) - def check_serial(self): - """ - Check for a new serial number. - """ + def get_serial(self): + """ + Read, cache, and return current serial number, or None if we can't + find the serial number file. The latter condition should never + happen, but maybe we got started in server mode while the cronjob + mode instance is still building its database. + """ - old_serial = self.current_serial - return old_serial != self.get_serial() + self.current_serial, self.current_nonce = read_current(self.version) + return self.current_serial - def notify(self, data = None, force = False): - """ - Cronjob instance kicked us: check whether our serial number has - changed, and send a notify message if so. + def check_serial(self): + """ + Check for a new serial number. + """ - We have to check rather than just blindly notifying when kicked - because the cronjob instance has no good way of knowing which - protocol version we're running, thus has no good way of knowing - whether we care about a particular change set or not. - """ + old_serial = self.current_serial + return old_serial != self.get_serial() - if force or self.check_serial(): - self.push_pdu(SerialNotifyPDU(version = self.version, - serial = self.current_serial, - nonce = self.current_nonce)) - else: - self.logger.debug("Cronjob kicked me but I see no serial change, ignoring") + def notify(self, data = None, force = False): + """ + Cronjob instance kicked us: check whether our serial number has + changed, and send a notify message if so. + + We have to check rather than just blindly notifying when kicked + because the cronjob instance has no good way of knowing which + protocol version we're running, thus has no good way of knowing + whether we care about a particular change set or not. + """ + + if force or self.check_serial(): + self.push_pdu(SerialNotifyPDU(version = self.version, + serial = self.current_serial, + nonce = self.current_nonce)) + else: + self.logger.debug("Cronjob kicked me but I see no serial change, ignoring") class KickmeChannel(asyncore.dispatcher, object): - """ - asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to - kick servers when it's time to send notify PDUs to clients. - """ - - def __init__(self, server): - asyncore.dispatcher.__init__(self) # Old-style class - self.server = server - self.sockname = "%s.%d" % (kickme_base, os.getpid()) - self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM) - try: - self.bind(self.sockname) - os.chmod(self.sockname, 0660) - except socket.error, e: - self.server.logger.exception("Couldn't bind() kickme socket: %r", e) - self.close() - except OSError, e: - self.server.logger.exception("Couldn't chmod() kickme socket: %r", e) - - def writable(self): """ - This socket is read-only, never writable. + asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to + kick servers when it's time to send notify PDUs to clients. """ - return False + def __init__(self, server): + asyncore.dispatcher.__init__(self) # Old-style class + self.server = server + self.sockname = "%s.%d" % (kickme_base, os.getpid()) + self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM) + try: + self.bind(self.sockname) + os.chmod(self.sockname, 0660) + except socket.error, e: + self.server.logger.exception("Couldn't bind() kickme socket: %r", e) + self.close() + except OSError, e: + self.server.logger.exception("Couldn't chmod() kickme socket: %r", e) + + def writable(self): + """ + This socket is read-only, never writable. + """ + + return False + + def handle_connect(self): + """ + Ignore connect events (not very useful on datagram socket). + """ + + pass + + def handle_read(self): + """ + Handle receipt of a datagram. + """ + + data = self.recv(512) + self.server.notify(data) + + def cleanup(self): + """ + Clean up this dispatcher's socket. + """ + + self.close() + try: + os.unlink(self.sockname) + except: + pass - def handle_connect(self): - """ - Ignore connect events (not very useful on datagram socket). - """ + def log(self, msg): + """ + Intercept asyncore's logging. + """ - pass + self.server.logger.info(msg) - def handle_read(self): - """ - Handle receipt of a datagram. - """ + def log_info(self, msg, tag = "info"): + """ + Intercept asyncore's logging. + """ - data = self.recv(512) - self.server.notify(data) + self.server.logger.info("asyncore: %s: %s", tag, msg) - def cleanup(self): - """ - Clean up this dispatcher's socket. - """ + def handle_error(self): + """ + Handle errors caught by asyncore main loop. + """ + + self.server.logger.exception("[Unhandled exception]") + self.server.logger.critical("[Exiting after unhandled exception]") + sys.exit(1) - self.close() - try: - os.unlink(self.sockname) - except: # pylint: disable=W0702 - pass - def log(self, msg): +def hostport_tag(): """ - Intercept asyncore's logging. + Construct hostname/address + port when we're running under a + protocol we understand well enough to do that. This is all + kludgery. Just grit your teeth, or perhaps just close your eyes. """ - self.server.logger.info(msg) + proto = None - def log_info(self, msg, tag = "info"): - """ - Intercept asyncore's logging. - """ + if proto is None: + try: + host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername() + proto = "tcp" + except: + pass + + if proto is None: + try: + host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2] + proto = "tcp" + except: + pass + + if proto is None: + try: + host, port = os.environ["SSH_CONNECTION"].split()[0:2] + proto = "ssh" + except: + pass + + if proto is None: + try: + host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT") + proto = "ssl" + except: + pass + + if proto is None: + return "" + elif not port: + return "/%s/%s" % (proto, host) + elif ":" in host: + return "/%s/%s.%s" % (proto, host, port) + else: + return "/%s/%s:%s" % (proto, host, port) - self.server.logger.info("asyncore: %s: %s", tag, msg) - def handle_error(self): +def server_main(args): """ - Handle errors caught by asyncore main loop. + Implement the server side of the rpkk-router protocol. Other than + one PF_UNIX socket inode, this doesn't write anything to disk, so it + can be run with minimal privileges. Most of the work has already + been done by the database generator, so all this server has to do is + pass the results along to a client. """ - self.server.logger.exception("[Unhandled exception]") - self.server.logger.critical("[Exiting after unhandled exception]") - sys.exit(1) - + logger = logging.LoggerAdapter(logging.root, dict(connection = hostport_tag())) -def _hostport_tag(): - """ - Construct hostname/address + port when we're running under a - protocol we understand well enough to do that. This is all - kludgery. Just grit your teeth, or perhaps just close your eyes. - """ + logger.debug("[Starting]") - proto = None + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + logger.error("[Couldn't chdir(%r), exiting: %s]", args.rpki_rtr_dir, e) + sys.exit(1) - if proto is None: + kickme = None try: - host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername() - proto = "tcp" - except: # pylint: disable=W0702 - pass + server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire) + kickme = rpki.rtr.server.KickmeChannel(server = server) + asyncore.loop(timeout = None) + signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition + except KeyboardInterrupt: + sys.exit(0) + finally: + signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition + if kickme is not None: + kickme.cleanup() - if proto is None: - try: - host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2] - proto = "tcp" - except: # pylint: disable=W0702 - pass - if proto is None: - try: - host, port = os.environ["SSH_CONNECTION"].split()[0:2] - proto = "ssh" - except: # pylint: disable=W0702 - pass +def listener_main(args): + """ + Totally insecure TCP listener for rpki-rtr protocol. We only + implement this because it's all that the routers currently support. + In theory, we will all be running TCP-AO in the future, at which + point this listener will go away or become a TCP-AO listener. + """ - if proto is None: - try: - host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT") - proto = "ssl" - except: # pylint: disable=W0702 - pass + # Perhaps we should daemonize? Deal with that later. - if proto is None: - return "" - elif not port: - return "/%s/%s" % (proto, host) - elif ":" in host: - return "/%s/%s.%s" % (proto, host, port) - else: - return "/%s/%s:%s" % (proto, host, port) + # server_main() handles args.rpki_rtr_dir. + listener = None + try: + listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except: + if listener is not None: + listener.close() + listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except AttributeError: + pass + listener.bind(("", args.port)) + listener.listen(5) + logging.debug("[Listening on port %s]", args.port) + while True: + try: + s, ai = listener.accept() + except KeyboardInterrupt: + sys.exit(0) + logging.debug("[Received connection from %r]", ai) + pid = os.fork() + if pid == 0: + os.dup2(s.fileno(), 0) # pylint: disable=E1101 + os.dup2(s.fileno(), 1) # pylint: disable=E1101 + s.close() + #os.closerange(3, os.sysconf("SC_OPEN_MAX")) + server_main(args) + sys.exit() + else: + logging.debug("[Spawned server %d]", pid) + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) + if pid: + logging.debug("[Server %s exited with status 0x%x]", pid, status) + continue + except: + pass + break -def server_main(args): - """ - Implement the server side of the rpkk-router protocol. Other than - one PF_UNIX socket inode, this doesn't write anything to disk, so it - can be run with minimal privileges. Most of the work has already - been done by the database generator, so all this server has to do is - pass the results along to a client. - """ - logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag())) +def argparse_setup(subparsers): + """ + Set up argparse stuff for commands in this module. + """ - logger.debug("[Starting]") + # These could have been lambdas, but doing it this way results in + # more useful error messages on argparse failures. - if args.rpki_rtr_dir: - try: - os.chdir(args.rpki_rtr_dir) - except OSError, e: - logger.error("[Couldn't chdir(%r), exiting: %s]", args.rpki_rtr_dir, e) - sys.exit(1) - - kickme = None - try: - server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire) - kickme = rpki.rtr.server.KickmeChannel(server = server) - asyncore.loop(timeout = None) - signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition - except KeyboardInterrupt: - sys.exit(0) - finally: - signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition - if kickme is not None: - kickme.cleanup() + def refresh(v): + return rpki.rtr.pdus.valid_refresh(int(v)) + def retry(v): + return rpki.rtr.pdus.valid_retry(int(v)) -def listener_main(args): - """ - Totally insecure TCP listener for rpki-rtr protocol. We only - implement this because it's all that the routers currently support. - In theory, we will all be running TCP-AO in the future, at which - point this listener will go away or become a TCP-AO listener. - """ - - # Perhaps we should daemonize? Deal with that later. - - # server_main() handles args.rpki_rtr_dir. - - listener = None - try: - listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - except: # pylint: disable=W0702 - if listener is not None: - listener.close() - listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - try: - listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) - except AttributeError: - pass - listener.bind(("", args.port)) - listener.listen(5) - logging.debug("[Listening on port %s]", args.port) - while True: - try: - s, ai = listener.accept() - except KeyboardInterrupt: - sys.exit(0) - logging.debug("[Received connection from %r]", ai) - pid = os.fork() - if pid == 0: - os.dup2(s.fileno(), 0) # pylint: disable=E1103 - os.dup2(s.fileno(), 1) # pylint: disable=E1103 - s.close() - #os.closerange(3, os.sysconf("SC_OPEN_MAX")) - server_main(args) - sys.exit() - else: - logging.debug("[Spawned server %d]", pid) - while True: - try: - pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612 - if pid: - logging.debug("[Server %s exited]", pid) - continue - except: # pylint: disable=W0702 - pass - break + def expire(v): + return rpki.rtr.pdus.valid_expire(int(v)) + # Some duplication of arguments here, not enough to be worth huge + # effort to clean up, worry about it later in any case. -def argparse_setup(subparsers): - """ - Set up argparse stuff for commands in this module. - """ - - # These could have been lambdas, but doing it this way results in - # more useful error messages on argparse failures. - - def refresh(v): - return rpki.rtr.pdus.valid_refresh(int(v)) - - def retry(v): - return rpki.rtr.pdus.valid_retry(int(v)) - - def expire(v): - return rpki.rtr.pdus.valid_expire(int(v)) - - # Some duplication of arguments here, not enough to be worth huge - # effort to clean up, worry about it later in any case. - - subparser = subparsers.add_parser("server", description = server_main.__doc__, - help = "RPKI-RTR protocol server") - subparser.set_defaults(func = server_main, default_log_to = "syslog") - subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") - subparser.add_argument("--retry", type = retry, help = "override default retry timer") - subparser.add_argument("--expire", type = expire, help = "override default expire timer") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") - - subparser = subparsers.add_parser("listener", description = listener_main.__doc__, - help = "TCP listener for RPKI-RTR protocol server") - subparser.set_defaults(func = listener_main, default_log_to = "syslog") - subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") - subparser.add_argument("--retry", type = retry, help = "override default retry timer") - subparser.add_argument("--expire", type = expire, help = "override default expire timer") - subparser.add_argument("port", type = int, help = "TCP port on which to listen") - subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + subparser = subparsers.add_parser("server", description = server_main.__doc__, + help = "RPKI-RTR protocol server") + subparser.set_defaults(func = server_main, default_log_destination = "syslog") + subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") + subparser.add_argument("--retry", type = retry, help = "override default retry timer") + subparser.add_argument("--expire", type = expire, help = "override default expire timer") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + + subparser = subparsers.add_parser("listener", description = listener_main.__doc__, + help = "TCP listener for RPKI-RTR protocol server") + subparser.set_defaults(func = listener_main, default_log_destination = "syslog") + subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer") + subparser.add_argument("--retry", type = retry, help = "override default retry timer") + subparser.add_argument("--expire", type = expire, help = "override default expire timer") + subparser.add_argument("port", type = int, help = "TCP port on which to listen") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") |