aboutsummaryrefslogtreecommitdiff
path: root/rp
diff options
context:
space:
mode:
authorRob Austein <sra@hactrn.net>2014-04-17 22:27:29 +0000
committerRob Austein <sra@hactrn.net>2014-04-17 22:27:29 +0000
commitbcd5cf161f3b28ee3ecc37c88aa7768e7ffb7e00 (patch)
tree8c277bbc8a555104d0747620ccb9a22f0b46fe41 /rp
parent85589d2c84ce1eb91c04ef7534db6a303f28297a (diff)
Use class decorator to construct PDU dispatch list (preparation for
supporting multiple protocol versions). Start dragging coding standard up to something a little more recent. svn path=/trunk/; revision=5810
Diffstat (limited to 'rp')
-rwxr-xr-xrp/rpki-rtr/rtr-origin371
1 files changed, 239 insertions, 132 deletions
diff --git a/rp/rpki-rtr/rtr-origin b/rp/rpki-rtr/rtr-origin
index 06ae2ee4..c676e0d7 100755
--- a/rp/rpki-rtr/rtr-origin
+++ b/rp/rpki-rtr/rtr-origin
@@ -2,19 +2,19 @@
# Router origin-authentication rpki-router protocol implementation. See
# draft-ietf-sidr-rpki-rtr in fine Internet-Draft repositories near you.
-#
+#
# Run the program with the --help argument for usage information, or see
# documentation for the *_main() functions.
#
-#
+#
# $Id$
-#
+#
# Copyright (C) 2009-2013 Internet Systems Consortium ("ISC")
-#
+#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
-#
+#
# THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS. IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
@@ -54,7 +54,7 @@ class IgnoreThisRecord(Exception):
pass
-class timestamp(int):
+class Timestamp(int):
"""
Wrapper around time module.
"""
@@ -102,6 +102,7 @@ def read_current():
serial and nonce not recorded. For backwards compatibility, treat
file containing just a serial number as having a nonce of zero.
"""
+
try:
f = open("current", "r")
values = tuple(int(s) for s in f.read().split())
@@ -116,6 +117,7 @@ def write_current(serial, nonce):
"""
Write serial number and nonce.
"""
+
tmpfn = "current.%d.tmp" % os.getpid()
try:
f = open(tmpfn, "w")
@@ -133,6 +135,7 @@ def new_nonce():
"""
Create and return a new nonce value.
"""
+
if force_zero_nonce:
return 0
try:
@@ -141,7 +144,7 @@ def new_nonce():
return int(random.getrandbits(16))
-class read_buffer(object):
+class ReadBuffer(object):
"""
Wrapper around synchronous/asynchronous read state.
"""
@@ -153,6 +156,7 @@ class read_buffer(object):
"""
Update count of needed bytes and callback, then dispatch to callback.
"""
+
self.need = need
self.callback = callback
return self.callback(self)
@@ -161,24 +165,28 @@ class read_buffer(object):
"""
How much data do we have available in this buffer?
"""
+
return len(self.buffer)
def needed(self):
"""
How much more data does this buffer need to become ready?
"""
+
return self.need - self.available()
def ready(self):
"""
Is this buffer ready to read yet?
"""
+
return self.available() >= self.need
def get(self, n):
"""
Hand some data to the caller.
"""
+
b = self.buffer[:n]
self.buffer = self.buffer[n:]
return b
@@ -187,23 +195,26 @@ class read_buffer(object):
"""
Accumulate some data.
"""
+
self.buffer += b
def retry(self):
"""
Try dispatching to the callback again.
"""
+
return self.callback(self)
class PDUException(Exception):
"""
Parent exception type for exceptions that signal particular protocol
errors. String value of exception instance will be the message to
- put in the error_report PDU, error_report_code value of exception
+ put in the ErrorReportPDU, error_report_code value of exception
will be the numeric code to use.
"""
def __init__(self, msg = None, pdu = None):
+ Exception.__init__(self)
assert msg is None or isinstance(msg, (str, unicode))
self.error_report_msg = msg
self.error_report_pdu = pdu
@@ -212,9 +223,9 @@ class PDUException(Exception):
return self.error_report_msg or self.__class__.__name__
def make_error_report(self):
- return error_report(errno = self.error_report_code,
- errmsg = self.error_report_msg,
- errpdu = self.error_report_pdu)
+ return ErrorReportPDU(errno = self.error_report_code,
+ errmsg = self.error_report_msg,
+ errpdu = self.error_report_pdu)
class UnsupportedProtocolVersion(PDUException):
error_report_code = 4
@@ -225,12 +236,28 @@ class UnsupportedPDUType(PDUException):
class CorruptData(PDUException):
error_report_code = 0
-class pdu(object):
+
+def wire_pdu(cls):
+ """
+ Class decorator to add a PDU class to the set of known PDUs.
+
+ In the long run, this decorator may take additional arguments
+ specifying which protocol version(s) use this particular PDU,
+ but we're not there yet.
+ """
+
+ PDU.pdu_map[cls.pdu_type] = cls
+ return cls
+
+
+class PDU(object):
"""
Object representing a generic PDU in the rpki-router protocol.
Real PDUs are subclasses of this class.
"""
+ pdu_map = {} # Updated by @wire_pdu
+
version = 0 # Protocol version
_pdu = None # Cached when first generated
@@ -244,6 +271,7 @@ class pdu(object):
"""
Check attributes to make sure they're within range.
"""
+
pass
@classmethod
@@ -273,24 +301,27 @@ class pdu(object):
Handle results in test client. Default behavior is just to print
out the PDU.
"""
+
blather(self)
def send_file(self, server, filename):
"""
Send a content of a file as a cache response. Caller should catch IOError.
"""
+
f = open(filename, "rb")
- server.push_pdu(cache_response(nonce = server.current_nonce))
+ server.push_pdu(CacheResponsePDU(nonce = server.current_nonce))
server.push_file(f)
- server.push_pdu(end_of_data(serial = server.current_serial, nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce))
def send_nodata(self, server):
"""
Send a nodata error.
"""
- server.push_pdu(error_report(errno = error_report.codes["No Data Available"], errpdu = self))
-class pdu_with_serial(pdu):
+ server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["No Data Available"], errpdu = self))
+
+class PDUWithSerial(PDU):
"""
Base class for PDUs consisting of just a serial number and nonce.
"""
@@ -312,6 +343,7 @@ class pdu_with_serial(pdu):
"""
Generate the wire format PDU.
"""
+
if self._pdu is None:
self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
self.header_struct.size, self.serial)
@@ -327,7 +359,7 @@ class pdu_with_serial(pdu):
assert b == self.to_pdu()
return self
-class pdu_nonce(pdu):
+class PDUWithNonce(PDU):
"""
Base class for PDUs consisting of just a nonce.
"""
@@ -346,6 +378,7 @@ class pdu_nonce(pdu):
"""
Generate the wire format PDU.
"""
+
if self._pdu is None:
self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size)
return self._pdu
@@ -360,7 +393,7 @@ class pdu_nonce(pdu):
assert b == self.to_pdu()
return self
-class pdu_empty(pdu):
+class PDUEmpty(PDU):
"""
Base class for empty PDUs.
"""
@@ -374,6 +407,7 @@ class pdu_empty(pdu):
"""
Generate the wire format PDU for this prefix.
"""
+
if self._pdu is None:
self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size)
return self._pdu
@@ -390,7 +424,8 @@ class pdu_empty(pdu):
assert b == self.to_pdu()
return self
-class serial_notify(pdu_with_serial):
+@wire_pdu
+class SerialNotifyPDU(PDUWithSerial):
"""
Serial Notify PDU.
"""
@@ -399,18 +434,20 @@ class serial_notify(pdu_with_serial):
def consume(self, client):
"""
- Respond to a serial_notify message with either a serial_query or
- reset_query, depending on what we already know.
+ Respond to a SerialNotifyPDU with either a SerialQueryPDU or a
+ ResetQueryPDU, depending on what we already know.
"""
+
blather(self)
if client.current_serial is None or client.current_nonce != self.nonce:
- client.push_pdu(reset_query())
+ client.push_pdu(ResetQueryPDU())
elif self.serial != client.current_serial:
- client.push_pdu(serial_query(serial = client.current_serial, nonce = client.current_nonce))
+ client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce))
else:
blather("[Notify did not change serial number, ignoring]")
-class serial_query(pdu_with_serial):
+@wire_pdu
+class SerialQueryPDU(PDUWithSerial):
"""
Serial Query PDU.
"""
@@ -423,25 +460,27 @@ class serial_query(pdu_with_serial):
If client is already up to date, just send an empty incremental
transfer.
"""
+
blather(self)
if server.get_serial() is None:
self.send_nodata(server)
elif server.current_nonce != self.nonce:
log("[Client requested wrong nonce, resetting client]")
- server.push_pdu(cache_reset())
+ server.push_pdu(CacheResetPDU())
elif server.current_serial == self.serial:
blather("[Client is already current, sending empty IXFR]")
- server.push_pdu(cache_response(nonce = server.current_nonce))
- server.push_pdu(end_of_data(serial = server.current_serial, nonce = server.current_nonce))
+ server.push_pdu(CacheResponsePDU(nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(serial = server.current_serial, nonce = server.current_nonce))
elif disable_incrementals:
- server.push_pdu(cache_reset())
+ server.push_pdu(CacheResetPDU())
else:
try:
self.send_file(server, "%d.ix.%d" % (server.current_serial, self.serial))
except IOError:
- server.push_pdu(cache_reset())
+ server.push_pdu(CacheResetPDU())
-class reset_query(pdu_empty):
+@wire_pdu
+class ResetQueryPDU(PDUEmpty):
"""
Reset Query PDU.
"""
@@ -452,6 +491,7 @@ class reset_query(pdu_empty):
"""
Received a reset query, send full current state in response.
"""
+
blather(self)
if server.get_serial() is None:
self.send_nodata(server)
@@ -460,10 +500,11 @@ class reset_query(pdu_empty):
fn = "%d.ax" % server.current_serial
self.send_file(server, fn)
except IOError:
- server.push_pdu(error_report(errno = error_report.codes["Internal Error"],
- errpdu = self, errmsg = "Couldn't open %s" % fn))
+ server.push_pdu(ErrorReportPDU(errno = ErrorReportPDU.codes["Internal Error"],
+ errpdu = self, errmsg = "Couldn't open %s" % fn))
-class cache_response(pdu_nonce):
+@wire_pdu
+class CacheResponsePDU(PDUWithNonce):
"""
Cache Response PDU.
"""
@@ -472,14 +513,16 @@ class cache_response(pdu_nonce):
def consume(self, client):
"""
- Handle cache_response.
+ Handle CacheResponsePDU.
"""
+
blather(self)
if self.nonce != client.current_nonce:
blather("[Nonce changed, resetting]")
client.cache_reset()
-class end_of_data(pdu_with_serial):
+@wire_pdu
+class EndOfDataPDU(PDUWithSerial):
"""
End of Data PDU.
"""
@@ -488,12 +531,14 @@ class end_of_data(pdu_with_serial):
def consume(self, client):
"""
- Handle end_of_data response.
+ Handle EndOfDataPDU response.
"""
+
blather(self)
client.end_of_data(self.serial, self.nonce)
-class cache_reset(pdu_empty):
+@wire_pdu
+class CacheResetPDU(PDUEmpty):
"""
Cache reset PDU.
"""
@@ -502,21 +547,22 @@ class cache_reset(pdu_empty):
def consume(self, client):
"""
- Handle cache_reset response, by issuing a reset_query.
+ Handle CacheResetPDU response, by issuing a ResetQueryPDU.
"""
+
blather(self)
client.cache_reset()
- client.push_pdu(reset_query())
+ client.push_pdu(ResetQueryPDU())
-class prefix(pdu):
+class PrefixPDU(PDU):
"""
Object representing one prefix. This corresponds closely to one PDU
in the rpki-router protocol, so closely that we use lexical ordering
of the wire format of the PDU as the ordering for this class.
This is a virtual class, but the .from_text() constructor
- instantiates the correct concrete subclass (ipv4_prefix or
- ipv6_prefix) depending on the syntax of its input text.
+ instantiates the correct concrete subclass (IPv4PrefixPDU or
+ IPv6PrefixPDU) depending on the syntax of its input text.
"""
header_struct = struct.Struct("!BB2xLBBBx")
@@ -528,7 +574,7 @@ class prefix(pdu):
Construct a prefix from its text form.
"""
- cls = ipv6_prefix if ":" in addr else ipv4_prefix
+ cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU
self = cls()
self.asn = long(asnum)
p, l = addr.split("/")
@@ -540,7 +586,7 @@ class prefix(pdu):
self.announce = 1
self.check()
return self
-
+
@staticmethod
def from_roa(asnum, prefix_tuple):
"""
@@ -548,7 +594,7 @@ class prefix(pdu):
"""
address, length, maxlength = prefix_tuple
- cls = ipv6_prefix if address.version == 6 else ipv4_prefix
+ cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU
self = cls()
self.asn = asnum
# Kludge: Should just use IPAddress, coersion here is historical
@@ -576,6 +622,7 @@ class prefix(pdu):
"""
Handle one incoming prefix PDU
"""
+
blather(self)
client.consume_prefix(self)
@@ -583,6 +630,7 @@ class prefix(pdu):
"""
Check attributes to make sure they're within range.
"""
+
if self.announce not in (0, 1):
raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
if self.prefixlen < 0 or self.prefixlen > self.addr_type.size * 8:
@@ -597,6 +645,7 @@ class prefix(pdu):
"""
Generate the wire format PDU for this prefix.
"""
+
if announce is not None:
assert announce in (0, 1)
elif self._pdu is not None:
@@ -633,9 +682,9 @@ class prefix(pdu):
fields = line.split("|")
# Parse prefix, including figuring out IP protocol version
- cls = ipv6_prefix if ":" in fields[5] else ipv4_prefix
+ cls = IPv6PrefixPDU if ":" in fields[5] else IPv4PrefixPDU
self = cls()
- self.timestamp = timestamp(fields[1])
+ self.timestamp = Timestamp(fields[1])
p, l = fields[5].split("/")
self.prefix = self.addr_type(p)
self.prefixlen = self.max_prefixlen = int(l)
@@ -670,21 +719,26 @@ class prefix(pdu):
log("Ignoring line %r: %s" % (line, e))
raise IgnoreThisRecord
-class ipv4_prefix(prefix):
+@wire_pdu
+class IPv4PrefixPDU(PrefixPDU):
"""
IPv4 flavor of a prefix.
"""
+
pdu_type = 4
addr_type = v4addr
-class ipv6_prefix(prefix):
+@wire_pdu
+class IPv6PrefixPDU(PrefixPDU):
"""
IPv6 flavor of a prefix.
"""
+
pdu_type = 6
addr_type = v6addr
-class router_key(pdu):
+@wire_pdu
+class RouterKeyPDU(PDU):
"""
Router Key PDU.
"""
@@ -754,7 +808,7 @@ class router_key(pdu):
return self._pdu
pdulen = self.header_struct.size + len(self.key)
pdu = (self.header_struct.pack(self.version,
- self.pdu_type,
+ self.pdu_type,
announce if announce is not None else self.announce,
pdulen,
self.ski,
@@ -778,7 +832,8 @@ class router_key(pdu):
return self
-class error_report(pdu):
+@wire_pdu
+class ErrorReportPDU(PDU):
"""
Error Report PDU.
"""
@@ -828,13 +883,14 @@ class error_report(pdu):
"""
Generate the wire format PDU for this error report.
"""
+
if self._pdu is None:
assert isinstance(self.errno, int)
- assert not isinstance(self.errpdu, error_report)
+ assert not isinstance(self.errpdu, ErrorReportPDU)
p = self.errpdu
if p is None:
p = ""
- elif isinstance(p, pdu):
+ elif isinstance(p, PDU):
p = p.to_pdu()
assert isinstance(p, str)
pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg)
@@ -862,17 +918,15 @@ class error_report(pdu):
def serve(self, server):
"""
- Received an error_report from client. Not much we can do beyond
+ Received an ErrorReportPDU from client. Not much we can do beyond
logging it, then killing the connection if error was fatal.
"""
+
log(self)
if self.errno in self.fatal:
log("[Shutting down due to reported fatal protocol error]")
sys.exit(1)
-pdu.pdu_map = dict((p.pdu_type, p) for p in (ipv4_prefix, ipv6_prefix, serial_notify, serial_query, reset_query,
- cache_response, end_of_data, cache_reset, router_key, error_report))
-
class ROA(rpki.POW.ROA):
"""
@@ -909,7 +963,7 @@ class X509(rpki.POW.X509):
yield asn
-class pdu_set(list):
+class PDUSet(list):
"""
Object representing a set of PDUs, that is, one versioned and
(theoretically) consistant set of prefixes and router keys extracted
@@ -919,13 +973,14 @@ class pdu_set(list):
@classmethod
def _load_file(cls, filename):
"""
- Low-level method to read pdu_set from a file.
+ Low-level method to read PDUSet from a file.
"""
+
self = cls()
f = open(filename, "rb")
- r = read_buffer()
+ r = ReadBuffer()
while True:
- p = pdu.read_pdu(r)
+ p = PDU.read_pdu(r)
while p is None:
b = f.read(r.needed())
if b == "":
@@ -940,7 +995,7 @@ class pdu_set(list):
return ((a - b) % (1 << 32)) < (1 << 31)
-class axfr_set(pdu_set):
+class AXFRSet(PDUSet):
"""
Object representing a complete set of PDUs, that is, one versioned
and (theoretically) consistant set of prefixes and router
@@ -952,7 +1007,7 @@ class axfr_set(pdu_set):
def parse_rcynic(cls, rcynic_dir):
"""
Parse ROAS and router certificates fetched (and validated!) by
- rcynic to create a new axfr_set.
+ rcynic to create a new AXFRSet.
In normal operation, we use os.walk() and the rpki.POW library to
parse these data directly, but we can, if so instructed, use
@@ -965,7 +1020,7 @@ class axfr_set(pdu_set):
"""
self = cls()
- self.serial = timestamp.now()
+ self.serial = Timestamp.now()
if scan_roas is None or scan_routercerts is None:
for root, dirs, files in os.walk(rcynic_dir):
@@ -973,7 +1028,7 @@ class axfr_set(pdu_set):
if scan_roas is None and fn.endswith(".roa"):
roa = ROA.derReadFile(os.path.join(root, fn))
asn = roa.getASID()
- self.extend(prefix.from_roa(asn, roa_prefix)
+ self.extend(PrefixPDU.from_roa(asn, roa_prefix)
for roa_prefix in roa.prefixes)
if scan_routercerts is None and fn.endswith(".cer"):
x = X509.derReadFile(os.path.join(root, fn))
@@ -981,7 +1036,7 @@ class axfr_set(pdu_set):
if eku is not None and rpki.oids.id_kp_bgpsec_router in eku:
ski = x.getSKI()
key = x.getPublicKey().derWritePublic()
- self.extend(router_key.from_certificate(asn, ski, key)
+ self.extend(RouterKeyPDU.from_certificate(asn, ski, key)
for asn in x.asns)
if scan_roas is not None:
@@ -990,7 +1045,7 @@ class axfr_set(pdu_set):
for line in p.stdout:
line = line.split()
asn = line[1]
- self.extend(prefix.from_text(asn, addr) for addr in line[2:])
+ self.extend(PrefixPDU.from_text(asn, addr) for addr in line[2:])
except OSError, e:
sys.exit("Could not run %s: %s" % (scan_roas, e))
@@ -1001,7 +1056,7 @@ class axfr_set(pdu_set):
line = line.split()
gski = line[0]
key = line[-1]
- self.extend(router_key.from_text(asn, gski, key) for asn in line[1:-1])
+ self.extend(RouterKeyPDU.from_text(asn, gski, key) for asn in line[1:-1])
except OSError, e:
sys.exit("Could not run %s: %s" % (scan_routercerts, e))
@@ -1014,25 +1069,28 @@ class axfr_set(pdu_set):
@classmethod
def load(cls, filename):
"""
- Load an axfr_set from a file, parse filename to obtain serial.
+ Load an AXFRSet from a file, parse filename to obtain serial.
"""
+
fn1, fn2 = os.path.basename(filename).split(".")
assert fn1.isdigit() and fn2 == "ax"
self = cls._load_file(filename)
- self.serial = timestamp(fn1)
+ self.serial = Timestamp(fn1)
return self
def filename(self):
"""
- Generate filename for this axfr_set.
+ Generate filename for this AXFRSet.
"""
+
return "%d.ax" % self.serial
@classmethod
def load_current(cls):
"""
- Load current axfr_set. Return None if can't.
+ Load current AXFRSet. Return None if can't.
"""
+
serial = read_current()[0]
if serial is None:
return None
@@ -1043,8 +1101,9 @@ class axfr_set(pdu_set):
def save_axfr(self):
"""
- Write axfr__set to file with magic filename.
+ Write AXFRSet to file with magic filename.
"""
+
f = open(self.filename(), "wb")
for p in self:
f.write(p.to_pdu())
@@ -1055,6 +1114,7 @@ class axfr_set(pdu_set):
Destroy old data files, presumably because our nonce changed and
the old serial numbers are no longer valid.
"""
+
for i in glob.iglob("*.ix.*"):
os.unlink(i)
for i in glob.iglob("*.ax"):
@@ -1067,6 +1127,7 @@ class axfr_set(pdu_set):
necessary. Creating a new nonce triggers cleanup of old state, as
the new nonce invalidates all old serial numbers.
"""
+
old_serial, nonce = read_current()
if old_serial is None or self.seq_ge(old_serial, self.serial):
blather("Creating new nonce and deleting stale data")
@@ -1076,11 +1137,12 @@ class axfr_set(pdu_set):
def save_ixfr(self, other):
"""
- Comparing this axfr_set with an older one and write the resulting
- ixfr_set to file with magic filename. Since we store pdu_sets
+ Comparing this AXFRSet with an older one and write the resulting
+ IXFRSet to file with magic filename. Since we store PDUSets
in sorted order, computing the difference is a trivial linear
comparison.
"""
+
f = open("%d.ix.%d" % (self.serial, other.serial), "wb")
old = other
new = self
@@ -1105,8 +1167,9 @@ class axfr_set(pdu_set):
def show(self):
"""
- Print this axfr_set.
+ Print this AXFRSet.
"""
+
blather("# AXFR %d (%s)" % (self.serial, self.serial))
for p in self:
blather(p)
@@ -1126,7 +1189,7 @@ class axfr_set(pdu_set):
self.serial = None
for line in cls.read_bgpdump(filename):
try:
- pfx = prefix.from_bgpdump(line, rib_dump = True)
+ pfx = PrefixPDU.from_bgpdump(line, rib_dump = True)
except IgnoreThisRecord:
continue
self.append(pfx)
@@ -1143,7 +1206,7 @@ class axfr_set(pdu_set):
assert os.path.basename(filename).startswith("updates.")
for line in self.read_bgpdump(filename):
try:
- pfx = prefix.from_bgpdump(line, rib_dump = False)
+ pfx = PrefixPDU.from_bgpdump(line, rib_dump = False)
except IgnoreThisRecord:
continue
announce = pfx.announce
@@ -1157,7 +1220,7 @@ class axfr_set(pdu_set):
del self[i]
self.serial = pfx.timestamp
-class ixfr_set(pdu_set):
+class IXFRSet(PDUSet):
"""
Object representing an incremental set of PDUs, that is, the
differences between one versioned and (theoretically) consistant set
@@ -1169,31 +1232,34 @@ class ixfr_set(pdu_set):
@classmethod
def load(cls, filename):
"""
- Load an ixfr_set from a file, parse filename to obtain serials.
+ Load an IXFRSet from a file, parse filename to obtain serials.
"""
+
fn1, fn2, fn3 = os.path.basename(filename).split(".")
assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit()
self = cls._load_file(filename)
- self.from_serial = timestamp(fn3)
- self.to_serial = timestamp(fn1)
+ self.from_serial = Timestamp(fn3)
+ self.to_serial = Timestamp(fn1)
return self
def filename(self):
"""
- Generate filename for this ixfr_set.
+ Generate filename for this IXFRSet.
"""
+
return "%d.ix.%d" % (self.to_serial, self.from_serial)
def show(self):
"""
- Print this ixfr_set.
+ Print this IXFRSet.
"""
+
blather("# IXFR %d (%s) -> %d (%s)" % (self.from_serial, self.from_serial,
self.to_serial, self.to_serial))
for p in self:
blather(p)
-class file_producer(object):
+class FileProducer(object):
"""
File-based producer object for asynchat.
"""
@@ -1205,7 +1271,7 @@ class file_producer(object):
def more(self):
return self.handle.read(self.buffersize)
-class pdu_channel(asynchat.async_chat):
+class PDUChannel(asynchat.async_chat):
"""
asynchat subclass that understands our PDUs. This just handles
network I/O. Specific engines (client, server) should be subclasses
@@ -1215,17 +1281,18 @@ class pdu_channel(asynchat.async_chat):
def __init__(self, conn = None):
asynchat.async_chat.__init__(self, conn)
- self.reader = read_buffer()
+ self.reader = ReadBuffer()
def start_new_pdu(self):
"""
Start read of a new PDU.
"""
+
try:
- p = pdu.read_pdu(self.reader)
+ p = PDU.read_pdu(self.reader)
while p is not None:
self.deliver_pdu(p)
- p = pdu.read_pdu(self.reader)
+ p = PDU.read_pdu(self.reader)
except PDUException, e:
self.push_pdu(e.make_error_report())
self.close_when_done()
@@ -1237,13 +1304,15 @@ class pdu_channel(asynchat.async_chat):
"""
Collect data into the read buffer.
"""
+
self.reader.put(data)
-
+
def found_terminator(self):
"""
Got requested data, see if we now have a PDU. If so, pass it
along, then restart cycle for a new PDU.
"""
+
p = self.reader.retry()
if p is None:
self.set_terminator(self.reader.needed())
@@ -1255,6 +1324,7 @@ class pdu_channel(asynchat.async_chat):
"""
Write PDU to stream.
"""
+
try:
self.push(pdu.to_pdu())
except OSError, e:
@@ -1265,8 +1335,9 @@ class pdu_channel(asynchat.async_chat):
"""
Write content of a file to stream.
"""
+
try:
- self.push_with_producer(file_producer(f, self.ac_out_buffer_size))
+ self.push_with_producer(FileProducer(f, self.ac_out_buffer_size))
except OSError, e:
if e.errno != errno.EAGAIN:
raise
@@ -1275,18 +1346,21 @@ class pdu_channel(asynchat.async_chat):
"""
Intercept asyncore's logging.
"""
+
log(msg)
def log_info(self, msg, tag = "info"):
"""
Intercept asynchat's logging.
"""
+
log("asynchat: %s: %s" % (tag, msg))
def handle_error(self):
"""
Handle errors caught by asyncore main loop.
"""
+
c, e = sys.exc_info()[:2]
if backtrace_on_exceptions or e == 0:
for line in traceback.format_exc().splitlines():
@@ -1300,8 +1374,9 @@ class pdu_channel(asynchat.async_chat):
"""
Kludge to plug asyncore.file_dispatcher into asynchat. Call from
subclass's __init__() method, after calling
- pdu_channel.__init__(), and don't read this on a full stomach.
+ PDUChannel.__init__(), and don't read this on a full stomach.
"""
+
self.connected = True
self._fileno = fd
self.socket = asyncore.file_wrapper(fd)
@@ -1314,10 +1389,11 @@ class pdu_channel(asynchat.async_chat):
"""
Exit when channel closed.
"""
+
asynchat.async_chat.handle_close(self)
sys.exit(0)
-class server_write_channel(pdu_channel):
+class ServerWriteChannel(PDUChannel):
"""
Kludge to deal with ssh's habit of sometimes (compile time option)
invoking us with two unidirectional pipes instead of one
@@ -1330,18 +1406,20 @@ class server_write_channel(pdu_channel):
"""
Set up stdout.
"""
- pdu_channel.__init__(self)
+
+ PDUChannel.__init__(self)
self.init_file_dispatcher(sys.stdout.fileno())
def readable(self):
"""
This channel is never readable.
"""
+
return False
-class server_channel(pdu_channel):
+class ServerChannel(PDUChannel):
"""
- Server protocol engine, handles upcalls from pdu_channel to
+ Server protocol engine, handles upcalls from PDUChannel to
implement protocol logic.
"""
@@ -1350,9 +1428,10 @@ class server_channel(pdu_channel):
Set up stdin and stdout as connection and start listening for
first PDU.
"""
- pdu_channel.__init__(self)
+
+ PDUChannel.__init__(self)
self.init_file_dispatcher(sys.stdin.fileno())
- self.writer = server_write_channel()
+ self.writer = ServerWriteChannel()
self.get_serial()
self.start_new_pdu()
@@ -1360,36 +1439,42 @@ class server_channel(pdu_channel):
"""
This channel is never writable.
"""
+
return False
def push(self, data):
"""
Redirect to writer channel.
"""
+
return self.writer.push(data)
def push_with_producer(self, producer):
"""
Redirect to writer channel.
"""
+
return self.writer.push_with_producer(producer)
def push_pdu(self, pdu):
"""
Redirect to writer channel.
"""
+
return self.writer.push_pdu(pdu)
def push_file(self, f):
"""
Redirect to writer channel.
"""
+
return self.writer.push_file(f)
def deliver_pdu(self, pdu):
"""
Handle received PDU.
"""
+
pdu.serve(self)
def get_serial(self):
@@ -1399,6 +1484,7 @@ class server_channel(pdu_channel):
happen, but maybe we got started in server mode while the cronjob
mode instance is still building its database.
"""
+
self.current_serial, self.current_nonce = read_current()
return self.current_serial
@@ -1406,6 +1492,7 @@ class server_channel(pdu_channel):
"""
Check for a new serial number.
"""
+
old_serial = self.current_serial
return old_serial != self.get_serial()
@@ -1413,14 +1500,15 @@ class server_channel(pdu_channel):
"""
Cronjob instance kicked us, send a notify message.
"""
+
if self.check_serial() is not None:
- self.push_pdu(serial_notify(serial = self.current_serial, nonce = self.current_nonce))
+ self.push_pdu(SerialNotifyPDU(serial = self.current_serial, nonce = self.current_nonce))
else:
log("Cronjob kicked me without a valid current serial number")
-class client_channel(pdu_channel):
+class ClientChannel(PDUChannel):
"""
- Client protocol engine, handles upcalls from pdu_channel.
+ Client protocol engine, handles upcalls from PDUChannel.
"""
current_serial = None
@@ -1435,7 +1523,7 @@ class client_channel(pdu_channel):
self.proc = proc
self.host = host
self.port = port
- pdu_channel.__init__(self, conn = sock)
+ PDUChannel.__init__(self, conn = sock)
self.start_new_pdu()
@classmethod
@@ -1443,6 +1531,7 @@ class client_channel(pdu_channel):
"""
Set up ssh connection and start listening for first PDU.
"""
+
args = ("ssh", "-p", port, "-s", host, "rpki-rtr")
blather("[Running ssh: %s]" % " ".join(args))
s = socket.socketpair()
@@ -1457,6 +1546,7 @@ class client_channel(pdu_channel):
"""
Set up TCP connection and start listening for first PDU.
"""
+
blather("[Starting raw TCP connection to %s:%s]" % (host, port))
try:
addrinfo = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
@@ -1486,6 +1576,7 @@ class client_channel(pdu_channel):
"""
Set up loopback connection and start listening for first PDU.
"""
+
s = socket.socketpair()
blather("[Using direct subprocess kludge for testing]")
argv = [sys.executable, sys.argv[0], "--server"]
@@ -1508,6 +1599,7 @@ class client_channel(pdu_channel):
properly (eg, gnutls-cli, or stunnel's client mode if that works
for such purposes this week).
"""
+
args = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (host, port))
blather("[Running: %s]" % " ".join(args))
s = socket.socketpair()
@@ -1521,6 +1613,7 @@ class client_channel(pdu_channel):
Set up an SQLite database to contain the table we receive. If
necessary, we will create the database.
"""
+
import sqlite3
missing = not os.path.exists(sqlname)
self.sql = sqlite3.connect(sqlname, detect_types = sqlite3.PARSE_DECLTYPES)
@@ -1572,8 +1665,9 @@ class client_channel(pdu_channel):
def cache_reset(self):
"""
- Handle cache_reset actions.
+ Handle CacheResetPDU actions.
"""
+
self.current_serial = None
if self.sql:
cur = self.sql.cursor()
@@ -1582,8 +1676,9 @@ class client_channel(pdu_channel):
def end_of_data(self, serial, nonce):
"""
- Handle end_of_data actions.
+ Handle EndOfDataPDU actions.
"""
+
self.current_serial = serial
self.current_nonce = nonce
if self.sql:
@@ -1595,6 +1690,7 @@ class client_channel(pdu_channel):
"""
Handle one prefix PDU.
"""
+
if self.sql:
values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen)
if prefix.announce:
@@ -1630,14 +1726,16 @@ class client_channel(pdu_channel):
"""
Handle received PDU.
"""
+
pdu.consume(self)
def push_pdu(self, pdu):
"""
Log outbound PDU then write it to stream.
"""
+
blather(pdu)
- pdu_channel.push_pdu(self, pdu)
+ PDUChannel.push_pdu(self, pdu)
def cleanup(self):
"""
@@ -1645,6 +1743,7 @@ class client_channel(pdu_channel):
well, child will have exited already before this method is called,
but we may need to whack it with a stick if something breaks.
"""
+
if self.proc is not None and self.proc.returncode is None:
try:
os.kill(self.proc.pid, self.killsig)
@@ -1655,10 +1754,11 @@ class client_channel(pdu_channel):
"""
Intercept close event so we can log it, then shut down.
"""
+
blather("Server closed channel")
- pdu_channel.handle_close(self)
+ PDUChannel.handle_close(self)
-class kickme_channel(asyncore.dispatcher):
+class KickmeChannel(asyncore.dispatcher):
"""
asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to
kick servers when it's time to send notify PDUs to clients.
@@ -1682,18 +1782,21 @@ class kickme_channel(asyncore.dispatcher):
"""
This socket is read-only, never writable.
"""
+
return False
def handle_connect(self):
"""
Ignore connect events (not very useful on datagram socket).
"""
+
pass
def handle_read(self):
"""
Handle receipt of a datagram.
"""
+
data = self.recv(512)
self.server.notify(data)
@@ -1701,6 +1804,7 @@ class kickme_channel(asyncore.dispatcher):
"""
Clean up this dispatcher's socket.
"""
+
self.close()
try:
os.unlink(self.sockname)
@@ -1711,18 +1815,21 @@ class kickme_channel(asyncore.dispatcher):
"""
Intercept asyncore's logging.
"""
+
log(msg)
def log_info(self, msg, tag = "info"):
"""
Intercept asyncore's logging.
"""
+
log("asyncore: %s: %s" % (tag, msg))
def handle_error(self):
"""
Handle errors caught by asyncore main loop.
"""
+
c, e = sys.exc_info()[:2]
if backtrace_on_exceptions or e == 0:
for line in traceback.format_exc().splitlines():
@@ -1833,21 +1940,21 @@ def cronjob_main(argv):
old_ixfrs = glob.glob("*.ix.*")
current = read_current()[0]
- cutoff = timestamp.now(-(24 * 60 * 60))
+ cutoff = Timestamp.now(-(24 * 60 * 60))
for f in glob.iglob("*.ax"):
- t = timestamp(int(f.split(".")[0]))
+ t = Timestamp(int(f.split(".")[0]))
if t < cutoff and t != current:
blather("# Deleting old file %s, timestamp %s" % (f, t))
os.unlink(f)
-
- pdus = axfr_set.parse_rcynic(argv[0])
- if pdus == axfr_set.load_current():
+
+ pdus = AXFRSet.parse_rcynic(argv[0])
+ if pdus == AXFRSet.load_current():
blather("# No change, new version not needed")
sys.exit()
pdus.save_axfr()
for axfr in glob.iglob("*.ax"):
if axfr != pdus.filename():
- pdus.save_ixfr(axfr_set.load(axfr))
+ pdus.save_ixfr(AXFRSet.load(axfr))
pdus.mark_current()
blather("# New serial is %d (%s)" % (pdus.serial, pdus.serial))
@@ -1877,12 +1984,12 @@ def show_main(argv):
g = glob.glob("*.ax")
g.sort()
for f in g:
- axfr_set.load(f).show()
+ AXFRSet.load(f).show()
g = glob.glob("*.ix.*")
g.sort()
for f in g:
- ixfr_set.load(f).show()
+ IXFRSet.load(f).show()
def server_main(argv):
"""
@@ -1926,8 +2033,8 @@ def server_main(argv):
sys.exit(e)
kickme = None
try:
- server = server_channel()
- kickme = kickme_channel(server = server)
+ server = ServerChannel()
+ kickme = KickmeChannel(server = server)
asyncore.loop(timeout = None)
except KeyboardInterrupt:
sys.exit(0)
@@ -1972,7 +2079,7 @@ def listener_tcp_main(argv):
except:
if listener is not None:
listener.close()
- listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
@@ -2048,11 +2155,11 @@ def client_main(argv):
argv = ["loopback"]
proto = argv[0]
if proto == "loopback" and len(argv) in (1, 2):
- constructor = client_channel.loopback
+ constructor = ClientChannel.loopback
host, port = "", ""
sqlname = None if len(argv) == 1 else argv[1]
elif proto in ("ssh", "tcp", "tls") and len(argv) in (3, 4):
- constructor = getattr(client_channel, proto)
+ constructor = getattr(ClientChannel, proto)
host, port = argv[1:3]
sqlname = None if len(argv) == 3 else argv[3]
else:
@@ -2064,9 +2171,9 @@ def client_main(argv):
client.setup_sql(sqlname)
while True:
if client.current_serial is None or client.current_nonce is None:
- client.push_pdu(reset_query())
+ client.push_pdu(ResetQueryPDU())
else:
- client.push_pdu(serial_query(serial = client.current_serial, nonce = client.current_nonce))
+ client.push_pdu(SerialQueryPDU(serial = client.current_serial, nonce = client.current_nonce))
wakeup = time.time() + 600
while True:
remaining = wakeup - time.time()
@@ -2106,10 +2213,10 @@ def bgpdump_convert_main(argv):
if filename.endswith(".ax"):
blather("Reading %s" % filename)
- db = axfr_set.load(filename)
+ db = AXFRSet.load(filename)
elif os.path.basename(filename).startswith("ribs."):
- db = axfr_set.parse_bgpdump_rib_dump(filename)
+ db = AXFRSet.parse_bgpdump_rib_dump(filename)
db.save_axfr()
elif not first:
@@ -2127,7 +2234,7 @@ def bgpdump_convert_main(argv):
for axfr in axfrs:
blather("Loading %s" % axfr)
- ax = axfr_set.load(axfr)
+ ax = AXFRSet.load(axfr)
blather("Computing changes from %d (%s) to %d (%s)" % (ax.serial, ax.serial, db.serial, db.serial))
db.save_ixfr(ax)
del ax
@@ -2152,7 +2259,7 @@ def bgpdump_select_main(argv):
try:
head, sep, tail = os.path.basename(argv[0]).partition(".")
if len(argv) == 1 and head.isdigit() and sep == "." and tail == "ax":
- serial = timestamp(head)
+ serial = Timestamp(head)
except:
pass
if serial is None:
@@ -2166,7 +2273,7 @@ def bgpdump_select_main(argv):
kick_all(serial)
-class bgpsec_replay_clock(object):
+class BGPDumpReplayClock(object):
"""
Internal clock for replaying BGP dump files.
@@ -2182,7 +2289,7 @@ class bgpsec_replay_clock(object):
"""
def __init__(self):
- self.timestamps = [timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax")]
+ self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax")]
self.timestamps.sort()
self.offset = self.timestamps[0] - int(time.time())
self.nonce = new_nonce()
@@ -2191,7 +2298,7 @@ class bgpsec_replay_clock(object):
return len(self.timestamps) > 0
def now(self):
- return timestamp.now(self.offset)
+ return Timestamp.now(self.offset)
def read_current(self):
now = self.now()
@@ -2238,11 +2345,11 @@ def bgpdump_server_main(argv):
# method to our clock object. Fun stuff, huh?
#
global read_current
- clock = bgpsec_replay_clock()
+ clock = BGPDumpReplayClock()
read_current = clock.read_current
#
try:
- server = server_channel()
+ server = ServerChannel()
old_serial = server.get_serial()
blather("[Starting at serial %d (%s)]" % (old_serial, old_serial))
while clock: