author    Rob Austein <sra@hactrn.net>    2015-10-26 06:29:00 +0000
committer Rob Austein <sra@hactrn.net>    2015-10-26 06:29:00 +0000
commit    b46deb1417dc3596e9ac9fe2fe8cc0b7f42457e7 (patch)
tree      ca0dc0276d1adc168bc3337ce0564c4ec4957c1b /rpki/rtr
parent    397beaf6d9900dc3b3cb612c89ebf1d57b1d16f6 (diff)
"Any programmer who fails to comply with the standard naming, formatting,
or commenting conventions should be shot. If it so happens that it is inconvenient to shoot him, then he is to be politely requested to recode his program in adherence to the above standard." -- Michael Spier, Digital Equipment Corporation svn path=/branches/tk705/; revision=6152
Diffstat (limited to 'rpki/rtr')
-rwxr-xr-x  rpki/rtr/bgpdump.py     482
-rw-r--r--  rpki/rtr/channels.py    364
-rw-r--r--  rpki/rtr/client.py      816
-rw-r--r--  rpki/rtr/generator.py   948
-rw-r--r--  rpki/rtr/main.py        110
-rw-r--r--  rpki/rtr/pdus.py        960
-rw-r--r--  rpki/rtr/server.py      872
7 files changed, 2276 insertions, 2276 deletions
diff --git a/rpki/rtr/bgpdump.py b/rpki/rtr/bgpdump.py
index 5ffabc4d..3336fb9f 100755
--- a/rpki/rtr/bgpdump.py
+++ b/rpki/rtr/bgpdump.py
@@ -39,292 +39,292 @@ from rpki.rtr.channels import Timestamp
class IgnoreThisRecord(Exception):
- pass
+ pass
class PrefixPDU(rpki.rtr.generator.PrefixPDU):
- @staticmethod
- def from_bgpdump(line, rib_dump):
- try:
- assert isinstance(rib_dump, bool)
- fields = line.split("|")
-
- # Parse prefix, including figuring out IP protocol version
- cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU
- self = cls()
- self.timestamp = Timestamp(fields[1])
- p, l = fields[5].split("/")
- self.prefix = rpki.POW.IPAddress(p)
- self.prefixlen = self.max_prefixlen = int(l)
-
- # Withdrawals don't have AS paths, so be careful
- assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W")
- if fields[2] == "W":
- self.asn = 0
- self.announce = 0
- else:
- self.announce = 1
- if not fields[6] or "{" in fields[6] or "(" in fields[6]:
- raise IgnoreThisRecord
- a = fields[6].split()[-1]
- if "." in a:
- a = [int(s) for s in a.split(".")]
- if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535:
- logging.warn("Bad dotted ASNum %r, ignoring record", fields[6])
+ @staticmethod
+ def from_bgpdump(line, rib_dump):
+ try:
+ assert isinstance(rib_dump, bool)
+ fields = line.split("|")
+
+ # Parse prefix, including figuring out IP protocol version
+ cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU
+ self = cls()
+ self.timestamp = Timestamp(fields[1])
+ p, l = fields[5].split("/")
+ self.prefix = rpki.POW.IPAddress(p)
+ self.prefixlen = self.max_prefixlen = int(l)
+
+ # Withdrawals don't have AS paths, so be careful
+ assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W")
+ if fields[2] == "W":
+ self.asn = 0
+ self.announce = 0
+ else:
+ self.announce = 1
+ if not fields[6] or "{" in fields[6] or "(" in fields[6]:
+ raise IgnoreThisRecord
+ a = fields[6].split()[-1]
+ if "." in a:
+ a = [int(s) for s in a.split(".")]
+ if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535:
+ logging.warn("Bad dotted ASNum %r, ignoring record", fields[6])
+ raise IgnoreThisRecord
+ a = (a[0] << 16) | a[1]
+ else:
+ a = int(a)
+ self.asn = a
+
+ self.check()
+ return self
+
+ except IgnoreThisRecord:
+ raise
+
+ except Exception, e:
+ logging.warn("Ignoring line %r: %s", line, e)
raise IgnoreThisRecord
- a = (a[0] << 16) | a[1]
- else:
- a = int(a)
- self.asn = a
- self.check()
- return self
- except IgnoreThisRecord:
- raise
+class AXFRSet(rpki.rtr.generator.AXFRSet):
- except Exception, e:
- logging.warn("Ignoring line %r: %s", line, e)
- raise IgnoreThisRecord
+ @staticmethod
+ def read_bgpdump(filename):
+ assert filename.endswith(".bz2")
+ logging.debug("Reading %s", filename)
+ bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE)
+ bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE)
+ return bgpdump.stdout
+
+ @classmethod
+ def parse_bgpdump_rib_dump(cls, filename):
+ assert os.path.basename(filename).startswith("ribs.")
+ self = cls()
+ self.serial = None
+ for line in cls.read_bgpdump(filename):
+ try:
+ pfx = PrefixPDU.from_bgpdump(line, rib_dump = True)
+ except IgnoreThisRecord:
+ continue
+ self.append(pfx)
+ self.serial = pfx.timestamp
+ if self.serial is None:
+ sys.exit("Failed to parse anything useful from %s" % filename)
+ self.sort()
+ for i in xrange(len(self) - 2, -1, -1):
+ if self[i] == self[i + 1]:
+ del self[i + 1]
+ return self
+
+ def parse_bgpdump_update(self, filename):
+ assert os.path.basename(filename).startswith("updates.")
+ for line in self.read_bgpdump(filename):
+ try:
+ pfx = PrefixPDU.from_bgpdump(line, rib_dump = False)
+ except IgnoreThisRecord:
+ continue
+ announce = pfx.announce
+ pfx.announce = 1
+ i = bisect.bisect_left(self, pfx)
+ if announce:
+ if i >= len(self) or pfx != self[i]:
+ self.insert(i, pfx)
+ else:
+ while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen:
+ del self[i]
+ self.serial = pfx.timestamp
-class AXFRSet(rpki.rtr.generator.AXFRSet):
+def bgpdump_convert_main(args):
+ """
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ Simulate route origin data from a set of BGP dump files.
+ argv is an ordered list of filenames. Each file must be a BGP RIB
+ dump, a BGP UPDATE dump, or an AXFR dump in the format written by
+ this program's --cronjob command. The first file must be a RIB dump
+ or AXFR dump; it cannot be an UPDATE dump. Output will be a set of
+ AXFR and IXFR files with timestamps derived from the BGP dumps,
+ which can be used as input to this program's --server command for
+ test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL.
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ """
+
+ first = True
+ db = None
+ axfrs = []
+ version = max(rpki.rtr.pdus.PDU.version_map.iterkeys())
+
+ for filename in args.files:
+
+ if ".ax.v" in filename:
+ logging.debug("Reading %s", filename)
+ db = AXFRSet.load(filename)
+
+ elif os.path.basename(filename).startswith("ribs."):
+ db = AXFRSet.parse_bgpdump_rib_dump(filename)
+ db.save_axfr()
+
+ elif not first:
+ assert db is not None
+ db.parse_bgpdump_update(filename)
+ db.save_axfr()
- @staticmethod
- def read_bgpdump(filename):
- assert filename.endswith(".bz2")
- logging.debug("Reading %s", filename)
- bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE)
- bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE)
- return bgpdump.stdout
-
- @classmethod
- def parse_bgpdump_rib_dump(cls, filename):
- assert os.path.basename(filename).startswith("ribs.")
- self = cls()
- self.serial = None
- for line in cls.read_bgpdump(filename):
- try:
- pfx = PrefixPDU.from_bgpdump(line, rib_dump = True)
- except IgnoreThisRecord:
- continue
- self.append(pfx)
- self.serial = pfx.timestamp
- if self.serial is None:
- sys.exit("Failed to parse anything useful from %s" % filename)
- self.sort()
- for i in xrange(len(self) - 2, -1, -1):
- if self[i] == self[i + 1]:
- del self[i + 1]
- return self
-
- def parse_bgpdump_update(self, filename):
- assert os.path.basename(filename).startswith("updates.")
- for line in self.read_bgpdump(filename):
- try:
- pfx = PrefixPDU.from_bgpdump(line, rib_dump = False)
- except IgnoreThisRecord:
- continue
- announce = pfx.announce
- pfx.announce = 1
- i = bisect.bisect_left(self, pfx)
- if announce:
- if i >= len(self) or pfx != self[i]:
- self.insert(i, pfx)
- else:
- while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen:
- del self[i]
- self.serial = pfx.timestamp
+ else:
+ sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename)
+ logging.debug("DB serial now %d (%s)", db.serial, db.serial)
+ if first and rpki.rtr.server.read_current(version) == (None, None):
+ db.mark_current()
+ first = False
-def bgpdump_convert_main(args):
- """
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- Simulate route origin data from a set of BGP dump files.
- argv is an ordered list of filenames. Each file must be a BGP RIB
- dump, a BGP UPDATE dump, or an AXFR dump in the format written by
- this program's --cronjob command. The first file must be a RIB dump
- or AXFR dump; it cannot be an UPDATE dump. Output will be a set of
- AXFR and IXFR files with timestamps derived from the BGP dumps,
- which can be used as input to this program's --server command for
- test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL.
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- """
-
- first = True
- db = None
- axfrs = []
- version = max(rpki.rtr.pdus.PDU.version_map.iterkeys())
-
- for filename in args.files:
-
- if ".ax.v" in filename:
- logging.debug("Reading %s", filename)
- db = AXFRSet.load(filename)
-
- elif os.path.basename(filename).startswith("ribs."):
- db = AXFRSet.parse_bgpdump_rib_dump(filename)
- db.save_axfr()
-
- elif not first:
- assert db is not None
- db.parse_bgpdump_update(filename)
- db.save_axfr()
-
- else:
- sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename)
-
- logging.debug("DB serial now %d (%s)", db.serial, db.serial)
- if first and rpki.rtr.server.read_current(version) == (None, None):
- db.mark_current()
- first = False
-
- for axfr in axfrs:
- logging.debug("Loading %s", axfr)
- ax = AXFRSet.load(axfr)
- logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial)
- db.save_ixfr(ax)
- del ax
-
- axfrs.append(db.filename())
+ for axfr in axfrs:
+ logging.debug("Loading %s", axfr)
+ ax = AXFRSet.load(axfr)
+ logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial)
+ db.save_ixfr(ax)
+ del ax
+
+ axfrs.append(db.filename())
def bgpdump_select_main(args):
- """
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- Simulate route origin data from a set of BGP dump files.
- Set current serial number to correspond to an .ax file created by
- converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL.
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- """
+ """
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ Simulate route origin data from a set of BGP dump files.
+ Set current serial number to correspond to an .ax file created by
+ converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL.
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ """
- head, sep, tail = os.path.basename(args.ax_file).partition(".")
- if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit():
- sys.exit("Argument must be name of a .ax file")
+ head, sep, tail = os.path.basename(args.ax_file).partition(".")
+ if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit():
+ sys.exit("Argument must be name of a .ax file")
- serial = Timestamp(head)
- version = int(tail[4:])
+ serial = Timestamp(head)
+ version = int(tail[4:])
- if version not in rpki.rtr.pdus.PDU.version_map:
- sys.exit("Unknown protocol version %d" % version)
+ if version not in rpki.rtr.pdus.PDU.version_map:
+ sys.exit("Unknown protocol version %d" % version)
- nonce = rpki.rtr.server.read_current(version)[1]
- if nonce is None:
- nonce = rpki.rtr.generator.new_nonce()
+ nonce = rpki.rtr.server.read_current(version)[1]
+ if nonce is None:
+ nonce = rpki.rtr.generator.new_nonce()
- rpki.rtr.server.write_current(serial, nonce, version)
- rpki.rtr.generator.kick_all(serial)
+ rpki.rtr.server.write_current(serial, nonce, version)
+ rpki.rtr.generator.kick_all(serial)
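
The bgpdump_select_main() command above relies on the .ax file naming convention SERIAL.ax.vVERSION, where the leading digits are a Timestamp-style serial number and the digits after "ax.v" are the protocol version. A minimal sketch of just that filename check, with a hypothetical helper name and example filename that are not part of the module:

    import os

    def parse_ax_name(path):
        # Mirror the checks in bgpdump_select_main(): accept only
        # "SERIAL.ax.vVERSION" and return (serial, version) as integers.
        head, sep, tail = os.path.basename(path).partition(".")
        if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit():
            raise ValueError("Argument must be name of a .ax file: %r" % (path,))
        return int(head), int(tail[4:])

    # parse_ax_name("1445840000.ax.v1") -> (1445840000, 1)
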
class BGPDumpReplayClock(object):
- """
- Internal clock for replaying BGP dump files.
+ """
+ Internal clock for replaying BGP dump files.
- * DANGER WILL ROBINSON! *
- * DEBUGGING AND TEST USE ONLY! *
+ * DANGER WILL ROBINSON! *
+ * DEBUGGING AND TEST USE ONLY! *
- This class replaces the normal on-disk serial number mechanism with
- an in-memory version based on pre-computed data.
+ This class replaces the normal on-disk serial number mechanism with
+ an in-memory version based on pre-computed data.
- bgpdump_server_main() uses this hack to replay historical data for
- testing purposes. DO NOT USE THIS IN PRODUCTION.
+ bgpdump_server_main() uses this hack to replay historical data for
+ testing purposes. DO NOT USE THIS IN PRODUCTION.
- You have been warned.
- """
+ You have been warned.
+ """
- def __init__(self):
- self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")]
- self.timestamps.sort()
- self.offset = self.timestamps[0] - int(time.time())
- self.nonce = rpki.rtr.generator.new_nonce()
+ def __init__(self):
+ self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")]
+ self.timestamps.sort()
+ self.offset = self.timestamps[0] - int(time.time())
+ self.nonce = rpki.rtr.generator.new_nonce()
- def __nonzero__(self):
- return len(self.timestamps) > 0
+ def __nonzero__(self):
+ return len(self.timestamps) > 0
- def now(self):
- return Timestamp.now(self.offset)
+ def now(self):
+ return Timestamp.now(self.offset)
- def read_current(self, version):
- now = self.now()
- while len(self.timestamps) > 1 and now >= self.timestamps[1]:
- del self.timestamps[0]
- return self.timestamps[0], self.nonce
+ def read_current(self, version):
+ now = self.now()
+ while len(self.timestamps) > 1 and now >= self.timestamps[1]:
+ del self.timestamps[0]
+ return self.timestamps[0], self.nonce
- def siesta(self):
- now = self.now()
- if len(self.timestamps) <= 1:
- return None
- elif now < self.timestamps[1]:
- return self.timestamps[1] - now
- else:
- return 1
+ def siesta(self):
+ now = self.now()
+ if len(self.timestamps) <= 1:
+ return None
+ elif now < self.timestamps[1]:
+ return self.timestamps[1] - now
+ else:
+ return 1
def bgpdump_server_main(args):
- """
- Simulate route origin data from a set of BGP dump files.
+ """
+ Simulate route origin data from a set of BGP dump files.
+
+ * DANGER WILL ROBINSON! *
+ * DEBUGGING AND TEST USE ONLY! *
+
+ This is a clone of server_main() which replaces the external serial
+ number updates triggered via the kickme channel by cronjob_main with
+ an internal clocking mechanism to replay historical test data.
- * DANGER WILL ROBINSON! *
- * DEBUGGING AND TEST USE ONLY! *
+ DO NOT USE THIS IN PRODUCTION.
- This is a clone of server_main() which replaces the external serial
- number updates triggered via the kickme channel by cronjob_main with
- an internal clocking mechanism to replay historical test data.
+ You have been warned.
+ """
- DO NOT USE THIS IN PRODUCTION.
+ logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server._hostport_tag()))
- You have been warned.
- """
+ logger.debug("[Starting]")
- logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server._hostport_tag()))
+ if args.rpki_rtr_dir:
+ try:
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ sys.exit(e)
- logger.debug("[Starting]")
+ # Yes, this really does replace a global function defined in another
+ # module with a bound method to our clock object. Fun stuff, huh?
+ #
+ clock = BGPDumpReplayClock()
+ rpki.rtr.server.read_current = clock.read_current
- if args.rpki_rtr_dir:
try:
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- sys.exit(e)
-
- # Yes, this really does replace a global function defined in another
- # module with a bound method to our clock object. Fun stuff, huh?
- #
- clock = BGPDumpReplayClock()
- rpki.rtr.server.read_current = clock.read_current
-
- try:
- server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
- old_serial = server.get_serial()
- logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial)
- while clock:
- new_serial = server.get_serial()
- if old_serial != new_serial:
- logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial)
- server.notify()
- old_serial = new_serial
- asyncore.loop(timeout = clock.siesta(), count = 1)
- except KeyboardInterrupt:
- sys.exit(0)
+ server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
+ old_serial = server.get_serial()
+ logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial)
+ while clock:
+ new_serial = server.get_serial()
+ if old_serial != new_serial:
+ logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial)
+ server.notify()
+ old_serial = new_serial
+ asyncore.loop(timeout = clock.siesta(), count = 1)
+ except KeyboardInterrupt:
+ sys.exit(0)
def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__,
- help = "Convert bgpdump to fake ROAs")
- subparser.set_defaults(func = bgpdump_convert_main, default_log_to = "syslog")
- subparser.add_argument("files", nargs = "+", help = "input files")
-
- subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__,
- help = "Set current serial number for fake ROA data")
- subparser.set_defaults(func = bgpdump_select_main, default_log_to = "syslog")
- subparser.add_argument("ax_file", help = "name of the .ax to select")
-
- subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__,
- help = "Replay fake ROAs generated from historical data")
- subparser.set_defaults(func = bgpdump_server_main, default_log_to = "syslog")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+ """
+ Set up argparse stuff for commands in this module.
+ """
+
+ subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__,
+ help = "Convert bgpdump to fake ROAs")
+ subparser.set_defaults(func = bgpdump_convert_main, default_log_to = "syslog")
+ subparser.add_argument("files", nargs = "+", help = "input files")
+
+ subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__,
+ help = "Set current serial number for fake ROA data")
+ subparser.set_defaults(func = bgpdump_select_main, default_log_to = "syslog")
+ subparser.add_argument("ax_file", help = "name of the .ax to select")
+
+ subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__,
+ help = "Replay fake ROAs generated from historical data")
+ subparser.set_defaults(func = bgpdump_server_main, default_log_to = "syslog")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
diff --git a/rpki/rtr/channels.py b/rpki/rtr/channels.py
index d14c024d..e2f443e8 100644
--- a/rpki/rtr/channels.py
+++ b/rpki/rtr/channels.py
@@ -32,215 +32,215 @@ import rpki.rtr.pdus
class Timestamp(int):
- """
- Wrapper around time module.
- """
-
- def __new__(cls, t):
- # __new__() is a static method, not a class method, hence the odd calling sequence.
- return super(Timestamp, cls).__new__(cls, t)
-
- @classmethod
- def now(cls, delta = 0):
- return cls(time.time() + delta)
-
- def __str__(self):
- return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self))
-
-
-class ReadBuffer(object):
- """
- Wrapper around synchronous/asynchronous read state.
-
- This also handles tracking the current protocol version,
- because it has to go somewhere and there's no better place.
- """
-
- def __init__(self):
- self.buffer = ""
- self.version = None
-
- def update(self, need, callback):
"""
- Update count of needed bytes and callback, then dispatch to callback.
+ Wrapper around time module.
"""
- self.need = need
- self.callback = callback
- return self.retry()
+ def __new__(cls, t):
+ # __new__() is a static method, not a class method, hence the odd calling sequence.
+ return super(Timestamp, cls).__new__(cls, t)
- def retry(self):
- """
- Try dispatching to the callback again.
- """
+ @classmethod
+ def now(cls, delta = 0):
+ return cls(time.time() + delta)
- return self.callback(self)
+ def __str__(self):
+ return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self))
- def available(self):
- """
- How much data do we have available in this buffer?
- """
- return len(self.buffer)
-
- def needed(self):
- """
- How much more data does this buffer need to become ready?
+class ReadBuffer(object):
"""
+ Wrapper around synchronous/asynchronous read state.
- return self.need - self.available()
-
- def ready(self):
- """
- Is this buffer ready to read yet?
+ This also handles tracking the current protocol version,
+ because it has to go somewhere and there's no better place.
"""
- return self.available() >= self.need
+ def __init__(self):
+ self.buffer = ""
+ self.version = None
- def get(self, n):
- """
- Hand some data to the caller.
- """
+ def update(self, need, callback):
+ """
+ Update count of needed bytes and callback, then dispatch to callback.
+ """
- b = self.buffer[:n]
- self.buffer = self.buffer[n:]
- return b
+ self.need = need
+ self.callback = callback
+ return self.retry()
- def put(self, b):
- """
- Accumulate some data.
- """
+ def retry(self):
+ """
+ Try dispatching to the callback again.
+ """
- self.buffer += b
+ return self.callback(self)
- def check_version(self, version):
- """
- Track version number of PDUs read from this buffer.
- Once set, the version must not change.
- """
+ def available(self):
+ """
+ How much data do we have available in this buffer?
+ """
- if self.version is not None and version != self.version:
- raise rpki.rtr.pdus.CorruptData(
- "Received PDU version %d, expected %d" % (version, self.version))
- if self.version is None and version not in rpki.rtr.pdus.PDU.version_map:
- raise rpki.rtr.pdus.UnsupportedProtocolVersion(
- "Received PDU version %s, known versions %s" % (
- version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map)))
- self.version = version
+ return len(self.buffer)
+ def needed(self):
+ """
+ How much more data does this buffer need to become ready?
+ """
-class PDUChannel(asynchat.async_chat, object):
- """
- asynchat subclass that understands our PDUs. This just handles
- network I/O. Specific engines (client, server) should be subclasses
- of this with methods that do something useful with the resulting
- PDUs.
- """
-
- def __init__(self, root_pdu_class, sock = None):
- asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super()
- self.reader = ReadBuffer()
- assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU)
- self.root_pdu_class = root_pdu_class
-
- @property
- def version(self):
- return self.reader.version
-
- @version.setter
- def version(self, version):
- self.reader.check_version(version)
-
- def start_new_pdu(self):
- """
- Start read of a new PDU.
- """
-
- try:
- p = self.root_pdu_class.read_pdu(self.reader)
- while p is not None:
- self.deliver_pdu(p)
- p = self.root_pdu_class.read_pdu(self.reader)
- except rpki.rtr.pdus.PDUException, e:
- self.push_pdu(e.make_error_report(version = self.version))
- self.close_when_done()
- else:
- assert not self.reader.ready()
- self.set_terminator(self.reader.needed())
-
- def collect_incoming_data(self, data):
- """
- Collect data into the read buffer.
- """
-
- self.reader.put(data)
-
- def found_terminator(self):
- """
- Got requested data, see if we now have a PDU. If so, pass it
- along, then restart cycle for a new PDU.
- """
-
- p = self.reader.retry()
- if p is None:
- self.set_terminator(self.reader.needed())
- else:
- self.deliver_pdu(p)
- self.start_new_pdu()
-
- def push_pdu(self, pdu):
- """
- Write PDU to stream.
- """
+ return self.need - self.available()
- try:
- self.push(pdu.to_pdu())
- except OSError, e:
- if e.errno != errno.EAGAIN:
- raise
+ def ready(self):
+ """
+ Is this buffer ready to read yet?
+ """
- def log(self, msg):
- """
- Intercept asyncore's logging.
- """
+ return self.available() >= self.need
- logging.info(msg)
+ def get(self, n):
+ """
+ Hand some data to the caller.
+ """
- def log_info(self, msg, tag = "info"):
- """
- Intercept asynchat's logging.
- """
+ b = self.buffer[:n]
+ self.buffer = self.buffer[n:]
+ return b
- logging.info("asynchat: %s: %s", tag, msg)
+ def put(self, b):
+ """
+ Accumulate some data.
+ """
- def handle_error(self):
- """
- Handle errors caught by asyncore main loop.
- """
+ self.buffer += b
- logging.exception("[Unhandled exception]")
- logging.critical("[Exiting after unhandled exception]")
- sys.exit(1)
+ def check_version(self, version):
+ """
+ Track version number of PDUs read from this buffer.
+ Once set, the version must not change.
+ """
- def init_file_dispatcher(self, fd):
- """
- Kludge to plug asyncore.file_dispatcher into asynchat. Call from
- subclass's __init__() method, after calling
- PDUChannel.__init__(), and don't read this on a full stomach.
- """
+ if self.version is not None and version != self.version:
+ raise rpki.rtr.pdus.CorruptData(
+ "Received PDU version %d, expected %d" % (version, self.version))
+ if self.version is None and version not in rpki.rtr.pdus.PDU.version_map:
+ raise rpki.rtr.pdus.UnsupportedProtocolVersion(
+ "Received PDU version %s, known versions %s" % (
+ version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map)))
+ self.version = version
- self.connected = True
- self._fileno = fd
- self.socket = asyncore.file_wrapper(fd)
- self.add_channel()
- flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
- flags = flags | os.O_NONBLOCK
- fcntl.fcntl(fd, fcntl.F_SETFL, flags)
- def handle_close(self):
- """
- Exit when channel closed.
+class PDUChannel(asynchat.async_chat, object):
"""
-
- asynchat.async_chat.handle_close(self)
- sys.exit(0)
+ asynchat subclass that understands our PDUs. This just handles
+ network I/O. Specific engines (client, server) should be subclasses
+ of this with methods that do something useful with the resulting
+ PDUs.
+ """
+
+ def __init__(self, root_pdu_class, sock = None):
+ asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super()
+ self.reader = ReadBuffer()
+ assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU)
+ self.root_pdu_class = root_pdu_class
+
+ @property
+ def version(self):
+ return self.reader.version
+
+ @version.setter
+ def version(self, version):
+ self.reader.check_version(version)
+
+ def start_new_pdu(self):
+ """
+ Start read of a new PDU.
+ """
+
+ try:
+ p = self.root_pdu_class.read_pdu(self.reader)
+ while p is not None:
+ self.deliver_pdu(p)
+ p = self.root_pdu_class.read_pdu(self.reader)
+ except rpki.rtr.pdus.PDUException, e:
+ self.push_pdu(e.make_error_report(version = self.version))
+ self.close_when_done()
+ else:
+ assert not self.reader.ready()
+ self.set_terminator(self.reader.needed())
+
+ def collect_incoming_data(self, data):
+ """
+ Collect data into the read buffer.
+ """
+
+ self.reader.put(data)
+
+ def found_terminator(self):
+ """
+ Got requested data, see if we now have a PDU. If so, pass it
+ along, then restart cycle for a new PDU.
+ """
+
+ p = self.reader.retry()
+ if p is None:
+ self.set_terminator(self.reader.needed())
+ else:
+ self.deliver_pdu(p)
+ self.start_new_pdu()
+
+ def push_pdu(self, pdu):
+ """
+ Write PDU to stream.
+ """
+
+ try:
+ self.push(pdu.to_pdu())
+ except OSError, e:
+ if e.errno != errno.EAGAIN:
+ raise
+
+ def log(self, msg):
+ """
+ Intercept asyncore's logging.
+ """
+
+ logging.info(msg)
+
+ def log_info(self, msg, tag = "info"):
+ """
+ Intercept asynchat's logging.
+ """
+
+ logging.info("asynchat: %s: %s", tag, msg)
+
+ def handle_error(self):
+ """
+ Handle errors caught by asyncore main loop.
+ """
+
+ logging.exception("[Unhandled exception]")
+ logging.critical("[Exiting after unhandled exception]")
+ sys.exit(1)
+
+ def init_file_dispatcher(self, fd):
+ """
+ Kludge to plug asyncore.file_dispatcher into asynchat. Call from
+ subclass's __init__() method, after calling
+ PDUChannel.__init__(), and don't read this on a full stomach.
+ """
+
+ self.connected = True
+ self._fileno = fd
+ self.socket = asyncore.file_wrapper(fd)
+ self.add_channel()
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
+ flags = flags | os.O_NONBLOCK
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+
+ def handle_close(self):
+ """
+ Exit when channel closed.
+ """
+
+ asynchat.async_chat.handle_close(self)
+ sys.exit(0)
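
The ReadBuffer class above drives reads through a need/callback pair: update() records how many bytes the caller wants and which callback to dispatch to, retry() re-invokes that callback (PDUChannel.found_terminator() calls it whenever more data arrives), and the callback returns None until ready() reports that enough bytes have accumulated. A minimal sketch of a consumer of that protocol, assuming an illustrative 8-byte header rather than the real PDU layout:

    HEADER_LEN = 8  # assumed size, for illustration only

    def read_header(reader):
        # Callback in the ReadBuffer style: return None until enough
        # bytes are buffered, then consume and return them.
        if not reader.ready():
            return None
        return reader.get(HEADER_LEN)

    def start_read(reader):
        # Register the byte count and the callback; found_terminator()
        # keeps calling reader.retry() until read_header() yields data.
        return reader.update(need = HEADER_LEN, callback = read_header)
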
diff --git a/rpki/rtr/client.py b/rpki/rtr/client.py
index a35ab81d..9c7a00d6 100644
--- a/rpki/rtr/client.py
+++ b/rpki/rtr/client.py
@@ -37,13 +37,13 @@ from rpki.rtr.channels import Timestamp
class PDU(rpki.rtr.pdus.PDU):
- def consume(self, client):
- """
- Handle results in test client. Default behavior is just to print
- out the PDU; data PDU subclasses may override this.
- """
+ def consume(self, client):
+ """
+ Handle results in test client. Default behavior is just to print
+ out the PDU; data PDU subclasses may override this.
+ """
- logging.debug(self)
+ logging.debug(self)
clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU)
@@ -52,407 +52,407 @@ clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU)
@clone_pdu
class SerialNotifyPDU(rpki.rtr.pdus.SerialNotifyPDU):
- def consume(self, client):
- """
- Respond to a SerialNotifyPDU with either a SerialQueryPDU or a
- ResetQueryPDU, depending on what we already know.
- """
+ def consume(self, client):
+ """
+ Respond to a SerialNotifyPDU with either a SerialQueryPDU or a
+ ResetQueryPDU, depending on what we already know.
+ """
- logging.debug(self)
- if client.serial is None or client.nonce != self.nonce:
- client.push_pdu(ResetQueryPDU(version = client.version))
- elif self.serial != client.serial:
- client.push_pdu(SerialQueryPDU(version = client.version,
- serial = client.serial,
- nonce = client.nonce))
- else:
- logging.debug("[Notify did not change serial number, ignoring]")
+ logging.debug(self)
+ if client.serial is None or client.nonce != self.nonce:
+ client.push_pdu(ResetQueryPDU(version = client.version))
+ elif self.serial != client.serial:
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.serial,
+ nonce = client.nonce))
+ else:
+ logging.debug("[Notify did not change serial number, ignoring]")
@clone_pdu
class CacheResponsePDU(rpki.rtr.pdus.CacheResponsePDU):
- def consume(self, client):
- """
- Handle CacheResponsePDU.
- """
+ def consume(self, client):
+ """
+ Handle CacheResponsePDU.
+ """
- logging.debug(self)
- if self.nonce != client.nonce:
- logging.debug("[Nonce changed, resetting]")
- client.cache_reset()
+ logging.debug(self)
+ if self.nonce != client.nonce:
+ logging.debug("[Nonce changed, resetting]")
+ client.cache_reset()
@clone_pdu
class EndOfDataPDUv0(rpki.rtr.pdus.EndOfDataPDUv0):
- def consume(self, client):
- """
- Handle EndOfDataPDU response.
- """
+ def consume(self, client):
+ """
+ Handle EndOfDataPDU response.
+ """
- logging.debug(self)
- client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+ logging.debug(self)
+ client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
@clone_pdu
class EndOfDataPDUv1(rpki.rtr.pdus.EndOfDataPDUv1):
- def consume(self, client):
- """
- Handle EndOfDataPDU response.
- """
+ def consume(self, client):
+ """
+ Handle EndOfDataPDU response.
+ """
- logging.debug(self)
- client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+ logging.debug(self)
+ client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
@clone_pdu
class CacheResetPDU(rpki.rtr.pdus.CacheResetPDU):
- def consume(self, client):
- """
- Handle CacheResetPDU response, by issuing a ResetQueryPDU.
- """
+ def consume(self, client):
+ """
+ Handle CacheResetPDU response, by issuing a ResetQueryPDU.
+ """
- logging.debug(self)
- client.cache_reset()
- client.push_pdu(ResetQueryPDU(version = client.version))
+ logging.debug(self)
+ client.cache_reset()
+ client.push_pdu(ResetQueryPDU(version = client.version))
class PrefixPDU(rpki.rtr.pdus.PrefixPDU):
- """
- Object representing one prefix. This corresponds closely to one PDU
- in the rpki-router protocol, so closely that we use lexical ordering
- of the wire format of the PDU as the ordering for this class.
-
- This is a virtual class, but the .from_text() constructor
- instantiates the correct concrete subclass (IPv4PrefixPDU or
- IPv6PrefixPDU) depending on the syntax of its input text.
- """
-
- def consume(self, client):
"""
- Handle one incoming prefix PDU
+ Object representing one prefix. This corresponds closely to one PDU
+ in the rpki-router protocol, so closely that we use lexical ordering
+ of the wire format of the PDU as the ordering for this class.
+
+ This is a virtual class, but the .from_text() constructor
+ instantiates the correct concrete subclass (IPv4PrefixPDU or
+ IPv6PrefixPDU) depending on the syntax of its input text.
"""
- logging.debug(self)
- client.consume_prefix(self)
+ def consume(self, client):
+ """
+ Handle one incoming prefix PDU
+ """
+
+ logging.debug(self)
+ client.consume_prefix(self)
@clone_pdu
class IPv4PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv4PrefixPDU):
- pass
+ pass
@clone_pdu
class IPv6PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv6PrefixPDU):
- pass
+ pass
@clone_pdu
class ErrorReportPDU(PDU, rpki.rtr.pdus.ErrorReportPDU):
- pass
+ pass
@clone_pdu
class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU):
- """
- Router Key PDU.
- """
-
- def consume(self, client):
"""
- Handle one incoming Router Key PDU
+ Router Key PDU.
"""
- logging.debug(self)
- client.consume_routerkey(self)
+ def consume(self, client):
+ """
+ Handle one incoming Router Key PDU
+ """
+ logging.debug(self)
+ client.consume_routerkey(self)
-class ClientChannel(rpki.rtr.channels.PDUChannel):
- """
- Client protocol engine, handles upcalls from PDUChannel.
- """
-
- serial = None
- nonce = None
- sql = None
- host = None
- port = None
- cache_id = None
- refresh = rpki.rtr.pdus.default_refresh
- retry = rpki.rtr.pdus.default_retry
- expire = rpki.rtr.pdus.default_expire
- updated = Timestamp(0)
-
- def __init__(self, sock, proc, killsig, args, host = None, port = None):
- self.killsig = killsig
- self.proc = proc
- self.args = args
- self.host = args.host if host is None else host
- self.port = args.port if port is None else port
- super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU)
- if args.force_version is not None:
- self.version = args.force_version
- self.start_new_pdu()
- if args.sql_database:
- self.setup_sql()
-
- @classmethod
- def ssh(cls, args):
- """
- Set up ssh connection and start listening for first PDU.
- """
- if args.port is None:
- argv = ("ssh", "-s", args.host, "rpki-rtr")
- else:
- argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr")
- logging.debug("[Running ssh: %s]", " ".join(argv))
- s = socket.socketpair()
- return cls(sock = s[1],
- proc = subprocess.Popen(argv, executable = "/usr/bin/ssh",
- stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGKILL, args = args)
-
- @classmethod
- def tcp(cls, args):
- """
- Set up TCP connection and start listening for first PDU.
+class ClientChannel(rpki.rtr.channels.PDUChannel):
"""
-
- logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port)
- try:
- addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM)
- except socket.error, e:
- logging.debug("[socket.getaddrinfo() failed: %s]", e)
- else:
- for ai in addrinfo:
- af, socktype, proto, cn, sa = ai # pylint: disable=W0612
- logging.debug("[Trying addr %s port %s]", sa[0], sa[1])
+ Client protocol engine, handles upcalls from PDUChannel.
+ """
+
+ serial = None
+ nonce = None
+ sql = None
+ host = None
+ port = None
+ cache_id = None
+ refresh = rpki.rtr.pdus.default_refresh
+ retry = rpki.rtr.pdus.default_retry
+ expire = rpki.rtr.pdus.default_expire
+ updated = Timestamp(0)
+
+ def __init__(self, sock, proc, killsig, args, host = None, port = None):
+ self.killsig = killsig
+ self.proc = proc
+ self.args = args
+ self.host = args.host if host is None else host
+ self.port = args.port if port is None else port
+ super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU)
+ if args.force_version is not None:
+ self.version = args.force_version
+ self.start_new_pdu()
+ if args.sql_database:
+ self.setup_sql()
+
+ @classmethod
+ def ssh(cls, args):
+ """
+ Set up ssh connection and start listening for first PDU.
+ """
+
+ if args.port is None:
+ argv = ("ssh", "-s", args.host, "rpki-rtr")
+ else:
+ argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr")
+ logging.debug("[Running ssh: %s]", " ".join(argv))
+ s = socket.socketpair()
+ return cls(sock = s[1],
+ proc = subprocess.Popen(argv, executable = "/usr/bin/ssh",
+ stdin = s[0], stdout = s[0], close_fds = True),
+ killsig = signal.SIGKILL, args = args)
+
+ @classmethod
+ def tcp(cls, args):
+ """
+ Set up TCP connection and start listening for first PDU.
+ """
+
+ logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port)
try:
- s = socket.socket(af, socktype, proto)
+ addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.error, e:
- logging.debug("[socket.socket() failed: %s]", e)
- continue
+ logging.debug("[socket.getaddrinfo() failed: %s]", e)
+ else:
+ for ai in addrinfo:
+ af, socktype, proto, cn, sa = ai # pylint: disable=W0612
+ logging.debug("[Trying addr %s port %s]", sa[0], sa[1])
+ try:
+ s = socket.socket(af, socktype, proto)
+ except socket.error, e:
+ logging.debug("[socket.socket() failed: %s]", e)
+ continue
+ try:
+ s.connect(sa)
+ except socket.error, e:
+ logging.exception("[socket.connect() failed: %s]", e)
+ s.close()
+ continue
+ return cls(sock = s, proc = None, killsig = None, args = args)
+ sys.exit(1)
+
+ @classmethod
+ def loopback(cls, args):
+ """
+ Set up loopback connection and start listening for first PDU.
+ """
+
+ s = socket.socketpair()
+ logging.debug("[Using direct subprocess kludge for testing]")
+ argv = (sys.executable, sys.argv[0], "server")
+ return cls(sock = s[1],
+ proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
+ killsig = signal.SIGINT, args = args,
+ host = args.host or "none", port = args.port or "none")
+
+ @classmethod
+ def tls(cls, args):
+ """
+ Set up TLS connection and start listening for first PDU.
+
+ NB: This uses OpenSSL's "s_client" command, which does not
+ check server certificates properly, so this is not suitable for
+ production use. Fixing this would be a trivial change, it just
+ requires using a client program which does check certificates
+ properly (eg, gnutls-cli, or stunnel's client mode if that works
+ for such purposes this week).
+ """
+
+ argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port))
+ logging.debug("[Running: %s]", " ".join(argv))
+ s = socket.socketpair()
+ return cls(sock = s[1],
+ proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
+ killsig = signal.SIGKILL, args = args)
+
+ def setup_sql(self):
+ """
+ Set up an SQLite database to contain the table we receive. If
+ necessary, we will create the database.
+ """
+
+ import sqlite3
+ missing = not os.path.exists(self.args.sql_database)
+ self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES)
+ self.sql.text_factory = str
+ cur = self.sql.cursor()
+ cur.execute("PRAGMA foreign_keys = on")
+ if missing:
+ cur.execute('''
+ CREATE TABLE cache (
+ cache_id INTEGER PRIMARY KEY NOT NULL,
+ host TEXT NOT NULL,
+ port TEXT NOT NULL,
+ version INTEGER,
+ nonce INTEGER,
+ serial INTEGER,
+ updated INTEGER,
+ refresh INTEGER,
+ retry INTEGER,
+ expire INTEGER,
+ UNIQUE (host, port))''')
+ cur.execute('''
+ CREATE TABLE prefix (
+ cache_id INTEGER NOT NULL
+ REFERENCES cache(cache_id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE,
+ asn INTEGER NOT NULL,
+ prefix TEXT NOT NULL,
+ prefixlen INTEGER NOT NULL,
+ max_prefixlen INTEGER NOT NULL,
+ UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''')
+ cur.execute('''
+ CREATE TABLE routerkey (
+ cache_id INTEGER NOT NULL
+ REFERENCES cache(cache_id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE,
+ asn INTEGER NOT NULL,
+ ski TEXT NOT NULL,
+ key TEXT NOT NULL,
+ UNIQUE (cache_id, asn, ski),
+ UNIQUE (cache_id, asn, key))''')
+ elif self.args.reset_session:
+ cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
+ cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated "
+ "FROM cache WHERE host = ? AND port = ?",
+ (self.host, self.port))
try:
- s.connect(sa)
- except socket.error, e:
- logging.exception("[socket.connect() failed: %s]", e)
- s.close()
- continue
- return cls(sock = s, proc = None, killsig = None, args = args)
- sys.exit(1)
-
- @classmethod
- def loopback(cls, args):
- """
- Set up loopback connection and start listening for first PDU.
- """
-
- s = socket.socketpair()
- logging.debug("[Using direct subprocess kludge for testing]")
- argv = (sys.executable, sys.argv[0], "server")
- return cls(sock = s[1],
- proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGINT, args = args,
- host = args.host or "none", port = args.port or "none")
-
- @classmethod
- def tls(cls, args):
- """
- Set up TLS connection and start listening for first PDU.
-
- NB: This uses OpenSSL's "s_client" command, which does not
- check server certificates properly, so this is not suitable for
- production use. Fixing this would be a trivial change, it just
- requires using a client program which does check certificates
- properly (eg, gnutls-cli, or stunnel's client mode if that works
- for such purposes this week).
- """
-
- argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port))
- logging.debug("[Running: %s]", " ".join(argv))
- s = socket.socketpair()
- return cls(sock = s[1],
- proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGKILL, args = args)
-
- def setup_sql(self):
- """
- Set up an SQLite database to contain the table we receive. If
- necessary, we will create the database.
- """
-
- import sqlite3
- missing = not os.path.exists(self.args.sql_database)
- self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES)
- self.sql.text_factory = str
- cur = self.sql.cursor()
- cur.execute("PRAGMA foreign_keys = on")
- if missing:
- cur.execute('''
- CREATE TABLE cache (
- cache_id INTEGER PRIMARY KEY NOT NULL,
- host TEXT NOT NULL,
- port TEXT NOT NULL,
- version INTEGER,
- nonce INTEGER,
- serial INTEGER,
- updated INTEGER,
- refresh INTEGER,
- retry INTEGER,
- expire INTEGER,
- UNIQUE (host, port))''')
- cur.execute('''
- CREATE TABLE prefix (
- cache_id INTEGER NOT NULL
- REFERENCES cache(cache_id)
- ON DELETE CASCADE
- ON UPDATE CASCADE,
- asn INTEGER NOT NULL,
- prefix TEXT NOT NULL,
- prefixlen INTEGER NOT NULL,
- max_prefixlen INTEGER NOT NULL,
- UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''')
- cur.execute('''
- CREATE TABLE routerkey (
- cache_id INTEGER NOT NULL
- REFERENCES cache(cache_id)
- ON DELETE CASCADE
- ON UPDATE CASCADE,
- asn INTEGER NOT NULL,
- ski TEXT NOT NULL,
- key TEXT NOT NULL,
- UNIQUE (cache_id, asn, ski),
- UNIQUE (cache_id, asn, key))''')
- elif self.args.reset_session:
- cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
- cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated "
- "FROM cache WHERE host = ? AND port = ?",
- (self.host, self.port))
- try:
- self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone()
- if version is not None and self.version is not None and version != self.version:
- cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
- raise TypeError # Simulate lookup failure case
- if version is not None:
- self.version = version
- if refresh is not None:
+ self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone()
+ if version is not None and self.version is not None and version != self.version:
+ cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
+ raise TypeError # Simulate lookup failure case
+ if version is not None:
+ self.version = version
+ if refresh is not None:
+ self.refresh = refresh
+ if retry is not None:
+ self.retry = retry
+ if expire is not None:
+ self.expire = expire
+ if updated is not None:
+ self.updated = Timestamp(updated)
+ except TypeError:
+ cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port))
+ self.cache_id = cur.lastrowid
+ self.sql.commit()
+ logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]",
+ self.cache_id, self.version, self.nonce,
+ self.serial, self.refresh, self.retry, self.expire, self.updated)
+
+ def cache_reset(self):
+ """
+ Handle CacheResetPDU actions.
+ """
+
+ self.serial = None
+ if self.sql:
+ cur = self.sql.cursor()
+ cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id))
+ self.sql.commit()
+
+ def end_of_data(self, version, serial, nonce, refresh, retry, expire):
+ """
+ Handle EndOfDataPDU actions.
+ """
+
+ assert version == self.version
+ self.serial = serial
+ self.nonce = nonce
self.refresh = refresh
- if retry is not None:
- self.retry = retry
- if expire is not None:
- self.expire = expire
- if updated is not None:
- self.updated = Timestamp(updated)
- except TypeError:
- cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port))
- self.cache_id = cur.lastrowid
- self.sql.commit()
- logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]",
- self.cache_id, self.version, self.nonce,
- self.serial, self.refresh, self.retry, self.expire, self.updated)
-
- def cache_reset(self):
- """
- Handle CacheResetPDU actions.
- """
-
- self.serial = None
- if self.sql:
- cur = self.sql.cursor()
- cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,))
- cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,))
- cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id))
- self.sql.commit()
-
- def end_of_data(self, version, serial, nonce, refresh, retry, expire):
- """
- Handle EndOfDataPDU actions.
- """
-
- assert version == self.version
- self.serial = serial
- self.nonce = nonce
- self.refresh = refresh
- self.retry = retry
- self.expire = expire
- self.updated = Timestamp.now()
- if self.sql:
- self.sql.execute("UPDATE cache SET"
- " version = ?, serial = ?, nonce = ?,"
- " refresh = ?, retry = ?, expire = ?,"
- " updated = ? "
- "WHERE cache_id = ?",
- (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id))
- self.sql.commit()
-
- def consume_prefix(self, prefix):
- """
- Handle one prefix PDU.
- """
-
- if self.sql:
- values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen)
- if prefix.announce:
- self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) "
- "VALUES (?, ?, ?, ?, ?)",
- values)
- else:
- self.sql.execute("DELETE FROM prefix "
- "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?",
- values)
-
- def consume_routerkey(self, routerkey):
- """
- Handle one Router Key PDU.
- """
-
- if self.sql:
- values = (self.cache_id, routerkey.asn,
- base64.urlsafe_b64encode(routerkey.ski).rstrip("="),
- base64.b64encode(routerkey.key))
- if routerkey.announce:
- self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) "
- "VALUES (?, ?, ?, ?)",
- values)
- else:
- self.sql.execute("DELETE FROM routerkey "
- "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)",
- values)
-
- def deliver_pdu(self, pdu):
- """
- Handle received PDU.
- """
-
- pdu.consume(self)
-
- def push_pdu(self, pdu):
- """
- Log outbound PDU then write it to stream.
- """
-
- logging.debug(pdu)
- super(ClientChannel, self).push_pdu(pdu)
-
- def cleanup(self):
- """
- Force clean up this client's child process. If everything goes
- well, child will have exited already before this method is called,
- but we may need to whack it with a stick if something breaks.
- """
-
- if self.proc is not None and self.proc.returncode is None:
- try:
- os.kill(self.proc.pid, self.killsig)
- except OSError:
- pass
-
- def handle_close(self):
- """
- Intercept close event so we can log it, then shut down.
- """
-
- logging.debug("Server closed channel")
- super(ClientChannel, self).handle_close()
+ self.retry = retry
+ self.expire = expire
+ self.updated = Timestamp.now()
+ if self.sql:
+ self.sql.execute("UPDATE cache SET"
+ " version = ?, serial = ?, nonce = ?,"
+ " refresh = ?, retry = ?, expire = ?,"
+ " updated = ? "
+ "WHERE cache_id = ?",
+ (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id))
+ self.sql.commit()
+
+ def consume_prefix(self, prefix):
+ """
+ Handle one prefix PDU.
+ """
+
+ if self.sql:
+ values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen)
+ if prefix.announce:
+ self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) "
+ "VALUES (?, ?, ?, ?, ?)",
+ values)
+ else:
+ self.sql.execute("DELETE FROM prefix "
+ "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?",
+ values)
+
+ def consume_routerkey(self, routerkey):
+ """
+ Handle one Router Key PDU.
+ """
+
+ if self.sql:
+ values = (self.cache_id, routerkey.asn,
+ base64.urlsafe_b64encode(routerkey.ski).rstrip("="),
+ base64.b64encode(routerkey.key))
+ if routerkey.announce:
+ self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) "
+ "VALUES (?, ?, ?, ?)",
+ values)
+ else:
+ self.sql.execute("DELETE FROM routerkey "
+ "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)",
+ values)
+
+ def deliver_pdu(self, pdu):
+ """
+ Handle received PDU.
+ """
+
+ pdu.consume(self)
+
+ def push_pdu(self, pdu):
+ """
+ Log outbound PDU then write it to stream.
+ """
+
+ logging.debug(pdu)
+ super(ClientChannel, self).push_pdu(pdu)
+
+ def cleanup(self):
+ """
+ Force clean up this client's child process. If everything goes
+ well, child will have exited already before this method is called,
+ but we may need to whack it with a stick if something breaks.
+ """
+
+ if self.proc is not None and self.proc.returncode is None:
+ try:
+ os.kill(self.proc.pid, self.killsig)
+ except OSError:
+ pass
+
+ def handle_close(self):
+ """
+ Intercept close event so we can log it, then shut down.
+ """
+
+ logging.debug("Server closed channel")
+ super(ClientChannel, self).handle_close()
# Hack to let us subclass this from scripts without needing to rewrite client_main().
@@ -460,73 +460,73 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
ClientChannelClass = ClientChannel
def client_main(args):
- """
- Test client, intended primarily for debugging.
- """
+ """
+ Test client, intended primarily for debugging.
+ """
- logging.debug("[Startup]")
+ logging.debug("[Startup]")
- assert issubclass(ClientChannelClass, ClientChannel)
- constructor = getattr(ClientChannelClass, args.protocol)
+ assert issubclass(ClientChannelClass, ClientChannel)
+ constructor = getattr(ClientChannelClass, args.protocol)
- client = None
- try:
- client = constructor(args)
+ client = None
+ try:
+ client = constructor(args)
- polled = client.updated
- wakeup = None
+ polled = client.updated
+ wakeup = None
- while True:
+ while True:
- now = Timestamp.now()
+ now = Timestamp.now()
- if client.serial is not None and now > client.updated + client.expire:
- logging.info("[Expiring client data: serial %s, last updated %s, expire %s]",
- client.serial, client.updated, client.expire)
- client.cache_reset()
+ if client.serial is not None and now > client.updated + client.expire:
+ logging.info("[Expiring client data: serial %s, last updated %s, expire %s]",
+ client.serial, client.updated, client.expire)
+ client.cache_reset()
- if client.serial is None or client.nonce is None:
- polled = now
- client.push_pdu(ResetQueryPDU(version = client.version))
+ if client.serial is None or client.nonce is None:
+ polled = now
+ client.push_pdu(ResetQueryPDU(version = client.version))
- elif now >= client.updated + client.refresh:
- polled = now
- client.push_pdu(SerialQueryPDU(version = client.version,
- serial = client.serial,
- nonce = client.nonce))
+ elif now >= client.updated + client.refresh:
+ polled = now
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.serial,
+ nonce = client.nonce))
- remaining = 1
+ remaining = 1
- while remaining > 0:
- now = Timestamp.now()
- timer = client.retry if (now >= client.updated + client.refresh) else client.refresh
- wokeup = wakeup
- wakeup = max(now, Timestamp(max(polled, client.updated) + timer))
- remaining = wakeup - now
- if wakeup != wokeup:
- logging.info("[Last client poll %s, next %s]", polled, wakeup)
- asyncore.loop(timeout = remaining, count = 1)
+ while remaining > 0:
+ now = Timestamp.now()
+ timer = client.retry if (now >= client.updated + client.refresh) else client.refresh
+ wokeup = wakeup
+ wakeup = max(now, Timestamp(max(polled, client.updated) + timer))
+ remaining = wakeup - now
+ if wakeup != wokeup:
+ logging.info("[Last client poll %s, next %s]", polled, wakeup)
+ asyncore.loop(timeout = remaining, count = 1)
- except KeyboardInterrupt:
- sys.exit(0)
+ except KeyboardInterrupt:
+ sys.exit(0)
- finally:
- if client is not None:
- client.cleanup()
+ finally:
+ if client is not None:
+ client.cleanup()
def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- subparser = subparsers.add_parser("client", description = client_main.__doc__,
- help = "Test client for RPKI-RTR protocol")
- subparser.set_defaults(func = client_main, default_log_to = "stderr")
- subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state")
- subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version")
- subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database")
- subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol")
- subparser.add_argument("host", nargs = "?", help = "server host")
- subparser.add_argument("port", nargs = "?", help = "server port")
- return subparser
+ """
+ Set up argparse stuff for commands in this module.
+ """
+
+ subparser = subparsers.add_parser("client", description = client_main.__doc__,
+ help = "Test client for RPKI-RTR protocol")
+ subparser.set_defaults(func = client_main, default_log_to = "stderr")
+ subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state")
+ subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version")
+ subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database")
+ subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol")
+ subparser.add_argument("host", nargs = "?", help = "server host")
+ subparser.add_argument("port", nargs = "?", help = "server port")
+ return subparser
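
ClientChannel.setup_sql() above creates a small SQLite schema (cache, prefix, routerkey tables) and consume_prefix()/consume_routerkey() keep it in sync with the cache's answers. As a rough illustration of what ends up in that database, a sketch that dumps the prefix table; the database filename is whatever was passed to --sql-database, a made-up value here:

    import sqlite3

    db = sqlite3.connect("rpki-rtr.sqlite3")   # hypothetical --sql-database value
    for asn, prefix, prefixlen, max_prefixlen in db.execute(
            "SELECT asn, prefix, prefixlen, max_prefixlen "
            "FROM prefix ORDER BY asn, prefix"):
        print("AS%s %s/%s-%s" % (asn, prefix, prefixlen, max_prefixlen))
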
diff --git a/rpki/rtr/generator.py b/rpki/rtr/generator.py
index 26e25b6e..e00e44b7 100644
--- a/rpki/rtr/generator.py
+++ b/rpki/rtr/generator.py
@@ -37,539 +37,539 @@ import rpki.rtr.server
from rpki.rtr.channels import Timestamp
class PrefixPDU(rpki.rtr.pdus.PrefixPDU):
- """
- Object representing one prefix. This corresponds closely to one PDU
- in the rpki-router protocol, so closely that we use lexical ordering
- of the wire format of the PDU as the ordering for this class.
-
- This is a virtual class, but the .from_text() constructor
- instantiates the correct concrete subclass (IPv4PrefixPDU or
- IPv6PrefixPDU) depending on the syntax of its input text.
- """
-
- @staticmethod
- def from_text(version, asn, addr):
- """
- Construct a prefix from its text form.
- """
-
- cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU
- self = cls(version = version)
- self.asn = long(asn)
- p, l = addr.split("/")
- self.prefix = rpki.POW.IPAddress(p)
- if "-" in l:
- self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-"))
- else:
- self.prefixlen = self.max_prefixlen = int(l)
- self.announce = 1
- self.check()
- return self
-
- @staticmethod
- def from_roa(version, asn, prefix_tuple):
- """
- Construct a prefix from a ROA.
"""
-
- address, length, maxlength = prefix_tuple
- cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU
- self = cls(version = version)
- self.asn = asn
- self.prefix = address
- self.prefixlen = length
- self.max_prefixlen = length if maxlength is None else maxlength
- self.announce = 1
- self.check()
- return self
+ Object representing one prefix. This corresponds closely to one PDU
+ in the rpki-router protocol, so closely that we use lexical ordering
+ of the wire format of the PDU as the ordering for this class.
+
+ This is a virtual class, but the .from_text() constructor
+ instantiates the correct concrete subclass (IPv4PrefixPDU or
+ IPv6PrefixPDU) depending on the syntax of its input text.
+ """
+
+ @staticmethod
+ def from_text(version, asn, addr):
+ """
+ Construct a prefix from its text form.
+ """
+
+ cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU
+ self = cls(version = version)
+ self.asn = long(asn)
+ p, l = addr.split("/")
+ self.prefix = rpki.POW.IPAddress(p)
+ if "-" in l:
+ self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-"))
+ else:
+ self.prefixlen = self.max_prefixlen = int(l)
+ self.announce = 1
+ self.check()
+ return self
+
+ @staticmethod
+ def from_roa(version, asn, prefix_tuple):
+ """
+ Construct a prefix from a ROA.
+ """
+
+ address, length, maxlength = prefix_tuple
+ cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU
+ self = cls(version = version)
+ self.asn = asn
+ self.prefix = address
+ self.prefixlen = length
+ self.max_prefixlen = length if maxlength is None else maxlength
+ self.announce = 1
+ self.check()
+ return self
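The text form accepted by .from_text() above is either "prefix/len" or "prefix/len-maxlen", with a ":" in the address selecting the IPv6 subclass. A small stand-alone sketch of just that parsing step (not the real constructor, which also builds an rpki.POW.IPAddress and runs check()):

def parse_prefix_text(addr):
    # Same splitting rules as PrefixPDU.from_text(): "prefix/len" or
    # "prefix/len-maxlen", with ":" in the address selecting IPv6.
    family = "IPv6" if ":" in addr else "IPv4"
    p, l = addr.split("/")
    if "-" in l:
        prefixlen, max_prefixlen = tuple(int(i) for i in l.split("-"))
    else:
        prefixlen = max_prefixlen = int(l)
    return family, p, prefixlen, max_prefixlen

print(parse_prefix_text("192.0.2.0/24-28"))    # ('IPv4', '192.0.2.0', 24, 28)
print(parse_prefix_text("2001:db8::/32"))      # ('IPv6', '2001:db8::', 32, 32)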
class IPv4PrefixPDU(PrefixPDU):
- """
- IPv4 flavor of a prefix.
- """
+ """
+ IPv4 flavor of a prefix.
+ """
- pdu_type = 4
- address_byte_count = 4
+ pdu_type = 4
+ address_byte_count = 4
class IPv6PrefixPDU(PrefixPDU):
- """
- IPv6 flavor of a prefix.
- """
-
- pdu_type = 6
- address_byte_count = 16
-
-class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU):
- """
- Router Key PDU.
- """
-
- @classmethod
- def from_text(cls, version, asn, gski, key):
"""
- Construct a router key from its text form.
+ IPv6 flavor of a prefix.
"""
- self = cls(version = version)
- self.asn = long(asn)
- self.ski = base64.urlsafe_b64decode(gski + "=")
- self.key = base64.b64decode(key)
- self.announce = 1
- self.check()
- return self
+ pdu_type = 6
+ address_byte_count = 16
- @classmethod
- def from_certificate(cls, version, asn, ski, key):
+class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU):
"""
- Construct a router key from a certificate.
+ Router Key PDU.
"""
- self = cls(version = version)
- self.asn = asn
- self.ski = ski
- self.key = key
- self.announce = 1
- self.check()
- return self
+ @classmethod
+ def from_text(cls, version, asn, gski, key):
+ """
+ Construct a router key from its text form.
+ """
+ self = cls(version = version)
+ self.asn = long(asn)
+ self.ski = base64.urlsafe_b64decode(gski + "=")
+ self.key = base64.b64decode(key)
+ self.announce = 1
+ self.check()
+ return self
-class ROA(rpki.POW.ROA): # pylint: disable=W0232
- """
- Minor additions to rpki.POW.ROA.
- """
-
- @classmethod
- def derReadFile(cls, fn): # pylint: disable=E1002
- self = super(ROA, cls).derReadFile(fn)
- self.extractWithoutVerifying()
- return self
-
- @property
- def prefixes(self):
- v4, v6 = self.getPrefixes()
- if v4 is not None:
- for p in v4:
- yield p
- if v6 is not None:
- for p in v6:
- yield p
-
-class X509(rpki.POW.X509): # pylint: disable=W0232
- """
- Minor additions to rpki.POW.X509.
- """
+ @classmethod
+ def from_certificate(cls, version, asn, ski, key):
+ """
+ Construct a router key from a certificate.
+ """
- @property
- def asns(self):
- resources = self.getRFC3779()
- if resources is not None and resources[0] is not None:
- for min_asn, max_asn in resources[0]:
- for asn in xrange(min_asn, max_asn + 1):
- yield asn
+ self = cls(version = version)
+ self.asn = asn
+ self.ski = ski
+ self.key = key
+ self.announce = 1
+ self.check()
+ return self
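The gski argument to from_text() is the URL-safe base64 form of the 20-byte SKI with its trailing "=" padding stripped, which is why the decoder adds one "=" back. A quick stand-alone round-trip, hashing a placeholder string rather than a real router certificate's key:

import base64
import hashlib

# The SKI is normally the 20-byte SHA-1 of the router certificate's public
# key; here we hash a placeholder just to get 20 bytes of the right shape.
ski = hashlib.sha1(b"example router key").digest()

gski = base64.urlsafe_b64encode(ski).rstrip(b"=")      # 27-character text form
assert base64.urlsafe_b64decode(gski + b"=") == ski    # from_text() re-adds the "="
print(gski)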
-class PDUSet(list):
- """
- Object representing a set of PDUs, that is, one versioned and
- (theoretically) consistant set of prefixes and router keys extracted
- from rcynic's output.
- """
-
- def __init__(self, version):
- assert version in rpki.rtr.pdus.PDU.version_map
- super(PDUSet, self).__init__()
- self.version = version
-
- @classmethod
- def _load_file(cls, filename, version):
+class ROA(rpki.POW.ROA): # pylint: disable=W0232
"""
- Low-level method to read PDUSet from a file.
+ Minor additions to rpki.POW.ROA.
"""
- self = cls(version = version)
- f = open(filename, "rb")
- r = rpki.rtr.channels.ReadBuffer()
- while True:
- p = rpki.rtr.pdus.PDU.read_pdu(r)
- while p is None:
- b = f.read(r.needed())
- if b == "":
- assert r.available() == 0
- return self
- r.put(b)
- p = r.retry()
- assert p.version == self.version
- self.append(p)
-
- @staticmethod
- def seq_ge(a, b):
- return ((a - b) % (1 << 32)) < (1 << 31)
+ @classmethod
+ def derReadFile(cls, fn): # pylint: disable=E1002
+ self = super(ROA, cls).derReadFile(fn)
+ self.extractWithoutVerifying()
+ return self
+ @property
+ def prefixes(self):
+ v4, v6 = self.getPrefixes()
+ if v4 is not None:
+ for p in v4:
+ yield p
+ if v6 is not None:
+ for p in v6:
+ yield p
-class AXFRSet(PDUSet):
- """
- Object representing a complete set of PDUs, that is, one versioned
- and (theoretically) consistant set of prefixes and router
- certificates extracted from rcynic's output, all with the announce
- field set.
- """
-
- @classmethod
- def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None):
+class X509(rpki.POW.X509): # pylint: disable=W0232
+ """
+ Minor additions to rpki.POW.X509.
"""
- Parse ROAS and router certificates fetched (and validated!) by
- rcynic to create a new AXFRSet.
- In normal operation, we use os.walk() and the rpki.POW library to
- parse these data directly, but we can, if so instructed, use
- external programs instead, for testing, simulation, or to provide
- a way to inject local data.
+ @property
+ def asns(self):
+ resources = self.getRFC3779()
+ if resources is not None and resources[0] is not None:
+ for min_asn, max_asn in resources[0]:
+ for asn in xrange(min_asn, max_asn + 1):
+ yield asn
- At some point the ability to parse these data from external
- programs may move to a separate constructor function, so that we
- can make this one a bit simpler and faster.
- """
- self = cls(version = version)
- self.serial = rpki.rtr.channels.Timestamp.now()
-
- include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version]
-
- if scan_roas is None or (scan_routercerts is None and include_routercerts):
- for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612
- for fn in files:
- if scan_roas is None and fn.endswith(".roa"):
- roa = ROA.derReadFile(os.path.join(root, fn))
- asn = roa.getASID()
- self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple)
- for prefix_tuple in roa.prefixes)
- if include_routercerts and scan_routercerts is None and fn.endswith(".cer"):
- x = X509.derReadFile(os.path.join(root, fn))
- eku = x.getEKU()
- if eku is not None and rpki.oids.id_kp_bgpsec_router in eku:
- ski = x.getSKI()
- key = x.getPublicKey().derWritePublic()
- self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key)
- for asn in x.asns)
-
- if scan_roas is not None:
- try:
- p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE)
- for line in p.stdout:
- line = line.split()
- asn = line[1]
- self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr)
- for addr in line[2:])
- except OSError, e:
- sys.exit("Could not run %s: %s" % (scan_roas, e))
-
- if include_routercerts and scan_routercerts is not None:
- try:
- p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE)
- for line in p.stdout:
- line = line.split()
- gski = line[0]
- key = line[-1]
- self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key)
- for asn in line[1:-1])
- except OSError, e:
- sys.exit("Could not run %s: %s" % (scan_routercerts, e))
-
- self.sort()
- for i in xrange(len(self) - 2, -1, -1):
- if self[i] == self[i + 1]:
- del self[i + 1]
- return self
-
- @classmethod
- def load(cls, filename):
- """
- Load an AXFRSet from a file, parse filename to obtain version and serial.
+class PDUSet(list):
"""
+ Object representing a set of PDUs, that is, one versioned and
+ (theoretically) consistent set of prefixes and router keys extracted
+ from rcynic's output.
+ """
+
+ def __init__(self, version):
+ assert version in rpki.rtr.pdus.PDU.version_map
+ super(PDUSet, self).__init__()
+ self.version = version
+
+ @classmethod
+ def _load_file(cls, filename, version):
+ """
+ Low-level method to read PDUSet from a file.
+ """
+
+ self = cls(version = version)
+ f = open(filename, "rb")
+ r = rpki.rtr.channels.ReadBuffer()
+ while True:
+ p = rpki.rtr.pdus.PDU.read_pdu(r)
+ while p is None:
+ b = f.read(r.needed())
+ if b == "":
+ assert r.available() == 0
+ return self
+ r.put(b)
+ p = r.retry()
+ assert p.version == self.version
+ self.append(p)
+
+ @staticmethod
+ def seq_ge(a, b):
+ return ((a - b) % (1 << 32)) < (1 << 31)
- fn1, fn2, fn3 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit()
- version = int(fn3[1:])
- self = cls._load_file(filename, version)
- self.serial = rpki.rtr.channels.Timestamp(fn1)
- return self
- def filename(self):
- """
- Generate filename for this AXFRSet.
+class AXFRSet(PDUSet):
"""
+ Object representing a complete set of PDUs, that is, one versioned
+ and (theoretically) consistent set of prefixes and router
+ certificates extracted from rcynic's output, all with the announce
+ field set.
+ """
+
+ @classmethod
+ def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None):
+ """
+ Parse ROAs and router certificates fetched (and validated!) by
+ rcynic to create a new AXFRSet.
+
+ In normal operation, we use os.walk() and the rpki.POW library to
+ parse these data directly, but we can, if so instructed, use
+ external programs instead, for testing, simulation, or to provide
+ a way to inject local data.
+
+ At some point the ability to parse these data from external
+ programs may move to a separate constructor function, so that we
+ can make this one a bit simpler and faster.
+ """
+
+ self = cls(version = version)
+ self.serial = rpki.rtr.channels.Timestamp.now()
+
+ include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version]
+
+ if scan_roas is None or (scan_routercerts is None and include_routercerts):
+ for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612
+ for fn in files:
+ if scan_roas is None and fn.endswith(".roa"):
+ roa = ROA.derReadFile(os.path.join(root, fn))
+ asn = roa.getASID()
+ self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple)
+ for prefix_tuple in roa.prefixes)
+ if include_routercerts and scan_routercerts is None and fn.endswith(".cer"):
+ x = X509.derReadFile(os.path.join(root, fn))
+ eku = x.getEKU()
+ if eku is not None and rpki.oids.id_kp_bgpsec_router in eku:
+ ski = x.getSKI()
+ key = x.getPublicKey().derWritePublic()
+ self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key)
+ for asn in x.asns)
+
+ if scan_roas is not None:
+ try:
+ p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE)
+ for line in p.stdout:
+ line = line.split()
+ asn = line[1]
+ self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr)
+ for addr in line[2:])
+ except OSError, e:
+ sys.exit("Could not run %s: %s" % (scan_roas, e))
+
+ if include_routercerts and scan_routercerts is not None:
+ try:
+ p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE)
+ for line in p.stdout:
+ line = line.split()
+ gski = line[0]
+ key = line[-1]
+ self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key)
+ for asn in line[1:-1])
+ except OSError, e:
+ sys.exit("Could not run %s: %s" % (scan_routercerts, e))
+
+ self.sort()
+ for i in xrange(len(self) - 2, -1, -1):
+ if self[i] == self[i + 1]:
+ del self[i + 1]
+ return self
+
+ @classmethod
+ def load(cls, filename):
+ """
+ Load an AXFRSet from a file, parse filename to obtain version and serial.
+ """
+
+ fn1, fn2, fn3 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit()
+ version = int(fn3[1:])
+ self = cls._load_file(filename, version)
+ self.serial = rpki.rtr.channels.Timestamp(fn1)
+ return self
+
+ def filename(self):
+ """
+ Generate filename for this AXFRSet.
+ """
+
+ return "%d.ax.v%d" % (self.serial, self.version)
+
+ @classmethod
+ def load_current(cls, version):
+ """
+ Load current AXFRSet. Return None if we can't.
+ """
+
+ serial = rpki.rtr.server.read_current(version)[0]
+ if serial is None:
+ return None
+ try:
+ return cls.load("%d.ax.v%d" % (serial, version))
+ except IOError:
+ return None
+
+ def save_axfr(self):
+ """
+ Write AXFRSet to file with magic filename.
+ """
+
+ f = open(self.filename(), "wb")
+ for p in self:
+ f.write(p.to_pdu())
+ f.close()
+
+ def destroy_old_data(self):
+ """
+ Destroy old data files, presumably because our nonce changed and
+ the old serial numbers are no longer valid.
+ """
+
+ for i in glob.iglob("*.ix.*.v%d" % self.version):
+ os.unlink(i)
+ for i in glob.iglob("*.ax.v%d" % self.version):
+ if i != self.filename():
+ os.unlink(i)
+
+ @staticmethod
+ def new_nonce(force_zero_nonce):
+ """
+ Create and return a new nonce value.
+ """
+
+ if force_zero_nonce:
+ return 0
+ try:
+ return int(random.SystemRandom().getrandbits(16))
+ except NotImplementedError:
+ return int(random.getrandbits(16))
+
+ def mark_current(self, force_zero_nonce = False):
+ """
+ Save current serial number and nonce, creating new nonce if
+ necessary. Creating a new nonce triggers cleanup of old state, as
+ the new nonce invalidates all old serial numbers.
+ """
+
+ assert self.version in rpki.rtr.pdus.PDU.version_map
+ old_serial, nonce = rpki.rtr.server.read_current(self.version)
+ if old_serial is None or self.seq_ge(old_serial, self.serial):
+ logging.debug("Creating new nonce and deleting stale data")
+ nonce = self.new_nonce(force_zero_nonce)
+ self.destroy_old_data()
+ rpki.rtr.server.write_current(self.serial, nonce, self.version)
+
+ def save_ixfr(self, other):
+ """
+ Compare this AXFRSet with an older one and write the resulting
+ IXFRSet to file with magic filename. Since we store PDUSets
+ in sorted order, computing the difference is a trivial linear
+ comparison.
+ """
+
+ f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb")
+ old = other
+ new = self
+ len_old = len(old)
+ len_new = len(new)
+ i_old = i_new = 0
+ while i_old < len_old and i_new < len_new:
+ if old[i_old] < new[i_new]:
+ f.write(old[i_old].to_pdu(announce = 0))
+ i_old += 1
+ elif old[i_old] > new[i_new]:
+ f.write(new[i_new].to_pdu(announce = 1))
+ i_new += 1
+ else:
+ i_old += 1
+ i_new += 1
+ for i in xrange(i_old, len_old):
+ f.write(old[i].to_pdu(announce = 0))
+ for i in xrange(i_new, len_new):
+ f.write(new[i].to_pdu(announce = 1))
+ f.close()
+
+ def show(self):
+ """
+ Print this AXFRSet.
+ """
+
+ logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version)
+ for p in self:
+ logging.debug(p)
- return "%d.ax.v%d" % (self.serial, self.version)
- @classmethod
- def load_current(cls, version):
+class IXFRSet(PDUSet):
"""
- Load current AXFRSet. Return None if can't.
+ Object representing an incremental set of PDUs, that is, the
+ differences between one versioned and (theoretically) consistent set
+ of prefixes and router certificates extracted from rcynic's output
+ and another, with the announce fields set or cleared as necessary to
+ indicate the changes.
"""
- serial = rpki.rtr.server.read_current(version)[0]
- if serial is None:
- return None
- try:
- return cls.load("%d.ax.v%d" % (serial, version))
- except IOError:
- return None
+ @classmethod
+ def load(cls, filename):
+ """
+ Load an IXFRSet from a file, parse filename to obtain version and serials.
+ """
- def save_axfr(self):
- """
- Write AXFRSet to file with magic filename.
- """
+ fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit()
+ version = int(fn4[1:])
+ self = cls._load_file(filename, version)
+ self.from_serial = rpki.rtr.channels.Timestamp(fn3)
+ self.to_serial = rpki.rtr.channels.Timestamp(fn1)
+ return self
- f = open(self.filename(), "wb")
- for p in self:
- f.write(p.to_pdu())
- f.close()
+ def filename(self):
+ """
+ Generate filename for this IXFRSet.
+ """
- def destroy_old_data(self):
- """
- Destroy old data files, presumably because our nonce changed and
- the old serial numbers are no longer valid.
- """
+ return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version)
- for i in glob.iglob("*.ix.*.v%d" % self.version):
- os.unlink(i)
- for i in glob.iglob("*.ax.v%d" % self.version):
- if i != self.filename():
- os.unlink(i)
+ def show(self):
+ """
+ Print this IXFRSet.
+ """
- @staticmethod
- def new_nonce(force_zero_nonce):
- """
- Create and return a new nonce value.
- """
+ logging.debug("# IXFR %d (%s) -> %d (%s) v%d",
+ self.from_serial, self.from_serial,
+ self.to_serial, self.to_serial,
+ self.version)
+ for p in self:
+ logging.debug(p)
- if force_zero_nonce:
- return 0
- try:
- return int(random.SystemRandom().getrandbits(16))
- except NotImplementedError:
- return int(random.getrandbits(16))
- def mark_current(self, force_zero_nonce = False):
+def kick_all(serial):
"""
- Save current serial number and nonce, creating new nonce if
- necessary. Creating a new nonce triggers cleanup of old state, as
- the new nonce invalidates all old serial numbers.
+ Kick any existing server processes to wake them up.
"""
- assert self.version in rpki.rtr.pdus.PDU.version_map
- old_serial, nonce = rpki.rtr.server.read_current(self.version)
- if old_serial is None or self.seq_ge(old_serial, self.serial):
- logging.debug("Creating new nonce and deleting stale data")
- nonce = self.new_nonce(force_zero_nonce)
- self.destroy_old_data()
- rpki.rtr.server.write_current(self.serial, nonce, self.version)
+ try:
+ os.stat(rpki.rtr.server.kickme_dir)
+ except OSError:
+ logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir)
+ os.makedirs(rpki.rtr.server.kickme_dir)
+
+ msg = "Good morning, serial %d is ready" % serial
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base):
+ try:
+ logging.debug("# Kicking %s", name)
+ sock.sendto(msg, name)
+ except socket.error:
+ try:
+ logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name)
+ os.unlink(name)
+ except Exception, e:
+ logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e)
+ except Exception, e:
+ logging.warning("# Failed to kick %s and don't understand why: %s", name, e)
+ sock.close()
- def save_ixfr(self, other):
- """
- Comparing this AXFRSet with an older one and write the resulting
- IXFRSet to file with magic filename. Since we store PDUSets
- in sorted order, computing the difference is a trivial linear
- comparison.
- """
- f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb")
- old = other
- new = self
- len_old = len(old)
- len_new = len(new)
- i_old = i_new = 0
- while i_old < len_old and i_new < len_new:
- if old[i_old] < new[i_new]:
- f.write(old[i_old].to_pdu(announce = 0))
- i_old += 1
- elif old[i_old] > new[i_new]:
- f.write(new[i_new].to_pdu(announce = 1))
- i_new += 1
- else:
- i_old += 1
- i_new += 1
- for i in xrange(i_old, len_old):
- f.write(old[i].to_pdu(announce = 0))
- for i in xrange(i_new, len_new):
- f.write(new[i].to_pdu(announce = 1))
- f.close()
-
- def show(self):
- """
- Print this AXFRSet.
+def cronjob_main(args):
"""
-
- logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version)
- for p in self:
- logging.debug(p)
+ Run this right after running rcynic to wade through the ROAs and
+ router certificates that rcynic collects and translate that data
+ into the form used in the rpki-router protocol. Output is an
+ updated database containing both full dumps (AXFR) and incremental
+ dumps against a specific prior version (IXFR). After updating the
+ database, kicks any active servers, so that they can notify their
+ clients that a new version is available.
+ """
+
+ if args.rpki_rtr_dir:
+ try:
+ if not os.path.isdir(args.rpki_rtr_dir):
+ os.makedirs(args.rpki_rtr_dir)
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ logging.critical(str(e))
+ sys.exit(1)
+
+ for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True):
+
+ logging.debug("# Generating updates for protocol version %d", version)
+
+ old_ixfrs = glob.glob("*.ix.*.v%d" % version)
+
+ current = rpki.rtr.server.read_current(version)[0]
+ cutoff = Timestamp.now(-(24 * 60 * 60))
+ for f in glob.iglob("*.ax.v%d" % version):
+ t = Timestamp(int(f.split(".")[0]))
+ if t < cutoff and t != current:
+ logging.debug("# Deleting old file %s, timestamp %s", f, t)
+ os.unlink(f)
+
+ pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts)
+ if pdus == rpki.rtr.generator.AXFRSet.load_current(version):
+ logging.debug("# No change, new serial not needed")
+ continue
+ pdus.save_axfr()
+ for axfr in glob.iglob("*.ax.v%d" % version):
+ if axfr != pdus.filename():
+ pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr))
+ pdus.mark_current(args.force_zero_nonce)
+
+ logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial)
+
+ rpki.rtr.generator.kick_all(pdus.serial)
+
+ old_ixfrs.sort()
+ for ixfr in old_ixfrs:
+ try:
+ logging.debug("# Deleting old file %s", ixfr)
+ os.unlink(ixfr)
+ except OSError:
+ pass
-class IXFRSet(PDUSet):
- """
- Object representing an incremental set of PDUs, that is, the
- differences between one versioned and (theoretically) consistant set
- of prefixes and router certificates extracted from rcynic's output
- and another, with the announce fields set or cleared as necessary to
- indicate the changes.
- """
-
- @classmethod
- def load(cls, filename):
+def show_main(args):
"""
- Load an IXFRSet from a file, parse filename to obtain version and serials.
+ Display current rpki-rtr server database in textual form.
"""
- fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit()
- version = int(fn4[1:])
- self = cls._load_file(filename, version)
- self.from_serial = rpki.rtr.channels.Timestamp(fn3)
- self.to_serial = rpki.rtr.channels.Timestamp(fn1)
- return self
+ if args.rpki_rtr_dir:
+ try:
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ sys.exit(e)
- def filename(self):
- """
- Generate filename for this IXFRSet.
- """
+ g = glob.glob("*.ax.v*")
+ g.sort()
+ for f in g:
+ rpki.rtr.generator.AXFRSet.load(f).show()
- return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version)
+ g = glob.glob("*.ix.*.v*")
+ g.sort()
+ for f in g:
+ rpki.rtr.generator.IXFRSet.load(f).show()
- def show(self):
+def argparse_setup(subparsers):
"""
- Print this IXFRSet.
+ Set up argparse stuff for commands in this module.
"""
- logging.debug("# IXFR %d (%s) -> %d (%s) v%d",
- self.from_serial, self.from_serial,
- self.to_serial, self.to_serial,
- self.version)
- for p in self:
- logging.debug(p)
-
+ subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__,
+ help = "Generate RPKI-RTR database from rcynic output")
+ subparser.set_defaults(func = cronjob_main, default_log_to = "syslog")
+ subparser.add_argument("--scan-roas", help = "specify an external scan_roas program")
+ subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program")
+ subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero")
+ subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
-def kick_all(serial):
- """
- Kick any existing server processes to wake them up.
- """
-
- try:
- os.stat(rpki.rtr.server.kickme_dir)
- except OSError:
- logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir)
- os.makedirs(rpki.rtr.server.kickme_dir)
-
- msg = "Good morning, serial %d is ready" % serial
- sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
- for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base):
- try:
- logging.debug("# Kicking %s", name)
- sock.sendto(msg, name)
- except socket.error:
- try:
- logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name)
- os.unlink(name)
- except Exception, e:
- logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e)
- except Exception, e:
- logging.warning("# Failed to kick %s and don't understand why: %s", name, e)
- sock.close()
-
-
-def cronjob_main(args):
- """
- Run this right after running rcynic to wade through the ROAs and
- router certificates that rcynic collects and translate that data
- into the form used in the rpki-router protocol. Output is an
- updated database containing both full dumps (AXFR) and incremental
- dumps against a specific prior version (IXFR). After updating the
- database, kicks any active servers, so that they can notify their
- clients that a new version is available.
- """
-
- if args.rpki_rtr_dir:
- try:
- if not os.path.isdir(args.rpki_rtr_dir):
- os.makedirs(args.rpki_rtr_dir)
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- logging.critical(str(e))
- sys.exit(1)
-
- for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True):
-
- logging.debug("# Generating updates for protocol version %d", version)
-
- old_ixfrs = glob.glob("*.ix.*.v%d" % version)
-
- current = rpki.rtr.server.read_current(version)[0]
- cutoff = Timestamp.now(-(24 * 60 * 60))
- for f in glob.iglob("*.ax.v%d" % version):
- t = Timestamp(int(f.split(".")[0]))
- if t < cutoff and t != current:
- logging.debug("# Deleting old file %s, timestamp %s", f, t)
- os.unlink(f)
-
- pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts)
- if pdus == rpki.rtr.generator.AXFRSet.load_current(version):
- logging.debug("# No change, new serial not needed")
- continue
- pdus.save_axfr()
- for axfr in glob.iglob("*.ax.v%d" % version):
- if axfr != pdus.filename():
- pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr))
- pdus.mark_current(args.force_zero_nonce)
-
- logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial)
-
- rpki.rtr.generator.kick_all(pdus.serial)
-
- old_ixfrs.sort()
- for ixfr in old_ixfrs:
- try:
- logging.debug("# Deleting old file %s", ixfr)
- os.unlink(ixfr)
- except OSError:
- pass
-
-
-def show_main(args):
- """
- Display current rpki-rtr server database in textual form.
- """
-
- if args.rpki_rtr_dir:
- try:
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- sys.exit(e)
-
- g = glob.glob("*.ax.v*")
- g.sort()
- for f in g:
- rpki.rtr.generator.AXFRSet.load(f).show()
-
- g = glob.glob("*.ix.*.v*")
- g.sort()
- for f in g:
- rpki.rtr.generator.IXFRSet.load(f).show()
-
-def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__,
- help = "Generate RPKI-RTR database from rcynic output")
- subparser.set_defaults(func = cronjob_main, default_log_to = "syslog")
- subparser.add_argument("--scan-roas", help = "specify an external scan_roas program")
- subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program")
- subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero")
- subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
-
- subparser = subparsers.add_parser("show", description = show_main.__doc__,
- help = "Display content of RPKI-RTR database")
- subparser.set_defaults(func = show_main, default_log_to = "stderr")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+ subparser = subparsers.add_parser("show", description = show_main.__doc__,
+ help = "Display content of RPKI-RTR database")
+ subparser.set_defaults(func = show_main, default_log_to = "stderr")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
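The on-disk naming scheme used throughout this module is "<serial>.ax.v<version>" for full dumps and "<to>.ix.<from>.v<version>" for incrementals, and serials compare with 32-bit wraparound via PDUSet.seq_ge(). A stand-alone sketch of that comparison:

def seq_ge(a, b):
    # Same arithmetic as PDUSet.seq_ge(): 32-bit wraparound comparison, so a
    # serial just past the wrap still counts as no older than one just before it.
    return ((a - b) % (1 << 32)) < (1 << 31)

print(seq_ge(10, 5))               # True
print(seq_ge(5, 10))               # False
print(seq_ge(2, (1 << 32) - 1))    # True: 2 follows 4294967295 after the wrap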
diff --git a/rpki/rtr/main.py b/rpki/rtr/main.py
index 12de30cc..34f5598d 100644
--- a/rpki/rtr/main.py
+++ b/rpki/rtr/main.py
@@ -31,64 +31,64 @@ import argparse
class Formatter(logging.Formatter):
- converter = time.gmtime
+ converter = time.gmtime
- def __init__(self, debug, fmt, datefmt):
- self.debug = debug
- super(Formatter, self).__init__(fmt, datefmt)
+ def __init__(self, debug, fmt, datefmt):
+ self.debug = debug
+ super(Formatter, self).__init__(fmt, datefmt)
- def format(self, record):
- if getattr(record, "connection", None) is None:
- record.connection = ""
- return super(Formatter, self).format(record)
+ def format(self, record):
+ if getattr(record, "connection", None) is None:
+ record.connection = ""
+ return super(Formatter, self).format(record)
- def formatException(self, ei):
- if self.debug:
- return super(Formatter, self).formatException(ei)
- else:
- return str(ei[1])
+ def formatException(self, ei):
+ if self.debug:
+ return super(Formatter, self).formatException(ei)
+ else:
+ return str(ei[1])
def main():
- os.environ["TZ"] = "UTC"
- time.tzset()
-
- from rpki.rtr.server import argparse_setup as argparse_setup_server
- from rpki.rtr.client import argparse_setup as argparse_setup_client
- from rpki.rtr.generator import argparse_setup as argparse_setup_generator
-
- if "rpki.rtr.bgpdump" in sys.modules:
- from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump
- else:
- def argparse_setup_bgpdump(ignored):
- pass
-
- argparser = argparse.ArgumentParser(description = __doc__)
- argparser.add_argument("--debug", action = "store_true", help = "debugging mode")
- argparser.add_argument("--log-level", default = "debug",
- choices = ("debug", "info", "warning", "error", "critical"),
- type = lambda s: s.lower())
- argparser.add_argument("--log-to",
- choices = ("syslog", "stderr"))
- subparsers = argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode")
- argparse_setup_server(subparsers)
- argparse_setup_client(subparsers)
- argparse_setup_generator(subparsers)
- argparse_setup_bgpdump(subparsers)
- args = argparser.parse_args()
-
- fmt = "rpki-rtr/" + args.mode + "%(connection)s[%(process)d] %(message)s"
-
- if (args.log_to or args.default_log_to) == "stderr":
- handler = logging.StreamHandler()
- fmt = "%(asctime)s " + fmt
- elif os.path.exists("/dev/log"):
- handler = logging.handlers.SysLogHandler("/dev/log")
- else:
- handler = logging.handlers.SysLogHandler()
-
- handler.setFormatter(Formatter(args.debug, fmt, "%Y-%m-%dT%H:%M:%SZ"))
- logging.root.addHandler(handler)
- logging.root.setLevel(int(getattr(logging, args.log_level.upper())))
-
- return args.func(args)
+ os.environ["TZ"] = "UTC"
+ time.tzset()
+
+ from rpki.rtr.server import argparse_setup as argparse_setup_server
+ from rpki.rtr.client import argparse_setup as argparse_setup_client
+ from rpki.rtr.generator import argparse_setup as argparse_setup_generator
+
+ if "rpki.rtr.bgpdump" in sys.modules:
+ from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump
+ else:
+ def argparse_setup_bgpdump(ignored):
+ pass
+
+ argparser = argparse.ArgumentParser(description = __doc__)
+ argparser.add_argument("--debug", action = "store_true", help = "debugging mode")
+ argparser.add_argument("--log-level", default = "debug",
+ choices = ("debug", "info", "warning", "error", "critical"),
+ type = lambda s: s.lower())
+ argparser.add_argument("--log-to",
+ choices = ("syslog", "stderr"))
+ subparsers = argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode")
+ argparse_setup_server(subparsers)
+ argparse_setup_client(subparsers)
+ argparse_setup_generator(subparsers)
+ argparse_setup_bgpdump(subparsers)
+ args = argparser.parse_args()
+
+ fmt = "rpki-rtr/" + args.mode + "%(connection)s[%(process)d] %(message)s"
+
+ if (args.log_to or args.default_log_to) == "stderr":
+ handler = logging.StreamHandler()
+ fmt = "%(asctime)s " + fmt
+ elif os.path.exists("/dev/log"):
+ handler = logging.handlers.SysLogHandler("/dev/log")
+ else:
+ handler = logging.handlers.SysLogHandler()
+
+ handler.setFormatter(Formatter(args.debug, fmt, "%Y-%m-%dT%H:%M:%SZ"))
+ logging.root.addHandler(handler)
+ logging.root.setLevel(int(getattr(logging, args.log_level.upper())))
+
+ return args.func(args)
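The Formatter subclass above exists mainly so that %(connection)s in the log format always expands, even for records logged without a per-connection tag. A stand-alone sketch of the same trick, using a hypothetical ConnFormatter and a made-up connection string:

import logging

class ConnFormatter(logging.Formatter):
    # Same trick as the Formatter above: make sure %(connection)s always
    # expands, even for records logged without the extra attribute.
    def format(self, record):
        if getattr(record, "connection", None) is None:
            record.connection = ""
        return super(ConnFormatter, self).format(record)

handler = logging.StreamHandler()
handler.setFormatter(ConnFormatter("rpki-rtr/demo%(connection)s[%(process)d] %(message)s"))
log = logging.getLogger("demo")
log.addHandler(handler)
log.setLevel(logging.DEBUG)

log.info("no connection attribute set")    # %(connection)s expands to ""
log.info("per-connection message", extra = {"connection": " 192.0.2.1"})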
diff --git a/rpki/rtr/pdus.py b/rpki/rtr/pdus.py
index 0d2e5928..94f579a1 100644
--- a/rpki/rtr/pdus.py
+++ b/rpki/rtr/pdus.py
@@ -28,292 +28,292 @@ import rpki.POW
# Exceptions
class PDUException(Exception):
- """
- Parent exception type for exceptions that signal particular protocol
- errors. String value of exception instance will be the message to
- put in the ErrorReportPDU, error_report_code value of exception
- will be the numeric code to use.
- """
-
- def __init__(self, msg = None, pdu = None):
- super(PDUException, self).__init__()
- assert msg is None or isinstance(msg, (str, unicode))
- self.error_report_msg = msg
- self.error_report_pdu = pdu
-
- def __str__(self):
- return self.error_report_msg or self.__class__.__name__
-
- def make_error_report(self, version):
- return ErrorReportPDU(version = version,
- errno = self.error_report_code,
- errmsg = self.error_report_msg,
- errpdu = self.error_report_pdu)
+ """
+ Parent exception type for exceptions that signal particular protocol
+ errors. String value of exception instance will be the message to
+ put in the ErrorReportPDU, error_report_code value of exception
+ will be the numeric code to use.
+ """
+
+ def __init__(self, msg = None, pdu = None):
+ super(PDUException, self).__init__()
+ assert msg is None or isinstance(msg, (str, unicode))
+ self.error_report_msg = msg
+ self.error_report_pdu = pdu
+
+ def __str__(self):
+ return self.error_report_msg or self.__class__.__name__
+
+ def make_error_report(self, version):
+ return ErrorReportPDU(version = version,
+ errno = self.error_report_code,
+ errmsg = self.error_report_msg,
+ errpdu = self.error_report_pdu)
class UnsupportedProtocolVersion(PDUException):
- error_report_code = 4
+ error_report_code = 4
class UnsupportedPDUType(PDUException):
- error_report_code = 5
+ error_report_code = 5
class CorruptData(PDUException):
- error_report_code = 0
+ error_report_code = 0
# Decorators
def wire_pdu(cls, versions = None):
- """
- Class decorator to add a PDU class to the set of known PDUs
- for all supported protocol versions.
- """
+ """
+ Class decorator to add a PDU class to the set of known PDUs
+ for all supported protocol versions.
+ """
- for v in PDU.version_map.iterkeys() if versions is None else versions:
- assert cls.pdu_type not in PDU.version_map[v]
- PDU.version_map[v][cls.pdu_type] = cls
- return cls
+ for v in PDU.version_map.iterkeys() if versions is None else versions:
+ assert cls.pdu_type not in PDU.version_map[v]
+ PDU.version_map[v][cls.pdu_type] = cls
+ return cls
def wire_pdu_only(*versions):
- """
- Class decorator to add a PDU class to the set of known PDUs
- for specific protocol versions.
- """
+ """
+ Class decorator to add a PDU class to the set of known PDUs
+ for specific protocol versions.
+ """
- assert versions and all(v in PDU.version_map for v in versions)
- return lambda cls: wire_pdu(cls, versions)
+ assert versions and all(v in PDU.version_map for v in versions)
+ return lambda cls: wire_pdu(cls, versions)
def clone_pdu_root(root_pdu_class):
- """
- Replace a PDU root class's version_map with a two-level deep copy of itself,
- and return a class decorator which subclasses can use to replace their
- parent classes with themselves in the resulting cloned version map.
+ """
+ Replace a PDU root class's version_map with a two-level deep copy of itself,
+ and return a class decorator which subclasses can use to replace their
+ parent classes with themselves in the resulting cloned version map.
- This function is not itself a decorator, it returns one.
- """
+ This function is not itself a decorator, it returns one.
+ """
- root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems())
+ root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems())
- def decorator(cls):
- for pdu_map in root_pdu_class.version_map.itervalues():
- for pdu_type, pdu_class in pdu_map.items():
- if pdu_class in cls.__bases__:
- pdu_map[pdu_type] = cls
- return cls
+ def decorator(cls):
+ for pdu_map in root_pdu_class.version_map.itervalues():
+ for pdu_type, pdu_class in pdu_map.items():
+ if pdu_class in cls.__bases__:
+ pdu_map[pdu_type] = cls
+ return cls
- return decorator
+ return decorator
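To see how wire_pdu() and clone_pdu_root() cooperate, here is a self-contained toy registry with made-up PDU names: the first decorator registers a class per protocol version, and the cloned map lets a subclass quietly replace its parent in the copy.

class BasePDU(object):
    version_map = {0 : {}, 1 : {}}

def toy_wire_pdu(cls, versions = None):
    # Register cls under its pdu_type for each protocol version.
    for v in (BasePDU.version_map if versions is None else versions):
        BasePDU.version_map[v][cls.pdu_type] = cls
    return cls

@toy_wire_pdu
class DemoPDU(BasePDU):
    pdu_type = 99

def toy_clone_pdu_root(root):
    # Two-level copy of the registry, plus a decorator that swaps a subclass
    # in wherever its parent appears in the copy.
    root.version_map = dict((k, v.copy()) for k, v in root.version_map.items())
    def decorator(cls):
        for pdu_map in root.version_map.values():
            for pdu_type, pdu_class in list(pdu_map.items()):
                if pdu_class in cls.__bases__:
                    pdu_map[pdu_type] = cls
        return cls
    return decorator

clone_pdu = toy_clone_pdu_root(BasePDU)

@clone_pdu
class FancierDemoPDU(DemoPDU):
    pass

print(BasePDU.version_map[1][99].__name__)    # FancierDemoPDU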
# PDUs
class PDU(object):
- """
- Base PDU. Real PDUs are subclasses of this class.
- """
+ """
+ Base PDU. Real PDUs are subclasses of this class.
+ """
- version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu
+ version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu
- _pdu = None # Cached when first generated
+ _pdu = None # Cached when first generated
- header_struct = struct.Struct("!BB2xL")
+ header_struct = struct.Struct("!BB2xL")
- def __init__(self, version):
- assert version in self.version_map
- self.version = version
+ def __init__(self, version):
+ assert version in self.version_map
+ self.version = version
- def __cmp__(self, other):
- return cmp(self.to_pdu(), other.to_pdu())
+ def __cmp__(self, other):
+ return cmp(self.to_pdu(), other.to_pdu())
- @property
- def default_version(self):
- return max(self.version_map.iterkeys())
+ @property
+ def default_version(self):
+ return max(self.version_map.iterkeys())
- def check(self):
- pass
+ def check(self):
+ pass
- @classmethod
- def read_pdu(cls, reader):
- return reader.update(need = cls.header_struct.size, callback = cls.got_header)
+ @classmethod
+ def read_pdu(cls, reader):
+ return reader.update(need = cls.header_struct.size, callback = cls.got_header)
- @classmethod
- def got_header(cls, reader):
- if not reader.ready():
- return None
- assert reader.available() >= cls.header_struct.size
- version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
- reader.check_version(version)
- if pdu_type not in cls.version_map[version]:
- raise UnsupportedPDUType(
- "Received unsupported PDU type %d" % pdu_type)
- if length < 8:
- raise CorruptData(
- "Received PDU with length %d, which is too short to be valid" % length)
- self = cls.version_map[version][pdu_type](version = version)
- return reader.update(need = length, callback = self.got_pdu)
+ @classmethod
+ def got_header(cls, reader):
+ if not reader.ready():
+ return None
+ assert reader.available() >= cls.header_struct.size
+ version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
+ reader.check_version(version)
+ if pdu_type not in cls.version_map[version]:
+ raise UnsupportedPDUType(
+ "Received unsupported PDU type %d" % pdu_type)
+ if length < 8:
+ raise CorruptData(
+ "Received PDU with length %d, which is too short to be valid" % length)
+ self = cls.version_map[version][pdu_type](version = version)
+ return reader.update(need = length, callback = self.got_pdu)
class PDUWithSerial(PDU):
- """
- Base class for PDUs consisting of just a serial number and nonce.
- """
-
- header_struct = struct.Struct("!BBHLL")
-
- def __init__(self, version, serial = None, nonce = None):
- super(PDUWithSerial, self).__init__(version)
- if serial is not None:
- assert isinstance(serial, int)
- self.serial = serial
- if nonce is not None:
- assert isinstance(nonce, int)
- self.nonce = nonce
-
- def __str__(self):
- return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce)
-
- def to_pdu(self):
"""
- Generate the wire format PDU.
+ Base class for PDUs consisting of just a serial number and nonce.
"""
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
- self.header_struct.size, self.serial)
- return self._pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if length != 12:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ header_struct = struct.Struct("!BBHLL")
+
+ def __init__(self, version, serial = None, nonce = None):
+ super(PDUWithSerial, self).__init__(version)
+ if serial is not None:
+ assert isinstance(serial, int)
+ self.serial = serial
+ if nonce is not None:
+ assert isinstance(nonce, int)
+ self.nonce = nonce
+
+ def __str__(self):
+ return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce)
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
+
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
+ self.header_struct.size, self.serial)
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 12:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
class PDUWithNonce(PDU):
- """
- Base class for PDUs consisting of just a nonce.
- """
-
- header_struct = struct.Struct("!BBHL")
-
- def __init__(self, version, nonce = None):
- super(PDUWithNonce, self).__init__(version)
- if nonce is not None:
- assert isinstance(nonce, int)
- self.nonce = nonce
-
- def __str__(self):
- return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce)
-
- def to_pdu(self):
"""
- Generate the wire format PDU.
+ Base class for PDUs consisting of just a nonce.
"""
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size)
- return self._pdu
+ header_struct = struct.Struct("!BBHL")
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, self.nonce, length = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if length != 8:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ def __init__(self, version, nonce = None):
+ super(PDUWithNonce, self).__init__(version)
+ if nonce is not None:
+ assert isinstance(nonce, int)
+ self.nonce = nonce
+ def __str__(self):
+ return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce)
-class PDUEmpty(PDU):
- """
- Base class for empty PDUs.
- """
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
- header_struct = struct.Struct("!BBHL")
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size)
+ return self._pdu
- def __str__(self):
- return "[%s]" % self.__class__.__name__
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 8:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
- def to_pdu(self):
+
+class PDUEmpty(PDU):
"""
- Generate the wire format PDU for this prefix.
+ Base class for empty PDUs.
"""
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size)
- return self._pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, zero, length = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if zero != 0:
- raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self)
- if length != 8:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ header_struct = struct.Struct("!BBHL")
+
+ def __str__(self):
+ return "[%s]" % self.__class__.__name__
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU for this prefix.
+ """
+
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size)
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, zero, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if zero != 0:
+ raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self)
+ if length != 8:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
@wire_pdu
class SerialNotifyPDU(PDUWithSerial):
- """
- Serial Notify PDU.
- """
+ """
+ Serial Notify PDU.
+ """
- pdu_type = 0
+ pdu_type = 0
@wire_pdu
class SerialQueryPDU(PDUWithSerial):
- """
- Serial Query PDU.
- """
+ """
+ Serial Query PDU.
+ """
- pdu_type = 1
+ pdu_type = 1
- def __init__(self, version, serial = None, nonce = None):
- super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce)
+ def __init__(self, version, serial = None, nonce = None):
+ super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce)
@wire_pdu
class ResetQueryPDU(PDUEmpty):
- """
- Reset Query PDU.
- """
+ """
+ Reset Query PDU.
+ """
- pdu_type = 2
+ pdu_type = 2
- def __init__(self, version):
- super(ResetQueryPDU, self).__init__(self.default_version if version is None else version)
+ def __init__(self, version):
+ super(ResetQueryPDU, self).__init__(self.default_version if version is None else version)
@wire_pdu
class CacheResponsePDU(PDUWithNonce):
- """
- Cache Response PDU.
- """
+ """
+ Cache Response PDU.
+ """
- pdu_type = 3
+ pdu_type = 3
def EndOfDataPDU(version, *args, **kwargs):
- """
- Factory for the EndOfDataPDU classes, which take different forms in
- different protocol versions.
- """
+ """
+ Factory for the EndOfDataPDU classes, which take different forms in
+ different protocol versions.
+ """
- if version == 0:
- return EndOfDataPDUv0(version, *args, **kwargs)
- if version == 1:
- return EndOfDataPDUv1(version, *args, **kwargs)
- raise NotImplementedError
+ if version == 0:
+ return EndOfDataPDUv0(version, *args, **kwargs)
+ if version == 1:
+ return EndOfDataPDUv1(version, *args, **kwargs)
+ raise NotImplementedError
# Min, max, and default values, from the current RFC 6810 bis I-D.
@@ -324,325 +324,325 @@ def EndOfDataPDU(version, *args, **kwargs):
default_refresh = 3600
def valid_refresh(refresh):
- if not isinstance(refresh, int) or refresh < 120 or refresh > 86400:
- raise ValueError
- return refresh
+ if not isinstance(refresh, int) or refresh < 120 or refresh > 86400:
+ raise ValueError
+ return refresh
default_retry = 600
def valid_retry(retry):
- if not isinstance(retry, int) or retry < 120 or retry > 7200:
- raise ValueError
- return retry
+ if not isinstance(retry, int) or retry < 120 or retry > 7200:
+ raise ValueError
+ return retry
default_expire = 7200
def valid_expire(expire):
- if not isinstance(expire, int) or expire < 600 or expire > 172800:
- raise ValueError
- return expire
+ if not isinstance(expire, int) or expire < 600 or expire > 172800:
+ raise ValueError
+ return expire
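For reference, the timing parameters validated above (and carried in the version 1 End of Data PDU) have the following defaults and bounds, all in seconds; a tiny sketch that just collects and prints them:

# Defaults and bounds enforced by valid_refresh/valid_retry/valid_expire above.
limits = {
    "refresh" : (120, 3600, 86400),     # (min, default, max)
    "retry"   : (120, 600, 7200),
    "expire"  : (600, 7200, 172800),
}
for name in ("refresh", "retry", "expire"):
    lo, default, hi = limits[name]
    print("%-7s default %6d, allowed %d..%d" % (name, default, lo, hi))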
@wire_pdu_only(0)
class EndOfDataPDUv0(PDUWithSerial):
- """
- End of Data PDU, protocol version 0.
- """
+ """
+ End of Data PDU, protocol version 0.
+ """
- pdu_type = 7
+ pdu_type = 7
- def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None):
- super(EndOfDataPDUv0, self).__init__(version, serial, nonce)
- self.refresh = valid_refresh(default_refresh if refresh is None else refresh)
- self.retry = valid_retry( default_retry if retry is None else retry)
- self.expire = valid_expire( default_expire if expire is None else expire)
+ def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None):
+ super(EndOfDataPDUv0, self).__init__(version, serial, nonce)
+ self.refresh = valid_refresh(default_refresh if refresh is None else refresh)
+ self.retry = valid_retry( default_retry if retry is None else retry)
+ self.expire = valid_expire( default_expire if expire is None else expire)
@wire_pdu_only(1)
class EndOfDataPDUv1(EndOfDataPDUv0):
- """
- End of Data PDU, protocol version 1.
- """
+ """
+ End of Data PDU, protocol version 1.
+ """
- header_struct = struct.Struct("!BBHLLLLL")
+ header_struct = struct.Struct("!BBHLLLLL")
- def __str__(self):
- return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % (
- self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+ def __str__(self):
+ return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % (
+ self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire)
- def to_pdu(self):
- """
- Generate the wire format PDU.
- """
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
- self.header_struct.size, self.serial,
- self.refresh, self.retry, self.expire)
- return self._pdu
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
+ self.header_struct.size, self.serial,
+ self.refresh, self.retry, self.expire)
+ return self._pdu
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \
- = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if length != 24:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \
+ = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 24:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
@wire_pdu
class CacheResetPDU(PDUEmpty):
- """
- Cache reset PDU.
- """
+ """
+ Cache reset PDU.
+ """
- pdu_type = 8
+ pdu_type = 8
class PrefixPDU(PDU):
- """
- Object representing one prefix. This corresponds closely to one PDU
- in the rpki-router protocol, so closely that we use lexical ordering
- of the wire format of the PDU as the ordering for this class.
-
- This is a virtual class, but the .from_text() constructor
- instantiates the correct concrete subclass (IPv4PrefixPDU or
- IPv6PrefixPDU) depending on the syntax of its input text.
- """
-
- header_struct = struct.Struct("!BB2xLBBBx")
- asnum_struct = struct.Struct("!L")
-
- def __str__(self):
- plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen)
- return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm,
- ":".join(("%02X" % ord(b) for b in self.to_pdu())))
-
- def show(self):
- logging.debug("# Class: %s", self.__class__.__name__)
- logging.debug("# ASN: %s", self.asn)
- logging.debug("# Prefix: %s", self.prefix)
- logging.debug("# Prefixlen: %s", self.prefixlen)
- logging.debug("# MaxPrefixlen: %s", self.max_prefixlen)
- logging.debug("# Announce: %s", self.announce)
-
- def check(self):
- """
- Check attributes to make sure they're within range.
- """
-
- if self.announce not in (0, 1):
- raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
- if self.prefix.bits != self.address_byte_count * 8:
- raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self)
- if self.prefixlen < 0 or self.prefixlen > self.prefix.bits:
- raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self)
- if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits:
- raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self)
- pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
- if len(self.to_pdu()) != pdulen:
- raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
-
- def to_pdu(self, announce = None):
- """
- Generate the wire format PDU for this prefix.
- """
-
- if announce is not None:
- assert announce in (0, 1)
- elif self._pdu is not None:
- return self._pdu
- pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
- pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen,
- announce if announce is not None else self.announce,
- self.prefixlen, self.max_prefixlen) +
- self.prefix.toBytes() +
- self.asnum_struct.pack(self.asn))
- if announce is None:
- assert self._pdu is None
- self._pdu = pdu
- return pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b1 = reader.get(self.header_struct.size)
- b2 = reader.get(self.address_byte_count)
- b3 = reader.get(self.asnum_struct.size)
- version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1)
- assert version == self.version and pdu_type == self.pdu_type
- if length != len(b1) + len(b2) + len(b3):
- raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self)
- self.prefix = rpki.POW.IPAddress.fromBytes(b2)
- self.asn = self.asnum_struct.unpack(b3)[0]
- assert b1 + b2 + b3 == self.to_pdu()
- return self
+ """
+ Object representing one prefix. This corresponds closely to one PDU
+ in the rpki-router protocol, so closely that we use lexical ordering
+ of the wire format of the PDU as the ordering for this class.
+
+ This is a virtual class, but the .from_text() constructor
+ instantiates the correct concrete subclass (IPv4PrefixPDU or
+ IPv6PrefixPDU) depending on the syntax of its input text.
+ """
+
+ header_struct = struct.Struct("!BB2xLBBBx")
+ asnum_struct = struct.Struct("!L")
+
+ def __str__(self):
+ plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen)
+ return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm,
+ ":".join(("%02X" % ord(b) for b in self.to_pdu())))
+
+ def show(self):
+ logging.debug("# Class: %s", self.__class__.__name__)
+ logging.debug("# ASN: %s", self.asn)
+ logging.debug("# Prefix: %s", self.prefix)
+ logging.debug("# Prefixlen: %s", self.prefixlen)
+ logging.debug("# MaxPrefixlen: %s", self.max_prefixlen)
+ logging.debug("# Announce: %s", self.announce)
+
+ def check(self):
+ """
+ Check attributes to make sure they're within range.
+ """
+
+ if self.announce not in (0, 1):
+ raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
+ if self.prefix.bits != self.address_byte_count * 8:
+ raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self)
+ if self.prefixlen < 0 or self.prefixlen > self.prefix.bits:
+ raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self)
+ if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits:
+ raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self)
+ pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
+ if len(self.to_pdu()) != pdulen:
+ raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
+
+ def to_pdu(self, announce = None):
+ """
+ Generate the wire format PDU for this prefix.
+ """
+
+ if announce is not None:
+ assert announce in (0, 1)
+ elif self._pdu is not None:
+ return self._pdu
+ pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
+ pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen,
+ announce if announce is not None else self.announce,
+ self.prefixlen, self.max_prefixlen) +
+ self.prefix.toBytes() +
+ self.asnum_struct.pack(self.asn))
+ if announce is None:
+ assert self._pdu is None
+ self._pdu = pdu
+ return pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b1 = reader.get(self.header_struct.size)
+ b2 = reader.get(self.address_byte_count)
+ b3 = reader.get(self.asnum_struct.size)
+ version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != len(b1) + len(b2) + len(b3):
+ raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self)
+ self.prefix = rpki.POW.IPAddress.fromBytes(b2)
+ self.asn = self.asnum_struct.unpack(b3)[0]
+ assert b1 + b2 + b3 == self.to_pdu()
+ return self
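
For reference, here is a minimal sketch (not part of this patch) of the wire layout that PrefixPDU's header_struct and asnum_struct describe, packing an IPv4 Prefix PDU with nothing but the standard library; the prefix, ASN and version values are made up for illustration.

    import socket
    import struct

    header_struct = struct.Struct("!BB2xLBBBx")   # version, type, pad, length, flags, prefixlen, maxlen, pad
    asnum_struct = struct.Struct("!L")            # origin ASN

    def pack_ipv4_prefix_pdu(version, announce, prefix, prefixlen, max_prefixlen, asn):
        addr = socket.inet_aton(prefix)           # 4 address bytes for IPv4
        pdulen = header_struct.size + len(addr) + asnum_struct.size
        return (header_struct.pack(version, 4, pdulen, announce, prefixlen, max_prefixlen)
                + addr + asnum_struct.pack(asn))

    pdu = pack_ipv4_prefix_pdu(0, 1, "192.0.2.0", 24, 24, 64496)   # example values only
    assert len(pdu) == 20                         # IPv4 Prefix PDUs are always 20 bytes
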
@wire_pdu
class IPv4PrefixPDU(PrefixPDU):
- """
- IPv4 flavor of a prefix.
- """
+ """
+ IPv4 flavor of a prefix.
+ """
- pdu_type = 4
- address_byte_count = 4
+ pdu_type = 4
+ address_byte_count = 4
@wire_pdu
class IPv6PrefixPDU(PrefixPDU):
- """
- IPv6 flavor of a prefix.
- """
+ """
+ IPv6 flavor of a prefix.
+ """
- pdu_type = 6
- address_byte_count = 16
+ pdu_type = 6
+ address_byte_count = 16
@wire_pdu_only(1)
class RouterKeyPDU(PDU):
- """
- Router Key PDU.
- """
-
- pdu_type = 9
-
- header_struct = struct.Struct("!BBBxL20sL")
-
- def __str__(self):
- return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn,
- base64.urlsafe_b64encode(self.ski).rstrip("="),
- ":".join(("%02X" % ord(b) for b in self.to_pdu())))
-
- def check(self):
- """
- Check attributes to make sure they're within range.
- """
-
- if self.announce not in (0, 1):
- raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
- if len(self.ski) != 20:
- raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self)
- pdulen = self.header_struct.size + len(self.key)
- if len(self.to_pdu()) != pdulen:
- raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
-
- def to_pdu(self, announce = None):
- if announce is not None:
- assert announce in (0, 1)
- elif self._pdu is not None:
- return self._pdu
- pdulen = self.header_struct.size + len(self.key)
- pdu = (self.header_struct.pack(self.version,
- self.pdu_type,
- announce if announce is not None else self.announce,
- pdulen,
- self.ski,
- self.asn)
- + self.key)
- if announce is None:
- assert self._pdu is None
- self._pdu = pdu
- return pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- header = reader.get(self.header_struct.size)
- version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header)
- assert version == self.version and pdu_type == self.pdu_type
- remaining = length - self.header_struct.size
- if remaining <= 0:
- raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self)
- self.key = reader.get(remaining)
- assert header + self.key == self.to_pdu()
- return self
+ """
+ Router Key PDU.
+ """
+
+ pdu_type = 9
+
+ header_struct = struct.Struct("!BBBxL20sL")
+
+ def __str__(self):
+ return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn,
+ base64.urlsafe_b64encode(self.ski).rstrip("="),
+ ":".join(("%02X" % ord(b) for b in self.to_pdu())))
+
+ def check(self):
+ """
+ Check attributes to make sure they're within range.
+ """
+
+ if self.announce not in (0, 1):
+ raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
+ if len(self.ski) != 20:
+ raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self)
+ pdulen = self.header_struct.size + len(self.key)
+ if len(self.to_pdu()) != pdulen:
+ raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
+
+ def to_pdu(self, announce = None):
+ if announce is not None:
+ assert announce in (0, 1)
+ elif self._pdu is not None:
+ return self._pdu
+ pdulen = self.header_struct.size + len(self.key)
+ pdu = (self.header_struct.pack(self.version,
+ self.pdu_type,
+ announce if announce is not None else self.announce,
+ pdulen,
+ self.ski,
+ self.asn)
+ + self.key)
+ if announce is None:
+ assert self._pdu is None
+ self._pdu = pdu
+ return pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ header = reader.get(self.header_struct.size)
+ version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
+ remaining = length - self.header_struct.size
+ if remaining <= 0:
+ raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self)
+ self.key = reader.get(remaining)
+ assert header + self.key == self.to_pdu()
+ return self
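
A similar sketch (again, not from the library) for the Router Key PDU layout above: a fixed 32-byte header ("!BBBxL20sL") followed by the variable-length subjectPublicKeyInfo. The SKI and key bytes here are placeholders.

    import struct

    header_struct = struct.Struct("!BBBxL20sL")   # version, type, flags, pad, length, SKI, ASN

    def pack_router_key_pdu(version, announce, ski, asn, spki):
        assert len(ski) == 20                     # the SKI is always a 20-byte SHA-1 hash
        pdulen = header_struct.size + len(spki)
        return header_struct.pack(version, 9, announce, pdulen, ski, asn) + spki

    pdu = pack_router_key_pdu(1, 1, b"\x00" * 20, 64496, b"placeholder-spki")
    assert len(pdu) == 32 + len(b"placeholder-spki")
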
@wire_pdu
class ErrorReportPDU(PDU):
- """
- Error Report PDU.
- """
-
- pdu_type = 10
-
- header_struct = struct.Struct("!BBHL")
- string_struct = struct.Struct("!L")
-
- errors = {
- 2 : "No Data Available" }
-
- fatal = {
- 0 : "Corrupt Data",
- 1 : "Internal Error",
- 3 : "Invalid Request",
- 4 : "Unsupported Protocol Version",
- 5 : "Unsupported PDU Type",
- 6 : "Withdrawal of Unknown Record",
- 7 : "Duplicate Announcement Received" }
-
- assert set(errors) & set(fatal) == set()
-
- errors.update(fatal)
-
- codes = dict((v, k) for k, v in errors.items())
-
- def __init__(self, version, errno = None, errpdu = None, errmsg = None):
- super(ErrorReportPDU, self).__init__(version)
- assert errno is None or errno in self.errors
- self.errno = errno
- self.errpdu = errpdu
- self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno]
-
- def __str__(self):
- return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg)
-
- def to_counted_string(self, s):
- return self.string_struct.pack(len(s)) + s
-
- def read_counted_string(self, reader, remaining):
- assert remaining >= self.string_struct.size
- n = self.string_struct.unpack(reader.get(self.string_struct.size))[0]
- assert remaining >= self.string_struct.size + n
- return n, reader.get(n), (remaining - self.string_struct.size - n)
-
- def to_pdu(self):
- """
- Generate the wire format PDU for this error report.
- """
-
- if self._pdu is None:
- assert isinstance(self.errno, int)
- assert not isinstance(self.errpdu, ErrorReportPDU)
- p = self.errpdu
- if p is None:
- p = ""
- elif isinstance(p, PDU):
- p = p.to_pdu()
- assert isinstance(p, str)
- pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg)
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen)
- self._pdu += self.to_counted_string(p)
- self._pdu += self.to_counted_string(self.errmsg.encode("utf8"))
- return self._pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- header = reader.get(self.header_struct.size)
- version, pdu_type, self.errno, length = self.header_struct.unpack(header)
- assert version == self.version and pdu_type == self.pdu_type
- remaining = length - self.header_struct.size
- self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining)
- self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining)
- if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen:
- raise CorruptData("Got PDU length %d, expected %d" % (
- length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen))
- assert (header
- + self.to_counted_string(self.errpdu)
- + self.to_counted_string(self.errmsg.encode("utf8"))
- == self.to_pdu())
- return self
+ """
+ Error Report PDU.
+ """
+
+ pdu_type = 10
+
+ header_struct = struct.Struct("!BBHL")
+ string_struct = struct.Struct("!L")
+
+ errors = {
+ 2 : "No Data Available" }
+
+ fatal = {
+ 0 : "Corrupt Data",
+ 1 : "Internal Error",
+ 3 : "Invalid Request",
+ 4 : "Unsupported Protocol Version",
+ 5 : "Unsupported PDU Type",
+ 6 : "Withdrawal of Unknown Record",
+ 7 : "Duplicate Announcement Received" }
+
+ assert set(errors) & set(fatal) == set()
+
+ errors.update(fatal)
+
+ codes = dict((v, k) for k, v in errors.items())
+
+ def __init__(self, version, errno = None, errpdu = None, errmsg = None):
+ super(ErrorReportPDU, self).__init__(version)
+ assert errno is None or errno in self.errors
+ self.errno = errno
+ self.errpdu = errpdu
+ self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno]
+
+ def __str__(self):
+ return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg)
+
+ def to_counted_string(self, s):
+ return self.string_struct.pack(len(s)) + s
+
+ def read_counted_string(self, reader, remaining):
+ assert remaining >= self.string_struct.size
+ n = self.string_struct.unpack(reader.get(self.string_struct.size))[0]
+ assert remaining >= self.string_struct.size + n
+ return n, reader.get(n), (remaining - self.string_struct.size - n)
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU for this error report.
+ """
+
+ if self._pdu is None:
+ assert isinstance(self.errno, int)
+ assert not isinstance(self.errpdu, ErrorReportPDU)
+ p = self.errpdu
+ if p is None:
+ p = ""
+ elif isinstance(p, PDU):
+ p = p.to_pdu()
+ assert isinstance(p, str)
+ pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg)
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen)
+ self._pdu += self.to_counted_string(p)
+ self._pdu += self.to_counted_string(self.errmsg.encode("utf8"))
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ header = reader.get(self.header_struct.size)
+ version, pdu_type, self.errno, length = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
+ remaining = length - self.header_struct.size
+ self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining)
+ self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining)
+ if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen:
+ raise CorruptData("Got PDU length %d, expected %d" % (
+ length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen))
+ assert (header
+ + self.to_counted_string(self.errpdu)
+ + self.to_counted_string(self.errmsg.encode("utf8"))
+ == self.to_pdu())
+ return self
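
The counted strings used by ErrorReportPDU above are simply a 4-byte big-endian length followed by the raw bytes. A small round-trip sketch, for illustration only:

    import struct

    string_struct = struct.Struct("!L")

    def to_counted_string(s):
        return string_struct.pack(len(s)) + s

    def from_counted_string(buf):
        n = string_struct.unpack(buf[:string_struct.size])[0]
        body = buf[string_struct.size:string_struct.size + n]
        return body, buf[string_struct.size + n:]

    encoded = to_counted_string(b"No Data Available")
    body, rest = from_counted_string(encoded)
    assert body == b"No Data Available" and rest == b""
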
diff --git a/rpki/rtr/server.py b/rpki/rtr/server.py
index 1c7a5e78..f57c3037 100644
--- a/rpki/rtr/server.py
+++ b/rpki/rtr/server.py
@@ -44,37 +44,37 @@ kickme_base = os.path.join(kickme_dir, "kickme")
class PDU(rpki.rtr.pdus.PDU):
- """
- Generic server PDU.
- """
-
- def send_file(self, server, filename):
"""
- Send a content of a file as a cache response. Caller should catch IOError.
+ Generic server PDU.
"""
- fn2 = os.path.splitext(filename)[1]
- assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version
-
- f = open(filename, "rb")
- server.push_pdu(CacheResponsePDU(version = server.version,
- nonce = server.current_nonce))
- server.push_file(f)
- server.push_pdu(EndOfDataPDU(version = server.version,
- serial = server.current_serial,
- nonce = server.current_nonce,
- refresh = server.refresh,
- retry = server.retry,
- expire = server.expire))
-
- def send_nodata(self, server):
- """
- Send a nodata error.
- """
+ def send_file(self, server, filename):
+ """
+        Send the content of a file as a cache response. The caller should catch IOError.
+ """
+
+ fn2 = os.path.splitext(filename)[1]
+ assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version
- server.push_pdu(ErrorReportPDU(version = server.version,
- errno = ErrorReportPDU.codes["No Data Available"],
- errpdu = self))
+ f = open(filename, "rb")
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
+ server.push_file(f)
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce,
+ refresh = server.refresh,
+ retry = server.retry,
+ expire = server.expire))
+
+ def send_nodata(self, server):
+ """
+ Send a nodata error.
+ """
+
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["No Data Available"],
+ errpdu = self))
clone_pdu = clone_pdu_root(PDU)
@@ -82,512 +82,512 @@ clone_pdu = clone_pdu_root(PDU)
@clone_pdu
class SerialQueryPDU(PDU, rpki.rtr.pdus.SerialQueryPDU):
- """
- Serial Query PDU.
- """
-
- def serve(self, server):
- """
- Received a serial query, send incremental transfer in response.
- If client is already up to date, just send an empty incremental
- transfer.
"""
-
- server.logger.debug(self)
- if server.get_serial() is None:
- self.send_nodata(server)
- elif server.current_nonce != self.nonce:
- server.logger.info("[Client requested wrong nonce, resetting client]")
- server.push_pdu(CacheResetPDU(version = server.version))
- elif server.current_serial == self.serial:
- server.logger.debug("[Client is already current, sending empty IXFR]")
- server.push_pdu(CacheResponsePDU(version = server.version,
- nonce = server.current_nonce))
- server.push_pdu(EndOfDataPDU(version = server.version,
- serial = server.current_serial,
- nonce = server.current_nonce,
- refresh = server.refresh,
- retry = server.retry,
- expire = server.expire))
- elif disable_incrementals:
- server.push_pdu(CacheResetPDU(version = server.version))
- else:
- try:
- self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version))
- except IOError:
- server.push_pdu(CacheResetPDU(version = server.version))
+ Serial Query PDU.
+ """
+
+ def serve(self, server):
+ """
+        Received a serial query; send an incremental transfer in response.
+        If the client is already up to date, just send an empty incremental
+        transfer.
+ """
+
+ server.logger.debug(self)
+ if server.get_serial() is None:
+ self.send_nodata(server)
+ elif server.current_nonce != self.nonce:
+ server.logger.info("[Client requested wrong nonce, resetting client]")
+ server.push_pdu(CacheResetPDU(version = server.version))
+ elif server.current_serial == self.serial:
+ server.logger.debug("[Client is already current, sending empty IXFR]")
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce,
+ refresh = server.refresh,
+ retry = server.retry,
+ expire = server.expire))
+ elif disable_incrementals:
+ server.push_pdu(CacheResetPDU(version = server.version))
+ else:
+ try:
+ self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version))
+ except IOError:
+ server.push_pdu(CacheResetPDU(version = server.version))
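
The serve() method above relies on a file naming convention maintained by the database generator; a small sketch with made-up serial numbers showing the names it looks for on disk:

    # Hypothetical serials and protocol version, for illustration only.
    current_serial, client_serial, version = 1002, 1000, 1

    ixfr = "%d.ix.%d.v%d" % (current_serial, client_serial, version)   # incremental: "1002.ix.1000.v1"
    axfr = "%d.ax.v%d" % (current_serial, version)                     # full snapshot: "1002.ax.v1"
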
@clone_pdu
class ResetQueryPDU(PDU, rpki.rtr.pdus.ResetQueryPDU):
- """
- Reset Query PDU.
- """
-
- def serve(self, server):
"""
- Received a reset query, send full current state in response.
+ Reset Query PDU.
"""
- server.logger.debug(self)
- if server.get_serial() is None:
- self.send_nodata(server)
- else:
- try:
- fn = "%d.ax.v%d" % (server.current_serial, server.version)
- self.send_file(server, fn)
- except IOError:
- server.push_pdu(ErrorReportPDU(version = server.version,
- errno = ErrorReportPDU.codes["Internal Error"],
- errpdu = self,
- errmsg = "Couldn't open %s" % fn))
+ def serve(self, server):
+ """
+        Received a reset query; send the full current state in response.
+ """
+
+ server.logger.debug(self)
+ if server.get_serial() is None:
+ self.send_nodata(server)
+ else:
+ try:
+ fn = "%d.ax.v%d" % (server.current_serial, server.version)
+ self.send_file(server, fn)
+ except IOError:
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["Internal Error"],
+ errpdu = self,
+ errmsg = "Couldn't open %s" % fn))
@clone_pdu
class ErrorReportPDU(rpki.rtr.pdus.ErrorReportPDU):
- """
- Error Report PDU.
- """
-
- def serve(self, server):
"""
- Received an ErrorReportPDU from client. Not much we can do beyond
- logging it, then killing the connection if error was fatal.
+ Error Report PDU.
"""
- server.logger.error(self)
- if self.errno in self.fatal:
- server.logger.error("[Shutting down due to reported fatal protocol error]")
- sys.exit(1)
+ def serve(self, server):
+ """
+        Received an ErrorReportPDU from the client. Not much we can do beyond
+        logging it, then killing the connection if the error was fatal.
+ """
+
+ server.logger.error(self)
+ if self.errno in self.fatal:
+ server.logger.error("[Shutting down due to reported fatal protocol error]")
+ sys.exit(1)
def read_current(version):
- """
- Read current serial number and nonce. Return None for both if
- serial and nonce not recorded. For backwards compatibility, treat
- file containing just a serial number as having a nonce of zero.
- """
-
- if version is None:
- return None, None
- try:
- with open("current.v%d" % version, "r") as f:
- values = tuple(int(s) for s in f.read().split())
- return values[0], values[1]
- except IndexError:
- return values[0], 0
- except IOError:
- return None, None
+ """
+    Read the current serial number and nonce. Return None for both if
+    they have not been recorded. For backwards compatibility, treat a
+    file containing just a serial number as having a nonce of zero.
+ """
+
+ if version is None:
+ return None, None
+ try:
+ with open("current.v%d" % version, "r") as f:
+ values = tuple(int(s) for s in f.read().split())
+ return values[0], values[1]
+ except IndexError:
+ return values[0], 0
+ except IOError:
+ return None, None
def write_current(serial, nonce, version):
- """
- Write serial number and nonce.
- """
+ """
+ Write serial number and nonce.
+ """
- curfn = "current.v%d" % version
- tmpfn = curfn + "%d.tmp" % os.getpid()
- with open(tmpfn, "w") as f:
- f.write("%d %d\n" % (serial, nonce))
- os.rename(tmpfn, curfn)
+ curfn = "current.v%d" % version
+ tmpfn = curfn + "%d.tmp" % os.getpid()
+ with open(tmpfn, "w") as f:
+ f.write("%d %d\n" % (serial, nonce))
+ os.rename(tmpfn, curfn)
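
read_current() and write_current() above share a trivial on-disk format: the serial number and nonce as two decimal integers in current.vN. A round-trip sketch with made-up values:

    serial, nonce, version = 1002, 54321, 1        # example values only

    with open("current.v%d" % version, "w") as f:  # what write_current() produces
        f.write("%d %d\n" % (serial, nonce))

    with open("current.v%d" % version) as f:       # what read_current() parses
        values = tuple(int(s) for s in f.read().split())
    assert values == (serial, nonce)
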
class FileProducer(object):
- """
- File-based producer object for asynchat.
- """
+ """
+ File-based producer object for asynchat.
+ """
- def __init__(self, handle, buffersize):
- self.handle = handle
- self.buffersize = buffersize
+ def __init__(self, handle, buffersize):
+ self.handle = handle
+ self.buffersize = buffersize
- def more(self):
- return self.handle.read(self.buffersize)
+ def more(self):
+ return self.handle.read(self.buffersize)
class ServerWriteChannel(rpki.rtr.channels.PDUChannel):
- """
- Kludge to deal with ssh's habit of sometimes (compile time option)
- invoking us with two unidirectional pipes instead of one
- bidirectional socketpair. All the server logic is in the
- ServerChannel class, this class just deals with sending the
- server's output to a different file descriptor.
- """
-
- def __init__(self):
"""
- Set up stdout.
+ Kludge to deal with ssh's habit of sometimes (compile time option)
+ invoking us with two unidirectional pipes instead of one
+ bidirectional socketpair. All the server logic is in the
+ ServerChannel class, this class just deals with sending the
+ server's output to a different file descriptor.
"""
- super(ServerWriteChannel, self).__init__(root_pdu_class = PDU)
- self.init_file_dispatcher(sys.stdout.fileno())
+ def __init__(self):
+ """
+ Set up stdout.
+ """
- def readable(self):
- """
- This channel is never readable.
- """
+ super(ServerWriteChannel, self).__init__(root_pdu_class = PDU)
+ self.init_file_dispatcher(sys.stdout.fileno())
- return False
+ def readable(self):
+ """
+ This channel is never readable.
+ """
- def push_file(self, f):
- """
- Write content of a file to stream.
- """
+ return False
- try:
- self.push_with_producer(FileProducer(f, self.ac_out_buffer_size))
- except OSError, e:
- if e.errno != errno.EAGAIN:
- raise
+ def push_file(self, f):
+ """
+        Write the content of a file to the stream.
+ """
+ try:
+ self.push_with_producer(FileProducer(f, self.ac_out_buffer_size))
+ except OSError, e:
+ if e.errno != errno.EAGAIN:
+ raise
-class ServerChannel(rpki.rtr.channels.PDUChannel):
- """
- Server protocol engine, handles upcalls from PDUChannel to
- implement protocol logic.
- """
- def __init__(self, logger, refresh, retry, expire):
+class ServerChannel(rpki.rtr.channels.PDUChannel):
"""
- Set up stdin and stdout as connection and start listening for
- first PDU.
+ Server protocol engine, handles upcalls from PDUChannel to
+ implement protocol logic.
"""
- super(ServerChannel, self).__init__(root_pdu_class = PDU)
- self.init_file_dispatcher(sys.stdin.fileno())
- self.writer = ServerWriteChannel()
- self.logger = logger
- self.refresh = refresh
- self.retry = retry
- self.expire = expire
- self.get_serial()
- self.start_new_pdu()
-
- def writable(self):
- """
- This channel is never writable.
- """
+ def __init__(self, logger, refresh, retry, expire):
+ """
+ Set up stdin and stdout as connection and start listening for
+ first PDU.
+ """
- return False
+ super(ServerChannel, self).__init__(root_pdu_class = PDU)
+ self.init_file_dispatcher(sys.stdin.fileno())
+ self.writer = ServerWriteChannel()
+ self.logger = logger
+ self.refresh = refresh
+ self.retry = retry
+ self.expire = expire
+ self.get_serial()
+ self.start_new_pdu()
- def push(self, data):
- """
- Redirect to writer channel.
- """
+ def writable(self):
+ """
+ This channel is never writable.
+ """
- return self.writer.push(data)
+ return False
- def push_with_producer(self, producer):
- """
- Redirect to writer channel.
- """
+ def push(self, data):
+ """
+ Redirect to writer channel.
+ """
- return self.writer.push_with_producer(producer)
+ return self.writer.push(data)
- def push_pdu(self, pdu):
- """
- Redirect to writer channel.
- """
+ def push_with_producer(self, producer):
+ """
+ Redirect to writer channel.
+ """
- return self.writer.push_pdu(pdu)
+ return self.writer.push_with_producer(producer)
- def push_file(self, f):
- """
- Redirect to writer channel.
- """
+ def push_pdu(self, pdu):
+ """
+ Redirect to writer channel.
+ """
- return self.writer.push_file(f)
+ return self.writer.push_pdu(pdu)
- def deliver_pdu(self, pdu):
- """
- Handle received PDU.
- """
+ def push_file(self, f):
+ """
+ Redirect to writer channel.
+ """
- pdu.serve(self)
+ return self.writer.push_file(f)
- def get_serial(self):
- """
- Read, cache, and return current serial number, or None if we can't
- find the serial number file. The latter condition should never
- happen, but maybe we got started in server mode while the cronjob
- mode instance is still building its database.
- """
+ def deliver_pdu(self, pdu):
+ """
+ Handle received PDU.
+ """
- self.current_serial, self.current_nonce = read_current(self.version)
- return self.current_serial
+ pdu.serve(self)
- def check_serial(self):
- """
- Check for a new serial number.
- """
+ def get_serial(self):
+ """
+ Read, cache, and return current serial number, or None if we can't
+ find the serial number file. The latter condition should never
+ happen, but maybe we got started in server mode while the cronjob
+ mode instance is still building its database.
+ """
- old_serial = self.current_serial
- return old_serial != self.get_serial()
+ self.current_serial, self.current_nonce = read_current(self.version)
+ return self.current_serial
- def notify(self, data = None, force = False):
- """
- Cronjob instance kicked us: check whether our serial number has
- changed, and send a notify message if so.
+ def check_serial(self):
+ """
+ Check for a new serial number.
+ """
- We have to check rather than just blindly notifying when kicked
- because the cronjob instance has no good way of knowing which
- protocol version we're running, thus has no good way of knowing
- whether we care about a particular change set or not.
- """
+ old_serial = self.current_serial
+ return old_serial != self.get_serial()
- if force or self.check_serial():
- self.push_pdu(SerialNotifyPDU(version = self.version,
- serial = self.current_serial,
- nonce = self.current_nonce))
- else:
- self.logger.debug("Cronjob kicked me but I see no serial change, ignoring")
+ def notify(self, data = None, force = False):
+ """
+ Cronjob instance kicked us: check whether our serial number has
+ changed, and send a notify message if so.
+
+ We have to check rather than just blindly notifying when kicked
+ because the cronjob instance has no good way of knowing which
+ protocol version we're running, thus has no good way of knowing
+ whether we care about a particular change set or not.
+ """
+
+ if force or self.check_serial():
+ self.push_pdu(SerialNotifyPDU(version = self.version,
+ serial = self.current_serial,
+ nonce = self.current_nonce))
+ else:
+ self.logger.debug("Cronjob kicked me but I see no serial change, ignoring")
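
For context, notify() is driven by datagrams arriving on the kickme socket bound by KickmeChannel below. A cronjob-style sender might look roughly like this; the socket path and payload are hypothetical, since the server only cares that a datagram arrived:

    import socket

    s = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
    s.sendto(b"kick", "sockets/kickme.12345")   # hypothetical kickme_base + server PID
    s.close()
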
class KickmeChannel(asyncore.dispatcher, object):
- """
- asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to
- kick servers when it's time to send notify PDUs to clients.
- """
-
- def __init__(self, server):
- asyncore.dispatcher.__init__(self) # Old-style class
- self.server = server
- self.sockname = "%s.%d" % (kickme_base, os.getpid())
- self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM)
- try:
- self.bind(self.sockname)
- os.chmod(self.sockname, 0660)
- except socket.error, e:
- self.server.logger.exception("Couldn't bind() kickme socket: %r", e)
- self.close()
- except OSError, e:
- self.server.logger.exception("Couldn't chmod() kickme socket: %r", e)
-
- def writable(self):
"""
- This socket is read-only, never writable.
+ asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to
+ kick servers when it's time to send notify PDUs to clients.
"""
- return False
+ def __init__(self, server):
+ asyncore.dispatcher.__init__(self) # Old-style class
+ self.server = server
+ self.sockname = "%s.%d" % (kickme_base, os.getpid())
+ self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ try:
+ self.bind(self.sockname)
+ os.chmod(self.sockname, 0660)
+ except socket.error, e:
+ self.server.logger.exception("Couldn't bind() kickme socket: %r", e)
+ self.close()
+ except OSError, e:
+ self.server.logger.exception("Couldn't chmod() kickme socket: %r", e)
+
+ def writable(self):
+ """
+ This socket is read-only, never writable.
+ """
+
+ return False
+
+ def handle_connect(self):
+ """
+ Ignore connect events (not very useful on datagram socket).
+ """
+
+ pass
+
+ def handle_read(self):
+ """
+ Handle receipt of a datagram.
+ """
+
+ data = self.recv(512)
+ self.server.notify(data)
+
+ def cleanup(self):
+ """
+ Clean up this dispatcher's socket.
+ """
+
+ self.close()
+ try:
+ os.unlink(self.sockname)
+ except: # pylint: disable=W0702
+ pass
- def handle_connect(self):
- """
- Ignore connect events (not very useful on datagram socket).
- """
+ def log(self, msg):
+ """
+ Intercept asyncore's logging.
+ """
- pass
+ self.server.logger.info(msg)
- def handle_read(self):
- """
- Handle receipt of a datagram.
- """
+ def log_info(self, msg, tag = "info"):
+ """
+ Intercept asyncore's logging.
+ """
- data = self.recv(512)
- self.server.notify(data)
+ self.server.logger.info("asyncore: %s: %s", tag, msg)
- def cleanup(self):
- """
- Clean up this dispatcher's socket.
- """
+ def handle_error(self):
+ """
+ Handle errors caught by asyncore main loop.
+ """
- self.close()
- try:
- os.unlink(self.sockname)
- except: # pylint: disable=W0702
- pass
+ self.server.logger.exception("[Unhandled exception]")
+ self.server.logger.critical("[Exiting after unhandled exception]")
+ sys.exit(1)
- def log(self, msg):
+
+def _hostport_tag():
"""
- Intercept asyncore's logging.
+ Construct hostname/address + port when we're running under a
+ protocol we understand well enough to do that. This is all
+ kludgery. Just grit your teeth, or perhaps just close your eyes.
"""
- self.server.logger.info(msg)
+ proto = None
- def log_info(self, msg, tag = "info"):
- """
- Intercept asyncore's logging.
- """
+ if proto is None:
+ try:
+ host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername()
+ proto = "tcp"
+ except: # pylint: disable=W0702
+ pass
- self.server.logger.info("asyncore: %s: %s", tag, msg)
+ if proto is None:
+ try:
+ host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2]
+ proto = "tcp"
+ except: # pylint: disable=W0702
+ pass
- def handle_error(self):
+ if proto is None:
+ try:
+ host, port = os.environ["SSH_CONNECTION"].split()[0:2]
+ proto = "ssh"
+ except: # pylint: disable=W0702
+ pass
+
+ if proto is None:
+ try:
+ host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT")
+ proto = "ssl"
+ except: # pylint: disable=W0702
+ pass
+
+ if proto is None:
+ return ""
+ elif not port:
+ return "/%s/%s" % (proto, host)
+ elif ":" in host:
+ return "/%s/%s.%s" % (proto, host, port)
+ else:
+ return "/%s/%s:%s" % (proto, host, port)
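
The tags produced by _hostport_tag() end up in the logger's connection field; a few illustrative (made-up) outputs:

    "/tcp/192.0.2.1:12345"      # IPv4 TCP peer
    "/tcp/2001:db8::1.12345"    # IPv6 TCP peer (port separated with "." because the host contains ":")
    "/ssh/192.0.2.1:54021"      # connection arriving over ssh
    ""                          # transport not recognized, no tag
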
+
+
+def server_main(args):
"""
- Handle errors caught by asyncore main loop.
+    Implement the server side of the rpki-router protocol. Other than
+ one PF_UNIX socket inode, this doesn't write anything to disk, so it
+ can be run with minimal privileges. Most of the work has already
+ been done by the database generator, so all this server has to do is
+ pass the results along to a client.
"""
- self.server.logger.exception("[Unhandled exception]")
- self.server.logger.critical("[Exiting after unhandled exception]")
- sys.exit(1)
+ logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag()))
+ logger.debug("[Starting]")
-def _hostport_tag():
- """
- Construct hostname/address + port when we're running under a
- protocol we understand well enough to do that. This is all
- kludgery. Just grit your teeth, or perhaps just close your eyes.
- """
-
- proto = None
+ if args.rpki_rtr_dir:
+ try:
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ sys.exit(e)
- if proto is None:
+ kickme = None
try:
- host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername()
- proto = "tcp"
- except: # pylint: disable=W0702
- pass
+ server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
+ kickme = rpki.rtr.server.KickmeChannel(server = server)
+ asyncore.loop(timeout = None)
+ signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition
+ except KeyboardInterrupt:
+ sys.exit(0)
+ finally:
+ signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition
+ if kickme is not None:
+ kickme.cleanup()
- if proto is None:
- try:
- host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2]
- proto = "tcp"
- except: # pylint: disable=W0702
- pass
- if proto is None:
- try:
- host, port = os.environ["SSH_CONNECTION"].split()[0:2]
- proto = "ssh"
- except: # pylint: disable=W0702
- pass
+def listener_main(args):
+ """
+ Totally insecure TCP listener for rpki-rtr protocol. We only
+ implement this because it's all that the routers currently support.
+ In theory, we will all be running TCP-AO in the future, at which
+ point this listener will go away or become a TCP-AO listener.
+ """
- if proto is None:
- try:
- host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT")
- proto = "ssl"
- except: # pylint: disable=W0702
- pass
+ # Perhaps we should daemonize? Deal with that later.
- if proto is None:
- return ""
- elif not port:
- return "/%s/%s" % (proto, host)
- elif ":" in host:
- return "/%s/%s.%s" % (proto, host, port)
- else:
- return "/%s/%s:%s" % (proto, host, port)
+ # server_main() handles args.rpki_rtr_dir.
+ listener = None
+ try:
+ listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+ except: # pylint: disable=W0702
+ if listener is not None:
+ listener.close()
+ listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ try:
+ listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ except AttributeError:
+ pass
+ listener.bind(("", args.port))
+ listener.listen(5)
+ logging.debug("[Listening on port %s]", args.port)
+ while True:
+ try:
+ s, ai = listener.accept()
+ except KeyboardInterrupt:
+ sys.exit(0)
+ logging.debug("[Received connection from %r]", ai)
+ pid = os.fork()
+ if pid == 0:
+ os.dup2(s.fileno(), 0) # pylint: disable=E1103
+ os.dup2(s.fileno(), 1) # pylint: disable=E1103
+ s.close()
+ #os.closerange(3, os.sysconf("SC_OPEN_MAX"))
+ server_main(args)
+ sys.exit()
+ else:
+ logging.debug("[Spawned server %d]", pid)
+ while True:
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612
+ if pid:
+ logging.debug("[Server %s exited]", pid)
+ continue
+ except: # pylint: disable=W0702
+ pass
+ break
-def server_main(args):
- """
- Implement the server side of the rpkk-router protocol. Other than
- one PF_UNIX socket inode, this doesn't write anything to disk, so it
- can be run with minimal privileges. Most of the work has already
- been done by the database generator, so all this server has to do is
- pass the results along to a client.
- """
- logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag()))
+def argparse_setup(subparsers):
+ """
+ Set up argparse stuff for commands in this module.
+ """
- logger.debug("[Starting]")
+ # These could have been lambdas, but doing it this way results in
+ # more useful error messages on argparse failures.
- if args.rpki_rtr_dir:
- try:
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- sys.exit(e)
-
- kickme = None
- try:
- server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
- kickme = rpki.rtr.server.KickmeChannel(server = server)
- asyncore.loop(timeout = None)
- signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition
- except KeyboardInterrupt:
- sys.exit(0)
- finally:
- signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition
- if kickme is not None:
- kickme.cleanup()
+ def refresh(v):
+ return rpki.rtr.pdus.valid_refresh(int(v))
+ def retry(v):
+ return rpki.rtr.pdus.valid_retry(int(v))
-def listener_main(args):
- """
- Totally insecure TCP listener for rpki-rtr protocol. We only
- implement this because it's all that the routers currently support.
- In theory, we will all be running TCP-AO in the future, at which
- point this listener will go away or become a TCP-AO listener.
- """
-
- # Perhaps we should daemonize? Deal with that later.
-
- # server_main() handles args.rpki_rtr_dir.
-
- listener = None
- try:
- listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
- listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
- except: # pylint: disable=W0702
- if listener is not None:
- listener.close()
- listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- try:
- listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
- except AttributeError:
- pass
- listener.bind(("", args.port))
- listener.listen(5)
- logging.debug("[Listening on port %s]", args.port)
- while True:
- try:
- s, ai = listener.accept()
- except KeyboardInterrupt:
- sys.exit(0)
- logging.debug("[Received connection from %r]", ai)
- pid = os.fork()
- if pid == 0:
- os.dup2(s.fileno(), 0) # pylint: disable=E1103
- os.dup2(s.fileno(), 1) # pylint: disable=E1103
- s.close()
- #os.closerange(3, os.sysconf("SC_OPEN_MAX"))
- server_main(args)
- sys.exit()
- else:
- logging.debug("[Spawned server %d]", pid)
- while True:
- try:
- pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612
- if pid:
- logging.debug("[Server %s exited]", pid)
- continue
- except: # pylint: disable=W0702
- pass
- break
+ def expire(v):
+ return rpki.rtr.pdus.valid_expire(int(v))
+ # Some duplication of arguments here, not enough to be worth huge
+ # effort to clean up, worry about it later in any case.
-def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- # These could have been lambdas, but doing it this way results in
- # more useful error messages on argparse failures.
-
- def refresh(v):
- return rpki.rtr.pdus.valid_refresh(int(v))
-
- def retry(v):
- return rpki.rtr.pdus.valid_retry(int(v))
-
- def expire(v):
- return rpki.rtr.pdus.valid_expire(int(v))
-
- # Some duplication of arguments here, not enough to be worth huge
- # effort to clean up, worry about it later in any case.
-
- subparser = subparsers.add_parser("server", description = server_main.__doc__,
- help = "RPKI-RTR protocol server")
- subparser.set_defaults(func = server_main, default_log_to = "syslog")
- subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
- subparser.add_argument("--retry", type = retry, help = "override default retry timer")
- subparser.add_argument("--expire", type = expire, help = "override default expire timer")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
-
- subparser = subparsers.add_parser("listener", description = listener_main.__doc__,
- help = "TCP listener for RPKI-RTR protocol server")
- subparser.set_defaults(func = listener_main, default_log_to = "syslog")
- subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
- subparser.add_argument("--retry", type = retry, help = "override default retry timer")
- subparser.add_argument("--expire", type = expire, help = "override default expire timer")
- subparser.add_argument("port", type = int, help = "TCP port on which to listen")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+ subparser = subparsers.add_parser("server", description = server_main.__doc__,
+ help = "RPKI-RTR protocol server")
+ subparser.set_defaults(func = server_main, default_log_to = "syslog")
+ subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
+ subparser.add_argument("--retry", type = retry, help = "override default retry timer")
+ subparser.add_argument("--expire", type = expire, help = "override default expire timer")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+
+ subparser = subparsers.add_parser("listener", description = listener_main.__doc__,
+ help = "TCP listener for RPKI-RTR protocol server")
+ subparser.set_defaults(func = listener_main, default_log_to = "syslog")
+ subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
+ subparser.add_argument("--retry", type = retry, help = "override default retry timer")
+ subparser.add_argument("--expire", type = expire, help = "override default expire timer")
+ subparser.add_argument("port", type = int, help = "TCP port on which to listen")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
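
Finally, a sketch (not from the codebase) of how argparse_setup() could be wired into a top-level parser; the real command-line entry point does the equivalent of this, and the database directory here is made up:

    import argparse

    import rpki.rtr.server

    parser = argparse.ArgumentParser(description="rpki-rtr tools")
    subparsers = parser.add_subparsers(title="commands")
    rpki.rtr.server.argparse_setup(subparsers)

    args = parser.parse_args(["server", "/var/rpki-rtr"])   # hypothetical database directory
    args.func(args)                                         # dispatches to server_main()
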