aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xrp/rpki-rtr/rtr-origin588
1 files changed, 404 insertions, 184 deletions
diff --git a/rp/rpki-rtr/rtr-origin b/rp/rpki-rtr/rtr-origin
index e127b2b2..b2b3d1ab 100755
--- a/rp/rpki-rtr/rtr-origin
+++ b/rp/rpki-rtr/rtr-origin
@@ -59,8 +59,10 @@ class Timestamp(int):
Wrapper around time module.
"""
- def __new__(cls, x):
- return int.__new__(cls, x)
+ def __new__(cls, t):
+ # http://stackoverflow.com/questions/7471255/pythons-super-and-new-confused-me
+ #return int.__new__(cls, t)
+ return super(Timestamp, cls).__new__(cls, t)
@classmethod
def now(cls, delta = 0):
@@ -70,39 +72,36 @@ class Timestamp(int):
return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self))
-def read_current():
+def read_current(version):
"""
Read current serial number and nonce. Return None for both if
serial and nonce not recorded. For backwards compatibility, treat
file containing just a serial number as having a nonce of zero.
"""
+ if version is None:
+ return None, None
try:
- f = open("current", "r")
- values = tuple(int(s) for s in f.read().split())
- f.close()
+ with open("current.v%d" % version, "r") as f:
+ values = tuple(int(s) for s in f.read().split())
return values[0], values[1]
except IndexError:
return values[0], 0
except IOError:
return None, None
-def write_current(serial, nonce):
+
+def write_current(serial, nonce, version):
"""
Write serial number and nonce.
"""
- tmpfn = "current.%d.tmp" % os.getpid()
- try:
- f = open(tmpfn, "w")
+ assert version in PDU.version_map
+ curfn = "current.v%d" % version
+ tmpfn = curfn + "%d.tmp" % os.getpid()
+ with open(tmpfn, "w") as f:
f.write("%d %d\n" % (serial, nonce))
- f.close()
- os.rename(tmpfn, "current")
- finally:
- try:
- os.unlink(tmpfn)
- except:
- pass
+ os.rename(tmpfn, curfn)
def new_nonce():
@@ -121,10 +120,14 @@ def new_nonce():
class ReadBuffer(object):
"""
Wrapper around synchronous/asynchronous read state.
+
+ This also handles tracking the current protocol version,
+ because it has to go somewhere and there's no better place.
"""
def __init__(self):
self.buffer = ""
+ self.version = None
def update(self, need, callback):
"""
@@ -133,6 +136,13 @@ class ReadBuffer(object):
self.need = need
self.callback = callback
+ return self.retry()
+
+ def retry(self):
+ """
+ Try dispatching to the callback again.
+ """
+
return self.callback(self)
def available(self):
@@ -172,12 +182,20 @@ class ReadBuffer(object):
self.buffer += b
- def retry(self):
+ def check_version(self, version):
"""
- Try dispatching to the callback again.
+ Track version number of PDUs read from this buffer.
+ Once set, the version must not change.
"""
- return self.callback(self)
+ if self.version is not None and version != self.version:
+ raise CorruptData(
+ "Received PDU version %d, expected %d" % (version, self.version))
+ if self.version is None and version not in PDU.version_map:
+ raise UnsupportedProtocolVersion(
+ "Received PDU version %d, known versions %s" % (version, ", ".PDU.version_map.iterkeys()))
+ self.version = version
+
class PDUException(Exception):
"""
@@ -188,7 +206,7 @@ class PDUException(Exception):
"""
def __init__(self, msg = None, pdu = None):
- Exception.__init__(self)
+ super(PDUException, self).__init__()
assert msg is None or isinstance(msg, (str, unicode))
self.error_report_msg = msg
self.error_report_pdu = pdu
@@ -196,10 +214,11 @@ class PDUException(Exception):
def __str__(self):
return self.error_report_msg or self.__class__.__name__
- def make_error_report(self):
- return ErrorReportPDU(errno = self.error_report_code,
- errmsg = self.error_report_msg,
- errpdu = self.error_report_pdu)
+ def make_error_report(self, version):
+ return ErrorReportPDU(version = version,
+ errno = self.error_report_code,
+ errmsg = self.error_report_msg,
+ errpdu = self.error_report_pdu)
class UnsupportedProtocolVersion(PDUException):
error_report_code = 4
@@ -211,41 +230,51 @@ class CorruptData(PDUException):
error_report_code = 0
-def wire_pdu(cls):
+def wire_pdu(cls, versions = None):
"""
- Class decorator to add a PDU class to the set of known PDUs.
-
- In the long run, this decorator may take additional arguments
- specifying which protocol version(s) use this particular PDU,
- but we're not there yet.
+ Class decorator to add a PDU class to the set of known PDUs
+ for all supported protocol versions.
"""
- PDU.pdu_map[cls.pdu_type] = cls
+ for v in PDU.version_map.iterkeys() if versions is None else versions:
+ assert cls.pdu_type not in PDU.version_map[v]
+ PDU.version_map[v][cls.pdu_type] = cls
return cls
+def wire_pdu_only(*versions):
+ """
+ Class decorator to add a PDU class to the set of known PDUs
+ for specific protocol versions.
+ """
+
+ assert versions and all(v in PDU.version_map for v in versions)
+ return lambda cls: wire_pdu(cls, versions)
+
class PDU(object):
"""
Object representing a generic PDU in the rpki-router protocol.
Real PDUs are subclasses of this class.
"""
- pdu_map = {} # Updated by @wire_pdu
-
- version = 0 # Protocol version
+ version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu
_pdu = None # Cached when first generated
- header_struct = struct.Struct("!BBHL")
+ header_struct = struct.Struct("!BB2xL")
+
+ def __init__(self, version):
+ assert version in self.version_map
+ self.version = version
def __cmp__(self, other):
return cmp(self.to_pdu(), other.to_pdu())
- def check(self):
- """
- Check attributes to make sure they're within range.
- """
+ @property
+ def default_version(self):
+ return max(self.version_map.iterkeys())
+ def check(self):
pass
@classmethod
@@ -257,23 +286,21 @@ class PDU(object):
if not reader.ready():
return None
assert reader.available() >= cls.header_struct.size
- version, pdu_type, whatever, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
- if version != cls.version:
- raise UnsupportedProtocolVersion(
- "Received PDU version %d, expected %d" % (version, cls.version))
- if pdu_type not in cls.pdu_map:
+ version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
+ reader.check_version(version)
+ if pdu_type not in cls.version_map[version]:
raise UnsupportedPDUType(
"Received unsupported PDU type %d" % pdu_type)
if length < 8:
raise CorruptData(
"Received PDU with length %d, which is too short to be valid" % length)
- self = cls.pdu_map[pdu_type]()
+ self = cls.version_map[version][pdu_type](version = version)
return reader.update(need = length, callback = self.got_pdu)
def consume(self, client):
"""
Handle results in test client. Default behavior is just to print
- out the PDU.
+ out the PDU; data PDU subclasses may override this.
"""
blather(self)
@@ -283,17 +310,25 @@ class PDU(object):
Send a content of a file as a cache response. Caller should catch IOError.
"""
+ fn2 = os.path.splitext(filename)[1]
+ assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version
+
f = open(filename, "rb")
- server.push_pdu(CacheResponsePDU(nonce = server.current_nonce))
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
server.push_file(f)
- server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce))
def send_nodata(self, server):
"""
Send a nodata error.
"""
- server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["No Data Available"], errpdu = self))
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["No Data Available"],
+ errpdu = self))
class PDUWithSerial(PDU):
"""
@@ -302,7 +337,8 @@ class PDUWithSerial(PDU):
header_struct = struct.Struct("!BBHLL")
- def __init__(self, serial = None, nonce = None):
+ def __init__(self, version, serial = None, nonce = None):
+ super(PDUWithSerial, self).__init__(version)
if serial is not None:
assert isinstance(serial, int)
self.serial = serial
@@ -328,6 +364,7 @@ class PDUWithSerial(PDU):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
if length != 12:
raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
assert b == self.to_pdu()
@@ -340,7 +377,8 @@ class PDUWithNonce(PDU):
header_struct = struct.Struct("!BBHL")
- def __init__(self, nonce = None):
+ def __init__(self, version, nonce = None):
+ super(PDUWithNonce, self).__init__(version)
if nonce is not None:
assert isinstance(nonce, int)
self.nonce = nonce
@@ -362,6 +400,7 @@ class PDUWithNonce(PDU):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, self.nonce, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
if length != 8:
raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
assert b == self.to_pdu()
@@ -391,6 +430,7 @@ class PDUEmpty(PDU):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, zero, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
if zero != 0:
raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self)
if length != 8:
@@ -414,9 +454,11 @@ class SerialNotifyPDU(PDUWithSerial):
blather(self)
if client.current_serial is None or client.current_nonce != self.nonce:
- client.push_pdu(ResetQueryPDU())
+ client.push_pdu(ResetQueryPDU(version = client.version))
elif self.serial != client.current_serial:
- client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce))
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.current_serial,
+ nonce = client.current_nonce))
else:
blather("[Notify did not change serial number, ignoring]")
@@ -428,6 +470,9 @@ class SerialQueryPDU(PDUWithSerial):
pdu_type = 1
+ def __init__(self, version, serial = None, nonce = None):
+ super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce)
+
def serve(self, server):
"""
Received a serial query, send incremental transfer in response.
@@ -440,18 +485,21 @@ class SerialQueryPDU(PDUWithSerial):
self.send_nodata(server)
elif server.current_nonce != self.nonce:
log("[Client requested wrong nonce, resetting client]")
- server.push_pdu(CacheResetPDU())
+ server.push_pdu(CacheResetPDU(version = server.version))
elif server.current_serial == self.serial:
blather("[Client is already current, sending empty IXFR]")
- server.push_pdu(CacheResponsePDU(nonce = server.current_nonce))
- server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce))
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce))
elif disable_incrementals:
- server.push_pdu(CacheResetPDU())
+ server.push_pdu(CacheResetPDU(version = server.version))
else:
try:
- self.send_file(server, "%d.ix.%d" % (server.current_serial, self.serial))
+ self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version))
except IOError:
- server.push_pdu(CacheResetPDU())
+ server.push_pdu(CacheResetPDU(version = server.version))
@wire_pdu
class ResetQueryPDU(PDUEmpty):
@@ -461,6 +509,9 @@ class ResetQueryPDU(PDUEmpty):
pdu_type = 2
+ def __init__(self, version):
+ super(ResetQueryPDU, self).__init__(self.default_version if version is None else version)
+
def serve(self, server):
"""
Received a reset query, send full current state in response.
@@ -471,11 +522,13 @@ class ResetQueryPDU(PDUEmpty):
self.send_nodata(server)
else:
try:
- fn = "%d.ax" % server.current_serial
+ fn = "%d.ax.v%d" % (server.current_serial, server.version)
self.send_file(server, fn)
except IOError:
- server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["Internal Error"],
- errpdu = self, errmsg = "Couldn't open %s" % fn))
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["Internal Error"],
+ errpdu = self,
+ errmsg = "Couldn't open %s" % fn))
@wire_pdu
class CacheResponsePDU(PDUWithNonce):
@@ -495,21 +548,98 @@ class CacheResponsePDU(PDUWithNonce):
blather("[Nonce changed, resetting]")
client.cache_reset()
-@wire_pdu
-class EndOfDataPDU(PDUWithSerial):
+
+def EndOfDataPDU(version, *args, **kwargs):
"""
- End of Data PDU.
+ Factory for the EndOfDataPDU classes, which take different forms in
+ different protocol versions.
+ """
+
+ if version == 0:
+ return EndOfDataPDUv0(version, *args, **kwargs)
+ if version == 1:
+ return EndOfDataPDUv1(version, *args, **kwargs)
+ raise NotImplementedError
+
+
+@wire_pdu_only(0)
+class EndOfDataPDUv0(PDUWithSerial):
+ """
+ End of Data PDU, protocol version 0.
"""
pdu_type = 7
+ # Default values, from the current RFC 6810 bis I-D.
+ # Putting these here lets us use them in our client API for both
+ # protocol versions, even though they can only be set in the
+ # protocol in version 1.
+
+ refresh = 3600
+ retry = 600
+ expire = 7200
+
def consume(self, client):
"""
Handle EndOfDataPDU response.
"""
blather(self)
- client.end_of_data(self.serial, self.nonce)
+ client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+
+@wire_pdu_only(1)
+class EndOfDataPDUv1(EndOfDataPDUv0):
+ """
+ End of Data PDU, protocol version 1.
+ """
+
+ header_struct = struct.Struct("!BBHLLLLL")
+
+ def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None):
+ super(EndOfDataPDUv1, self).__init__(version)
+ if serial is not None:
+ assert isinstance(serial, int)
+ self.serial = serial
+ if nonce is not None:
+ assert isinstance(nonce, int)
+ self.nonce = nonce
+ if refresh is not None:
+ assert isinstance(refresh, int)
+ self.refresh = refresh
+ if retry is not None:
+ assert isinstance(retry, int)
+ self.retry = retry
+ if expire is not None:
+ assert isinstance(expire, int)
+ self.expire = expire
+
+ def __str__(self):
+ return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % (
+ self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
+
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
+ self.header_struct.size, self.serial,
+ self.refresh, self.retry, self.expire)
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \
+ = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 24:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
+
@wire_pdu
class CacheResetPDU(PDUEmpty):
@@ -526,7 +656,7 @@ class CacheResetPDU(PDUEmpty):
blather(self)
client.cache_reset()
- client.push_pdu(ResetQueryPDU())
+ client.push_pdu(ResetQueryPDU(version = client.version))
class PrefixPDU(PDU):
"""
@@ -543,14 +673,14 @@ class PrefixPDU(PDU):
asnum_struct = struct.Struct("!L")
@staticmethod
- def from_text(asnum, addr):
+ def from_text(version, asn, addr):
"""
Construct a prefix from its text form.
"""
cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU
- self = cls()
- self.asn = long(asnum)
+ self = cls(version = version)
+ self.asn = long(asn)
p, l = addr.split("/")
self.prefix = rpki.POW.IPAddress(p)
if "-" in l:
@@ -562,15 +692,15 @@ class PrefixPDU(PDU):
return self
@staticmethod
- def from_roa(asnum, prefix_tuple):
+ def from_roa(version, asn, prefix_tuple):
"""
Construct a prefix from a ROA.
"""
address, length, maxlength = prefix_tuple
cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU
- self = cls()
- self.asn = asnum
+ self = cls(version = version)
+ self.asn = asn
self.prefix = address
self.prefixlen = length
self.max_prefixlen = length if maxlength is None else maxlength
@@ -643,6 +773,7 @@ class PrefixPDU(PDU):
b2 = reader.get(self.address_byte_count)
b3 = reader.get(self.asnum_struct.size)
version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1)
+ assert version == self.version and pdu_type == self.pdu_type
if length != len(b1) + len(b2) + len(b3):
raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self)
self.prefix = rpki.POW.IPAddress.fromBytes(b2)
@@ -712,7 +843,7 @@ class IPv6PrefixPDU(PrefixPDU):
pdu_type = 6
address_byte_count = 16
-@wire_pdu
+@wire_pdu_only(1)
class RouterKeyPDU(PDU):
"""
Router Key PDU.
@@ -723,13 +854,13 @@ class RouterKeyPDU(PDU):
header_struct = struct.Struct("!BBBxL20sL")
@classmethod
- def from_text(cls, asnum, gski, key):
+ def from_text(cls, version, asn, gski, key):
"""
Construct a router key from its text form.
"""
- self = cls()
- self.asn = long(asnum)
+ self = cls(version = version)
+ self.asn = long(asn)
self.ski = base64.urlsafe_b64decode(gski + "=")
self.key = base64.b64decode(key)
self.announce = 1
@@ -737,13 +868,13 @@ class RouterKeyPDU(PDU):
return self
@classmethod
- def from_certificate(cls, asnum, ski, key):
+ def from_certificate(cls, version, asn, ski, key):
"""
Construct a router key from a certificate.
"""
- self = cls()
- self.asn = asnum
+ self = cls(version = version)
+ self.asn = asn
self.ski = ski
self.key = key
self.announce = 1
@@ -799,6 +930,7 @@ class RouterKeyPDU(PDU):
return None
header = reader.get(self.header_struct.size)
version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
remaining = length - self.header_struct.size
if remaining <= 0:
raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self)
@@ -836,7 +968,8 @@ class ErrorReportPDU(PDU):
codes = dict((v, k) for k, v in errors.items())
- def __init__(self, errno = None, errpdu = None, errmsg = None):
+ def __init__(self, version, errno = None, errpdu = None, errmsg = None):
+ super(ErrorReportPDU, self).__init__(version)
assert errno is None or errno in self.errors
self.errno = errno
self.errpdu = errpdu
@@ -879,6 +1012,7 @@ class ErrorReportPDU(PDU):
return None
header = reader.get(self.header_struct.size)
version, pdu_type, self.errno, length = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
remaining = length - self.header_struct.size
self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining)
self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining)
@@ -903,7 +1037,7 @@ class ErrorReportPDU(PDU):
sys.exit(1)
-class ROA(rpki.POW.ROA):
+class ROA(rpki.POW.ROA): # pylint: disable=W0232
"""
Minor additions to rpki.POW.ROA.
"""
@@ -924,7 +1058,7 @@ class ROA(rpki.POW.ROA):
for p in v6:
yield p
-class X509(rpki.POW.X509):
+class X509(rpki.POW.X509): # pylint: disable=W0232
"""
Minor additions to rpki.POW.X509.
"""
@@ -945,13 +1079,18 @@ class PDUSet(list):
from rcynic's output.
"""
+ def __init__(self, version):
+ assert version in PDU.version_map
+ super(PDUSet, self).__init__()
+ self.version = version
+
@classmethod
- def _load_file(cls, filename):
+ def _load_file(cls, filename, version):
"""
Low-level method to read PDUSet from a file.
"""
- self = cls()
+ self = cls(version = version)
f = open(filename, "rb")
r = ReadBuffer()
while True:
@@ -963,6 +1102,7 @@ class PDUSet(list):
return self
r.put(b)
p = r.retry()
+ assert p.version == self.version
self.append(p)
@staticmethod
@@ -979,7 +1119,7 @@ class AXFRSet(PDUSet):
"""
@classmethod
- def parse_rcynic(cls, rcynic_dir):
+ def parse_rcynic(cls, rcynic_dir, version):
"""
Parse ROAS and router certificates fetched (and validated!) by
rcynic to create a new AXFRSet.
@@ -994,24 +1134,26 @@ class AXFRSet(PDUSet):
can make this one a bit simpler and faster.
"""
- self = cls()
+ self = cls(version = version)
self.serial = Timestamp.now()
- if scan_roas is None or scan_routercerts is None:
- for root, dirs, files in os.walk(rcynic_dir):
+ include_routercerts = RouterKeyPDU.pdu_type in PDU.version_map[version]
+
+ if scan_roas is None or (scan_routercerts is None and include_routercerts):
+ for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612
for fn in files:
if scan_roas is None and fn.endswith(".roa"):
roa = ROA.derReadFile(os.path.join(root, fn))
asn = roa.getASID()
- self.extend(PrefixPDU.from_roa(asn, roa_prefix)
- for roa_prefix in roa.prefixes)
- if scan_routercerts is None and fn.endswith(".cer"):
+ self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple)
+ for prefix_tuple in roa.prefixes)
+ if include_routercerts and scan_routercerts is None and fn.endswith(".cer"):
x = X509.derReadFile(os.path.join(root, fn))
eku = x.getEKU()
if eku is not None and rpki.oids.id_kp_bgpsec_router in eku:
ski = x.getSKI()
key = x.getPublicKey().derWritePublic()
- self.extend(RouterKeyPDU.from_certificate(asn, ski, key)
+ self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key)
for asn in x.asns)
if scan_roas is not None:
@@ -1020,18 +1162,20 @@ class AXFRSet(PDUSet):
for line in p.stdout:
line = line.split()
asn = line[1]
- self.extend(PrefixPDU.from_text(asn, addr) for addr in line[2:])
+ self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr)
+ for addr in line[2:])
except OSError, e:
sys.exit("Could not run %s: %s" % (scan_roas, e))
- if scan_routercerts is not None:
+ if include_routercerts and scan_routercerts is not None:
try:
p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE)
for line in p.stdout:
line = line.split()
gski = line[0]
key = line[-1]
- self.extend(RouterKeyPDU.from_text(asn, gski, key) for asn in line[1:-1])
+ self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key)
+ for asn in line[1:-1])
except OSError, e:
sys.exit("Could not run %s: %s" % (scan_routercerts, e))
@@ -1044,12 +1188,13 @@ class AXFRSet(PDUSet):
@classmethod
def load(cls, filename):
"""
- Load an AXFRSet from a file, parse filename to obtain serial.
+ Load an AXFRSet from a file, parse filename to obtain version and serial.
"""
- fn1, fn2 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ax"
- self = cls._load_file(filename)
+ fn1, fn2, fn3 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit()
+ version = int(fn3[1:])
+ self = cls._load_file(filename, version)
self.serial = Timestamp(fn1)
return self
@@ -1058,19 +1203,19 @@ class AXFRSet(PDUSet):
Generate filename for this AXFRSet.
"""
- return "%d.ax" % self.serial
+ return "%d.ax.v%d" % (self.serial, self.version)
@classmethod
- def load_current(cls):
+ def load_current(cls, version):
"""
Load current AXFRSet. Return None if can't.
"""
- serial = read_current()[0]
+ serial = read_current(version)[0]
if serial is None:
return None
try:
- return cls.load("%d.ax" % serial)
+ return cls.load("%d.ax.v%d" % (serial, version))
except IOError:
return None
@@ -1090,9 +1235,9 @@ class AXFRSet(PDUSet):
the old serial numbers are no longer valid.
"""
- for i in glob.iglob("*.ix.*"):
+ for i in glob.iglob("*.ix.*.v%d" % self.version):
os.unlink(i)
- for i in glob.iglob("*.ax"):
+ for i in glob.iglob("*.ax.v%d" % self.version):
if i != self.filename():
os.unlink(i)
@@ -1103,12 +1248,12 @@ class AXFRSet(PDUSet):
the new nonce invalidates all old serial numbers.
"""
- old_serial, nonce = read_current()
+ old_serial, nonce = read_current(self.version)
if old_serial is None or self.seq_ge(old_serial, self.serial):
blather("Creating new nonce and deleting stale data")
nonce = new_nonce()
self.destroy_old_data()
- write_current(self.serial, nonce)
+ write_current(self.serial, nonce, self.version)
def save_ixfr(self, other):
"""
@@ -1118,7 +1263,7 @@ class AXFRSet(PDUSet):
comparison.
"""
- f = open("%d.ix.%d" % (self.serial, other.serial), "wb")
+ f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb")
old = other
new = self
len_old = len(old)
@@ -1145,7 +1290,7 @@ class AXFRSet(PDUSet):
Print this AXFRSet.
"""
- blather("# AXFR %d (%s)" % (self.serial, self.serial))
+ blather("# AXFR %d (%s) v%d" % (self.serial, self.serial, self.version))
for p in self:
blather(p)
@@ -1207,12 +1352,13 @@ class IXFRSet(PDUSet):
@classmethod
def load(cls, filename):
"""
- Load an IXFRSet from a file, parse filename to obtain serials.
+ Load an IXFRSet from a file, parse filename to obtain version and serials.
"""
- fn1, fn2, fn3 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit()
- self = cls._load_file(filename)
+ fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit()
+ version = int(fn4[1:])
+ self = cls._load_file(filename, version)
self.from_serial = Timestamp(fn3)
self.to_serial = Timestamp(fn1)
return self
@@ -1222,15 +1368,16 @@ class IXFRSet(PDUSet):
Generate filename for this IXFRSet.
"""
- return "%d.ix.%d" % (self.to_serial, self.from_serial)
+ return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version)
def show(self):
"""
Print this IXFRSet.
"""
- blather("# IXFR %d (%s) -> %d (%s)" % (self.from_serial, self.from_serial,
- self.to_serial, self.to_serial))
+ blather("# IXFR %d (%s) -> %d (%s) v%d" % (self.from_serial, self.from_serial,
+ self.to_serial, self.to_serial,
+ self.version))
for p in self:
blather(p)
@@ -1246,7 +1393,7 @@ class FileProducer(object):
def more(self):
return self.handle.read(self.buffersize)
-class PDUChannel(asynchat.async_chat):
+class PDUChannel(asynchat.async_chat, object):
"""
asynchat subclass that understands our PDUs. This just handles
network I/O. Specific engines (client, server) should be subclasses
@@ -1255,9 +1402,17 @@ class PDUChannel(asynchat.async_chat):
"""
def __init__(self, conn = None):
- asynchat.async_chat.__init__(self, conn)
+ asynchat.async_chat.__init__(self, conn) # Old-style class
self.reader = ReadBuffer()
+ @property
+ def version(self):
+ return self.reader.version
+
+ @version.setter
+ def version(self, version):
+ self.reader.check_version(version)
+
def start_new_pdu(self):
"""
Start read of a new PDU.
@@ -1269,7 +1424,7 @@ class PDUChannel(asynchat.async_chat):
self.deliver_pdu(p)
p = PDU.read_pdu(self.reader)
except PDUException, e:
- self.push_pdu(e.make_error_report())
+ self.push_pdu(e.make_error_report(version = self.version))
self.close_when_done()
else:
assert not self.reader.ready()
@@ -1382,7 +1537,7 @@ class ServerWriteChannel(PDUChannel):
Set up stdout.
"""
- PDUChannel.__init__(self)
+ super(ServerWriteChannel, self).__init__()
self.init_file_dispatcher(sys.stdout.fileno())
def readable(self):
@@ -1392,6 +1547,7 @@ class ServerWriteChannel(PDUChannel):
return False
+
class ServerChannel(PDUChannel):
"""
Server protocol engine, handles upcalls from PDUChannel to
@@ -1404,7 +1560,7 @@ class ServerChannel(PDUChannel):
first PDU.
"""
- PDUChannel.__init__(self)
+ super(ServerChannel, self).__init__()
self.init_file_dispatcher(sys.stdin.fileno())
self.writer = ServerWriteChannel()
self.get_serial()
@@ -1460,7 +1616,7 @@ class ServerChannel(PDUChannel):
mode instance is still building its database.
"""
- self.current_serial, self.current_nonce = read_current()
+ self.current_serial, self.current_nonce = read_current(self.version)
return self.current_serial
def check_serial(self):
@@ -1473,13 +1629,21 @@ class ServerChannel(PDUChannel):
def notify(self, data = None):
"""
- Cronjob instance kicked us, send a notify message.
+ Cronjob instance kicked us: check whether our serial number has
+ changed, and send a notify message if so.
+
+ We have to check rather than just blindly notifying when kicked
+ because the cronjob instance has no good way of knowing which
+ protocol version we're running, thus has no good way of knowing
+ whether we care about a particular change set or not.
"""
- if self.check_serial() is not None:
- self.push_pdu(SerialNotifyPDU(serial = self.current_serial, nonce = self.current_nonce))
+ if self.check_serial():
+ self.push_pdu(SerialNotifyPDU(version = self.version,
+ serial = self.current_serial,
+ nonce = self.current_nonce))
else:
- log("Cronjob kicked me without a valid current serial number")
+ log("Cronjob kicked me but I see no serial change, ignoring")
class ClientChannel(PDUChannel):
"""
@@ -1493,12 +1657,21 @@ class ClientChannel(PDUChannel):
port = None
cache_id = None
+ # For initial test purposes, let's use the minimum allowed values
+ # from the RFC 6810 bis I-D as the initial defaults for refresh and
+ # retry, and the maximum allowed for expire; these will be overriden
+ # as soon as we receive an EndOfDataPDU.
+ #
+ refresh = 120
+ retry = 120
+ expire = 172800
+
def __init__(self, sock, proc, killsig, host, port):
self.killsig = killsig
self.proc = proc
self.host = host
self.port = port
- PDUChannel.__init__(self, conn = sock)
+ super(ClientChannel, self).__init__(conn = sock)
self.start_new_pdu()
@classmethod
@@ -1529,7 +1702,7 @@ class ClientChannel(PDUChannel):
blather("[socket.getaddrinfo() failed: %s]" % e)
else:
for ai in addrinfo:
- af, socktype, proto, cn, sa = ai
+ af, socktype, proto, cn, sa = ai # pylint: disable=W0612
blather("[Trying addr %s port %s]" % sa[:2])
try:
s = socket.socket(af, socktype, proto)
@@ -1601,9 +1774,13 @@ class ClientChannel(PDUChannel):
cache_id INTEGER PRIMARY KEY NOT NULL,
host TEXT NOT NULL,
port TEXT NOT NULL,
+ version INTEGER,
nonce INTEGER,
serial INTEGER,
updated INTEGER,
+ refresh INTEGER,
+ retry INTEGER,
+ expire INTEGER,
UNIQUE (host, port))''')
cur.execute('''
CREATE TABLE prefix (
@@ -1629,14 +1806,26 @@ class ClientChannel(PDUChannel):
UNIQUE (cache_id, asn, ski),
UNIQUE (cache_id, asn, key))''')
- cur.execute("SELECT cache_id, nonce, serial FROM cache WHERE host = ? AND port = ?",
+ cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire "
+ "FROM cache WHERE host = ? AND port = ?",
(self.host, self.port))
try:
- self.cache_id, self.current_nonce, self.current_serial = cur.fetchone()
+ self.cache_id, version, self.current_nonce, self.current_serial, refresh, retry, expire = cur.fetchone()
+ if version is not None:
+ self.version = version
+ if refresh is not None:
+ self.refresh = refresh
+ if retry is not None:
+ self.retry = retry
+ if expire is not None:
+ self.expire = expire
except TypeError:
cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port))
self.cache_id = cur.lastrowid
self.sql.commit()
+ log("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s]" % (
+ self.cache_id, self.version, self.current_nonce, self.current_serial, self.refresh, self.retry, self.expire))
+
def cache_reset(self):
"""
@@ -1645,20 +1834,34 @@ class ClientChannel(PDUChannel):
self.current_serial = None
if self.sql:
+ #
+ # For some reason there was no commit here. Dunno why.
+ # See if adding one breaks anything....
+ #
cur = self.sql.cursor()
cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,))
- cur.execute("UPDATE cache SET serial = NULL WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id))
+ self.sql.commit()
- def end_of_data(self, serial, nonce):
+ def end_of_data(self, version, serial, nonce, refresh, retry, expire):
"""
Handle EndOfDataPDU actions.
"""
- self.current_serial = serial
- self.current_nonce = nonce
+ assert version == self.version
+ self.current_serial = serial
+ self.current_nonce = nonce
+ self.refresh = refresh
+ self.retry = retry
+ self.expire = expire
if self.sql:
- self.sql.execute("UPDATE cache SET serial = ?, nonce = ?, updated = datetime('now') WHERE cache_id = ?",
- (serial, nonce, self.cache_id))
+ self.sql.execute("UPDATE cache SET"
+ " version = ?, serial = ?, nonce = ?,"
+ " refresh = ?, retry = ?, expire = ?,"
+ " updated = datetime('now') "
+ "WHERE cache_id = ?",
+ (version, serial, nonce, refresh, retry, expire, self.cache_id))
self.sql.commit()
def consume_prefix(self, prefix):
@@ -1733,14 +1936,14 @@ class ClientChannel(PDUChannel):
blather("Server closed channel")
PDUChannel.handle_close(self)
-class KickmeChannel(asyncore.dispatcher):
+class KickmeChannel(asyncore.dispatcher, object):
"""
asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to
kick servers when it's time to send notify PDUs to clients.
"""
def __init__(self, server):
- asyncore.dispatcher.__init__(self)
+ asyncore.dispatcher.__init__(self) # Old-style class
self.server = server
self.sockname = "%s.%d" % (kickme_base, os.getpid())
self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM)
@@ -1912,37 +2115,41 @@ def cronjob_main(argv):
if len(argv) != 1:
sys.exit("Expected one argument, got %r" % (argv,))
- old_ixfrs = glob.glob("*.ix.*")
+ for version in sorted(PDU.version_map.iterkeys(), reverse = True):
- current = read_current()[0]
- cutoff = Timestamp.now(-(24 * 60 * 60))
- for f in glob.iglob("*.ax"):
- t = Timestamp(int(f.split(".")[0]))
- if t < cutoff and t != current:
- blather("# Deleting old file %s, timestamp %s" % (f, t))
- os.unlink(f)
+ blather("# Generating updates for protocol version %d" % version)
- pdus = AXFRSet.parse_rcynic(argv[0])
- if pdus == AXFRSet.load_current():
- blather("# No change, new version not needed")
- sys.exit()
- pdus.save_axfr()
- for axfr in glob.iglob("*.ax"):
- if axfr != pdus.filename():
- pdus.save_ixfr(AXFRSet.load(axfr))
- pdus.mark_current()
+ old_ixfrs = glob.glob("*.ix.*.v%d" % version)
- blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial))
+ current = read_current(version)[0]
+ cutoff = Timestamp.now(-(24 * 60 * 60))
+ for f in glob.iglob("*.ax.v%d" % version):
+ t = Timestamp(int(f.split(".")[0]))
+ if t < cutoff and t != current:
+ blather("# Deleting old file %s, timestamp %s" % (f, t))
+ os.unlink(f)
- kick_all(pdus.serial)
+ pdus = AXFRSet.parse_rcynic(argv[0], version)
+ if pdus == AXFRSet.load_current(version):
+ blather("# No change, new serial not needed")
+ continue
+ pdus.save_axfr()
+ for axfr in glob.iglob("*.ax.v%d" % version):
+ if axfr != pdus.filename():
+ pdus.save_ixfr(AXFRSet.load(axfr))
+ pdus.mark_current()
- old_ixfrs.sort()
- for ixfr in old_ixfrs:
- try:
- blather("# Deleting old file %s" % ixfr)
- os.unlink(ixfr)
- except OSError:
- pass
+ blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial))
+
+ kick_all(pdus.serial)
+
+ old_ixfrs.sort()
+ for ixfr in old_ixfrs:
+ try:
+ blather("# Deleting old file %s" % ixfr)
+ os.unlink(ixfr)
+ except OSError:
+ pass
def show_main(argv):
"""
@@ -1956,12 +2163,12 @@ def show_main(argv):
if argv:
sys.exit("Unexpected arguments: %r" % (argv,))
- g = glob.glob("*.ax")
+ g = glob.glob("*.ax.v*")
g.sort()
for f in g:
AXFRSet.load(f).show()
- g = glob.glob("*.ix.*")
+ g = glob.glob("*.ix.*.v*")
g.sort()
for f in g:
IXFRSet.load(f).show()
@@ -2068,8 +2275,8 @@ def listener_tcp_main(argv):
blather("[Received connection from %r]" % (ai,))
pid = os.fork()
if pid == 0:
- os.dup2(s.fileno(), 0)
- os.dup2(s.fileno(), 1)
+ os.dup2(s.fileno(), 0) # pylint: disable=E1103
+ os.dup2(s.fileno(), 1) # pylint: disable=E1103
s.close()
#os.closerange(3, os.sysconf("SC_OPEN_MAX"))
global log_tag
@@ -2082,7 +2289,7 @@ def listener_tcp_main(argv):
blather("[Spawned server %d]" % pid)
try:
while True:
- pid, status = os.waitpid(0, os.WNOHANG)
+ pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612
if pid:
blather("[Server %s exited]" % pid)
else:
@@ -2146,11 +2353,17 @@ def client_main(argv):
client.setup_sql(sqlname)
while True:
if client.current_serial is None or client.current_nonce is None:
- client.push_pdu(ResetQueryPDU())
+ client.push_pdu(ResetQueryPDU(version = client.version))
else:
- client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce))
- wakeup = time.time() + 600
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.current_serial,
+ nonce = client.current_nonce))
+ polled = Timestamp.now()
+ wakeup = None
while True:
+ if wakeup != polled + client.refresh:
+ wakeup = Timestamp(polled + client.refresh)
+ log("[Last client poll %s, next %s]" % (polled, wakeup))
remaining = wakeup - time.time()
if remaining < 0:
break
@@ -2183,10 +2396,11 @@ def bgpdump_convert_main(argv):
first = True
db = None
axfrs = []
+ version = max(PDU.version_map.iterkeys())
for filename in argv:
- if filename.endswith(".ax"):
+ if ".ax.v" in filename:
blather("Reading %s" % filename)
db = AXFRSet.load(filename)
@@ -2203,7 +2417,7 @@ def bgpdump_convert_main(argv):
sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename)
blather("DB serial now %d (%s)" % (db.serial, db.serial))
- if first and read_current() == (None, None):
+ if first and read_current(version) == (None, None):
db.mark_current()
first = False
@@ -2230,21 +2444,22 @@ def bgpdump_select_main(argv):
You have been warned.
"""
+ version = max(PDU.version_map.iterkeys())
serial = None
try:
head, sep, tail = os.path.basename(argv[0]).partition(".")
- if len(argv) == 1 and head.isdigit() and sep == "." and tail == "ax":
+ if len(argv) == 1 and head.isdigit() and sep == "." and tail.startswith("ax.v") and tail[4:].isdigit():
serial = Timestamp(head)
except:
pass
if serial is None:
sys.exit("Argument must be name of a .ax file")
- nonce = read_current()[1]
+ nonce = read_current(version)[1]
if nonce is None:
nonce = new_nonce()
- write_current(serial, nonce)
+ write_current(serial, nonce, version)
kick_all(serial)
@@ -2264,7 +2479,7 @@ class BGPDumpReplayClock(object):
"""
def __init__(self):
- self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax")]
+ self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")]
self.timestamps.sort()
self.offset = self.timestamps[0] - int(time.time())
self.nonce = new_nonce()
@@ -2275,7 +2490,7 @@ class BGPDumpReplayClock(object):
def now(self):
return Timestamp.now(self.offset)
- def read_current(self):
+ def read_current(self, version):
now = self.now()
while len(self.timestamps) > 1 and now >= self.timestamps[1]:
del self.timestamps[0]
@@ -2340,6 +2555,7 @@ def bgpdump_server_main(argv):
scan_roas = None
scan_routercerts = None
force_zero_nonce = False
+debug = False
kickme_dir = "sockets"
kickme_base = os.path.join(kickme_dir, "kickme")
@@ -2360,6 +2576,8 @@ def usage(msg = None):
f.write("\n")
f.write("where options are zero or more of:\n")
f.write("\n")
+ f.write("--debug\n")
+ f.write("\n")
f.write("--scan-roas /path/to/scan_roas\n")
f.write("\n")
f.write("--scan-routercerts /path/to/scan_routercerts\n")
@@ -2385,13 +2603,15 @@ if __name__ == "__main__":
syslog_facility, syslog_warning, syslog_info = syslog.LOG_DAEMON, syslog.LOG_WARNING, syslog.LOG_INFO
- opts, argv = getopt.getopt(sys.argv[1:], "hs:z?", ["help", "scan-roas=", "scan-routercerts=",
- "syslog=", "zero-nonce"] + main_dispatch.keys())
+ opts, argv = getopt.getopt(sys.argv[1:], "dhs:z?", ["help", "debug", "scan-roas=", "scan-routercerts=",
+ "syslog=", "zero-nonce"] + main_dispatch.keys())
for o, a in opts:
if o in ("-h", "--help", "-?"):
usage()
elif o in ("-z", "--zero-nonce"):
force_zero_nonce = True
+ elif o in ("-d", "--debug"):
+ debug = True
elif o in ("-s", "--syslog"):
try:
a = [getattr(syslog, "LOG_" + i.upper()) for i in a.split(".")]
@@ -2419,7 +2639,7 @@ if __name__ == "__main__":
if mode in ("server", "bgpdump_server"):
log_tag += hostport_tag()
- if mode in ("cronjob", "server" , "bgpdump_server"):
+ if not debug and mode in ("cronjob", "server" , "bgpdump_server"):
syslog.openlog(log_tag, syslog.LOG_PID, syslog_facility)
def log(msg):
return syslog.syslog(syslog_warning, str(msg))