diff options
author | Rob Austein <sra@hactrn.net> | 2014-04-17 22:27:29 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2014-04-17 22:27:29 +0000 |
commit | bcd5cf161f3b28ee3ecc37c88aa7768e7ffb7e00 (patch) | |
tree | 8c277bbc8a555104d0747620ccb9a22f0b46fe41 /rp | |
parent | 85589d2c84ce1eb91c04ef7534db6a303f28297a (diff) |
Use class decorator to construct PDU dispatch list (preparation for
supporting multiple protocol versions). Start dragging coding
standard up to something a little more recent.
svn path=/trunk/; revision=5810
Diffstat (limited to 'rp')
-rwxr-xr-x | rp/rpki-rtr/rtr-origin | 371 |
1 files changed, 239 insertions, 132 deletions
diff --git a/rp/rpki-rtr/rtr-origin b/rp/rpki-rtr/rtr-origin index 06ae2ee4..c676e0d7 100755 --- a/rp/rpki-rtr/rtr-origin +++ b/rp/rpki-rtr/rtr-origin @@ -2,19 +2,19 @@ # Router origin-authentication rpki-router protocol implementation. See # draft-ietf-sidr-rpki-rtr in fine Internet-Draft repositories near you. -# +# # Run the program with the --help argument for usage information, or see # documentation for the *_main() functions. # -# +# # $Id$ -# +# # Copyright (C) 2009-2013 Internet Systems Consortium ("ISC") -# +# # Permission to use, copy, modify, and distribute this software for any # purpose with or without fee is hereby granted, provided that the above # copyright notice and this permission notice appear in all copies. -# +# # THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH # REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY # AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT, @@ -54,7 +54,7 @@ class IgnoreThisRecord(Exception): pass -class timestamp(int): +class Timestamp(int): """ Wrapper around time module. """ @@ -102,6 +102,7 @@ def read_current(): serial and nonce not recorded. For backwards compatibility, treat file containing just a serial number as having a nonce of zero. """ + try: f = open("current", "r") values = tuple(int(s) for s in f.read().split()) @@ -116,6 +117,7 @@ def write_current(serial, nonce): """ Write serial number and nonce. """ + tmpfn = "current.%d.tmp" % os.getpid() try: f = open(tmpfn, "w") @@ -133,6 +135,7 @@ def new_nonce(): """ Create and return a new nonce value. """ + if force_zero_nonce: return 0 try: @@ -141,7 +144,7 @@ def new_nonce(): return int(random.getrandbits(16)) -class read_buffer(object): +class ReadBuffer(object): """ Wrapper around synchronous/asynchronous read state. """ @@ -153,6 +156,7 @@ class read_buffer(object): """ Update count of needed bytes and callback, then dispatch to callback. """ + self.need = need self.callback = callback return self.callback(self) @@ -161,24 +165,28 @@ class read_buffer(object): """ How much data do we have available in this buffer? """ + return len(self.buffer) def needed(self): """ How much more data does this buffer need to become ready? """ + return self.need - self.available() def ready(self): """ Is this buffer ready to read yet? """ + return self.available() >= self.need def get(self, n): """ Hand some data to the caller. """ + b = self.buffer[:n] self.buffer = self.buffer[n:] return b @@ -187,23 +195,26 @@ class read_buffer(object): """ Accumulate some data. """ + self.buffer += b def retry(self): """ Try dispatching to the callback again. """ + return self.callback(self) class PDUException(Exception): """ Parent exception type for exceptions that signal particular protocol errors. String value of exception instance will be the message to - put in the error_report PDU, error_report_code value of exception + put in the ErrorReportPDU, error_report_code value of exception will be the numeric code to use. """ def __init__(self, msg = None, pdu = None): + Exception.__init__(self) assert msg is None or isinstance(msg, (str, unicode)) self.error_report_msg = msg self.error_report_pdu = pdu @@ -212,9 +223,9 @@ class PDUException(Exception): return self.error_report_msg or self.__class__.__name__ def make_error_report(self): - return error_report(errno = self.error_report_code, - errmsg = self.error_report_msg, - errpdu = self.error_report_pdu) + return ErrorReportPDU(errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) class UnsupportedProtocolVersion(PDUException): error_report_code = 4 @@ -225,12 +236,28 @@ class UnsupportedPDUType(PDUException): class CorruptData(PDUException): error_report_code = 0 -class pdu(object): + +def wire_pdu(cls): + """ + Class decorator to add a PDU class to the set of known PDUs. + + In the long run, this decorator may take additional arguments + specifying which protocol version(s) use this particular PDU, + but we're not there yet. + """ + + PDU.pdu_map[cls.pdu_type] = cls + return cls + + +class PDU(object): """ Object representing a generic PDU in the rpki-router protocol. Real PDUs are subclasses of this class. """ + pdu_map = {} # Updated by @wire_pdu + version = 0 # Protocol version _pdu = None # Cached when first generated @@ -244,6 +271,7 @@ class pdu(object): """ Check attributes to make sure they're within range. """ + pass @classmethod @@ -273,24 +301,27 @@ class pdu(object): Handle results in test client. Default behavior is just to print out the PDU. """ + blather(self) def send_file(self, server, filename): """ Send a content of a file as a cache response. Caller should catch IOError. """ + f = open(filename, "rb") - server.push_pdu(cache_response(nonce = server.current_nonce)) + server.push_pdu(CacheResponsePDU(nonce = server.current_nonce)) server.push_file(f) - server.push_pdu(end_of_data(serial = server.current_serial, nonce = server.current_nonce)) + server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce)) def send_nodata(self, server): """ Send a nodata error. """ - server.push_pdu(error_report(errno = error_report.codes["No Data Available"], errpdu = self)) -class pdu_with_serial(pdu): + server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["No Data Available"], errpdu = self)) + +class PDUWithSerial(PDU): """ Base class for PDUs consisting of just a serial number and nonce. """ @@ -312,6 +343,7 @@ class pdu_with_serial(pdu): """ Generate the wire format PDU. """ + if self._pdu is None: self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size, self.serial) @@ -327,7 +359,7 @@ class pdu_with_serial(pdu): assert b == self.to_pdu() return self -class pdu_nonce(pdu): +class PDUWithNonce(PDU): """ Base class for PDUs consisting of just a nonce. """ @@ -346,6 +378,7 @@ class pdu_nonce(pdu): """ Generate the wire format PDU. """ + if self._pdu is None: self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) return self._pdu @@ -360,7 +393,7 @@ class pdu_nonce(pdu): assert b == self.to_pdu() return self -class pdu_empty(pdu): +class PDUEmpty(PDU): """ Base class for empty PDUs. """ @@ -374,6 +407,7 @@ class pdu_empty(pdu): """ Generate the wire format PDU for this prefix. """ + if self._pdu is None: self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) return self._pdu @@ -390,7 +424,8 @@ class pdu_empty(pdu): assert b == self.to_pdu() return self -class serial_notify(pdu_with_serial): +@wire_pdu +class SerialNotifyPDU(PDUWithSerial): """ Serial Notify PDU. """ @@ -399,18 +434,20 @@ class serial_notify(pdu_with_serial): def consume(self, client): """ - Respond to a serial_notify message with either a serial_query or - reset_query, depending on what we already know. + Respond to a SerialNotifyPDU with either a SerialQueryPDU or a + ResetQueryPDU, depending on what we already know. """ + blather(self) if client.current_serial is None or client.current_nonce != self.nonce: - client.push_pdu(reset_query()) + client.push_pdu(ResetQueryPDU()) elif self.serial != client.current_serial: - client.push_pdu(serial_query(serial = client.current_serial, nonce = client.current_nonce)) + client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce)) else: blather("[Notify did not change serial number, ignoring]") -class serial_query(pdu_with_serial): +@wire_pdu +class SerialQueryPDU(PDUWithSerial): """ Serial Query PDU. """ @@ -423,25 +460,27 @@ class serial_query(pdu_with_serial): If client is already up to date, just send an empty incremental transfer. """ + blather(self) if server.get_serial() is None: self.send_nodata(server) elif server.current_nonce != self.nonce: log("[Client requested wrong nonce, resetting client]") - server.push_pdu(cache_reset()) + server.push_pdu(CacheResetPDU()) elif server.current_serial == self.serial: blather("[Client is already current, sending empty IXFR]") - server.push_pdu(cache_response(nonce = server.current_nonce)) - server.push_pdu(end_of_data(serial = server.current_serial, nonce = server.current_nonce)) + server.push_pdu(CacheResponsePDU(nonce = server.current_nonce)) + server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce)) elif disable_incrementals: - server.push_pdu(cache_reset()) + server.push_pdu(CacheResetPDU()) else: try: self.send_file(server, "%d.ix.%d" % (server.current_serial, self.serial)) except IOError: - server.push_pdu(cache_reset()) + server.push_pdu(CacheResetPDU()) -class reset_query(pdu_empty): +@wire_pdu +class ResetQueryPDU(PDUEmpty): """ Reset Query PDU. """ @@ -452,6 +491,7 @@ class reset_query(pdu_empty): """ Received a reset query, send full current state in response. """ + blather(self) if server.get_serial() is None: self.send_nodata(server) @@ -460,10 +500,11 @@ class reset_query(pdu_empty): fn = "%d.ax" % server.current_serial self.send_file(server, fn) except IOError: - server.push_pdu(error_report(errno = error_report.codes["Internal Error"], - errpdu = self, errmsg = "Couldn't open %s" % fn)) + server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["Internal Error"], + errpdu = self, errmsg = "Couldn't open %s" % fn)) -class cache_response(pdu_nonce): +@wire_pdu +class CacheResponsePDU(PDUWithNonce): """ Cache Response PDU. """ @@ -472,14 +513,16 @@ class cache_response(pdu_nonce): def consume(self, client): """ - Handle cache_response. + Handle CacheResponsePDU. """ + blather(self) if self.nonce != client.current_nonce: blather("[Nonce changed, resetting]") client.cache_reset() -class end_of_data(pdu_with_serial): +@wire_pdu +class EndOfDataPDU(PDUWithSerial): """ End of Data PDU. """ @@ -488,12 +531,14 @@ class end_of_data(pdu_with_serial): def consume(self, client): """ - Handle end_of_data response. + Handle EndOfDataPDU response. """ + blather(self) client.end_of_data(self.serial, self.nonce) -class cache_reset(pdu_empty): +@wire_pdu +class CacheResetPDU(PDUEmpty): """ Cache reset PDU. """ @@ -502,21 +547,22 @@ class cache_reset(pdu_empty): def consume(self, client): """ - Handle cache_reset response, by issuing a reset_query. + Handle CacheResetPDU response, by issuing a ResetQueryPDU. """ + blather(self) client.cache_reset() - client.push_pdu(reset_query()) + client.push_pdu(ResetQueryPDU()) -class prefix(pdu): +class PrefixPDU(PDU): """ Object representing one prefix. This corresponds closely to one PDU in the rpki-router protocol, so closely that we use lexical ordering of the wire format of the PDU as the ordering for this class. This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (ipv4_prefix or - ipv6_prefix) depending on the syntax of its input text. + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. """ header_struct = struct.Struct("!BB2xLBBBx") @@ -528,7 +574,7 @@ class prefix(pdu): Construct a prefix from its text form. """ - cls = ipv6_prefix if ":" in addr else ipv4_prefix + cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU self = cls() self.asn = long(asnum) p, l = addr.split("/") @@ -540,7 +586,7 @@ class prefix(pdu): self.announce = 1 self.check() return self - + @staticmethod def from_roa(asnum, prefix_tuple): """ @@ -548,7 +594,7 @@ class prefix(pdu): """ address, length, maxlength = prefix_tuple - cls = ipv6_prefix if address.version == 6 else ipv4_prefix + cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU self = cls() self.asn = asnum # Kludge: Should just use IPAddress, coersion here is historical @@ -576,6 +622,7 @@ class prefix(pdu): """ Handle one incoming prefix PDU """ + blather(self) client.consume_prefix(self) @@ -583,6 +630,7 @@ class prefix(pdu): """ Check attributes to make sure they're within range. """ + if self.announce not in (0, 1): raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) if self.prefixlen < 0 or self.prefixlen > self.addr_type.size * 8: @@ -597,6 +645,7 @@ class prefix(pdu): """ Generate the wire format PDU for this prefix. """ + if announce is not None: assert announce in (0, 1) elif self._pdu is not None: @@ -633,9 +682,9 @@ class prefix(pdu): fields = line.split("|") # Parse prefix, including figuring out IP protocol version - cls = ipv6_prefix if ":" in fields[5] else ipv4_prefix + cls = IPv6PrefixPDU if ":" in fields[5] else IPv4PrefixPDU self = cls() - self.timestamp = timestamp(fields[1]) + self.timestamp = Timestamp(fields[1]) p, l = fields[5].split("/") self.prefix = self.addr_type(p) self.prefixlen = self.max_prefixlen = int(l) @@ -670,21 +719,26 @@ class prefix(pdu): log("Ignoring line %r: %s" % (line, e)) raise IgnoreThisRecord -class ipv4_prefix(prefix): +@wire_pdu +class IPv4PrefixPDU(PrefixPDU): """ IPv4 flavor of a prefix. """ + pdu_type = 4 addr_type = v4addr -class ipv6_prefix(prefix): +@wire_pdu +class IPv6PrefixPDU(PrefixPDU): """ IPv6 flavor of a prefix. """ + pdu_type = 6 addr_type = v6addr -class router_key(pdu): +@wire_pdu +class RouterKeyPDU(PDU): """ Router Key PDU. """ @@ -754,7 +808,7 @@ class router_key(pdu): return self._pdu pdulen = self.header_struct.size + len(self.key) pdu = (self.header_struct.pack(self.version, - self.pdu_type, + self.pdu_type, announce if announce is not None else self.announce, pdulen, self.ski, @@ -778,7 +832,8 @@ class router_key(pdu): return self -class error_report(pdu): +@wire_pdu +class ErrorReportPDU(PDU): """ Error Report PDU. """ @@ -828,13 +883,14 @@ class error_report(pdu): """ Generate the wire format PDU for this error report. """ + if self._pdu is None: assert isinstance(self.errno, int) - assert not isinstance(self.errpdu, error_report) + assert not isinstance(self.errpdu, ErrorReportPDU) p = self.errpdu if p is None: p = "" - elif isinstance(p, pdu): + elif isinstance(p, PDU): p = p.to_pdu() assert isinstance(p, str) pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) @@ -862,17 +918,15 @@ class error_report(pdu): def serve(self, server): """ - Received an error_report from client. Not much we can do beyond + Received an ErrorReportPDU from client. Not much we can do beyond logging it, then killing the connection if error was fatal. """ + log(self) if self.errno in self.fatal: log("[Shutting down due to reported fatal protocol error]") sys.exit(1) -pdu.pdu_map = dict((p.pdu_type, p) for p in (ipv4_prefix, ipv6_prefix, serial_notify, serial_query, reset_query, - cache_response, end_of_data, cache_reset, router_key, error_report)) - class ROA(rpki.POW.ROA): """ @@ -909,7 +963,7 @@ class X509(rpki.POW.X509): yield asn -class pdu_set(list): +class PDUSet(list): """ Object representing a set of PDUs, that is, one versioned and (theoretically) consistant set of prefixes and router keys extracted @@ -919,13 +973,14 @@ class pdu_set(list): @classmethod def _load_file(cls, filename): """ - Low-level method to read pdu_set from a file. + Low-level method to read PDUSet from a file. """ + self = cls() f = open(filename, "rb") - r = read_buffer() + r = ReadBuffer() while True: - p = pdu.read_pdu(r) + p = PDU.read_pdu(r) while p is None: b = f.read(r.needed()) if b == "": @@ -940,7 +995,7 @@ class pdu_set(list): return ((a - b) % (1 << 32)) < (1 << 31) -class axfr_set(pdu_set): +class AXFRSet(PDUSet): """ Object representing a complete set of PDUs, that is, one versioned and (theoretically) consistant set of prefixes and router @@ -952,7 +1007,7 @@ class axfr_set(pdu_set): def parse_rcynic(cls, rcynic_dir): """ Parse ROAS and router certificates fetched (and validated!) by - rcynic to create a new axfr_set. + rcynic to create a new AXFRSet. In normal operation, we use os.walk() and the rpki.POW library to parse these data directly, but we can, if so instructed, use @@ -965,7 +1020,7 @@ class axfr_set(pdu_set): """ self = cls() - self.serial = timestamp.now() + self.serial = Timestamp.now() if scan_roas is None or scan_routercerts is None: for root, dirs, files in os.walk(rcynic_dir): @@ -973,7 +1028,7 @@ class axfr_set(pdu_set): if scan_roas is None and fn.endswith(".roa"): roa = ROA.derReadFile(os.path.join(root, fn)) asn = roa.getASID() - self.extend(prefix.from_roa(asn, roa_prefix) + self.extend(PrefixPDU.from_roa(asn, roa_prefix) for roa_prefix in roa.prefixes) if scan_routercerts is None and fn.endswith(".cer"): x = X509.derReadFile(os.path.join(root, fn)) @@ -981,7 +1036,7 @@ class axfr_set(pdu_set): if eku is not None and rpki.oids.id_kp_bgpsec_router in eku: ski = x.getSKI() key = x.getPublicKey().derWritePublic() - self.extend(router_key.from_certificate(asn, ski, key) + self.extend(RouterKeyPDU.from_certificate(asn, ski, key) for asn in x.asns) if scan_roas is not None: @@ -990,7 +1045,7 @@ class axfr_set(pdu_set): for line in p.stdout: line = line.split() asn = line[1] - self.extend(prefix.from_text(asn, addr) for addr in line[2:]) + self.extend(PrefixPDU.from_text(asn, addr) for addr in line[2:]) except OSError, e: sys.exit("Could not run %s: %s" % (scan_roas, e)) @@ -1001,7 +1056,7 @@ class axfr_set(pdu_set): line = line.split() gski = line[0] key = line[-1] - self.extend(router_key.from_text(asn, gski, key) for asn in line[1:-1]) + self.extend(RouterKeyPDU.from_text(asn, gski, key) for asn in line[1:-1]) except OSError, e: sys.exit("Could not run %s: %s" % (scan_routercerts, e)) @@ -1014,25 +1069,28 @@ class axfr_set(pdu_set): @classmethod def load(cls, filename): """ - Load an axfr_set from a file, parse filename to obtain serial. + Load an AXFRSet from a file, parse filename to obtain serial. """ + fn1, fn2 = os.path.basename(filename).split(".") assert fn1.isdigit() and fn2 == "ax" self = cls._load_file(filename) - self.serial = timestamp(fn1) + self.serial = Timestamp(fn1) return self def filename(self): """ - Generate filename for this axfr_set. + Generate filename for this AXFRSet. """ + return "%d.ax" % self.serial @classmethod def load_current(cls): """ - Load current axfr_set. Return None if can't. + Load current AXFRSet. Return None if can't. """ + serial = read_current()[0] if serial is None: return None @@ -1043,8 +1101,9 @@ class axfr_set(pdu_set): def save_axfr(self): """ - Write axfr__set to file with magic filename. + Write AXFRSet to file with magic filename. """ + f = open(self.filename(), "wb") for p in self: f.write(p.to_pdu()) @@ -1055,6 +1114,7 @@ class axfr_set(pdu_set): Destroy old data files, presumably because our nonce changed and the old serial numbers are no longer valid. """ + for i in glob.iglob("*.ix.*"): os.unlink(i) for i in glob.iglob("*.ax"): @@ -1067,6 +1127,7 @@ class axfr_set(pdu_set): necessary. Creating a new nonce triggers cleanup of old state, as the new nonce invalidates all old serial numbers. """ + old_serial, nonce = read_current() if old_serial is None or self.seq_ge(old_serial, self.serial): blather("Creating new nonce and deleting stale data") @@ -1076,11 +1137,12 @@ class axfr_set(pdu_set): def save_ixfr(self, other): """ - Comparing this axfr_set with an older one and write the resulting - ixfr_set to file with magic filename. Since we store pdu_sets + Comparing this AXFRSet with an older one and write the resulting + IXFRSet to file with magic filename. Since we store PDUSets in sorted order, computing the difference is a trivial linear comparison. """ + f = open("%d.ix.%d" % (self.serial, other.serial), "wb") old = other new = self @@ -1105,8 +1167,9 @@ class axfr_set(pdu_set): def show(self): """ - Print this axfr_set. + Print this AXFRSet. """ + blather("# AXFR %d (%s)" % (self.serial, self.serial)) for p in self: blather(p) @@ -1126,7 +1189,7 @@ class axfr_set(pdu_set): self.serial = None for line in cls.read_bgpdump(filename): try: - pfx = prefix.from_bgpdump(line, rib_dump = True) + pfx = PrefixPDU.from_bgpdump(line, rib_dump = True) except IgnoreThisRecord: continue self.append(pfx) @@ -1143,7 +1206,7 @@ class axfr_set(pdu_set): assert os.path.basename(filename).startswith("updates.") for line in self.read_bgpdump(filename): try: - pfx = prefix.from_bgpdump(line, rib_dump = False) + pfx = PrefixPDU.from_bgpdump(line, rib_dump = False) except IgnoreThisRecord: continue announce = pfx.announce @@ -1157,7 +1220,7 @@ class axfr_set(pdu_set): del self[i] self.serial = pfx.timestamp -class ixfr_set(pdu_set): +class IXFRSet(PDUSet): """ Object representing an incremental set of PDUs, that is, the differences between one versioned and (theoretically) consistant set @@ -1169,31 +1232,34 @@ class ixfr_set(pdu_set): @classmethod def load(cls, filename): """ - Load an ixfr_set from a file, parse filename to obtain serials. + Load an IXFRSet from a file, parse filename to obtain serials. """ + fn1, fn2, fn3 = os.path.basename(filename).split(".") assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() self = cls._load_file(filename) - self.from_serial = timestamp(fn3) - self.to_serial = timestamp(fn1) + self.from_serial = Timestamp(fn3) + self.to_serial = Timestamp(fn1) return self def filename(self): """ - Generate filename for this ixfr_set. + Generate filename for this IXFRSet. """ + return "%d.ix.%d" % (self.to_serial, self.from_serial) def show(self): """ - Print this ixfr_set. + Print this IXFRSet. """ + blather("# IXFR %d (%s) -> %d (%s)" % (self.from_serial, self.from_serial, self.to_serial, self.to_serial)) for p in self: blather(p) -class file_producer(object): +class FileProducer(object): """ File-based producer object for asynchat. """ @@ -1205,7 +1271,7 @@ class file_producer(object): def more(self): return self.handle.read(self.buffersize) -class pdu_channel(asynchat.async_chat): +class PDUChannel(asynchat.async_chat): """ asynchat subclass that understands our PDUs. This just handles network I/O. Specific engines (client, server) should be subclasses @@ -1215,17 +1281,18 @@ class pdu_channel(asynchat.async_chat): def __init__(self, conn = None): asynchat.async_chat.__init__(self, conn) - self.reader = read_buffer() + self.reader = ReadBuffer() def start_new_pdu(self): """ Start read of a new PDU. """ + try: - p = pdu.read_pdu(self.reader) + p = PDU.read_pdu(self.reader) while p is not None: self.deliver_pdu(p) - p = pdu.read_pdu(self.reader) + p = PDU.read_pdu(self.reader) except PDUException, e: self.push_pdu(e.make_error_report()) self.close_when_done() @@ -1237,13 +1304,15 @@ class pdu_channel(asynchat.async_chat): """ Collect data into the read buffer. """ + self.reader.put(data) - + def found_terminator(self): """ Got requested data, see if we now have a PDU. If so, pass it along, then restart cycle for a new PDU. """ + p = self.reader.retry() if p is None: self.set_terminator(self.reader.needed()) @@ -1255,6 +1324,7 @@ class pdu_channel(asynchat.async_chat): """ Write PDU to stream. """ + try: self.push(pdu.to_pdu()) except OSError, e: @@ -1265,8 +1335,9 @@ class pdu_channel(asynchat.async_chat): """ Write content of a file to stream. """ + try: - self.push_with_producer(file_producer(f, self.ac_out_buffer_size)) + self.push_with_producer(FileProducer(f, self.ac_out_buffer_size)) except OSError, e: if e.errno != errno.EAGAIN: raise @@ -1275,18 +1346,21 @@ class pdu_channel(asynchat.async_chat): """ Intercept asyncore's logging. """ + log(msg) def log_info(self, msg, tag = "info"): """ Intercept asynchat's logging. """ + log("asynchat: %s: %s" % (tag, msg)) def handle_error(self): """ Handle errors caught by asyncore main loop. """ + c, e = sys.exc_info()[:2] if backtrace_on_exceptions or e == 0: for line in traceback.format_exc().splitlines(): @@ -1300,8 +1374,9 @@ class pdu_channel(asynchat.async_chat): """ Kludge to plug asyncore.file_dispatcher into asynchat. Call from subclass's __init__() method, after calling - pdu_channel.__init__(), and don't read this on a full stomach. + PDUChannel.__init__(), and don't read this on a full stomach. """ + self.connected = True self._fileno = fd self.socket = asyncore.file_wrapper(fd) @@ -1314,10 +1389,11 @@ class pdu_channel(asynchat.async_chat): """ Exit when channel closed. """ + asynchat.async_chat.handle_close(self) sys.exit(0) -class server_write_channel(pdu_channel): +class ServerWriteChannel(PDUChannel): """ Kludge to deal with ssh's habit of sometimes (compile time option) invoking us with two unidirectional pipes instead of one @@ -1330,18 +1406,20 @@ class server_write_channel(pdu_channel): """ Set up stdout. """ - pdu_channel.__init__(self) + + PDUChannel.__init__(self) self.init_file_dispatcher(sys.stdout.fileno()) def readable(self): """ This channel is never readable. """ + return False -class server_channel(pdu_channel): +class ServerChannel(PDUChannel): """ - Server protocol engine, handles upcalls from pdu_channel to + Server protocol engine, handles upcalls from PDUChannel to implement protocol logic. """ @@ -1350,9 +1428,10 @@ class server_channel(pdu_channel): Set up stdin and stdout as connection and start listening for first PDU. """ - pdu_channel.__init__(self) + + PDUChannel.__init__(self) self.init_file_dispatcher(sys.stdin.fileno()) - self.writer = server_write_channel() + self.writer = ServerWriteChannel() self.get_serial() self.start_new_pdu() @@ -1360,36 +1439,42 @@ class server_channel(pdu_channel): """ This channel is never writable. """ + return False def push(self, data): """ Redirect to writer channel. """ + return self.writer.push(data) def push_with_producer(self, producer): """ Redirect to writer channel. """ + return self.writer.push_with_producer(producer) def push_pdu(self, pdu): """ Redirect to writer channel. """ + return self.writer.push_pdu(pdu) def push_file(self, f): """ Redirect to writer channel. """ + return self.writer.push_file(f) def deliver_pdu(self, pdu): """ Handle received PDU. """ + pdu.serve(self) def get_serial(self): @@ -1399,6 +1484,7 @@ class server_channel(pdu_channel): happen, but maybe we got started in server mode while the cronjob mode instance is still building its database. """ + self.current_serial, self.current_nonce = read_current() return self.current_serial @@ -1406,6 +1492,7 @@ class server_channel(pdu_channel): """ Check for a new serial number. """ + old_serial = self.current_serial return old_serial != self.get_serial() @@ -1413,14 +1500,15 @@ class server_channel(pdu_channel): """ Cronjob instance kicked us, send a notify message. """ + if self.check_serial() is not None: - self.push_pdu(serial_notify(serial = self.current_serial, nonce = self.current_nonce)) + self.push_pdu(SerialNotifyPDU(serial = self.current_serial, nonce = self.current_nonce)) else: log("Cronjob kicked me without a valid current serial number") -class client_channel(pdu_channel): +class ClientChannel(PDUChannel): """ - Client protocol engine, handles upcalls from pdu_channel. + Client protocol engine, handles upcalls from PDUChannel. """ current_serial = None @@ -1435,7 +1523,7 @@ class client_channel(pdu_channel): self.proc = proc self.host = host self.port = port - pdu_channel.__init__(self, conn = sock) + PDUChannel.__init__(self, conn = sock) self.start_new_pdu() @classmethod @@ -1443,6 +1531,7 @@ class client_channel(pdu_channel): """ Set up ssh connection and start listening for first PDU. """ + args = ("ssh", "-p", port, "-s", host, "rpki-rtr") blather("[Running ssh: %s]" % " ".join(args)) s = socket.socketpair() @@ -1457,6 +1546,7 @@ class client_channel(pdu_channel): """ Set up TCP connection and start listening for first PDU. """ + blather("[Starting raw TCP connection to %s:%s]" % (host, port)) try: addrinfo = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM) @@ -1486,6 +1576,7 @@ class client_channel(pdu_channel): """ Set up loopback connection and start listening for first PDU. """ + s = socket.socketpair() blather("[Using direct subprocess kludge for testing]") argv = [sys.executable, sys.argv[0], "--server"] @@ -1508,6 +1599,7 @@ class client_channel(pdu_channel): properly (eg, gnutls-cli, or stunnel's client mode if that works for such purposes this week). """ + args = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (host, port)) blather("[Running: %s]" % " ".join(args)) s = socket.socketpair() @@ -1521,6 +1613,7 @@ class client_channel(pdu_channel): Set up an SQLite database to contain the table we receive. If necessary, we will create the database. """ + import sqlite3 missing = not os.path.exists(sqlname) self.sql = sqlite3.connect(sqlname, detect_types = sqlite3.PARSE_DECLTYPES) @@ -1572,8 +1665,9 @@ class client_channel(pdu_channel): def cache_reset(self): """ - Handle cache_reset actions. + Handle CacheResetPDU actions. """ + self.current_serial = None if self.sql: cur = self.sql.cursor() @@ -1582,8 +1676,9 @@ class client_channel(pdu_channel): def end_of_data(self, serial, nonce): """ - Handle end_of_data actions. + Handle EndOfDataPDU actions. """ + self.current_serial = serial self.current_nonce = nonce if self.sql: @@ -1595,6 +1690,7 @@ class client_channel(pdu_channel): """ Handle one prefix PDU. """ + if self.sql: values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen) if prefix.announce: @@ -1630,14 +1726,16 @@ class client_channel(pdu_channel): """ Handle received PDU. """ + pdu.consume(self) def push_pdu(self, pdu): """ Log outbound PDU then write it to stream. """ + blather(pdu) - pdu_channel.push_pdu(self, pdu) + PDUChannel.push_pdu(self, pdu) def cleanup(self): """ @@ -1645,6 +1743,7 @@ class client_channel(pdu_channel): well, child will have exited already before this method is called, but we may need to whack it with a stick if something breaks. """ + if self.proc is not None and self.proc.returncode is None: try: os.kill(self.proc.pid, self.killsig) @@ -1655,10 +1754,11 @@ class client_channel(pdu_channel): """ Intercept close event so we can log it, then shut down. """ + blather("Server closed channel") - pdu_channel.handle_close(self) + PDUChannel.handle_close(self) -class kickme_channel(asyncore.dispatcher): +class KickmeChannel(asyncore.dispatcher): """ asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to kick servers when it's time to send notify PDUs to clients. @@ -1682,18 +1782,21 @@ class kickme_channel(asyncore.dispatcher): """ This socket is read-only, never writable. """ + return False def handle_connect(self): """ Ignore connect events (not very useful on datagram socket). """ + pass def handle_read(self): """ Handle receipt of a datagram. """ + data = self.recv(512) self.server.notify(data) @@ -1701,6 +1804,7 @@ class kickme_channel(asyncore.dispatcher): """ Clean up this dispatcher's socket. """ + self.close() try: os.unlink(self.sockname) @@ -1711,18 +1815,21 @@ class kickme_channel(asyncore.dispatcher): """ Intercept asyncore's logging. """ + log(msg) def log_info(self, msg, tag = "info"): """ Intercept asyncore's logging. """ + log("asyncore: %s: %s" % (tag, msg)) def handle_error(self): """ Handle errors caught by asyncore main loop. """ + c, e = sys.exc_info()[:2] if backtrace_on_exceptions or e == 0: for line in traceback.format_exc().splitlines(): @@ -1833,21 +1940,21 @@ def cronjob_main(argv): old_ixfrs = glob.glob("*.ix.*") current = read_current()[0] - cutoff = timestamp.now(-(24 * 60 * 60)) + cutoff = Timestamp.now(-(24 * 60 * 60)) for f in glob.iglob("*.ax"): - t = timestamp(int(f.split(".")[0])) + t = Timestamp(int(f.split(".")[0])) if t < cutoff and t != current: blather("# Deleting old file %s, timestamp %s" % (f, t)) os.unlink(f) - - pdus = axfr_set.parse_rcynic(argv[0]) - if pdus == axfr_set.load_current(): + + pdus = AXFRSet.parse_rcynic(argv[0]) + if pdus == AXFRSet.load_current(): blather("# No change, new version not needed") sys.exit() pdus.save_axfr() for axfr in glob.iglob("*.ax"): if axfr != pdus.filename(): - pdus.save_ixfr(axfr_set.load(axfr)) + pdus.save_ixfr(AXFRSet.load(axfr)) pdus.mark_current() blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial)) @@ -1877,12 +1984,12 @@ def show_main(argv): g = glob.glob("*.ax") g.sort() for f in g: - axfr_set.load(f).show() + AXFRSet.load(f).show() g = glob.glob("*.ix.*") g.sort() for f in g: - ixfr_set.load(f).show() + IXFRSet.load(f).show() def server_main(argv): """ @@ -1926,8 +2033,8 @@ def server_main(argv): sys.exit(e) kickme = None try: - server = server_channel() - kickme = kickme_channel(server = server) + server = ServerChannel() + kickme = KickmeChannel(server = server) asyncore.loop(timeout = None) except KeyboardInterrupt: sys.exit(0) @@ -1972,7 +2079,7 @@ def listener_tcp_main(argv): except: if listener is not None: listener.close() - listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) @@ -2048,11 +2155,11 @@ def client_main(argv): argv = ["loopback"] proto = argv[0] if proto == "loopback" and len(argv) in (1, 2): - constructor = client_channel.loopback + constructor = ClientChannel.loopback host, port = "", "" sqlname = None if len(argv) == 1 else argv[1] elif proto in ("ssh", "tcp", "tls") and len(argv) in (3, 4): - constructor = getattr(client_channel, proto) + constructor = getattr(ClientChannel, proto) host, port = argv[1:3] sqlname = None if len(argv) == 3 else argv[3] else: @@ -2064,9 +2171,9 @@ def client_main(argv): client.setup_sql(sqlname) while True: if client.current_serial is None or client.current_nonce is None: - client.push_pdu(reset_query()) + client.push_pdu(ResetQueryPDU()) else: - client.push_pdu(serial_query(serial = client.current_serial, nonce = client.current_nonce)) + client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce)) wakeup = time.time() + 600 while True: remaining = wakeup - time.time() @@ -2106,10 +2213,10 @@ def bgpdump_convert_main(argv): if filename.endswith(".ax"): blather("Reading %s" % filename) - db = axfr_set.load(filename) + db = AXFRSet.load(filename) elif os.path.basename(filename).startswith("ribs."): - db = axfr_set.parse_bgpdump_rib_dump(filename) + db = AXFRSet.parse_bgpdump_rib_dump(filename) db.save_axfr() elif not first: @@ -2127,7 +2234,7 @@ def bgpdump_convert_main(argv): for axfr in axfrs: blather("Loading %s" % axfr) - ax = axfr_set.load(axfr) + ax = AXFRSet.load(axfr) blather("Computing changes from %d (%s) to %d (%s)" % (ax.serial, ax.serial, db.serial, db.serial)) db.save_ixfr(ax) del ax @@ -2152,7 +2259,7 @@ def bgpdump_select_main(argv): try: head, sep, tail = os.path.basename(argv[0]).partition(".") if len(argv) == 1 and head.isdigit() and sep == "." and tail == "ax": - serial = timestamp(head) + serial = Timestamp(head) except: pass if serial is None: @@ -2166,7 +2273,7 @@ def bgpdump_select_main(argv): kick_all(serial) -class bgpsec_replay_clock(object): +class BGPDumpReplayClock(object): """ Internal clock for replaying BGP dump files. @@ -2182,7 +2289,7 @@ class bgpsec_replay_clock(object): """ def __init__(self): - self.timestamps = [timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax")] + self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax")] self.timestamps.sort() self.offset = self.timestamps[0] - int(time.time()) self.nonce = new_nonce() @@ -2191,7 +2298,7 @@ class bgpsec_replay_clock(object): return len(self.timestamps) > 0 def now(self): - return timestamp.now(self.offset) + return Timestamp.now(self.offset) def read_current(self): now = self.now() @@ -2238,11 +2345,11 @@ def bgpdump_server_main(argv): # method to our clock object. Fun stuff, huh? # global read_current - clock = bgpsec_replay_clock() + clock = BGPDumpReplayClock() read_current = clock.read_current # try: - server = server_channel() + server = ServerChannel() old_serial = server.get_serial() blather("[Starting at serial %d (%s)]" % (old_serial, old_serial)) while clock: |