diff options
-rwxr-xr-x | rp/rpki-rtr/rtr-origin | 588 |
1 files changed, 404 insertions, 184 deletions
diff --git a/rp/rpki-rtr/rtr-origin b/rp/rpki-rtr/rtr-origin index e127b2b2..b2b3d1ab 100755 --- a/rp/rpki-rtr/rtr-origin +++ b/rp/rpki-rtr/rtr-origin @@ -59,8 +59,10 @@ class Timestamp(int): Wrapper around time module. """ - def __new__(cls, x): - return int.__new__(cls, x) + def __new__(cls, t): + # http://stackoverflow.com/questions/7471255/pythons-super-and-new-confused-me + #return int.__new__(cls, t) + return super(Timestamp, cls).__new__(cls, t) @classmethod def now(cls, delta = 0): @@ -70,39 +72,36 @@ class Timestamp(int): return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self)) -def read_current(): +def read_current(version): """ Read current serial number and nonce. Return None for both if serial and nonce not recorded. For backwards compatibility, treat file containing just a serial number as having a nonce of zero. """ + if version is None: + return None, None try: - f = open("current", "r") - values = tuple(int(s) for s in f.read().split()) - f.close() + with open("current.v%d" % version, "r") as f: + values = tuple(int(s) for s in f.read().split()) return values[0], values[1] except IndexError: return values[0], 0 except IOError: return None, None -def write_current(serial, nonce): + +def write_current(serial, nonce, version): """ Write serial number and nonce. """ - tmpfn = "current.%d.tmp" % os.getpid() - try: - f = open(tmpfn, "w") + assert version in PDU.version_map + curfn = "current.v%d" % version + tmpfn = curfn + "%d.tmp" % os.getpid() + with open(tmpfn, "w") as f: f.write("%d %d\n" % (serial, nonce)) - f.close() - os.rename(tmpfn, "current") - finally: - try: - os.unlink(tmpfn) - except: - pass + os.rename(tmpfn, curfn) def new_nonce(): @@ -121,10 +120,14 @@ def new_nonce(): class ReadBuffer(object): """ Wrapper around synchronous/asynchronous read state. + + This also handles tracking the current protocol version, + because it has to go somewhere and there's no better place. """ def __init__(self): self.buffer = "" + self.version = None def update(self, need, callback): """ @@ -133,6 +136,13 @@ class ReadBuffer(object): self.need = need self.callback = callback + return self.retry() + + def retry(self): + """ + Try dispatching to the callback again. + """ + return self.callback(self) def available(self): @@ -172,12 +182,20 @@ class ReadBuffer(object): self.buffer += b - def retry(self): + def check_version(self, version): """ - Try dispatching to the callback again. + Track version number of PDUs read from this buffer. + Once set, the version must not change. """ - return self.callback(self) + if self.version is not None and version != self.version: + raise CorruptData( + "Received PDU version %d, expected %d" % (version, self.version)) + if self.version is None and version not in PDU.version_map: + raise UnsupportedProtocolVersion( + "Received PDU version %d, known versions %s" % (version, ", ".PDU.version_map.iterkeys())) + self.version = version + class PDUException(Exception): """ @@ -188,7 +206,7 @@ class PDUException(Exception): """ def __init__(self, msg = None, pdu = None): - Exception.__init__(self) + super(PDUException, self).__init__() assert msg is None or isinstance(msg, (str, unicode)) self.error_report_msg = msg self.error_report_pdu = pdu @@ -196,10 +214,11 @@ class PDUException(Exception): def __str__(self): return self.error_report_msg or self.__class__.__name__ - def make_error_report(self): - return ErrorReportPDU(errno = self.error_report_code, - errmsg = self.error_report_msg, - errpdu = self.error_report_pdu) + def make_error_report(self, version): + return ErrorReportPDU(version = version, + errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) class UnsupportedProtocolVersion(PDUException): error_report_code = 4 @@ -211,41 +230,51 @@ class CorruptData(PDUException): error_report_code = 0 -def wire_pdu(cls): +def wire_pdu(cls, versions = None): """ - Class decorator to add a PDU class to the set of known PDUs. - - In the long run, this decorator may take additional arguments - specifying which protocol version(s) use this particular PDU, - but we're not there yet. + Class decorator to add a PDU class to the set of known PDUs + for all supported protocol versions. """ - PDU.pdu_map[cls.pdu_type] = cls + for v in PDU.version_map.iterkeys() if versions is None else versions: + assert cls.pdu_type not in PDU.version_map[v] + PDU.version_map[v][cls.pdu_type] = cls return cls +def wire_pdu_only(*versions): + """ + Class decorator to add a PDU class to the set of known PDUs + for specific protocol versions. + """ + + assert versions and all(v in PDU.version_map for v in versions) + return lambda cls: wire_pdu(cls, versions) + class PDU(object): """ Object representing a generic PDU in the rpki-router protocol. Real PDUs are subclasses of this class. """ - pdu_map = {} # Updated by @wire_pdu - - version = 0 # Protocol version + version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu _pdu = None # Cached when first generated - header_struct = struct.Struct("!BBHL") + header_struct = struct.Struct("!BB2xL") + + def __init__(self, version): + assert version in self.version_map + self.version = version def __cmp__(self, other): return cmp(self.to_pdu(), other.to_pdu()) - def check(self): - """ - Check attributes to make sure they're within range. - """ + @property + def default_version(self): + return max(self.version_map.iterkeys()) + def check(self): pass @classmethod @@ -257,23 +286,21 @@ class PDU(object): if not reader.ready(): return None assert reader.available() >= cls.header_struct.size - version, pdu_type, whatever, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) - if version != cls.version: - raise UnsupportedProtocolVersion( - "Received PDU version %d, expected %d" % (version, cls.version)) - if pdu_type not in cls.pdu_map: + version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) + reader.check_version(version) + if pdu_type not in cls.version_map[version]: raise UnsupportedPDUType( "Received unsupported PDU type %d" % pdu_type) if length < 8: raise CorruptData( "Received PDU with length %d, which is too short to be valid" % length) - self = cls.pdu_map[pdu_type]() + self = cls.version_map[version][pdu_type](version = version) return reader.update(need = length, callback = self.got_pdu) def consume(self, client): """ Handle results in test client. Default behavior is just to print - out the PDU. + out the PDU; data PDU subclasses may override this. """ blather(self) @@ -283,17 +310,25 @@ class PDU(object): Send a content of a file as a cache response. Caller should catch IOError. """ + fn2 = os.path.splitext(filename)[1] + assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version + f = open(filename, "rb") - server.push_pdu(CacheResponsePDU(nonce = server.current_nonce)) + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) server.push_file(f) - server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce)) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce)) def send_nodata(self, server): """ Send a nodata error. """ - server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["No Data Available"], errpdu = self)) + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["No Data Available"], + errpdu = self)) class PDUWithSerial(PDU): """ @@ -302,7 +337,8 @@ class PDUWithSerial(PDU): header_struct = struct.Struct("!BBHLL") - def __init__(self, serial = None, nonce = None): + def __init__(self, version, serial = None, nonce = None): + super(PDUWithSerial, self).__init__(version) if serial is not None: assert isinstance(serial, int) self.serial = serial @@ -328,6 +364,7 @@ class PDUWithSerial(PDU): return None b = reader.get(self.header_struct.size) version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type if length != 12: raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() @@ -340,7 +377,8 @@ class PDUWithNonce(PDU): header_struct = struct.Struct("!BBHL") - def __init__(self, nonce = None): + def __init__(self, version, nonce = None): + super(PDUWithNonce, self).__init__(version) if nonce is not None: assert isinstance(nonce, int) self.nonce = nonce @@ -362,6 +400,7 @@ class PDUWithNonce(PDU): return None b = reader.get(self.header_struct.size) version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type if length != 8: raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() @@ -391,6 +430,7 @@ class PDUEmpty(PDU): return None b = reader.get(self.header_struct.size) version, pdu_type, zero, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type if zero != 0: raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) if length != 8: @@ -414,9 +454,11 @@ class SerialNotifyPDU(PDUWithSerial): blather(self) if client.current_serial is None or client.current_nonce != self.nonce: - client.push_pdu(ResetQueryPDU()) + client.push_pdu(ResetQueryPDU(version = client.version)) elif self.serial != client.current_serial: - client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce)) + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.current_serial, + nonce = client.current_nonce)) else: blather("[Notify did not change serial number, ignoring]") @@ -428,6 +470,9 @@ class SerialQueryPDU(PDUWithSerial): pdu_type = 1 + def __init__(self, version, serial = None, nonce = None): + super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) + def serve(self, server): """ Received a serial query, send incremental transfer in response. @@ -440,18 +485,21 @@ class SerialQueryPDU(PDUWithSerial): self.send_nodata(server) elif server.current_nonce != self.nonce: log("[Client requested wrong nonce, resetting client]") - server.push_pdu(CacheResetPDU()) + server.push_pdu(CacheResetPDU(version = server.version)) elif server.current_serial == self.serial: blather("[Client is already current, sending empty IXFR]") - server.push_pdu(CacheResponsePDU(nonce = server.current_nonce)) - server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce)) + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce)) elif disable_incrementals: - server.push_pdu(CacheResetPDU()) + server.push_pdu(CacheResetPDU(version = server.version)) else: try: - self.send_file(server, "%d.ix.%d" % (server.current_serial, self.serial)) + self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version)) except IOError: - server.push_pdu(CacheResetPDU()) + server.push_pdu(CacheResetPDU(version = server.version)) @wire_pdu class ResetQueryPDU(PDUEmpty): @@ -461,6 +509,9 @@ class ResetQueryPDU(PDUEmpty): pdu_type = 2 + def __init__(self, version): + super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) + def serve(self, server): """ Received a reset query, send full current state in response. @@ -471,11 +522,13 @@ class ResetQueryPDU(PDUEmpty): self.send_nodata(server) else: try: - fn = "%d.ax" % server.current_serial + fn = "%d.ax.v%d" % (server.current_serial, server.version) self.send_file(server, fn) except IOError: - server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["Internal Error"], - errpdu = self, errmsg = "Couldn't open %s" % fn)) + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["Internal Error"], + errpdu = self, + errmsg = "Couldn't open %s" % fn)) @wire_pdu class CacheResponsePDU(PDUWithNonce): @@ -495,21 +548,98 @@ class CacheResponsePDU(PDUWithNonce): blather("[Nonce changed, resetting]") client.cache_reset() -@wire_pdu -class EndOfDataPDU(PDUWithSerial): + +def EndOfDataPDU(version, *args, **kwargs): """ - End of Data PDU. + Factory for the EndOfDataPDU classes, which take different forms in + different protocol versions. + """ + + if version == 0: + return EndOfDataPDUv0(version, *args, **kwargs) + if version == 1: + return EndOfDataPDUv1(version, *args, **kwargs) + raise NotImplementedError + + +@wire_pdu_only(0) +class EndOfDataPDUv0(PDUWithSerial): + """ + End of Data PDU, protocol version 0. """ pdu_type = 7 + # Default values, from the current RFC 6810 bis I-D. + # Putting these here lets us use them in our client API for both + # protocol versions, even though they can only be set in the + # protocol in version 1. + + refresh = 3600 + retry = 600 + expire = 7200 + def consume(self, client): """ Handle EndOfDataPDU response. """ blather(self) - client.end_of_data(self.serial, self.nonce) + client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) + +@wire_pdu_only(1) +class EndOfDataPDUv1(EndOfDataPDUv0): + """ + End of Data PDU, protocol version 1. + """ + + header_struct = struct.Struct("!BBHLLLLL") + + def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): + super(EndOfDataPDUv1, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + if refresh is not None: + assert isinstance(refresh, int) + self.refresh = refresh + if retry is not None: + assert isinstance(retry, int) + self.retry = retry + if expire is not None: + assert isinstance(expire, int) + self.expire = expire + + def __str__(self): + return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( + self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial, + self.refresh, self.retry, self.expire) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ + = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 24: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + @wire_pdu class CacheResetPDU(PDUEmpty): @@ -526,7 +656,7 @@ class CacheResetPDU(PDUEmpty): blather(self) client.cache_reset() - client.push_pdu(ResetQueryPDU()) + client.push_pdu(ResetQueryPDU(version = client.version)) class PrefixPDU(PDU): """ @@ -543,14 +673,14 @@ class PrefixPDU(PDU): asnum_struct = struct.Struct("!L") @staticmethod - def from_text(asnum, addr): + def from_text(version, asn, addr): """ Construct a prefix from its text form. """ cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU - self = cls() - self.asn = long(asnum) + self = cls(version = version) + self.asn = long(asn) p, l = addr.split("/") self.prefix = rpki.POW.IPAddress(p) if "-" in l: @@ -562,15 +692,15 @@ class PrefixPDU(PDU): return self @staticmethod - def from_roa(asnum, prefix_tuple): + def from_roa(version, asn, prefix_tuple): """ Construct a prefix from a ROA. """ address, length, maxlength = prefix_tuple cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU - self = cls() - self.asn = asnum + self = cls(version = version) + self.asn = asn self.prefix = address self.prefixlen = length self.max_prefixlen = length if maxlength is None else maxlength @@ -643,6 +773,7 @@ class PrefixPDU(PDU): b2 = reader.get(self.address_byte_count) b3 = reader.get(self.asnum_struct.size) version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) + assert version == self.version and pdu_type == self.pdu_type if length != len(b1) + len(b2) + len(b3): raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) self.prefix = rpki.POW.IPAddress.fromBytes(b2) @@ -712,7 +843,7 @@ class IPv6PrefixPDU(PrefixPDU): pdu_type = 6 address_byte_count = 16 -@wire_pdu +@wire_pdu_only(1) class RouterKeyPDU(PDU): """ Router Key PDU. @@ -723,13 +854,13 @@ class RouterKeyPDU(PDU): header_struct = struct.Struct("!BBBxL20sL") @classmethod - def from_text(cls, asnum, gski, key): + def from_text(cls, version, asn, gski, key): """ Construct a router key from its text form. """ - self = cls() - self.asn = long(asnum) + self = cls(version = version) + self.asn = long(asn) self.ski = base64.urlsafe_b64decode(gski + "=") self.key = base64.b64decode(key) self.announce = 1 @@ -737,13 +868,13 @@ class RouterKeyPDU(PDU): return self @classmethod - def from_certificate(cls, asnum, ski, key): + def from_certificate(cls, version, asn, ski, key): """ Construct a router key from a certificate. """ - self = cls() - self.asn = asnum + self = cls(version = version) + self.asn = asn self.ski = ski self.key = key self.announce = 1 @@ -799,6 +930,7 @@ class RouterKeyPDU(PDU): return None header = reader.get(self.header_struct.size) version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type remaining = length - self.header_struct.size if remaining <= 0: raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) @@ -836,7 +968,8 @@ class ErrorReportPDU(PDU): codes = dict((v, k) for k, v in errors.items()) - def __init__(self, errno = None, errpdu = None, errmsg = None): + def __init__(self, version, errno = None, errpdu = None, errmsg = None): + super(ErrorReportPDU, self).__init__(version) assert errno is None or errno in self.errors self.errno = errno self.errpdu = errpdu @@ -879,6 +1012,7 @@ class ErrorReportPDU(PDU): return None header = reader.get(self.header_struct.size) version, pdu_type, self.errno, length = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type remaining = length - self.header_struct.size self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) @@ -903,7 +1037,7 @@ class ErrorReportPDU(PDU): sys.exit(1) -class ROA(rpki.POW.ROA): +class ROA(rpki.POW.ROA): # pylint: disable=W0232 """ Minor additions to rpki.POW.ROA. """ @@ -924,7 +1058,7 @@ class ROA(rpki.POW.ROA): for p in v6: yield p -class X509(rpki.POW.X509): +class X509(rpki.POW.X509): # pylint: disable=W0232 """ Minor additions to rpki.POW.X509. """ @@ -945,13 +1079,18 @@ class PDUSet(list): from rcynic's output. """ + def __init__(self, version): + assert version in PDU.version_map + super(PDUSet, self).__init__() + self.version = version + @classmethod - def _load_file(cls, filename): + def _load_file(cls, filename, version): """ Low-level method to read PDUSet from a file. """ - self = cls() + self = cls(version = version) f = open(filename, "rb") r = ReadBuffer() while True: @@ -963,6 +1102,7 @@ class PDUSet(list): return self r.put(b) p = r.retry() + assert p.version == self.version self.append(p) @staticmethod @@ -979,7 +1119,7 @@ class AXFRSet(PDUSet): """ @classmethod - def parse_rcynic(cls, rcynic_dir): + def parse_rcynic(cls, rcynic_dir, version): """ Parse ROAS and router certificates fetched (and validated!) by rcynic to create a new AXFRSet. @@ -994,24 +1134,26 @@ class AXFRSet(PDUSet): can make this one a bit simpler and faster. """ - self = cls() + self = cls(version = version) self.serial = Timestamp.now() - if scan_roas is None or scan_routercerts is None: - for root, dirs, files in os.walk(rcynic_dir): + include_routercerts = RouterKeyPDU.pdu_type in PDU.version_map[version] + + if scan_roas is None or (scan_routercerts is None and include_routercerts): + for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612 for fn in files: if scan_roas is None and fn.endswith(".roa"): roa = ROA.derReadFile(os.path.join(root, fn)) asn = roa.getASID() - self.extend(PrefixPDU.from_roa(asn, roa_prefix) - for roa_prefix in roa.prefixes) - if scan_routercerts is None and fn.endswith(".cer"): + self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple) + for prefix_tuple in roa.prefixes) + if include_routercerts and scan_routercerts is None and fn.endswith(".cer"): x = X509.derReadFile(os.path.join(root, fn)) eku = x.getEKU() if eku is not None and rpki.oids.id_kp_bgpsec_router in eku: ski = x.getSKI() key = x.getPublicKey().derWritePublic() - self.extend(RouterKeyPDU.from_certificate(asn, ski, key) + self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key) for asn in x.asns) if scan_roas is not None: @@ -1020,18 +1162,20 @@ class AXFRSet(PDUSet): for line in p.stdout: line = line.split() asn = line[1] - self.extend(PrefixPDU.from_text(asn, addr) for addr in line[2:]) + self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr) + for addr in line[2:]) except OSError, e: sys.exit("Could not run %s: %s" % (scan_roas, e)) - if scan_routercerts is not None: + if include_routercerts and scan_routercerts is not None: try: p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE) for line in p.stdout: line = line.split() gski = line[0] key = line[-1] - self.extend(RouterKeyPDU.from_text(asn, gski, key) for asn in line[1:-1]) + self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key) + for asn in line[1:-1]) except OSError, e: sys.exit("Could not run %s: %s" % (scan_routercerts, e)) @@ -1044,12 +1188,13 @@ class AXFRSet(PDUSet): @classmethod def load(cls, filename): """ - Load an AXFRSet from a file, parse filename to obtain serial. + Load an AXFRSet from a file, parse filename to obtain version and serial. """ - fn1, fn2 = os.path.basename(filename).split(".") - assert fn1.isdigit() and fn2 == "ax" - self = cls._load_file(filename) + fn1, fn2, fn3 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit() + version = int(fn3[1:]) + self = cls._load_file(filename, version) self.serial = Timestamp(fn1) return self @@ -1058,19 +1203,19 @@ class AXFRSet(PDUSet): Generate filename for this AXFRSet. """ - return "%d.ax" % self.serial + return "%d.ax.v%d" % (self.serial, self.version) @classmethod - def load_current(cls): + def load_current(cls, version): """ Load current AXFRSet. Return None if can't. """ - serial = read_current()[0] + serial = read_current(version)[0] if serial is None: return None try: - return cls.load("%d.ax" % serial) + return cls.load("%d.ax.v%d" % (serial, version)) except IOError: return None @@ -1090,9 +1235,9 @@ class AXFRSet(PDUSet): the old serial numbers are no longer valid. """ - for i in glob.iglob("*.ix.*"): + for i in glob.iglob("*.ix.*.v%d" % self.version): os.unlink(i) - for i in glob.iglob("*.ax"): + for i in glob.iglob("*.ax.v%d" % self.version): if i != self.filename(): os.unlink(i) @@ -1103,12 +1248,12 @@ class AXFRSet(PDUSet): the new nonce invalidates all old serial numbers. """ - old_serial, nonce = read_current() + old_serial, nonce = read_current(self.version) if old_serial is None or self.seq_ge(old_serial, self.serial): blather("Creating new nonce and deleting stale data") nonce = new_nonce() self.destroy_old_data() - write_current(self.serial, nonce) + write_current(self.serial, nonce, self.version) def save_ixfr(self, other): """ @@ -1118,7 +1263,7 @@ class AXFRSet(PDUSet): comparison. """ - f = open("%d.ix.%d" % (self.serial, other.serial), "wb") + f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb") old = other new = self len_old = len(old) @@ -1145,7 +1290,7 @@ class AXFRSet(PDUSet): Print this AXFRSet. """ - blather("# AXFR %d (%s)" % (self.serial, self.serial)) + blather("# AXFR %d (%s) v%d" % (self.serial, self.serial, self.version)) for p in self: blather(p) @@ -1207,12 +1352,13 @@ class IXFRSet(PDUSet): @classmethod def load(cls, filename): """ - Load an IXFRSet from a file, parse filename to obtain serials. + Load an IXFRSet from a file, parse filename to obtain version and serials. """ - fn1, fn2, fn3 = os.path.basename(filename).split(".") - assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() - self = cls._load_file(filename) + fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit() + version = int(fn4[1:]) + self = cls._load_file(filename, version) self.from_serial = Timestamp(fn3) self.to_serial = Timestamp(fn1) return self @@ -1222,15 +1368,16 @@ class IXFRSet(PDUSet): Generate filename for this IXFRSet. """ - return "%d.ix.%d" % (self.to_serial, self.from_serial) + return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version) def show(self): """ Print this IXFRSet. """ - blather("# IXFR %d (%s) -> %d (%s)" % (self.from_serial, self.from_serial, - self.to_serial, self.to_serial)) + blather("# IXFR %d (%s) -> %d (%s) v%d" % (self.from_serial, self.from_serial, + self.to_serial, self.to_serial, + self.version)) for p in self: blather(p) @@ -1246,7 +1393,7 @@ class FileProducer(object): def more(self): return self.handle.read(self.buffersize) -class PDUChannel(asynchat.async_chat): +class PDUChannel(asynchat.async_chat, object): """ asynchat subclass that understands our PDUs. This just handles network I/O. Specific engines (client, server) should be subclasses @@ -1255,9 +1402,17 @@ class PDUChannel(asynchat.async_chat): """ def __init__(self, conn = None): - asynchat.async_chat.__init__(self, conn) + asynchat.async_chat.__init__(self, conn) # Old-style class self.reader = ReadBuffer() + @property + def version(self): + return self.reader.version + + @version.setter + def version(self, version): + self.reader.check_version(version) + def start_new_pdu(self): """ Start read of a new PDU. @@ -1269,7 +1424,7 @@ class PDUChannel(asynchat.async_chat): self.deliver_pdu(p) p = PDU.read_pdu(self.reader) except PDUException, e: - self.push_pdu(e.make_error_report()) + self.push_pdu(e.make_error_report(version = self.version)) self.close_when_done() else: assert not self.reader.ready() @@ -1382,7 +1537,7 @@ class ServerWriteChannel(PDUChannel): Set up stdout. """ - PDUChannel.__init__(self) + super(ServerWriteChannel, self).__init__() self.init_file_dispatcher(sys.stdout.fileno()) def readable(self): @@ -1392,6 +1547,7 @@ class ServerWriteChannel(PDUChannel): return False + class ServerChannel(PDUChannel): """ Server protocol engine, handles upcalls from PDUChannel to @@ -1404,7 +1560,7 @@ class ServerChannel(PDUChannel): first PDU. """ - PDUChannel.__init__(self) + super(ServerChannel, self).__init__() self.init_file_dispatcher(sys.stdin.fileno()) self.writer = ServerWriteChannel() self.get_serial() @@ -1460,7 +1616,7 @@ class ServerChannel(PDUChannel): mode instance is still building its database. """ - self.current_serial, self.current_nonce = read_current() + self.current_serial, self.current_nonce = read_current(self.version) return self.current_serial def check_serial(self): @@ -1473,13 +1629,21 @@ class ServerChannel(PDUChannel): def notify(self, data = None): """ - Cronjob instance kicked us, send a notify message. + Cronjob instance kicked us: check whether our serial number has + changed, and send a notify message if so. + + We have to check rather than just blindly notifying when kicked + because the cronjob instance has no good way of knowing which + protocol version we're running, thus has no good way of knowing + whether we care about a particular change set or not. """ - if self.check_serial() is not None: - self.push_pdu(SerialNotifyPDU(serial = self.current_serial, nonce = self.current_nonce)) + if self.check_serial(): + self.push_pdu(SerialNotifyPDU(version = self.version, + serial = self.current_serial, + nonce = self.current_nonce)) else: - log("Cronjob kicked me without a valid current serial number") + log("Cronjob kicked me but I see no serial change, ignoring") class ClientChannel(PDUChannel): """ @@ -1493,12 +1657,21 @@ class ClientChannel(PDUChannel): port = None cache_id = None + # For initial test purposes, let's use the minimum allowed values + # from the RFC 6810 bis I-D as the initial defaults for refresh and + # retry, and the maximum allowed for expire; these will be overriden + # as soon as we receive an EndOfDataPDU. + # + refresh = 120 + retry = 120 + expire = 172800 + def __init__(self, sock, proc, killsig, host, port): self.killsig = killsig self.proc = proc self.host = host self.port = port - PDUChannel.__init__(self, conn = sock) + super(ClientChannel, self).__init__(conn = sock) self.start_new_pdu() @classmethod @@ -1529,7 +1702,7 @@ class ClientChannel(PDUChannel): blather("[socket.getaddrinfo() failed: %s]" % e) else: for ai in addrinfo: - af, socktype, proto, cn, sa = ai + af, socktype, proto, cn, sa = ai # pylint: disable=W0612 blather("[Trying addr %s port %s]" % sa[:2]) try: s = socket.socket(af, socktype, proto) @@ -1601,9 +1774,13 @@ class ClientChannel(PDUChannel): cache_id INTEGER PRIMARY KEY NOT NULL, host TEXT NOT NULL, port TEXT NOT NULL, + version INTEGER, nonce INTEGER, serial INTEGER, updated INTEGER, + refresh INTEGER, + retry INTEGER, + expire INTEGER, UNIQUE (host, port))''') cur.execute(''' CREATE TABLE prefix ( @@ -1629,14 +1806,26 @@ class ClientChannel(PDUChannel): UNIQUE (cache_id, asn, ski), UNIQUE (cache_id, asn, key))''') - cur.execute("SELECT cache_id, nonce, serial FROM cache WHERE host = ? AND port = ?", + cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire " + "FROM cache WHERE host = ? AND port = ?", (self.host, self.port)) try: - self.cache_id, self.current_nonce, self.current_serial = cur.fetchone() + self.cache_id, version, self.current_nonce, self.current_serial, refresh, retry, expire = cur.fetchone() + if version is not None: + self.version = version + if refresh is not None: + self.refresh = refresh + if retry is not None: + self.retry = retry + if expire is not None: + self.expire = expire except TypeError: cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port)) self.cache_id = cur.lastrowid self.sql.commit() + log("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s]" % ( + self.cache_id, self.version, self.current_nonce, self.current_serial, self.refresh, self.retry, self.expire)) + def cache_reset(self): """ @@ -1645,20 +1834,34 @@ class ClientChannel(PDUChannel): self.current_serial = None if self.sql: + # + # For some reason there was no commit here. Dunno why. + # See if adding one breaks anything.... + # cur = self.sql.cursor() cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,)) - cur.execute("UPDATE cache SET serial = NULL WHERE cache_id = ?", (self.cache_id,)) + cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,)) + cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id)) + self.sql.commit() - def end_of_data(self, serial, nonce): + def end_of_data(self, version, serial, nonce, refresh, retry, expire): """ Handle EndOfDataPDU actions. """ - self.current_serial = serial - self.current_nonce = nonce + assert version == self.version + self.current_serial = serial + self.current_nonce = nonce + self.refresh = refresh + self.retry = retry + self.expire = expire if self.sql: - self.sql.execute("UPDATE cache SET serial = ?, nonce = ?, updated = datetime('now') WHERE cache_id = ?", - (serial, nonce, self.cache_id)) + self.sql.execute("UPDATE cache SET" + " version = ?, serial = ?, nonce = ?," + " refresh = ?, retry = ?, expire = ?," + " updated = datetime('now') " + "WHERE cache_id = ?", + (version, serial, nonce, refresh, retry, expire, self.cache_id)) self.sql.commit() def consume_prefix(self, prefix): @@ -1733,14 +1936,14 @@ class ClientChannel(PDUChannel): blather("Server closed channel") PDUChannel.handle_close(self) -class KickmeChannel(asyncore.dispatcher): +class KickmeChannel(asyncore.dispatcher, object): """ asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to kick servers when it's time to send notify PDUs to clients. """ def __init__(self, server): - asyncore.dispatcher.__init__(self) + asyncore.dispatcher.__init__(self) # Old-style class self.server = server self.sockname = "%s.%d" % (kickme_base, os.getpid()) self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM) @@ -1912,37 +2115,41 @@ def cronjob_main(argv): if len(argv) != 1: sys.exit("Expected one argument, got %r" % (argv,)) - old_ixfrs = glob.glob("*.ix.*") + for version in sorted(PDU.version_map.iterkeys(), reverse = True): - current = read_current()[0] - cutoff = Timestamp.now(-(24 * 60 * 60)) - for f in glob.iglob("*.ax"): - t = Timestamp(int(f.split(".")[0])) - if t < cutoff and t != current: - blather("# Deleting old file %s, timestamp %s" % (f, t)) - os.unlink(f) + blather("# Generating updates for protocol version %d" % version) - pdus = AXFRSet.parse_rcynic(argv[0]) - if pdus == AXFRSet.load_current(): - blather("# No change, new version not needed") - sys.exit() - pdus.save_axfr() - for axfr in glob.iglob("*.ax"): - if axfr != pdus.filename(): - pdus.save_ixfr(AXFRSet.load(axfr)) - pdus.mark_current() + old_ixfrs = glob.glob("*.ix.*.v%d" % version) - blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial)) + current = read_current(version)[0] + cutoff = Timestamp.now(-(24 * 60 * 60)) + for f in glob.iglob("*.ax.v%d" % version): + t = Timestamp(int(f.split(".")[0])) + if t < cutoff and t != current: + blather("# Deleting old file %s, timestamp %s" % (f, t)) + os.unlink(f) - kick_all(pdus.serial) + pdus = AXFRSet.parse_rcynic(argv[0], version) + if pdus == AXFRSet.load_current(version): + blather("# No change, new serial not needed") + continue + pdus.save_axfr() + for axfr in glob.iglob("*.ax.v%d" % version): + if axfr != pdus.filename(): + pdus.save_ixfr(AXFRSet.load(axfr)) + pdus.mark_current() - old_ixfrs.sort() - for ixfr in old_ixfrs: - try: - blather("# Deleting old file %s" % ixfr) - os.unlink(ixfr) - except OSError: - pass + blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial)) + + kick_all(pdus.serial) + + old_ixfrs.sort() + for ixfr in old_ixfrs: + try: + blather("# Deleting old file %s" % ixfr) + os.unlink(ixfr) + except OSError: + pass def show_main(argv): """ @@ -1956,12 +2163,12 @@ def show_main(argv): if argv: sys.exit("Unexpected arguments: %r" % (argv,)) - g = glob.glob("*.ax") + g = glob.glob("*.ax.v*") g.sort() for f in g: AXFRSet.load(f).show() - g = glob.glob("*.ix.*") + g = glob.glob("*.ix.*.v*") g.sort() for f in g: IXFRSet.load(f).show() @@ -2068,8 +2275,8 @@ def listener_tcp_main(argv): blather("[Received connection from %r]" % (ai,)) pid = os.fork() if pid == 0: - os.dup2(s.fileno(), 0) - os.dup2(s.fileno(), 1) + os.dup2(s.fileno(), 0) # pylint: disable=E1103 + os.dup2(s.fileno(), 1) # pylint: disable=E1103 s.close() #os.closerange(3, os.sysconf("SC_OPEN_MAX")) global log_tag @@ -2082,7 +2289,7 @@ def listener_tcp_main(argv): blather("[Spawned server %d]" % pid) try: while True: - pid, status = os.waitpid(0, os.WNOHANG) + pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612 if pid: blather("[Server %s exited]" % pid) else: @@ -2146,11 +2353,17 @@ def client_main(argv): client.setup_sql(sqlname) while True: if client.current_serial is None or client.current_nonce is None: - client.push_pdu(ResetQueryPDU()) + client.push_pdu(ResetQueryPDU(version = client.version)) else: - client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce)) - wakeup = time.time() + 600 + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.current_serial, + nonce = client.current_nonce)) + polled = Timestamp.now() + wakeup = None while True: + if wakeup != polled + client.refresh: + wakeup = Timestamp(polled + client.refresh) + log("[Last client poll %s, next %s]" % (polled, wakeup)) remaining = wakeup - time.time() if remaining < 0: break @@ -2183,10 +2396,11 @@ def bgpdump_convert_main(argv): first = True db = None axfrs = [] + version = max(PDU.version_map.iterkeys()) for filename in argv: - if filename.endswith(".ax"): + if ".ax.v" in filename: blather("Reading %s" % filename) db = AXFRSet.load(filename) @@ -2203,7 +2417,7 @@ def bgpdump_convert_main(argv): sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename) blather("DB serial now %d (%s)" % (db.serial, db.serial)) - if first and read_current() == (None, None): + if first and read_current(version) == (None, None): db.mark_current() first = False @@ -2230,21 +2444,22 @@ def bgpdump_select_main(argv): You have been warned. """ + version = max(PDU.version_map.iterkeys()) serial = None try: head, sep, tail = os.path.basename(argv[0]).partition(".") - if len(argv) == 1 and head.isdigit() and sep == "." and tail == "ax": + if len(argv) == 1 and head.isdigit() and sep == "." and tail.startswith("ax.v") and tail[4:].isdigit(): serial = Timestamp(head) except: pass if serial is None: sys.exit("Argument must be name of a .ax file") - nonce = read_current()[1] + nonce = read_current(version)[1] if nonce is None: nonce = new_nonce() - write_current(serial, nonce) + write_current(serial, nonce, version) kick_all(serial) @@ -2264,7 +2479,7 @@ class BGPDumpReplayClock(object): """ def __init__(self): - self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax")] + self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")] self.timestamps.sort() self.offset = self.timestamps[0] - int(time.time()) self.nonce = new_nonce() @@ -2275,7 +2490,7 @@ class BGPDumpReplayClock(object): def now(self): return Timestamp.now(self.offset) - def read_current(self): + def read_current(self, version): now = self.now() while len(self.timestamps) > 1 and now >= self.timestamps[1]: del self.timestamps[0] @@ -2340,6 +2555,7 @@ def bgpdump_server_main(argv): scan_roas = None scan_routercerts = None force_zero_nonce = False +debug = False kickme_dir = "sockets" kickme_base = os.path.join(kickme_dir, "kickme") @@ -2360,6 +2576,8 @@ def usage(msg = None): f.write("\n") f.write("where options are zero or more of:\n") f.write("\n") + f.write("--debug\n") + f.write("\n") f.write("--scan-roas /path/to/scan_roas\n") f.write("\n") f.write("--scan-routercerts /path/to/scan_routercerts\n") @@ -2385,13 +2603,15 @@ if __name__ == "__main__": syslog_facility, syslog_warning, syslog_info = syslog.LOG_DAEMON, syslog.LOG_WARNING, syslog.LOG_INFO - opts, argv = getopt.getopt(sys.argv[1:], "hs:z?", ["help", "scan-roas=", "scan-routercerts=", - "syslog=", "zero-nonce"] + main_dispatch.keys()) + opts, argv = getopt.getopt(sys.argv[1:], "dhs:z?", ["help", "debug", "scan-roas=", "scan-routercerts=", + "syslog=", "zero-nonce"] + main_dispatch.keys()) for o, a in opts: if o in ("-h", "--help", "-?"): usage() elif o in ("-z", "--zero-nonce"): force_zero_nonce = True + elif o in ("-d", "--debug"): + debug = True elif o in ("-s", "--syslog"): try: a = [getattr(syslog, "LOG_" + i.upper()) for i in a.split(".")] @@ -2419,7 +2639,7 @@ if __name__ == "__main__": if mode in ("server", "bgpdump_server"): log_tag += hostport_tag() - if mode in ("cronjob", "server" , "bgpdump_server"): + if not debug and mode in ("cronjob", "server" , "bgpdump_server"): syslog.openlog(log_tag, syslog.LOG_PID, syslog_facility) def log(msg): return syslog.syslog(syslog_warning, str(msg)) |