aboutsummaryrefslogtreecommitdiff
path: root/rp/rpki-rtr
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2014-04-18 23:50:21 +0000
committerRob Austein <sra@hactrn.net>2014-04-18 23:50:21 +0000
commit3c45a5e24353eaa4df854b0e9402c0b2e54f595a (patch)
treea9e4baf3c08a1c06e30d974e40cb4fbab8868025 /rp/rpki-rtr
parent2b60e7f371b87db763ba6a8df726fadd85d90d9a (diff)
Add support for multiple versions of rpki-rtr protocol, along with
rudimentary use of the new End Of Data timing parameters in the rpki-rtr test client. cronjob mode now generates a separate parallel database for each protocol version; server supports both current versions and picks which one to use for a particular session based on the initial client request. test client doesn't yet handle version fallback or make proper use of retry or expire parameters, but does now use refresh parameter. svn path=/trunk/; revision=5812
Diffstat (limited to 'rp/rpki-rtr')
-rwxr-xr-xrp/rpki-rtr/rtr-origin588
1 files changed, 404 insertions, 184 deletions
diff --git a/rp/rpki-rtr/rtr-origin b/rp/rpki-rtr/rtr-origin
index e127b2b2..b2b3d1ab 100755
--- a/rp/rpki-rtr/rtr-origin
+++ b/rp/rpki-rtr/rtr-origin
@@ -59,8 +59,10 @@ class Timestamp(int):
Wrapper around time module.
"""
- def __new__(cls, x):
- return int.__new__(cls, x)
+ def __new__(cls, t):
+ # http://stackoverflow.com/questions/7471255/pythons-super-and-new-confused-me
+ #return int.__new__(cls, t)
+ return super(Timestamp, cls).__new__(cls, t)
@classmethod
def now(cls, delta = 0):
@@ -70,39 +72,36 @@ class Timestamp(int):
return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self))
-def read_current():
+def read_current(version):
"""
Read current serial number and nonce. Return None for both if
serial and nonce not recorded. For backwards compatibility, treat
file containing just a serial number as having a nonce of zero.
"""
+ if version is None:
+ return None, None
try:
- f = open("current", "r")
- values = tuple(int(s) for s in f.read().split())
- f.close()
+ with open("current.v%d" % version, "r") as f:
+ values = tuple(int(s) for s in f.read().split())
return values[0], values[1]
except IndexError:
return values[0], 0
except IOError:
return None, None
-def write_current(serial, nonce):
+
+def write_current(serial, nonce, version):
"""
Write serial number and nonce.
"""
- tmpfn = "current.%d.tmp" % os.getpid()
- try:
- f = open(tmpfn, "w")
+ assert version in PDU.version_map
+ curfn = "current.v%d" % version
+ tmpfn = curfn + "%d.tmp" % os.getpid()
+ with open(tmpfn, "w") as f:
f.write("%d %d\n" % (serial, nonce))
- f.close()
- os.rename(tmpfn, "current")
- finally:
- try:
- os.unlink(tmpfn)
- except:
- pass
+ os.rename(tmpfn, curfn)
def new_nonce():
@@ -121,10 +120,14 @@ def new_nonce():
class ReadBuffer(object):
"""
Wrapper around synchronous/asynchronous read state.
+
+ This also handles tracking the current protocol version,
+ because it has to go somewhere and there's no better place.
"""
def __init__(self):
self.buffer = ""
+ self.version = None
def update(self, need, callback):
"""
@@ -133,6 +136,13 @@ class ReadBuffer(object):
self.need = need
self.callback = callback
+ return self.retry()
+
+ def retry(self):
+ """
+ Try dispatching to the callback again.
+ """
+
return self.callback(self)
def available(self):
@@ -172,12 +182,20 @@ class ReadBuffer(object):
self.buffer += b
- def retry(self):
+ def check_version(self, version):
"""
- Try dispatching to the callback again.
+ Track version number of PDUs read from this buffer.
+ Once set, the version must not change.
"""
- return self.callback(self)
+ if self.version is not None and version != self.version:
+ raise CorruptData(
+ "Received PDU version %d, expected %d" % (version, self.version))
+ if self.version is None and version not in PDU.version_map:
+ raise UnsupportedProtocolVersion(
+ "Received PDU version %d, known versions %s" % (version, ", ".PDU.version_map.iterkeys()))
+ self.version = version
+
class PDUException(Exception):
"""
@@ -188,7 +206,7 @@ class PDUException(Exception):
"""
def __init__(self, msg = None, pdu = None):
- Exception.__init__(self)
+ super(PDUException, self).__init__()
assert msg is None or isinstance(msg, (str, unicode))
self.error_report_msg = msg
self.error_report_pdu = pdu
@@ -196,10 +214,11 @@ class PDUException(Exception):
def __str__(self):
return self.error_report_msg or self.__class__.__name__
- def make_error_report(self):
- return ErrorReportPDU(errno = self.error_report_code,
- errmsg = self.error_report_msg,
- errpdu = self.error_report_pdu)
+ def make_error_report(self, version):
+ return ErrorReportPDU(version = version,
+ errno = self.error_report_code,
+ errmsg = self.error_report_msg,
+ errpdu = self.error_report_pdu)
class UnsupportedProtocolVersion(PDUException):
error_report_code = 4
@@ -211,41 +230,51 @@ class CorruptData(PDUException):
error_report_code = 0
-def wire_pdu(cls):
+def wire_pdu(cls, versions = None):
"""
- Class decorator to add a PDU class to the set of known PDUs.
-
- In the long run, this decorator may take additional arguments
- specifying which protocol version(s) use this particular PDU,
- but we're not there yet.
+ Class decorator to add a PDU class to the set of known PDUs
+ for all supported protocol versions.
"""
- PDU.pdu_map[cls.pdu_type] = cls
+ for v in PDU.version_map.iterkeys() if versions is None else versions:
+ assert cls.pdu_type not in PDU.version_map[v]
+ PDU.version_map[v][cls.pdu_type] = cls
return cls
+def wire_pdu_only(*versions):
+ """
+ Class decorator to add a PDU class to the set of known PDUs
+ for specific protocol versions.
+ """
+
+ assert versions and all(v in PDU.version_map for v in versions)
+ return lambda cls: wire_pdu(cls, versions)
+
class PDU(object):
"""
Object representing a generic PDU in the rpki-router protocol.
Real PDUs are subclasses of this class.
"""
- pdu_map = {} # Updated by @wire_pdu
-
- version = 0 # Protocol version
+ version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu
_pdu = None # Cached when first generated
- header_struct = struct.Struct("!BBHL")
+ header_struct = struct.Struct("!BB2xL")
+
+ def __init__(self, version):
+ assert version in self.version_map
+ self.version = version
def __cmp__(self, other):
return cmp(self.to_pdu(), other.to_pdu())
- def check(self):
- """
- Check attributes to make sure they're within range.
- """
+ @property
+ def default_version(self):
+ return max(self.version_map.iterkeys())
+ def check(self):
pass
@classmethod
@@ -257,23 +286,21 @@ class PDU(object):
if not reader.ready():
return None
assert reader.available() >= cls.header_struct.size
- version, pdu_type, whatever, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
- if version != cls.version:
- raise UnsupportedProtocolVersion(
- "Received PDU version %d, expected %d" % (version, cls.version))
- if pdu_type not in cls.pdu_map:
+ version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
+ reader.check_version(version)
+ if pdu_type not in cls.version_map[version]:
raise UnsupportedPDUType(
"Received unsupported PDU type %d" % pdu_type)
if length < 8:
raise CorruptData(
"Received PDU with length %d, which is too short to be valid" % length)
- self = cls.pdu_map[pdu_type]()
+ self = cls.version_map[version][pdu_type](version = version)
return reader.update(need = length, callback = self.got_pdu)
def consume(self, client):
"""
Handle results in test client. Default behavior is just to print
- out the PDU.
+ out the PDU; data PDU subclasses may override this.
"""
blather(self)
@@ -283,17 +310,25 @@ class PDU(object):
Send a content of a file as a cache response. Caller should catch IOError.
"""
+ fn2 = os.path.splitext(filename)[1]
+ assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version
+
f = open(filename, "rb")
- server.push_pdu(CacheResponsePDU(nonce = server.current_nonce))
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
server.push_file(f)
- server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce))
def send_nodata(self, server):
"""
Send a nodata error.
"""
- server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["No Data Available"], errpdu = self))
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["No Data Available"],
+ errpdu = self))
class PDUWithSerial(PDU):
"""
@@ -302,7 +337,8 @@ class PDUWithSerial(PDU):
header_struct = struct.Struct("!BBHLL")
- def __init__(self, serial = None, nonce = None):
+ def __init__(self, version, serial = None, nonce = None):
+ super(PDUWithSerial, self).__init__(version)
if serial is not None:
assert isinstance(serial, int)
self.serial = serial
@@ -328,6 +364,7 @@ class PDUWithSerial(PDU):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
if length != 12:
raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
assert b == self.to_pdu()
@@ -340,7 +377,8 @@ class PDUWithNonce(PDU):
header_struct = struct.Struct("!BBHL")
- def __init__(self, nonce = None):
+ def __init__(self, version, nonce = None):
+ super(PDUWithNonce, self).__init__(version)
if nonce is not None:
assert isinstance(nonce, int)
self.nonce = nonce
@@ -362,6 +400,7 @@ class PDUWithNonce(PDU):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, self.nonce, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
if length != 8:
raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
assert b == self.to_pdu()
@@ -391,6 +430,7 @@ class PDUEmpty(PDU):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, zero, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
if zero != 0:
raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self)
if length != 8:
@@ -414,9 +454,11 @@ class SerialNotifyPDU(PDUWithSerial):
blather(self)
if client.current_serial is None or client.current_nonce != self.nonce:
- client.push_pdu(ResetQueryPDU())
+ client.push_pdu(ResetQueryPDU(version = client.version))
elif self.serial != client.current_serial:
- client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce))
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.current_serial,
+ nonce = client.current_nonce))
else:
blather("[Notify did not change serial number, ignoring]")
@@ -428,6 +470,9 @@ class SerialQueryPDU(PDUWithSerial):
pdu_type = 1
+ def __init__(self, version, serial = None, nonce = None):
+ super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce)
+
def serve(self, server):
"""
Received a serial query, send incremental transfer in response.
@@ -440,18 +485,21 @@ class SerialQueryPDU(PDUWithSerial):
self.send_nodata(server)
elif server.current_nonce != self.nonce:
log("[Client requested wrong nonce, resetting client]")
- server.push_pdu(CacheResetPDU())
+ server.push_pdu(CacheResetPDU(version = server.version))
elif server.current_serial == self.serial:
blather("[Client is already current, sending empty IXFR]")
- server.push_pdu(CacheResponsePDU(nonce = server.current_nonce))
- server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce))
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce))
elif disable_incrementals:
- server.push_pdu(CacheResetPDU())
+ server.push_pdu(CacheResetPDU(version = server.version))
else:
try:
- self.send_file(server, "%d.ix.%d" % (server.current_serial, self.serial))
+ self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version))
except IOError:
- server.push_pdu(CacheResetPDU())
+ server.push_pdu(CacheResetPDU(version = server.version))
@wire_pdu
class ResetQueryPDU(PDUEmpty):
@@ -461,6 +509,9 @@ class ResetQueryPDU(PDUEmpty):
pdu_type = 2
+ def __init__(self, version):
+ super(ResetQueryPDU, self).__init__(self.default_version if version is None else version)
+
def serve(self, server):
"""
Received a reset query, send full current state in response.
@@ -471,11 +522,13 @@ class ResetQueryPDU(PDUEmpty):
self.send_nodata(server)
else:
try:
- fn = "%d.ax" % server.current_serial
+ fn = "%d.ax.v%d" % (server.current_serial, server.version)
self.send_file(server, fn)
except IOError:
- server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["Internal Error"],
- errpdu = self, errmsg = "Couldn't open %s" % fn))
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["Internal Error"],
+ errpdu = self,
+ errmsg = "Couldn't open %s" % fn))
@wire_pdu
class CacheResponsePDU(PDUWithNonce):
@@ -495,21 +548,98 @@ class CacheResponsePDU(PDUWithNonce):
blather("[Nonce changed, resetting]")
client.cache_reset()
-@wire_pdu
-class EndOfDataPDU(PDUWithSerial):
+
+def EndOfDataPDU(version, *args, **kwargs):
"""
- End of Data PDU.
+ Factory for the EndOfDataPDU classes, which take different forms in
+ different protocol versions.
+ """
+
+ if version == 0:
+ return EndOfDataPDUv0(version, *args, **kwargs)
+ if version == 1:
+ return EndOfDataPDUv1(version, *args, **kwargs)
+ raise NotImplementedError
+
+
+@wire_pdu_only(0)
+class EndOfDataPDUv0(PDUWithSerial):
+ """
+ End of Data PDU, protocol version 0.
"""
pdu_type = 7
+ # Default values, from the current RFC 6810 bis I-D.
+ # Putting these here lets us use them in our client API for both
+ # protocol versions, even though they can only be set in the
+ # protocol in version 1.
+
+ refresh = 3600
+ retry = 600
+ expire = 7200
+
def consume(self, client):
"""
Handle EndOfDataPDU response.
"""
blather(self)
- client.end_of_data(self.serial, self.nonce)
+ client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+
+@wire_pdu_only(1)
+class EndOfDataPDUv1(EndOfDataPDUv0):
+ """
+ End of Data PDU, protocol version 1.
+ """
+
+ header_struct = struct.Struct("!BBHLLLLL")
+
+ def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None):
+ super(EndOfDataPDUv1, self).__init__(version)
+ if serial is not None:
+ assert isinstance(serial, int)
+ self.serial = serial
+ if nonce is not None:
+ assert isinstance(nonce, int)
+ self.nonce = nonce
+ if refresh is not None:
+ assert isinstance(refresh, int)
+ self.refresh = refresh
+ if retry is not None:
+ assert isinstance(retry, int)
+ self.retry = retry
+ if expire is not None:
+ assert isinstance(expire, int)
+ self.expire = expire
+
+ def __str__(self):
+ return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % (
+ self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
+
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
+ self.header_struct.size, self.serial,
+ self.refresh, self.retry, self.expire)
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \
+ = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 24:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
+
@wire_pdu
class CacheResetPDU(PDUEmpty):
@@ -526,7 +656,7 @@ class CacheResetPDU(PDUEmpty):
blather(self)
client.cache_reset()
- client.push_pdu(ResetQueryPDU())
+ client.push_pdu(ResetQueryPDU(version = client.version))
class PrefixPDU(PDU):
"""
@@ -543,14 +673,14 @@ class PrefixPDU(PDU):
asnum_struct = struct.Struct("!L")
@staticmethod
- def from_text(asnum, addr):
+ def from_text(version, asn, addr):
"""
Construct a prefix from its text form.
"""
cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU
- self = cls()
- self.asn = long(asnum)
+ self = cls(version = version)
+ self.asn = long(asn)
p, l = addr.split("/")
self.prefix = rpki.POW.IPAddress(p)
if "-" in l:
@@ -562,15 +692,15 @@ class PrefixPDU(PDU):
return self
@staticmethod
- def from_roa(asnum, prefix_tuple):
+ def from_roa(version, asn, prefix_tuple):
"""
Construct a prefix from a ROA.
"""
address, length, maxlength = prefix_tuple
cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU
- self = cls()
- self.asn = asnum
+ self = cls(version = version)
+ self.asn = asn
self.prefix = address
self.prefixlen = length
self.max_prefixlen = length if maxlength is None else maxlength
@@ -643,6 +773,7 @@ class PrefixPDU(PDU):
b2 = reader.get(self.address_byte_count)
b3 = reader.get(self.asnum_struct.size)
version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1)
+ assert version == self.version and pdu_type == self.pdu_type
if length != len(b1) + len(b2) + len(b3):
raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self)
self.prefix = rpki.POW.IPAddress.fromBytes(b2)
@@ -712,7 +843,7 @@ class IPv6PrefixPDU(PrefixPDU):
pdu_type = 6
address_byte_count = 16
-@wire_pdu
+@wire_pdu_only(1)
class RouterKeyPDU(PDU):
"""
Router Key PDU.
@@ -723,13 +854,13 @@ class RouterKeyPDU(PDU):
header_struct = struct.Struct("!BBBxL20sL")
@classmethod
- def from_text(cls, asnum, gski, key):
+ def from_text(cls, version, asn, gski, key):
"""
Construct a router key from its text form.
"""
- self = cls()
- self.asn = long(asnum)
+ self = cls(version = version)
+ self.asn = long(asn)
self.ski = base64.urlsafe_b64decode(gski + "=")
self.key = base64.b64decode(key)
self.announce = 1
@@ -737,13 +868,13 @@ class RouterKeyPDU(PDU):
return self
@classmethod
- def from_certificate(cls, asnum, ski, key):
+ def from_certificate(cls, version, asn, ski, key):
"""
Construct a router key from a certificate.
"""
- self = cls()
- self.asn = asnum
+ self = cls(version = version)
+ self.asn = asn
self.ski = ski
self.key = key
self.announce = 1
@@ -799,6 +930,7 @@ class RouterKeyPDU(PDU):
return None
header = reader.get(self.header_struct.size)
version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
remaining = length - self.header_struct.size
if remaining <= 0:
raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self)
@@ -836,7 +968,8 @@ class ErrorReportPDU(PDU):
codes = dict((v, k) for k, v in errors.items())
- def __init__(self, errno = None, errpdu = None, errmsg = None):
+ def __init__(self, version, errno = None, errpdu = None, errmsg = None):
+ super(ErrorReportPDU, self).__init__(version)
assert errno is None or errno in self.errors
self.errno = errno
self.errpdu = errpdu
@@ -879,6 +1012,7 @@ class ErrorReportPDU(PDU):
return None
header = reader.get(self.header_struct.size)
version, pdu_type, self.errno, length = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
remaining = length - self.header_struct.size
self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining)
self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining)
@@ -903,7 +1037,7 @@ class ErrorReportPDU(PDU):
sys.exit(1)
-class ROA(rpki.POW.ROA):
+class ROA(rpki.POW.ROA): # pylint: disable=W0232
"""
Minor additions to rpki.POW.ROA.
"""
@@ -924,7 +1058,7 @@ class ROA(rpki.POW.ROA):
for p in v6:
yield p
-class X509(rpki.POW.X509):
+class X509(rpki.POW.X509): # pylint: disable=W0232
"""
Minor additions to rpki.POW.X509.
"""
@@ -945,13 +1079,18 @@ class PDUSet(list):
from rcynic's output.
"""
+ def __init__(self, version):
+ assert version in PDU.version_map
+ super(PDUSet, self).__init__()
+ self.version = version
+
@classmethod
- def _load_file(cls, filename):
+ def _load_file(cls, filename, version):
"""
Low-level method to read PDUSet from a file.
"""
- self = cls()
+ self = cls(version = version)
f = open(filename, "rb")
r = ReadBuffer()
while True:
@@ -963,6 +1102,7 @@ class PDUSet(list):
return self
r.put(b)
p = r.retry()
+ assert p.version == self.version
self.append(p)
@staticmethod
@@ -979,7 +1119,7 @@ class AXFRSet(PDUSet):
"""
@classmethod
- def parse_rcynic(cls, rcynic_dir):
+ def parse_rcynic(cls, rcynic_dir, version):
"""
Parse ROAS and router certificates fetched (and validated!) by
rcynic to create a new AXFRSet.
@@ -994,24 +1134,26 @@ class AXFRSet(PDUSet):
can make this one a bit simpler and faster.
"""
- self = cls()
+ self = cls(version = version)
self.serial = Timestamp.now()
- if scan_roas is None or scan_routercerts is None:
- for root, dirs, files in os.walk(rcynic_dir):
+ include_routercerts = RouterKeyPDU.pdu_type in PDU.version_map[version]
+
+ if scan_roas is None or (scan_routercerts is None and include_routercerts):
+ for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612
for fn in files:
if scan_roas is None and fn.endswith(".roa"):
roa = ROA.derReadFile(os.path.join(root, fn))
asn = roa.getASID()
- self.extend(PrefixPDU.from_roa(asn, roa_prefix)
- for roa_prefix in roa.prefixes)
- if scan_routercerts is None and fn.endswith(".cer"):
+ self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple)
+ for prefix_tuple in roa.prefixes)
+ if include_routercerts and scan_routercerts is None and fn.endswith(".cer"):
x = X509.derReadFile(os.path.join(root, fn))
eku = x.getEKU()
if eku is not None and rpki.oids.id_kp_bgpsec_router in eku:
ski = x.getSKI()
key = x.getPublicKey().derWritePublic()
- self.extend(RouterKeyPDU.from_certificate(asn, ski, key)
+ self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key)
for asn in x.asns)
if scan_roas is not None:
@@ -1020,18 +1162,20 @@ class AXFRSet(PDUSet):
for line in p.stdout:
line = line.split()
asn = line[1]
- self.extend(PrefixPDU.from_text(asn, addr) for addr in line[2:])
+ self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr)
+ for addr in line[2:])
except OSError, e:
sys.exit("Could not run %s: %s" % (scan_roas, e))
- if scan_routercerts is not None:
+ if include_routercerts and scan_routercerts is not None:
try:
p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE)
for line in p.stdout:
line = line.split()
gski = line[0]
key = line[-1]
- self.extend(RouterKeyPDU.from_text(asn, gski, key) for asn in line[1:-1])
+ self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key)
+ for asn in line[1:-1])
except OSError, e:
sys.exit("Could not run %s: %s" % (scan_routercerts, e))
@@ -1044,12 +1188,13 @@ class AXFRSet(PDUSet):
@classmethod
def load(cls, filename):
"""
- Load an AXFRSet from a file, parse filename to obtain serial.
+ Load an AXFRSet from a file, parse filename to obtain version and serial.
"""
- fn1, fn2 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ax"
- self = cls._load_file(filename)
+ fn1, fn2, fn3 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit()
+ version = int(fn3[1:])
+ self = cls._load_file(filename, version)
self.serial = Timestamp(fn1)
return self
@@ -1058,19 +1203,19 @@ class AXFRSet(PDUSet):
Generate filename for this AXFRSet.
"""
- return "%d.ax" % self.serial
+ return "%d.ax.v%d" % (self.serial, self.version)
@classmethod
- def load_current(cls):
+ def load_current(cls, version):
"""
Load current AXFRSet. Return None if can't.
"""
- serial = read_current()[0]
+ serial = read_current(version)[0]
if serial is None:
return None
try:
- return cls.load("%d.ax" % serial)
+ return cls.load("%d.ax.v%d" % (serial, version))
except IOError:
return None
@@ -1090,9 +1235,9 @@ class AXFRSet(PDUSet):
the old serial numbers are no longer valid.
"""
- for i in glob.iglob("*.ix.*"):
+ for i in glob.iglob("*.ix.*.v%d" % self.version):
os.unlink(i)
- for i in glob.iglob("*.ax"):
+ for i in glob.iglob("*.ax.v%d" % self.version):
if i != self.filename():
os.unlink(i)
@@ -1103,12 +1248,12 @@ class AXFRSet(PDUSet):
the new nonce invalidates all old serial numbers.
"""
- old_serial, nonce = read_current()
+ old_serial, nonce = read_current(self.version)
if old_serial is None or self.seq_ge(old_serial, self.serial):
blather("Creating new nonce and deleting stale data")
nonce = new_nonce()
self.destroy_old_data()
- write_current(self.serial, nonce)
+ write_current(self.serial, nonce, self.version)
def save_ixfr(self, other):
"""
@@ -1118,7 +1263,7 @@ class AXFRSet(PDUSet):
comparison.
"""
- f = open("%d.ix.%d" % (self.serial, other.serial), "wb")
+ f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb")
old = other
new = self
len_old = len(old)
@@ -1145,7 +1290,7 @@ class AXFRSet(PDUSet):
Print this AXFRSet.
"""
- blather("# AXFR %d (%s)" % (self.serial, self.serial))
+ blather("# AXFR %d (%s) v%d" % (self.serial, self.serial, self.version))
for p in self:
blather(p)
@@ -1207,12 +1352,13 @@ class IXFRSet(PDUSet):
@classmethod
def load(cls, filename):
"""
- Load an IXFRSet from a file, parse filename to obtain serials.
+ Load an IXFRSet from a file, parse filename to obtain version and serials.
"""
- fn1, fn2, fn3 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit()
- self = cls._load_file(filename)
+ fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit()
+ version = int(fn4[1:])
+ self = cls._load_file(filename, version)
self.from_serial = Timestamp(fn3)
self.to_serial = Timestamp(fn1)
return self
@@ -1222,15 +1368,16 @@ class IXFRSet(PDUSet):
Generate filename for this IXFRSet.
"""
- return "%d.ix.%d" % (self.to_serial, self.from_serial)
+ return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version)
def show(self):
"""
Print this IXFRSet.
"""
- blather("# IXFR %d (%s) -> %d (%s)" % (self.from_serial, self.from_serial,
- self.to_serial, self.to_serial))
+ blather("# IXFR %d (%s) -> %d (%s) v%d" % (self.from_serial, self.from_serial,
+ self.to_serial, self.to_serial,
+ self.version))
for p in self:
blather(p)
@@ -1246,7 +1393,7 @@ class FileProducer(object):
def more(self):
return self.handle.read(self.buffersize)
-class PDUChannel(asynchat.async_chat):
+class PDUChannel(asynchat.async_chat, object):
"""
asynchat subclass that understands our PDUs. This just handles
network I/O. Specific engines (client, server) should be subclasses
@@ -1255,9 +1402,17 @@ class PDUChannel(asynchat.async_chat):
"""
def __init__(self, conn = None):
- asynchat.async_chat.__init__(self, conn)
+ asynchat.async_chat.__init__(self, conn) # Old-style class
self.reader = ReadBuffer()
+ @property
+ def version(self):
+ return self.reader.version
+
+ @version.setter
+ def version(self, version):
+ self.reader.check_version(version)
+
def start_new_pdu(self):
"""
Start read of a new PDU.
@@ -1269,7 +1424,7 @@ class PDUChannel(asynchat.async_chat):
self.deliver_pdu(p)
p = PDU.read_pdu(self.reader)
except PDUException, e:
- self.push_pdu(e.make_error_report())
+ self.push_pdu(e.make_error_report(version = self.version))
self.close_when_done()
else:
assert not self.reader.ready()
@@ -1382,7 +1537,7 @@ class ServerWriteChannel(PDUChannel):
Set up stdout.
"""
- PDUChannel.__init__(self)
+ super(ServerWriteChannel, self).__init__()
self.init_file_dispatcher(sys.stdout.fileno())
def readable(self):
@@ -1392,6 +1547,7 @@ class ServerWriteChannel(PDUChannel):
return False
+
class ServerChannel(PDUChannel):
"""
Server protocol engine, handles upcalls from PDUChannel to
@@ -1404,7 +1560,7 @@ class ServerChannel(PDUChannel):
first PDU.
"""
- PDUChannel.__init__(self)
+ super(ServerChannel, self).__init__()
self.init_file_dispatcher(sys.stdin.fileno())
self.writer = ServerWriteChannel()
self.get_serial()
@@ -1460,7 +1616,7 @@ class ServerChannel(PDUChannel):
mode instance is still building its database.
"""
- self.current_serial, self.current_nonce = read_current()
+ self.current_serial, self.current_nonce = read_current(self.version)
return self.current_serial
def check_serial(self):
@@ -1473,13 +1629,21 @@ class ServerChannel(PDUChannel):
def notify(self, data = None):
"""
- Cronjob instance kicked us, send a notify message.
+ Cronjob instance kicked us: check whether our serial number has
+ changed, and send a notify message if so.
+
+ We have to check rather than just blindly notifying when kicked
+ because the cronjob instance has no good way of knowing which
+ protocol version we're running, thus has no good way of knowing
+ whether we care about a particular change set or not.
"""
- if self.check_serial() is not None:
- self.push_pdu(SerialNotifyPDU(serial = self.current_serial, nonce = self.current_nonce))
+ if self.check_serial():
+ self.push_pdu(SerialNotifyPDU(version = self.version,
+ serial = self.current_serial,
+ nonce = self.current_nonce))
else:
- log("Cronjob kicked me without a valid current serial number")
+ log("Cronjob kicked me but I see no serial change, ignoring")
class ClientChannel(PDUChannel):
"""
@@ -1493,12 +1657,21 @@ class ClientChannel(PDUChannel):
port = None
cache_id = None
+ # For initial test purposes, let's use the minimum allowed values
+ # from the RFC 6810 bis I-D as the initial defaults for refresh and
+ # retry, and the maximum allowed for expire; these will be overriden
+ # as soon as we receive an EndOfDataPDU.
+ #
+ refresh = 120
+ retry = 120
+ expire = 172800
+
def __init__(self, sock, proc, killsig, host, port):
self.killsig = killsig
self.proc = proc
self.host = host
self.port = port
- PDUChannel.__init__(self, conn = sock)
+ super(ClientChannel, self).__init__(conn = sock)
self.start_new_pdu()
@classmethod
@@ -1529,7 +1702,7 @@ class ClientChannel(PDUChannel):
blather("[socket.getaddrinfo() failed: %s]" % e)
else:
for ai in addrinfo:
- af, socktype, proto, cn, sa = ai
+ af, socktype, proto, cn, sa = ai # pylint: disable=W0612
blather("[Trying addr %s port %s]" % sa[:2])
try:
s = socket.socket(af, socktype, proto)
@@ -1601,9 +1774,13 @@ class ClientChannel(PDUChannel):
cache_id INTEGER PRIMARY KEY NOT NULL,
host TEXT NOT NULL,
port TEXT NOT NULL,
+ version INTEGER,
nonce INTEGER,
serial INTEGER,
updated INTEGER,
+ refresh INTEGER,
+ retry INTEGER,
+ expire INTEGER,
UNIQUE (host, port))''')
cur.execute('''
CREATE TABLE prefix (
@@ -1629,14 +1806,26 @@ class ClientChannel(PDUChannel):
UNIQUE (cache_id, asn, ski),
UNIQUE (cache_id, asn, key))''')
- cur.execute("SELECT cache_id, nonce, serial FROM cache WHERE host = ? AND port = ?",
+ cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire "
+ "FROM cache WHERE host = ? AND port = ?",
(self.host, self.port))
try:
- self.cache_id, self.current_nonce, self.current_serial = cur.fetchone()
+ self.cache_id, version, self.current_nonce, self.current_serial, refresh, retry, expire = cur.fetchone()
+ if version is not None:
+ self.version = version
+ if refresh is not None:
+ self.refresh = refresh
+ if retry is not None:
+ self.retry = retry
+ if expire is not None:
+ self.expire = expire
except TypeError:
cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port))
self.cache_id = cur.lastrowid
self.sql.commit()
+ log("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s]" % (
+ self.cache_id, self.version, self.current_nonce, self.current_serial, self.refresh, self.retry, self.expire))
+
def cache_reset(self):
"""
@@ -1645,20 +1834,34 @@ class ClientChannel(PDUChannel):
self.current_serial = None
if self.sql:
+ #
+ # For some reason there was no commit here. Dunno why.
+ # See if adding one breaks anything....
+ #
cur = self.sql.cursor()
cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,))
- cur.execute("UPDATE cache SET serial = NULL WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id))
+ self.sql.commit()
- def end_of_data(self, serial, nonce):
+ def end_of_data(self, version, serial, nonce, refresh, retry, expire):
"""
Handle EndOfDataPDU actions.
"""
- self.current_serial = serial
- self.current_nonce = nonce
+ assert version == self.version
+ self.current_serial = serial
+ self.current_nonce = nonce
+ self.refresh = refresh
+ self.retry = retry
+ self.expire = expire
if self.sql:
- self.sql.execute("UPDATE cache SET serial = ?, nonce = ?, updated = datetime('now') WHERE cache_id = ?",
- (serial, nonce, self.cache_id))
+ self.sql.execute("UPDATE cache SET"
+ " version = ?, serial = ?, nonce = ?,"
+ " refresh = ?, retry = ?, expire = ?,"
+ " updated = datetime('now') "
+ "WHERE cache_id = ?",
+ (version, serial, nonce, refresh, retry, expire, self.cache_id))
self.sql.commit()
def consume_prefix(self, prefix):
@@ -1733,14 +1936,14 @@ class ClientChannel(PDUChannel):
blather("Server closed channel")
PDUChannel.handle_close(self)
-class KickmeChannel(asyncore.dispatcher):
+class KickmeChannel(asyncore.dispatcher, object):
"""
asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to
kick servers when it's time to send notify PDUs to clients.
"""
def __init__(self, server):
- asyncore.dispatcher.__init__(self)
+ asyncore.dispatcher.__init__(self) # Old-style class
self.server = server
self.sockname = "%s.%d" % (kickme_base, os.getpid())
self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM)
@@ -1912,37 +2115,41 @@ def cronjob_main(argv):
if len(argv) != 1:
sys.exit("Expected one argument, got %r" % (argv,))
- old_ixfrs = glob.glob("*.ix.*")
+ for version in sorted(PDU.version_map.iterkeys(), reverse = True):
- current = read_current()[0]
- cutoff = Timestamp.now(-(24 * 60 * 60))
- for f in glob.iglob("*.ax"):
- t = Timestamp(int(f.split(".")[0]))
- if t < cutoff and t != current:
- blather("# Deleting old file %s, timestamp %s" % (f, t))
- os.unlink(f)
+ blather("# Generating updates for protocol version %d" % version)
- pdus = AXFRSet.parse_rcynic(argv[0])
- if pdus == AXFRSet.load_current():
- blather("# No change, new version not needed")
- sys.exit()
- pdus.save_axfr()
- for axfr in glob.iglob("*.ax"):
- if axfr != pdus.filename():
- pdus.save_ixfr(AXFRSet.load(axfr))
- pdus.mark_current()
+ old_ixfrs = glob.glob("*.ix.*.v%d" % version)
- blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial))
+ current = read_current(version)[0]
+ cutoff = Timestamp.now(-(24 * 60 * 60))
+ for f in glob.iglob("*.ax.v%d" % version):
+ t = Timestamp(int(f.split(".")[0]))
+ if t < cutoff and t != current:
+ blather("# Deleting old file %s, timestamp %s" % (f, t))
+ os.unlink(f)
- kick_all(pdus.serial)
+ pdus = AXFRSet.parse_rcynic(argv[0], version)
+ if pdus == AXFRSet.load_current(version):
+ blather("# No change, new serial not needed")
+ continue
+ pdus.save_axfr()
+ for axfr in glob.iglob("*.ax.v%d" % version):
+ if axfr != pdus.filename():
+ pdus.save_ixfr(AXFRSet.load(axfr))
+ pdus.mark_current()
- old_ixfrs.sort()
- for ixfr in old_ixfrs:
- try:
- blather("# Deleting old file %s" % ixfr)
- os.unlink(ixfr)
- except OSError:
- pass
+ blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial))
+
+ kick_all(pdus.serial)
+
+ old_ixfrs.sort()
+ for ixfr in old_ixfrs:
+ try:
+ blather("# Deleting old file %s" % ixfr)
+ os.unlink(ixfr)
+ except OSError:
+ pass
def show_main(argv):
"""
@@ -1956,12 +2163,12 @@ def show_main(argv):
if argv:
sys.exit("Unexpected arguments: %r" % (argv,))
- g = glob.glob("*.ax")
+ g = glob.glob("*.ax.v*")
g.sort()
for f in g:
AXFRSet.load(f).show()
- g = glob.glob("*.ix.*")
+ g = glob.glob("*.ix.*.v*")
g.sort()
for f in g:
IXFRSet.load(f).show()
@@ -2068,8 +2275,8 @@ def listener_tcp_main(argv):
blather("[Received connection from %r]" % (ai,))
pid = os.fork()
if pid == 0:
- os.dup2(s.fileno(), 0)
- os.dup2(s.fileno(), 1)
+ os.dup2(s.fileno(), 0) # pylint: disable=E1103
+ os.dup2(s.fileno(), 1) # pylint: disable=E1103
s.close()
#os.closerange(3, os.sysconf("SC_OPEN_MAX"))
global log_tag
@@ -2082,7 +2289,7 @@ def listener_tcp_main(argv):
blather("[Spawned server %d]" % pid)
try:
while True:
- pid, status = os.waitpid(0, os.WNOHANG)
+ pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612
if pid:
blather("[Server %s exited]" % pid)
else:
@@ -2146,11 +2353,17 @@ def client_main(argv):
client.setup_sql(sqlname)
while True:
if client.current_serial is None or client.current_nonce is None:
- client.push_pdu(ResetQueryPDU())
+ client.push_pdu(ResetQueryPDU(version = client.version))
else:
- client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce))
- wakeup = time.time() + 600
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.current_serial,
+ nonce = client.current_nonce))
+ polled = Timestamp.now()
+ wakeup = None
while True:
+ if wakeup != polled + client.refresh:
+ wakeup = Timestamp(polled + client.refresh)
+ log("[Last client poll %s, next %s]" % (polled, wakeup))
remaining = wakeup - time.time()
if remaining < 0:
break
@@ -2183,10 +2396,11 @@ def bgpdump_convert_main(argv):
first = True
db = None
axfrs = []
+ version = max(PDU.version_map.iterkeys())
for filename in argv:
- if filename.endswith(".ax"):
+ if ".ax.v" in filename:
blather("Reading %s" % filename)
db = AXFRSet.load(filename)
@@ -2203,7 +2417,7 @@ def bgpdump_convert_main(argv):
sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename)
blather("DB serial now %d (%s)" % (db.serial, db.serial))
- if first and read_current() == (None, None):
+ if first and read_current(version) == (None, None):
db.mark_current()
first = False
@@ -2230,21 +2444,22 @@ def bgpdump_select_main(argv):
You have been warned.
"""
+ version = max(PDU.version_map.iterkeys())
serial = None
try:
head, sep, tail = os.path.basename(argv[0]).partition(".")
- if len(argv) == 1 and head.isdigit() and sep == "." and tail == "ax":
+ if len(argv) == 1 and head.isdigit() and sep == "." and tail.startswith("ax.v") and tail[4:].isdigit():
serial = Timestamp(head)
except:
pass
if serial is None:
sys.exit("Argument must be name of a .ax file")
- nonce = read_current()[1]
+ nonce = read_current(version)[1]
if nonce is None:
nonce = new_nonce()
- write_current(serial, nonce)
+ write_current(serial, nonce, version)
kick_all(serial)
@@ -2264,7 +2479,7 @@ class BGPDumpReplayClock(object):
"""
def __init__(self):
- self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax")]
+ self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")]
self.timestamps.sort()
self.offset = self.timestamps[0] - int(time.time())
self.nonce = new_nonce()
@@ -2275,7 +2490,7 @@ class BGPDumpReplayClock(object):
def now(self):
return Timestamp.now(self.offset)
- def read_current(self):
+ def read_current(self, version):
now = self.now()
while len(self.timestamps) > 1 and now >= self.timestamps[1]:
del self.timestamps[0]
@@ -2340,6 +2555,7 @@ def bgpdump_server_main(argv):
scan_roas = None
scan_routercerts = None
force_zero_nonce = False
+debug = False
kickme_dir = "sockets"
kickme_base = os.path.join(kickme_dir, "kickme")
@@ -2360,6 +2576,8 @@ def usage(msg = None):
f.write("\n")
f.write("where options are zero or more of:\n")
f.write("\n")
+ f.write("--debug\n")
+ f.write("\n")
f.write("--scan-roas /path/to/scan_roas\n")
f.write("\n")
f.write("--scan-routercerts /path/to/scan_routercerts\n")
@@ -2385,13 +2603,15 @@ if __name__ == "__main__":
syslog_facility, syslog_warning, syslog_info = syslog.LOG_DAEMON, syslog.LOG_WARNING, syslog.LOG_INFO
- opts, argv = getopt.getopt(sys.argv[1:], "hs:z?", ["help", "scan-roas=", "scan-routercerts=",
- "syslog=", "zero-nonce"] + main_dispatch.keys())
+ opts, argv = getopt.getopt(sys.argv[1:], "dhs:z?", ["help", "debug", "scan-roas=", "scan-routercerts=",
+ "syslog=", "zero-nonce"] + main_dispatch.keys())
for o, a in opts:
if o in ("-h", "--help", "-?"):
usage()
elif o in ("-z", "--zero-nonce"):
force_zero_nonce = True
+ elif o in ("-d", "--debug"):
+ debug = True
elif o in ("-s", "--syslog"):
try:
a = [getattr(syslog, "LOG_" + i.upper()) for i in a.split(".")]
@@ -2419,7 +2639,7 @@ if __name__ == "__main__":
if mode in ("server", "bgpdump_server"):
log_tag += hostport_tag()
- if mode in ("cronjob", "server" , "bgpdump_server"):
+ if not debug and mode in ("cronjob", "server" , "bgpdump_server"):
syslog.openlog(log_tag, syslog.LOG_PID, syslog_facility)
def log(msg):
return syslog.syslog(syslog_warning, str(msg))