aboutsummaryrefslogtreecommitdiff
path: root/rpki/rtr
diff options
context:
space:
mode:
Diffstat (limited to 'rpki/rtr')
-rwxr-xr-xrpki/rtr/bgpdump.py482
-rw-r--r--rpki/rtr/channels.py364
-rw-r--r--rpki/rtr/client.py816
-rw-r--r--rpki/rtr/generator.py948
-rw-r--r--rpki/rtr/main.py110
-rw-r--r--rpki/rtr/pdus.py960
-rw-r--r--rpki/rtr/server.py872
7 files changed, 2276 insertions, 2276 deletions
diff --git a/rpki/rtr/bgpdump.py b/rpki/rtr/bgpdump.py
index 5ffabc4d..3336fb9f 100755
--- a/rpki/rtr/bgpdump.py
+++ b/rpki/rtr/bgpdump.py
@@ -39,292 +39,292 @@ from rpki.rtr.channels import Timestamp
class IgnoreThisRecord(Exception):
- pass
+ pass
class PrefixPDU(rpki.rtr.generator.PrefixPDU):
- @staticmethod
- def from_bgpdump(line, rib_dump):
- try:
- assert isinstance(rib_dump, bool)
- fields = line.split("|")
-
- # Parse prefix, including figuring out IP protocol version
- cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU
- self = cls()
- self.timestamp = Timestamp(fields[1])
- p, l = fields[5].split("/")
- self.prefix = rpki.POW.IPAddress(p)
- self.prefixlen = self.max_prefixlen = int(l)
-
- # Withdrawals don't have AS paths, so be careful
- assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W")
- if fields[2] == "W":
- self.asn = 0
- self.announce = 0
- else:
- self.announce = 1
- if not fields[6] or "{" in fields[6] or "(" in fields[6]:
- raise IgnoreThisRecord
- a = fields[6].split()[-1]
- if "." in a:
- a = [int(s) for s in a.split(".")]
- if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535:
- logging.warn("Bad dotted ASNum %r, ignoring record", fields[6])
+ @staticmethod
+ def from_bgpdump(line, rib_dump):
+ try:
+ assert isinstance(rib_dump, bool)
+ fields = line.split("|")
+
+ # Parse prefix, including figuring out IP protocol version
+ cls = rpki.rtr.generator.IPv6PrefixPDU if ":" in fields[5] else rpki.rtr.generator.IPv4PrefixPDU
+ self = cls()
+ self.timestamp = Timestamp(fields[1])
+ p, l = fields[5].split("/")
+ self.prefix = rpki.POW.IPAddress(p)
+ self.prefixlen = self.max_prefixlen = int(l)
+
+ # Withdrawals don't have AS paths, so be careful
+ assert fields[2] == "B" if rib_dump else fields[2] in ("A", "W")
+ if fields[2] == "W":
+ self.asn = 0
+ self.announce = 0
+ else:
+ self.announce = 1
+ if not fields[6] or "{" in fields[6] or "(" in fields[6]:
+ raise IgnoreThisRecord
+ a = fields[6].split()[-1]
+ if "." in a:
+ a = [int(s) for s in a.split(".")]
+ if len(a) != 2 or a[0] < 0 or a[0] > 65535 or a[1] < 0 or a[1] > 65535:
+ logging.warn("Bad dotted ASNum %r, ignoring record", fields[6])
+ raise IgnoreThisRecord
+ a = (a[0] << 16) | a[1]
+ else:
+ a = int(a)
+ self.asn = a
+
+ self.check()
+ return self
+
+ except IgnoreThisRecord:
+ raise
+
+ except Exception, e:
+ logging.warn("Ignoring line %r: %s", line, e)
raise IgnoreThisRecord
- a = (a[0] << 16) | a[1]
- else:
- a = int(a)
- self.asn = a
- self.check()
- return self
- except IgnoreThisRecord:
- raise
+class AXFRSet(rpki.rtr.generator.AXFRSet):
- except Exception, e:
- logging.warn("Ignoring line %r: %s", line, e)
- raise IgnoreThisRecord
+ @staticmethod
+ def read_bgpdump(filename):
+ assert filename.endswith(".bz2")
+ logging.debug("Reading %s", filename)
+ bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE)
+ bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE)
+ return bgpdump.stdout
+
+ @classmethod
+ def parse_bgpdump_rib_dump(cls, filename):
+ assert os.path.basename(filename).startswith("ribs.")
+ self = cls()
+ self.serial = None
+ for line in cls.read_bgpdump(filename):
+ try:
+ pfx = PrefixPDU.from_bgpdump(line, rib_dump = True)
+ except IgnoreThisRecord:
+ continue
+ self.append(pfx)
+ self.serial = pfx.timestamp
+ if self.serial is None:
+ sys.exit("Failed to parse anything useful from %s" % filename)
+ self.sort()
+ for i in xrange(len(self) - 2, -1, -1):
+ if self[i] == self[i + 1]:
+ del self[i + 1]
+ return self
+
+ def parse_bgpdump_update(self, filename):
+ assert os.path.basename(filename).startswith("updates.")
+ for line in self.read_bgpdump(filename):
+ try:
+ pfx = PrefixPDU.from_bgpdump(line, rib_dump = False)
+ except IgnoreThisRecord:
+ continue
+ announce = pfx.announce
+ pfx.announce = 1
+ i = bisect.bisect_left(self, pfx)
+ if announce:
+ if i >= len(self) or pfx != self[i]:
+ self.insert(i, pfx)
+ else:
+ while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen:
+ del self[i]
+ self.serial = pfx.timestamp
-class AXFRSet(rpki.rtr.generator.AXFRSet):
+def bgpdump_convert_main(args):
+ """
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ Simulate route origin data from a set of BGP dump files.
+ argv is an ordered list of filenames. Each file must be a BGP RIB
+ dumps, a BGP UPDATE dumps, or an AXFR dump in the format written by
+ this program's --cronjob command. The first file must be a RIB dump
+ or AXFR dump, it cannot be an UPDATE dump. Output will be a set of
+ AXFR and IXFR files with timestamps derived from the BGP dumps,
+ which can be used as input to this program's --server command for
+ test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL.
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ """
+
+ first = True
+ db = None
+ axfrs = []
+ version = max(rpki.rtr.pdus.PDU.version_map.iterkeys())
+
+ for filename in args.files:
+
+ if ".ax.v" in filename:
+ logging.debug("Reading %s", filename)
+ db = AXFRSet.load(filename)
+
+ elif os.path.basename(filename).startswith("ribs."):
+ db = AXFRSet.parse_bgpdump_rib_dump(filename)
+ db.save_axfr()
+
+ elif not first:
+ assert db is not None
+ db.parse_bgpdump_update(filename)
+ db.save_axfr()
- @staticmethod
- def read_bgpdump(filename):
- assert filename.endswith(".bz2")
- logging.debug("Reading %s", filename)
- bunzip2 = subprocess.Popen(("bzip2", "-c", "-d", filename), stdout = subprocess.PIPE)
- bgpdump = subprocess.Popen(("bgpdump", "-m", "-"), stdin = bunzip2.stdout, stdout = subprocess.PIPE)
- return bgpdump.stdout
-
- @classmethod
- def parse_bgpdump_rib_dump(cls, filename):
- assert os.path.basename(filename).startswith("ribs.")
- self = cls()
- self.serial = None
- for line in cls.read_bgpdump(filename):
- try:
- pfx = PrefixPDU.from_bgpdump(line, rib_dump = True)
- except IgnoreThisRecord:
- continue
- self.append(pfx)
- self.serial = pfx.timestamp
- if self.serial is None:
- sys.exit("Failed to parse anything useful from %s" % filename)
- self.sort()
- for i in xrange(len(self) - 2, -1, -1):
- if self[i] == self[i + 1]:
- del self[i + 1]
- return self
-
- def parse_bgpdump_update(self, filename):
- assert os.path.basename(filename).startswith("updates.")
- for line in self.read_bgpdump(filename):
- try:
- pfx = PrefixPDU.from_bgpdump(line, rib_dump = False)
- except IgnoreThisRecord:
- continue
- announce = pfx.announce
- pfx.announce = 1
- i = bisect.bisect_left(self, pfx)
- if announce:
- if i >= len(self) or pfx != self[i]:
- self.insert(i, pfx)
- else:
- while i < len(self) and pfx.prefix == self[i].prefix and pfx.prefixlen == self[i].prefixlen:
- del self[i]
- self.serial = pfx.timestamp
+ else:
+ sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename)
+ logging.debug("DB serial now %d (%s)", db.serial, db.serial)
+ if first and rpki.rtr.server.read_current(version) == (None, None):
+ db.mark_current()
+ first = False
-def bgpdump_convert_main(args):
- """
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- Simulate route origin data from a set of BGP dump files.
- argv is an ordered list of filenames. Each file must be a BGP RIB
- dumps, a BGP UPDATE dumps, or an AXFR dump in the format written by
- this program's --cronjob command. The first file must be a RIB dump
- or AXFR dump, it cannot be an UPDATE dump. Output will be a set of
- AXFR and IXFR files with timestamps derived from the BGP dumps,
- which can be used as input to this program's --server command for
- test purposes. SUCH DATA PROVIDE NO SECURITY AT ALL.
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- """
-
- first = True
- db = None
- axfrs = []
- version = max(rpki.rtr.pdus.PDU.version_map.iterkeys())
-
- for filename in args.files:
-
- if ".ax.v" in filename:
- logging.debug("Reading %s", filename)
- db = AXFRSet.load(filename)
-
- elif os.path.basename(filename).startswith("ribs."):
- db = AXFRSet.parse_bgpdump_rib_dump(filename)
- db.save_axfr()
-
- elif not first:
- assert db is not None
- db.parse_bgpdump_update(filename)
- db.save_axfr()
-
- else:
- sys.exit("First argument must be a RIB dump or .ax file, don't know what to do with %s" % filename)
-
- logging.debug("DB serial now %d (%s)", db.serial, db.serial)
- if first and rpki.rtr.server.read_current(version) == (None, None):
- db.mark_current()
- first = False
-
- for axfr in axfrs:
- logging.debug("Loading %s", axfr)
- ax = AXFRSet.load(axfr)
- logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial)
- db.save_ixfr(ax)
- del ax
-
- axfrs.append(db.filename())
+ for axfr in axfrs:
+ logging.debug("Loading %s", axfr)
+ ax = AXFRSet.load(axfr)
+ logging.debug("Computing changes from %d (%s) to %d (%s)", ax.serial, ax.serial, db.serial, db.serial)
+ db.save_ixfr(ax)
+ del ax
+
+ axfrs.append(db.filename())
def bgpdump_select_main(args):
- """
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- Simulate route origin data from a set of BGP dump files.
- Set current serial number to correspond to an .ax file created by
- converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL.
- * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
- """
+ """
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ Simulate route origin data from a set of BGP dump files.
+ Set current serial number to correspond to an .ax file created by
+ converting BGP dump files. SUCH DATA PROVIDE NO SECURITY AT ALL.
+ * DANGER WILL ROBINSON! * DEBUGGING AND TEST USE ONLY! *
+ """
- head, sep, tail = os.path.basename(args.ax_file).partition(".")
- if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit():
- sys.exit("Argument must be name of a .ax file")
+ head, sep, tail = os.path.basename(args.ax_file).partition(".")
+ if not head.isdigit() or sep != "." or not tail.startswith("ax.v") or not tail[4:].isdigit():
+ sys.exit("Argument must be name of a .ax file")
- serial = Timestamp(head)
- version = int(tail[4:])
+ serial = Timestamp(head)
+ version = int(tail[4:])
- if version not in rpki.rtr.pdus.PDU.version_map:
- sys.exit("Unknown protocol version %d" % version)
+ if version not in rpki.rtr.pdus.PDU.version_map:
+ sys.exit("Unknown protocol version %d" % version)
- nonce = rpki.rtr.server.read_current(version)[1]
- if nonce is None:
- nonce = rpki.rtr.generator.new_nonce()
+ nonce = rpki.rtr.server.read_current(version)[1]
+ if nonce is None:
+ nonce = rpki.rtr.generator.new_nonce()
- rpki.rtr.server.write_current(serial, nonce, version)
- rpki.rtr.generator.kick_all(serial)
+ rpki.rtr.server.write_current(serial, nonce, version)
+ rpki.rtr.generator.kick_all(serial)
class BGPDumpReplayClock(object):
- """
- Internal clock for replaying BGP dump files.
+ """
+ Internal clock for replaying BGP dump files.
- * DANGER WILL ROBINSON! *
- * DEBUGGING AND TEST USE ONLY! *
+ * DANGER WILL ROBINSON! *
+ * DEBUGGING AND TEST USE ONLY! *
- This class replaces the normal on-disk serial number mechanism with
- an in-memory version based on pre-computed data.
+ This class replaces the normal on-disk serial number mechanism with
+ an in-memory version based on pre-computed data.
- bgpdump_server_main() uses this hack to replay historical data for
- testing purposes. DO NOT USE THIS IN PRODUCTION.
+ bgpdump_server_main() uses this hack to replay historical data for
+ testing purposes. DO NOT USE THIS IN PRODUCTION.
- You have been warned.
- """
+ You have been warned.
+ """
- def __init__(self):
- self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")]
- self.timestamps.sort()
- self.offset = self.timestamps[0] - int(time.time())
- self.nonce = rpki.rtr.generator.new_nonce()
+ def __init__(self):
+ self.timestamps = [Timestamp(int(f.split(".")[0])) for f in glob.iglob("*.ax.v*")]
+ self.timestamps.sort()
+ self.offset = self.timestamps[0] - int(time.time())
+ self.nonce = rpki.rtr.generator.new_nonce()
- def __nonzero__(self):
- return len(self.timestamps) > 0
+ def __nonzero__(self):
+ return len(self.timestamps) > 0
- def now(self):
- return Timestamp.now(self.offset)
+ def now(self):
+ return Timestamp.now(self.offset)
- def read_current(self, version):
- now = self.now()
- while len(self.timestamps) > 1 and now >= self.timestamps[1]:
- del self.timestamps[0]
- return self.timestamps[0], self.nonce
+ def read_current(self, version):
+ now = self.now()
+ while len(self.timestamps) > 1 and now >= self.timestamps[1]:
+ del self.timestamps[0]
+ return self.timestamps[0], self.nonce
- def siesta(self):
- now = self.now()
- if len(self.timestamps) <= 1:
- return None
- elif now < self.timestamps[1]:
- return self.timestamps[1] - now
- else:
- return 1
+ def siesta(self):
+ now = self.now()
+ if len(self.timestamps) <= 1:
+ return None
+ elif now < self.timestamps[1]:
+ return self.timestamps[1] - now
+ else:
+ return 1
def bgpdump_server_main(args):
- """
- Simulate route origin data from a set of BGP dump files.
+ """
+ Simulate route origin data from a set of BGP dump files.
+
+ * DANGER WILL ROBINSON! *
+ * DEBUGGING AND TEST USE ONLY! *
+
+ This is a clone of server_main() which replaces the external serial
+ number updates triggered via the kickme channel by cronjob_main with
+ an internal clocking mechanism to replay historical test data.
- * DANGER WILL ROBINSON! *
- * DEBUGGING AND TEST USE ONLY! *
+ DO NOT USE THIS IN PRODUCTION.
- This is a clone of server_main() which replaces the external serial
- number updates triggered via the kickme channel by cronjob_main with
- an internal clocking mechanism to replay historical test data.
+ You have been warned.
+ """
- DO NOT USE THIS IN PRODUCTION.
+ logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server._hostport_tag()))
- You have been warned.
- """
+ logger.debug("[Starting]")
- logger = logging.LoggerAdapter(logging.root, dict(connection = rpki.rtr.server._hostport_tag()))
+ if args.rpki_rtr_dir:
+ try:
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ sys.exit(e)
- logger.debug("[Starting]")
+ # Yes, this really does replace a global function defined in another
+ # module with a bound method to our clock object. Fun stuff, huh?
+ #
+ clock = BGPDumpReplayClock()
+ rpki.rtr.server.read_current = clock.read_current
- if args.rpki_rtr_dir:
try:
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- sys.exit(e)
-
- # Yes, this really does replace a global function defined in another
- # module with a bound method to our clock object. Fun stuff, huh?
- #
- clock = BGPDumpReplayClock()
- rpki.rtr.server.read_current = clock.read_current
-
- try:
- server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
- old_serial = server.get_serial()
- logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial)
- while clock:
- new_serial = server.get_serial()
- if old_serial != new_serial:
- logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial)
- server.notify()
- old_serial = new_serial
- asyncore.loop(timeout = clock.siesta(), count = 1)
- except KeyboardInterrupt:
- sys.exit(0)
+ server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
+ old_serial = server.get_serial()
+ logger.debug("[Starting at serial %d (%s)]", old_serial, old_serial)
+ while clock:
+ new_serial = server.get_serial()
+ if old_serial != new_serial:
+ logger.debug("[Serial bumped from %d (%s) to %d (%s)]", old_serial, old_serial, new_serial, new_serial)
+ server.notify()
+ old_serial = new_serial
+ asyncore.loop(timeout = clock.siesta(), count = 1)
+ except KeyboardInterrupt:
+ sys.exit(0)
def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__,
- help = "Convert bgpdump to fake ROAs")
- subparser.set_defaults(func = bgpdump_convert_main, default_log_to = "syslog")
- subparser.add_argument("files", nargs = "+", help = "input files")
-
- subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__,
- help = "Set current serial number for fake ROA data")
- subparser.set_defaults(func = bgpdump_select_main, default_log_to = "syslog")
- subparser.add_argument("ax_file", help = "name of the .ax to select")
-
- subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__,
- help = "Replay fake ROAs generated from historical data")
- subparser.set_defaults(func = bgpdump_server_main, default_log_to = "syslog")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+ """
+ Set up argparse stuff for commands in this module.
+ """
+
+ subparser = subparsers.add_parser("bgpdump-convert", description = bgpdump_convert_main.__doc__,
+ help = "Convert bgpdump to fake ROAs")
+ subparser.set_defaults(func = bgpdump_convert_main, default_log_to = "syslog")
+ subparser.add_argument("files", nargs = "+", help = "input files")
+
+ subparser = subparsers.add_parser("bgpdump-select", description = bgpdump_select_main.__doc__,
+ help = "Set current serial number for fake ROA data")
+ subparser.set_defaults(func = bgpdump_select_main, default_log_to = "syslog")
+ subparser.add_argument("ax_file", help = "name of the .ax to select")
+
+ subparser = subparsers.add_parser("bgpdump-server", description = bgpdump_server_main.__doc__,
+ help = "Replay fake ROAs generated from historical data")
+ subparser.set_defaults(func = bgpdump_server_main, default_log_to = "syslog")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
diff --git a/rpki/rtr/channels.py b/rpki/rtr/channels.py
index d14c024d..e2f443e8 100644
--- a/rpki/rtr/channels.py
+++ b/rpki/rtr/channels.py
@@ -32,215 +32,215 @@ import rpki.rtr.pdus
class Timestamp(int):
- """
- Wrapper around time module.
- """
-
- def __new__(cls, t):
- # __new__() is a static method, not a class method, hence the odd calling sequence.
- return super(Timestamp, cls).__new__(cls, t)
-
- @classmethod
- def now(cls, delta = 0):
- return cls(time.time() + delta)
-
- def __str__(self):
- return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self))
-
-
-class ReadBuffer(object):
- """
- Wrapper around synchronous/asynchronous read state.
-
- This also handles tracking the current protocol version,
- because it has to go somewhere and there's no better place.
- """
-
- def __init__(self):
- self.buffer = ""
- self.version = None
-
- def update(self, need, callback):
"""
- Update count of needed bytes and callback, then dispatch to callback.
+ Wrapper around time module.
"""
- self.need = need
- self.callback = callback
- return self.retry()
+ def __new__(cls, t):
+ # __new__() is a static method, not a class method, hence the odd calling sequence.
+ return super(Timestamp, cls).__new__(cls, t)
- def retry(self):
- """
- Try dispatching to the callback again.
- """
+ @classmethod
+ def now(cls, delta = 0):
+ return cls(time.time() + delta)
- return self.callback(self)
+ def __str__(self):
+ return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self))
- def available(self):
- """
- How much data do we have available in this buffer?
- """
- return len(self.buffer)
-
- def needed(self):
- """
- How much more data does this buffer need to become ready?
+class ReadBuffer(object):
"""
+ Wrapper around synchronous/asynchronous read state.
- return self.need - self.available()
-
- def ready(self):
- """
- Is this buffer ready to read yet?
+ This also handles tracking the current protocol version,
+ because it has to go somewhere and there's no better place.
"""
- return self.available() >= self.need
+ def __init__(self):
+ self.buffer = ""
+ self.version = None
- def get(self, n):
- """
- Hand some data to the caller.
- """
+ def update(self, need, callback):
+ """
+ Update count of needed bytes and callback, then dispatch to callback.
+ """
- b = self.buffer[:n]
- self.buffer = self.buffer[n:]
- return b
+ self.need = need
+ self.callback = callback
+ return self.retry()
- def put(self, b):
- """
- Accumulate some data.
- """
+ def retry(self):
+ """
+ Try dispatching to the callback again.
+ """
- self.buffer += b
+ return self.callback(self)
- def check_version(self, version):
- """
- Track version number of PDUs read from this buffer.
- Once set, the version must not change.
- """
+ def available(self):
+ """
+ How much data do we have available in this buffer?
+ """
- if self.version is not None and version != self.version:
- raise rpki.rtr.pdus.CorruptData(
- "Received PDU version %d, expected %d" % (version, self.version))
- if self.version is None and version not in rpki.rtr.pdus.PDU.version_map:
- raise rpki.rtr.pdus.UnsupportedProtocolVersion(
- "Received PDU version %s, known versions %s" % (
- version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map)))
- self.version = version
+ return len(self.buffer)
+ def needed(self):
+ """
+ How much more data does this buffer need to become ready?
+ """
-class PDUChannel(asynchat.async_chat, object):
- """
- asynchat subclass that understands our PDUs. This just handles
- network I/O. Specific engines (client, server) should be subclasses
- of this with methods that do something useful with the resulting
- PDUs.
- """
-
- def __init__(self, root_pdu_class, sock = None):
- asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super()
- self.reader = ReadBuffer()
- assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU)
- self.root_pdu_class = root_pdu_class
-
- @property
- def version(self):
- return self.reader.version
-
- @version.setter
- def version(self, version):
- self.reader.check_version(version)
-
- def start_new_pdu(self):
- """
- Start read of a new PDU.
- """
-
- try:
- p = self.root_pdu_class.read_pdu(self.reader)
- while p is not None:
- self.deliver_pdu(p)
- p = self.root_pdu_class.read_pdu(self.reader)
- except rpki.rtr.pdus.PDUException, e:
- self.push_pdu(e.make_error_report(version = self.version))
- self.close_when_done()
- else:
- assert not self.reader.ready()
- self.set_terminator(self.reader.needed())
-
- def collect_incoming_data(self, data):
- """
- Collect data into the read buffer.
- """
-
- self.reader.put(data)
-
- def found_terminator(self):
- """
- Got requested data, see if we now have a PDU. If so, pass it
- along, then restart cycle for a new PDU.
- """
-
- p = self.reader.retry()
- if p is None:
- self.set_terminator(self.reader.needed())
- else:
- self.deliver_pdu(p)
- self.start_new_pdu()
-
- def push_pdu(self, pdu):
- """
- Write PDU to stream.
- """
+ return self.need - self.available()
- try:
- self.push(pdu.to_pdu())
- except OSError, e:
- if e.errno != errno.EAGAIN:
- raise
+ def ready(self):
+ """
+ Is this buffer ready to read yet?
+ """
- def log(self, msg):
- """
- Intercept asyncore's logging.
- """
+ return self.available() >= self.need
- logging.info(msg)
+ def get(self, n):
+ """
+ Hand some data to the caller.
+ """
- def log_info(self, msg, tag = "info"):
- """
- Intercept asynchat's logging.
- """
+ b = self.buffer[:n]
+ self.buffer = self.buffer[n:]
+ return b
- logging.info("asynchat: %s: %s", tag, msg)
+ def put(self, b):
+ """
+ Accumulate some data.
+ """
- def handle_error(self):
- """
- Handle errors caught by asyncore main loop.
- """
+ self.buffer += b
- logging.exception("[Unhandled exception]")
- logging.critical("[Exiting after unhandled exception]")
- sys.exit(1)
+ def check_version(self, version):
+ """
+ Track version number of PDUs read from this buffer.
+ Once set, the version must not change.
+ """
- def init_file_dispatcher(self, fd):
- """
- Kludge to plug asyncore.file_dispatcher into asynchat. Call from
- subclass's __init__() method, after calling
- PDUChannel.__init__(), and don't read this on a full stomach.
- """
+ if self.version is not None and version != self.version:
+ raise rpki.rtr.pdus.CorruptData(
+ "Received PDU version %d, expected %d" % (version, self.version))
+ if self.version is None and version not in rpki.rtr.pdus.PDU.version_map:
+ raise rpki.rtr.pdus.UnsupportedProtocolVersion(
+ "Received PDU version %s, known versions %s" % (
+ version, ", ".join(str(v) for v in rpki.rtr.pdus.PDU.version_map)))
+ self.version = version
- self.connected = True
- self._fileno = fd
- self.socket = asyncore.file_wrapper(fd)
- self.add_channel()
- flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
- flags = flags | os.O_NONBLOCK
- fcntl.fcntl(fd, fcntl.F_SETFL, flags)
- def handle_close(self):
- """
- Exit when channel closed.
+class PDUChannel(asynchat.async_chat, object):
"""
-
- asynchat.async_chat.handle_close(self)
- sys.exit(0)
+ asynchat subclass that understands our PDUs. This just handles
+ network I/O. Specific engines (client, server) should be subclasses
+ of this with methods that do something useful with the resulting
+ PDUs.
+ """
+
+ def __init__(self, root_pdu_class, sock = None):
+ asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super()
+ self.reader = ReadBuffer()
+ assert issubclass(root_pdu_class, rpki.rtr.pdus.PDU)
+ self.root_pdu_class = root_pdu_class
+
+ @property
+ def version(self):
+ return self.reader.version
+
+ @version.setter
+ def version(self, version):
+ self.reader.check_version(version)
+
+ def start_new_pdu(self):
+ """
+ Start read of a new PDU.
+ """
+
+ try:
+ p = self.root_pdu_class.read_pdu(self.reader)
+ while p is not None:
+ self.deliver_pdu(p)
+ p = self.root_pdu_class.read_pdu(self.reader)
+ except rpki.rtr.pdus.PDUException, e:
+ self.push_pdu(e.make_error_report(version = self.version))
+ self.close_when_done()
+ else:
+ assert not self.reader.ready()
+ self.set_terminator(self.reader.needed())
+
+ def collect_incoming_data(self, data):
+ """
+ Collect data into the read buffer.
+ """
+
+ self.reader.put(data)
+
+ def found_terminator(self):
+ """
+ Got requested data, see if we now have a PDU. If so, pass it
+ along, then restart cycle for a new PDU.
+ """
+
+ p = self.reader.retry()
+ if p is None:
+ self.set_terminator(self.reader.needed())
+ else:
+ self.deliver_pdu(p)
+ self.start_new_pdu()
+
+ def push_pdu(self, pdu):
+ """
+ Write PDU to stream.
+ """
+
+ try:
+ self.push(pdu.to_pdu())
+ except OSError, e:
+ if e.errno != errno.EAGAIN:
+ raise
+
+ def log(self, msg):
+ """
+ Intercept asyncore's logging.
+ """
+
+ logging.info(msg)
+
+ def log_info(self, msg, tag = "info"):
+ """
+ Intercept asynchat's logging.
+ """
+
+ logging.info("asynchat: %s: %s", tag, msg)
+
+ def handle_error(self):
+ """
+ Handle errors caught by asyncore main loop.
+ """
+
+ logging.exception("[Unhandled exception]")
+ logging.critical("[Exiting after unhandled exception]")
+ sys.exit(1)
+
+ def init_file_dispatcher(self, fd):
+ """
+ Kludge to plug asyncore.file_dispatcher into asynchat. Call from
+ subclass's __init__() method, after calling
+ PDUChannel.__init__(), and don't read this on a full stomach.
+ """
+
+ self.connected = True
+ self._fileno = fd
+ self.socket = asyncore.file_wrapper(fd)
+ self.add_channel()
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
+ flags = flags | os.O_NONBLOCK
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+
+ def handle_close(self):
+ """
+ Exit when channel closed.
+ """
+
+ asynchat.async_chat.handle_close(self)
+ sys.exit(0)
diff --git a/rpki/rtr/client.py b/rpki/rtr/client.py
index a35ab81d..9c7a00d6 100644
--- a/rpki/rtr/client.py
+++ b/rpki/rtr/client.py
@@ -37,13 +37,13 @@ from rpki.rtr.channels import Timestamp
class PDU(rpki.rtr.pdus.PDU):
- def consume(self, client):
- """
- Handle results in test client. Default behavior is just to print
- out the PDU; data PDU subclasses may override this.
- """
+ def consume(self, client):
+ """
+ Handle results in test client. Default behavior is just to print
+ out the PDU; data PDU subclasses may override this.
+ """
- logging.debug(self)
+ logging.debug(self)
clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU)
@@ -52,407 +52,407 @@ clone_pdu = rpki.rtr.pdus.clone_pdu_root(PDU)
@clone_pdu
class SerialNotifyPDU(rpki.rtr.pdus.SerialNotifyPDU):
- def consume(self, client):
- """
- Respond to a SerialNotifyPDU with either a SerialQueryPDU or a
- ResetQueryPDU, depending on what we already know.
- """
+ def consume(self, client):
+ """
+ Respond to a SerialNotifyPDU with either a SerialQueryPDU or a
+ ResetQueryPDU, depending on what we already know.
+ """
- logging.debug(self)
- if client.serial is None or client.nonce != self.nonce:
- client.push_pdu(ResetQueryPDU(version = client.version))
- elif self.serial != client.serial:
- client.push_pdu(SerialQueryPDU(version = client.version,
- serial = client.serial,
- nonce = client.nonce))
- else:
- logging.debug("[Notify did not change serial number, ignoring]")
+ logging.debug(self)
+ if client.serial is None or client.nonce != self.nonce:
+ client.push_pdu(ResetQueryPDU(version = client.version))
+ elif self.serial != client.serial:
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.serial,
+ nonce = client.nonce))
+ else:
+ logging.debug("[Notify did not change serial number, ignoring]")
@clone_pdu
class CacheResponsePDU(rpki.rtr.pdus.CacheResponsePDU):
- def consume(self, client):
- """
- Handle CacheResponsePDU.
- """
+ def consume(self, client):
+ """
+ Handle CacheResponsePDU.
+ """
- logging.debug(self)
- if self.nonce != client.nonce:
- logging.debug("[Nonce changed, resetting]")
- client.cache_reset()
+ logging.debug(self)
+ if self.nonce != client.nonce:
+ logging.debug("[Nonce changed, resetting]")
+ client.cache_reset()
@clone_pdu
class EndOfDataPDUv0(rpki.rtr.pdus.EndOfDataPDUv0):
- def consume(self, client):
- """
- Handle EndOfDataPDU response.
- """
+ def consume(self, client):
+ """
+ Handle EndOfDataPDU response.
+ """
- logging.debug(self)
- client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+ logging.debug(self)
+ client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
@clone_pdu
class EndOfDataPDUv1(rpki.rtr.pdus.EndOfDataPDUv1):
- def consume(self, client):
- """
- Handle EndOfDataPDU response.
- """
+ def consume(self, client):
+ """
+ Handle EndOfDataPDU response.
+ """
- logging.debug(self)
- client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+ logging.debug(self)
+ client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire)
@clone_pdu
class CacheResetPDU(rpki.rtr.pdus.CacheResetPDU):
- def consume(self, client):
- """
- Handle CacheResetPDU response, by issuing a ResetQueryPDU.
- """
+ def consume(self, client):
+ """
+ Handle CacheResetPDU response, by issuing a ResetQueryPDU.
+ """
- logging.debug(self)
- client.cache_reset()
- client.push_pdu(ResetQueryPDU(version = client.version))
+ logging.debug(self)
+ client.cache_reset()
+ client.push_pdu(ResetQueryPDU(version = client.version))
class PrefixPDU(rpki.rtr.pdus.PrefixPDU):
- """
- Object representing one prefix. This corresponds closely to one PDU
- in the rpki-router protocol, so closely that we use lexical ordering
- of the wire format of the PDU as the ordering for this class.
-
- This is a virtual class, but the .from_text() constructor
- instantiates the correct concrete subclass (IPv4PrefixPDU or
- IPv6PrefixPDU) depending on the syntax of its input text.
- """
-
- def consume(self, client):
"""
- Handle one incoming prefix PDU
+ Object representing one prefix. This corresponds closely to one PDU
+ in the rpki-router protocol, so closely that we use lexical ordering
+ of the wire format of the PDU as the ordering for this class.
+
+ This is a virtual class, but the .from_text() constructor
+ instantiates the correct concrete subclass (IPv4PrefixPDU or
+ IPv6PrefixPDU) depending on the syntax of its input text.
"""
- logging.debug(self)
- client.consume_prefix(self)
+ def consume(self, client):
+ """
+ Handle one incoming prefix PDU
+ """
+
+ logging.debug(self)
+ client.consume_prefix(self)
@clone_pdu
class IPv4PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv4PrefixPDU):
- pass
+ pass
@clone_pdu
class IPv6PrefixPDU(PrefixPDU, rpki.rtr.pdus.IPv6PrefixPDU):
- pass
+ pass
@clone_pdu
class ErrorReportPDU(PDU, rpki.rtr.pdus.ErrorReportPDU):
- pass
+ pass
@clone_pdu
class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU):
- """
- Router Key PDU.
- """
-
- def consume(self, client):
"""
- Handle one incoming Router Key PDU
+ Router Key PDU.
"""
- logging.debug(self)
- client.consume_routerkey(self)
+ def consume(self, client):
+ """
+ Handle one incoming Router Key PDU
+ """
+ logging.debug(self)
+ client.consume_routerkey(self)
-class ClientChannel(rpki.rtr.channels.PDUChannel):
- """
- Client protocol engine, handles upcalls from PDUChannel.
- """
-
- serial = None
- nonce = None
- sql = None
- host = None
- port = None
- cache_id = None
- refresh = rpki.rtr.pdus.default_refresh
- retry = rpki.rtr.pdus.default_retry
- expire = rpki.rtr.pdus.default_expire
- updated = Timestamp(0)
-
- def __init__(self, sock, proc, killsig, args, host = None, port = None):
- self.killsig = killsig
- self.proc = proc
- self.args = args
- self.host = args.host if host is None else host
- self.port = args.port if port is None else port
- super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU)
- if args.force_version is not None:
- self.version = args.force_version
- self.start_new_pdu()
- if args.sql_database:
- self.setup_sql()
-
- @classmethod
- def ssh(cls, args):
- """
- Set up ssh connection and start listening for first PDU.
- """
- if args.port is None:
- argv = ("ssh", "-s", args.host, "rpki-rtr")
- else:
- argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr")
- logging.debug("[Running ssh: %s]", " ".join(argv))
- s = socket.socketpair()
- return cls(sock = s[1],
- proc = subprocess.Popen(argv, executable = "/usr/bin/ssh",
- stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGKILL, args = args)
-
- @classmethod
- def tcp(cls, args):
- """
- Set up TCP connection and start listening for first PDU.
+class ClientChannel(rpki.rtr.channels.PDUChannel):
"""
-
- logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port)
- try:
- addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM)
- except socket.error, e:
- logging.debug("[socket.getaddrinfo() failed: %s]", e)
- else:
- for ai in addrinfo:
- af, socktype, proto, cn, sa = ai # pylint: disable=W0612
- logging.debug("[Trying addr %s port %s]", sa[0], sa[1])
+ Client protocol engine, handles upcalls from PDUChannel.
+ """
+
+ serial = None
+ nonce = None
+ sql = None
+ host = None
+ port = None
+ cache_id = None
+ refresh = rpki.rtr.pdus.default_refresh
+ retry = rpki.rtr.pdus.default_retry
+ expire = rpki.rtr.pdus.default_expire
+ updated = Timestamp(0)
+
+ def __init__(self, sock, proc, killsig, args, host = None, port = None):
+ self.killsig = killsig
+ self.proc = proc
+ self.args = args
+ self.host = args.host if host is None else host
+ self.port = args.port if port is None else port
+ super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU)
+ if args.force_version is not None:
+ self.version = args.force_version
+ self.start_new_pdu()
+ if args.sql_database:
+ self.setup_sql()
+
+ @classmethod
+ def ssh(cls, args):
+ """
+ Set up ssh connection and start listening for first PDU.
+ """
+
+ if args.port is None:
+ argv = ("ssh", "-s", args.host, "rpki-rtr")
+ else:
+ argv = ("ssh", "-p", args.port, "-s", args.host, "rpki-rtr")
+ logging.debug("[Running ssh: %s]", " ".join(argv))
+ s = socket.socketpair()
+ return cls(sock = s[1],
+ proc = subprocess.Popen(argv, executable = "/usr/bin/ssh",
+ stdin = s[0], stdout = s[0], close_fds = True),
+ killsig = signal.SIGKILL, args = args)
+
+ @classmethod
+ def tcp(cls, args):
+ """
+ Set up TCP connection and start listening for first PDU.
+ """
+
+ logging.debug("[Starting raw TCP connection to %s:%s]", args.host, args.port)
try:
- s = socket.socket(af, socktype, proto)
+ addrinfo = socket.getaddrinfo(args.host, args.port, socket.AF_UNSPEC, socket.SOCK_STREAM)
except socket.error, e:
- logging.debug("[socket.socket() failed: %s]", e)
- continue
+ logging.debug("[socket.getaddrinfo() failed: %s]", e)
+ else:
+ for ai in addrinfo:
+ af, socktype, proto, cn, sa = ai # pylint: disable=W0612
+ logging.debug("[Trying addr %s port %s]", sa[0], sa[1])
+ try:
+ s = socket.socket(af, socktype, proto)
+ except socket.error, e:
+ logging.debug("[socket.socket() failed: %s]", e)
+ continue
+ try:
+ s.connect(sa)
+ except socket.error, e:
+ logging.exception("[socket.connect() failed: %s]", e)
+ s.close()
+ continue
+ return cls(sock = s, proc = None, killsig = None, args = args)
+ sys.exit(1)
+
+ @classmethod
+ def loopback(cls, args):
+ """
+ Set up loopback connection and start listening for first PDU.
+ """
+
+ s = socket.socketpair()
+ logging.debug("[Using direct subprocess kludge for testing]")
+ argv = (sys.executable, sys.argv[0], "server")
+ return cls(sock = s[1],
+ proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
+ killsig = signal.SIGINT, args = args,
+ host = args.host or "none", port = args.port or "none")
+
+ @classmethod
+ def tls(cls, args):
+ """
+ Set up TLS connection and start listening for first PDU.
+
+ NB: This uses OpenSSL's "s_client" command, which does not
+ check server certificates properly, so this is not suitable for
+ production use. Fixing this would be a trivial change, it just
+ requires using a client program which does check certificates
+ properly (eg, gnutls-cli, or stunnel's client mode if that works
+ for such purposes this week).
+ """
+
+ argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port))
+ logging.debug("[Running: %s]", " ".join(argv))
+ s = socket.socketpair()
+ return cls(sock = s[1],
+ proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
+ killsig = signal.SIGKILL, args = args)
+
+ def setup_sql(self):
+ """
+ Set up an SQLite database to contain the table we receive. If
+ necessary, we will create the database.
+ """
+
+ import sqlite3
+ missing = not os.path.exists(self.args.sql_database)
+ self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES)
+ self.sql.text_factory = str
+ cur = self.sql.cursor()
+ cur.execute("PRAGMA foreign_keys = on")
+ if missing:
+ cur.execute('''
+ CREATE TABLE cache (
+ cache_id INTEGER PRIMARY KEY NOT NULL,
+ host TEXT NOT NULL,
+ port TEXT NOT NULL,
+ version INTEGER,
+ nonce INTEGER,
+ serial INTEGER,
+ updated INTEGER,
+ refresh INTEGER,
+ retry INTEGER,
+ expire INTEGER,
+ UNIQUE (host, port))''')
+ cur.execute('''
+ CREATE TABLE prefix (
+ cache_id INTEGER NOT NULL
+ REFERENCES cache(cache_id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE,
+ asn INTEGER NOT NULL,
+ prefix TEXT NOT NULL,
+ prefixlen INTEGER NOT NULL,
+ max_prefixlen INTEGER NOT NULL,
+ UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''')
+ cur.execute('''
+ CREATE TABLE routerkey (
+ cache_id INTEGER NOT NULL
+ REFERENCES cache(cache_id)
+ ON DELETE CASCADE
+ ON UPDATE CASCADE,
+ asn INTEGER NOT NULL,
+ ski TEXT NOT NULL,
+ key TEXT NOT NULL,
+ UNIQUE (cache_id, asn, ski),
+ UNIQUE (cache_id, asn, key))''')
+ elif self.args.reset_session:
+ cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
+ cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated "
+ "FROM cache WHERE host = ? AND port = ?",
+ (self.host, self.port))
try:
- s.connect(sa)
- except socket.error, e:
- logging.exception("[socket.connect() failed: %s]", e)
- s.close()
- continue
- return cls(sock = s, proc = None, killsig = None, args = args)
- sys.exit(1)
-
- @classmethod
- def loopback(cls, args):
- """
- Set up loopback connection and start listening for first PDU.
- """
-
- s = socket.socketpair()
- logging.debug("[Using direct subprocess kludge for testing]")
- argv = (sys.executable, sys.argv[0], "server")
- return cls(sock = s[1],
- proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGINT, args = args,
- host = args.host or "none", port = args.port or "none")
-
- @classmethod
- def tls(cls, args):
- """
- Set up TLS connection and start listening for first PDU.
-
- NB: This uses OpenSSL's "s_client" command, which does not
- check server certificates properly, so this is not suitable for
- production use. Fixing this would be a trivial change, it just
- requires using a client program which does check certificates
- properly (eg, gnutls-cli, or stunnel's client mode if that works
- for such purposes this week).
- """
-
- argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (args.host, args.port))
- logging.debug("[Running: %s]", " ".join(argv))
- s = socket.socketpair()
- return cls(sock = s[1],
- proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGKILL, args = args)
-
- def setup_sql(self):
- """
- Set up an SQLite database to contain the table we receive. If
- necessary, we will create the database.
- """
-
- import sqlite3
- missing = not os.path.exists(self.args.sql_database)
- self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES)
- self.sql.text_factory = str
- cur = self.sql.cursor()
- cur.execute("PRAGMA foreign_keys = on")
- if missing:
- cur.execute('''
- CREATE TABLE cache (
- cache_id INTEGER PRIMARY KEY NOT NULL,
- host TEXT NOT NULL,
- port TEXT NOT NULL,
- version INTEGER,
- nonce INTEGER,
- serial INTEGER,
- updated INTEGER,
- refresh INTEGER,
- retry INTEGER,
- expire INTEGER,
- UNIQUE (host, port))''')
- cur.execute('''
- CREATE TABLE prefix (
- cache_id INTEGER NOT NULL
- REFERENCES cache(cache_id)
- ON DELETE CASCADE
- ON UPDATE CASCADE,
- asn INTEGER NOT NULL,
- prefix TEXT NOT NULL,
- prefixlen INTEGER NOT NULL,
- max_prefixlen INTEGER NOT NULL,
- UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''')
- cur.execute('''
- CREATE TABLE routerkey (
- cache_id INTEGER NOT NULL
- REFERENCES cache(cache_id)
- ON DELETE CASCADE
- ON UPDATE CASCADE,
- asn INTEGER NOT NULL,
- ski TEXT NOT NULL,
- key TEXT NOT NULL,
- UNIQUE (cache_id, asn, ski),
- UNIQUE (cache_id, asn, key))''')
- elif self.args.reset_session:
- cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
- cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated "
- "FROM cache WHERE host = ? AND port = ?",
- (self.host, self.port))
- try:
- self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone()
- if version is not None and self.version is not None and version != self.version:
- cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
- raise TypeError # Simulate lookup failure case
- if version is not None:
- self.version = version
- if refresh is not None:
+ self.cache_id, version, self.nonce, self.serial, refresh, retry, expire, updated = cur.fetchone()
+ if version is not None and self.version is not None and version != self.version:
+ cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
+ raise TypeError # Simulate lookup failure case
+ if version is not None:
+ self.version = version
+ if refresh is not None:
+ self.refresh = refresh
+ if retry is not None:
+ self.retry = retry
+ if expire is not None:
+ self.expire = expire
+ if updated is not None:
+ self.updated = Timestamp(updated)
+ except TypeError:
+ cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port))
+ self.cache_id = cur.lastrowid
+ self.sql.commit()
+ logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]",
+ self.cache_id, self.version, self.nonce,
+ self.serial, self.refresh, self.retry, self.expire, self.updated)
+
+ def cache_reset(self):
+ """
+ Handle CacheResetPDU actions.
+ """
+
+ self.serial = None
+ if self.sql:
+ cur = self.sql.cursor()
+ cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,))
+ cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id))
+ self.sql.commit()
+
+ def end_of_data(self, version, serial, nonce, refresh, retry, expire):
+ """
+ Handle EndOfDataPDU actions.
+ """
+
+ assert version == self.version
+ self.serial = serial
+ self.nonce = nonce
self.refresh = refresh
- if retry is not None:
- self.retry = retry
- if expire is not None:
- self.expire = expire
- if updated is not None:
- self.updated = Timestamp(updated)
- except TypeError:
- cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port))
- self.cache_id = cur.lastrowid
- self.sql.commit()
- logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s updated %s]",
- self.cache_id, self.version, self.nonce,
- self.serial, self.refresh, self.retry, self.expire, self.updated)
-
- def cache_reset(self):
- """
- Handle CacheResetPDU actions.
- """
-
- self.serial = None
- if self.sql:
- cur = self.sql.cursor()
- cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,))
- cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,))
- cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id))
- self.sql.commit()
-
- def end_of_data(self, version, serial, nonce, refresh, retry, expire):
- """
- Handle EndOfDataPDU actions.
- """
-
- assert version == self.version
- self.serial = serial
- self.nonce = nonce
- self.refresh = refresh
- self.retry = retry
- self.expire = expire
- self.updated = Timestamp.now()
- if self.sql:
- self.sql.execute("UPDATE cache SET"
- " version = ?, serial = ?, nonce = ?,"
- " refresh = ?, retry = ?, expire = ?,"
- " updated = ? "
- "WHERE cache_id = ?",
- (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id))
- self.sql.commit()
-
- def consume_prefix(self, prefix):
- """
- Handle one prefix PDU.
- """
-
- if self.sql:
- values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen)
- if prefix.announce:
- self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) "
- "VALUES (?, ?, ?, ?, ?)",
- values)
- else:
- self.sql.execute("DELETE FROM prefix "
- "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?",
- values)
-
- def consume_routerkey(self, routerkey):
- """
- Handle one Router Key PDU.
- """
-
- if self.sql:
- values = (self.cache_id, routerkey.asn,
- base64.urlsafe_b64encode(routerkey.ski).rstrip("="),
- base64.b64encode(routerkey.key))
- if routerkey.announce:
- self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) "
- "VALUES (?, ?, ?, ?)",
- values)
- else:
- self.sql.execute("DELETE FROM routerkey "
- "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)",
- values)
-
- def deliver_pdu(self, pdu):
- """
- Handle received PDU.
- """
-
- pdu.consume(self)
-
- def push_pdu(self, pdu):
- """
- Log outbound PDU then write it to stream.
- """
-
- logging.debug(pdu)
- super(ClientChannel, self).push_pdu(pdu)
-
- def cleanup(self):
- """
- Force clean up this client's child process. If everything goes
- well, child will have exited already before this method is called,
- but we may need to whack it with a stick if something breaks.
- """
-
- if self.proc is not None and self.proc.returncode is None:
- try:
- os.kill(self.proc.pid, self.killsig)
- except OSError:
- pass
-
- def handle_close(self):
- """
- Intercept close event so we can log it, then shut down.
- """
-
- logging.debug("Server closed channel")
- super(ClientChannel, self).handle_close()
+ self.retry = retry
+ self.expire = expire
+ self.updated = Timestamp.now()
+ if self.sql:
+ self.sql.execute("UPDATE cache SET"
+ " version = ?, serial = ?, nonce = ?,"
+ " refresh = ?, retry = ?, expire = ?,"
+ " updated = ? "
+ "WHERE cache_id = ?",
+ (version, serial, nonce, refresh, retry, expire, int(self.updated), self.cache_id))
+ self.sql.commit()
+
+ def consume_prefix(self, prefix):
+ """
+ Handle one prefix PDU.
+ """
+
+ if self.sql:
+ values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen)
+ if prefix.announce:
+ self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) "
+ "VALUES (?, ?, ?, ?, ?)",
+ values)
+ else:
+ self.sql.execute("DELETE FROM prefix "
+ "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?",
+ values)
+
+ def consume_routerkey(self, routerkey):
+ """
+ Handle one Router Key PDU.
+ """
+
+ if self.sql:
+ values = (self.cache_id, routerkey.asn,
+ base64.urlsafe_b64encode(routerkey.ski).rstrip("="),
+ base64.b64encode(routerkey.key))
+ if routerkey.announce:
+ self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) "
+ "VALUES (?, ?, ?, ?)",
+ values)
+ else:
+ self.sql.execute("DELETE FROM routerkey "
+ "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)",
+ values)
+
+ def deliver_pdu(self, pdu):
+ """
+ Handle received PDU.
+ """
+
+ pdu.consume(self)
+
+ def push_pdu(self, pdu):
+ """
+ Log outbound PDU then write it to stream.
+ """
+
+ logging.debug(pdu)
+ super(ClientChannel, self).push_pdu(pdu)
+
+ def cleanup(self):
+ """
+ Force clean up this client's child process. If everything goes
+ well, child will have exited already before this method is called,
+ but we may need to whack it with a stick if something breaks.
+ """
+
+ if self.proc is not None and self.proc.returncode is None:
+ try:
+ os.kill(self.proc.pid, self.killsig)
+ except OSError:
+ pass
+
+ def handle_close(self):
+ """
+ Intercept close event so we can log it, then shut down.
+ """
+
+ logging.debug("Server closed channel")
+ super(ClientChannel, self).handle_close()
# Hack to let us subclass this from scripts without needing to rewrite client_main().
@@ -460,73 +460,73 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
ClientChannelClass = ClientChannel
def client_main(args):
- """
- Test client, intended primarily for debugging.
- """
+ """
+ Test client, intended primarily for debugging.
+ """
- logging.debug("[Startup]")
+ logging.debug("[Startup]")
- assert issubclass(ClientChannelClass, ClientChannel)
- constructor = getattr(ClientChannelClass, args.protocol)
+ assert issubclass(ClientChannelClass, ClientChannel)
+ constructor = getattr(ClientChannelClass, args.protocol)
- client = None
- try:
- client = constructor(args)
+ client = None
+ try:
+ client = constructor(args)
- polled = client.updated
- wakeup = None
+ polled = client.updated
+ wakeup = None
- while True:
+ while True:
- now = Timestamp.now()
+ now = Timestamp.now()
- if client.serial is not None and now > client.updated + client.expire:
- logging.info("[Expiring client data: serial %s, last updated %s, expire %s]",
- client.serial, client.updated, client.expire)
- client.cache_reset()
+ if client.serial is not None and now > client.updated + client.expire:
+ logging.info("[Expiring client data: serial %s, last updated %s, expire %s]",
+ client.serial, client.updated, client.expire)
+ client.cache_reset()
- if client.serial is None or client.nonce is None:
- polled = now
- client.push_pdu(ResetQueryPDU(version = client.version))
+ if client.serial is None or client.nonce is None:
+ polled = now
+ client.push_pdu(ResetQueryPDU(version = client.version))
- elif now >= client.updated + client.refresh:
- polled = now
- client.push_pdu(SerialQueryPDU(version = client.version,
- serial = client.serial,
- nonce = client.nonce))
+ elif now >= client.updated + client.refresh:
+ polled = now
+ client.push_pdu(SerialQueryPDU(version = client.version,
+ serial = client.serial,
+ nonce = client.nonce))
- remaining = 1
+ remaining = 1
- while remaining > 0:
- now = Timestamp.now()
- timer = client.retry if (now >= client.updated + client.refresh) else client.refresh
- wokeup = wakeup
- wakeup = max(now, Timestamp(max(polled, client.updated) + timer))
- remaining = wakeup - now
- if wakeup != wokeup:
- logging.info("[Last client poll %s, next %s]", polled, wakeup)
- asyncore.loop(timeout = remaining, count = 1)
+ while remaining > 0:
+ now = Timestamp.now()
+ timer = client.retry if (now >= client.updated + client.refresh) else client.refresh
+ wokeup = wakeup
+ wakeup = max(now, Timestamp(max(polled, client.updated) + timer))
+ remaining = wakeup - now
+ if wakeup != wokeup:
+ logging.info("[Last client poll %s, next %s]", polled, wakeup)
+ asyncore.loop(timeout = remaining, count = 1)
- except KeyboardInterrupt:
- sys.exit(0)
+ except KeyboardInterrupt:
+ sys.exit(0)
- finally:
- if client is not None:
- client.cleanup()
+ finally:
+ if client is not None:
+ client.cleanup()
def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- subparser = subparsers.add_parser("client", description = client_main.__doc__,
- help = "Test client for RPKI-RTR protocol")
- subparser.set_defaults(func = client_main, default_log_to = "stderr")
- subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state")
- subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version")
- subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database")
- subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol")
- subparser.add_argument("host", nargs = "?", help = "server host")
- subparser.add_argument("port", nargs = "?", help = "server port")
- return subparser
+ """
+ Set up argparse stuff for commands in this module.
+ """
+
+ subparser = subparsers.add_parser("client", description = client_main.__doc__,
+ help = "Test client for RPKI-RTR protocol")
+ subparser.set_defaults(func = client_main, default_log_to = "stderr")
+ subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state")
+ subparser.add_argument("--force-version", type = int, choices = PDU.version_map, help = "force specific protocol version")
+ subparser.add_argument("--reset-session", action = "store_true", help = "reset any existing session found in sqlite3 database")
+ subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol")
+ subparser.add_argument("host", nargs = "?", help = "server host")
+ subparser.add_argument("port", nargs = "?", help = "server port")
+ return subparser
diff --git a/rpki/rtr/generator.py b/rpki/rtr/generator.py
index 26e25b6e..e00e44b7 100644
--- a/rpki/rtr/generator.py
+++ b/rpki/rtr/generator.py
@@ -37,539 +37,539 @@ import rpki.rtr.server
from rpki.rtr.channels import Timestamp
class PrefixPDU(rpki.rtr.pdus.PrefixPDU):
- """
- Object representing one prefix. This corresponds closely to one PDU
- in the rpki-router protocol, so closely that we use lexical ordering
- of the wire format of the PDU as the ordering for this class.
-
- This is a virtual class, but the .from_text() constructor
- instantiates the correct concrete subclass (IPv4PrefixPDU or
- IPv6PrefixPDU) depending on the syntax of its input text.
- """
-
- @staticmethod
- def from_text(version, asn, addr):
- """
- Construct a prefix from its text form.
- """
-
- cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU
- self = cls(version = version)
- self.asn = long(asn)
- p, l = addr.split("/")
- self.prefix = rpki.POW.IPAddress(p)
- if "-" in l:
- self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-"))
- else:
- self.prefixlen = self.max_prefixlen = int(l)
- self.announce = 1
- self.check()
- return self
-
- @staticmethod
- def from_roa(version, asn, prefix_tuple):
- """
- Construct a prefix from a ROA.
"""
-
- address, length, maxlength = prefix_tuple
- cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU
- self = cls(version = version)
- self.asn = asn
- self.prefix = address
- self.prefixlen = length
- self.max_prefixlen = length if maxlength is None else maxlength
- self.announce = 1
- self.check()
- return self
+ Object representing one prefix. This corresponds closely to one PDU
+ in the rpki-router protocol, so closely that we use lexical ordering
+ of the wire format of the PDU as the ordering for this class.
+
+ This is a virtual class, but the .from_text() constructor
+ instantiates the correct concrete subclass (IPv4PrefixPDU or
+ IPv6PrefixPDU) depending on the syntax of its input text.
+ """
+
+ @staticmethod
+ def from_text(version, asn, addr):
+ """
+ Construct a prefix from its text form.
+ """
+
+ cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU
+ self = cls(version = version)
+ self.asn = long(asn)
+ p, l = addr.split("/")
+ self.prefix = rpki.POW.IPAddress(p)
+ if "-" in l:
+ self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-"))
+ else:
+ self.prefixlen = self.max_prefixlen = int(l)
+ self.announce = 1
+ self.check()
+ return self
+
+ @staticmethod
+ def from_roa(version, asn, prefix_tuple):
+ """
+ Construct a prefix from a ROA.
+ """
+
+ address, length, maxlength = prefix_tuple
+ cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU
+ self = cls(version = version)
+ self.asn = asn
+ self.prefix = address
+ self.prefixlen = length
+ self.max_prefixlen = length if maxlength is None else maxlength
+ self.announce = 1
+ self.check()
+ return self
class IPv4PrefixPDU(PrefixPDU):
- """
- IPv4 flavor of a prefix.
- """
+ """
+ IPv4 flavor of a prefix.
+ """
- pdu_type = 4
- address_byte_count = 4
+ pdu_type = 4
+ address_byte_count = 4
class IPv6PrefixPDU(PrefixPDU):
- """
- IPv6 flavor of a prefix.
- """
-
- pdu_type = 6
- address_byte_count = 16
-
-class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU):
- """
- Router Key PDU.
- """
-
- @classmethod
- def from_text(cls, version, asn, gski, key):
"""
- Construct a router key from its text form.
+ IPv6 flavor of a prefix.
"""
- self = cls(version = version)
- self.asn = long(asn)
- self.ski = base64.urlsafe_b64decode(gski + "=")
- self.key = base64.b64decode(key)
- self.announce = 1
- self.check()
- return self
+ pdu_type = 6
+ address_byte_count = 16
- @classmethod
- def from_certificate(cls, version, asn, ski, key):
+class RouterKeyPDU(rpki.rtr.pdus.RouterKeyPDU):
"""
- Construct a router key from a certificate.
+ Router Key PDU.
"""
- self = cls(version = version)
- self.asn = asn
- self.ski = ski
- self.key = key
- self.announce = 1
- self.check()
- return self
+ @classmethod
+ def from_text(cls, version, asn, gski, key):
+ """
+ Construct a router key from its text form.
+ """
+ self = cls(version = version)
+ self.asn = long(asn)
+ self.ski = base64.urlsafe_b64decode(gski + "=")
+ self.key = base64.b64decode(key)
+ self.announce = 1
+ self.check()
+ return self
-class ROA(rpki.POW.ROA): # pylint: disable=W0232
- """
- Minor additions to rpki.POW.ROA.
- """
-
- @classmethod
- def derReadFile(cls, fn): # pylint: disable=E1002
- self = super(ROA, cls).derReadFile(fn)
- self.extractWithoutVerifying()
- return self
-
- @property
- def prefixes(self):
- v4, v6 = self.getPrefixes()
- if v4 is not None:
- for p in v4:
- yield p
- if v6 is not None:
- for p in v6:
- yield p
-
-class X509(rpki.POW.X509): # pylint: disable=W0232
- """
- Minor additions to rpki.POW.X509.
- """
+ @classmethod
+ def from_certificate(cls, version, asn, ski, key):
+ """
+ Construct a router key from a certificate.
+ """
- @property
- def asns(self):
- resources = self.getRFC3779()
- if resources is not None and resources[0] is not None:
- for min_asn, max_asn in resources[0]:
- for asn in xrange(min_asn, max_asn + 1):
- yield asn
+ self = cls(version = version)
+ self.asn = asn
+ self.ski = ski
+ self.key = key
+ self.announce = 1
+ self.check()
+ return self
-class PDUSet(list):
- """
- Object representing a set of PDUs, that is, one versioned and
- (theoretically) consistant set of prefixes and router keys extracted
- from rcynic's output.
- """
-
- def __init__(self, version):
- assert version in rpki.rtr.pdus.PDU.version_map
- super(PDUSet, self).__init__()
- self.version = version
-
- @classmethod
- def _load_file(cls, filename, version):
+class ROA(rpki.POW.ROA): # pylint: disable=W0232
"""
- Low-level method to read PDUSet from a file.
+ Minor additions to rpki.POW.ROA.
"""
- self = cls(version = version)
- f = open(filename, "rb")
- r = rpki.rtr.channels.ReadBuffer()
- while True:
- p = rpki.rtr.pdus.PDU.read_pdu(r)
- while p is None:
- b = f.read(r.needed())
- if b == "":
- assert r.available() == 0
- return self
- r.put(b)
- p = r.retry()
- assert p.version == self.version
- self.append(p)
-
- @staticmethod
- def seq_ge(a, b):
- return ((a - b) % (1 << 32)) < (1 << 31)
+ @classmethod
+ def derReadFile(cls, fn): # pylint: disable=E1002
+ self = super(ROA, cls).derReadFile(fn)
+ self.extractWithoutVerifying()
+ return self
+ @property
+ def prefixes(self):
+ v4, v6 = self.getPrefixes()
+ if v4 is not None:
+ for p in v4:
+ yield p
+ if v6 is not None:
+ for p in v6:
+ yield p
-class AXFRSet(PDUSet):
- """
- Object representing a complete set of PDUs, that is, one versioned
- and (theoretically) consistant set of prefixes and router
- certificates extracted from rcynic's output, all with the announce
- field set.
- """
-
- @classmethod
- def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None):
+class X509(rpki.POW.X509): # pylint: disable=W0232
+ """
+ Minor additions to rpki.POW.X509.
"""
- Parse ROAS and router certificates fetched (and validated!) by
- rcynic to create a new AXFRSet.
- In normal operation, we use os.walk() and the rpki.POW library to
- parse these data directly, but we can, if so instructed, use
- external programs instead, for testing, simulation, or to provide
- a way to inject local data.
+ @property
+ def asns(self):
+ resources = self.getRFC3779()
+ if resources is not None and resources[0] is not None:
+ for min_asn, max_asn in resources[0]:
+ for asn in xrange(min_asn, max_asn + 1):
+ yield asn
- At some point the ability to parse these data from external
- programs may move to a separate constructor function, so that we
- can make this one a bit simpler and faster.
- """
- self = cls(version = version)
- self.serial = rpki.rtr.channels.Timestamp.now()
-
- include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version]
-
- if scan_roas is None or (scan_routercerts is None and include_routercerts):
- for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612
- for fn in files:
- if scan_roas is None and fn.endswith(".roa"):
- roa = ROA.derReadFile(os.path.join(root, fn))
- asn = roa.getASID()
- self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple)
- for prefix_tuple in roa.prefixes)
- if include_routercerts and scan_routercerts is None and fn.endswith(".cer"):
- x = X509.derReadFile(os.path.join(root, fn))
- eku = x.getEKU()
- if eku is not None and rpki.oids.id_kp_bgpsec_router in eku:
- ski = x.getSKI()
- key = x.getPublicKey().derWritePublic()
- self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key)
- for asn in x.asns)
-
- if scan_roas is not None:
- try:
- p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE)
- for line in p.stdout:
- line = line.split()
- asn = line[1]
- self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr)
- for addr in line[2:])
- except OSError, e:
- sys.exit("Could not run %s: %s" % (scan_roas, e))
-
- if include_routercerts and scan_routercerts is not None:
- try:
- p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE)
- for line in p.stdout:
- line = line.split()
- gski = line[0]
- key = line[-1]
- self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key)
- for asn in line[1:-1])
- except OSError, e:
- sys.exit("Could not run %s: %s" % (scan_routercerts, e))
-
- self.sort()
- for i in xrange(len(self) - 2, -1, -1):
- if self[i] == self[i + 1]:
- del self[i + 1]
- return self
-
- @classmethod
- def load(cls, filename):
- """
- Load an AXFRSet from a file, parse filename to obtain version and serial.
+class PDUSet(list):
"""
+ Object representing a set of PDUs, that is, one versioned and
+ (theoretically) consistant set of prefixes and router keys extracted
+ from rcynic's output.
+ """
+
+ def __init__(self, version):
+ assert version in rpki.rtr.pdus.PDU.version_map
+ super(PDUSet, self).__init__()
+ self.version = version
+
+ @classmethod
+ def _load_file(cls, filename, version):
+ """
+ Low-level method to read PDUSet from a file.
+ """
+
+ self = cls(version = version)
+ f = open(filename, "rb")
+ r = rpki.rtr.channels.ReadBuffer()
+ while True:
+ p = rpki.rtr.pdus.PDU.read_pdu(r)
+ while p is None:
+ b = f.read(r.needed())
+ if b == "":
+ assert r.available() == 0
+ return self
+ r.put(b)
+ p = r.retry()
+ assert p.version == self.version
+ self.append(p)
+
+ @staticmethod
+ def seq_ge(a, b):
+ return ((a - b) % (1 << 32)) < (1 << 31)
- fn1, fn2, fn3 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit()
- version = int(fn3[1:])
- self = cls._load_file(filename, version)
- self.serial = rpki.rtr.channels.Timestamp(fn1)
- return self
- def filename(self):
- """
- Generate filename for this AXFRSet.
+class AXFRSet(PDUSet):
"""
+ Object representing a complete set of PDUs, that is, one versioned
+ and (theoretically) consistant set of prefixes and router
+ certificates extracted from rcynic's output, all with the announce
+ field set.
+ """
+
+ @classmethod
+ def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None):
+ """
+ Parse ROAS and router certificates fetched (and validated!) by
+ rcynic to create a new AXFRSet.
+
+ In normal operation, we use os.walk() and the rpki.POW library to
+ parse these data directly, but we can, if so instructed, use
+ external programs instead, for testing, simulation, or to provide
+ a way to inject local data.
+
+ At some point the ability to parse these data from external
+ programs may move to a separate constructor function, so that we
+ can make this one a bit simpler and faster.
+ """
+
+ self = cls(version = version)
+ self.serial = rpki.rtr.channels.Timestamp.now()
+
+ include_routercerts = RouterKeyPDU.pdu_type in rpki.rtr.pdus.PDU.version_map[version]
+
+ if scan_roas is None or (scan_routercerts is None and include_routercerts):
+ for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612
+ for fn in files:
+ if scan_roas is None and fn.endswith(".roa"):
+ roa = ROA.derReadFile(os.path.join(root, fn))
+ asn = roa.getASID()
+ self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple)
+ for prefix_tuple in roa.prefixes)
+ if include_routercerts and scan_routercerts is None and fn.endswith(".cer"):
+ x = X509.derReadFile(os.path.join(root, fn))
+ eku = x.getEKU()
+ if eku is not None and rpki.oids.id_kp_bgpsec_router in eku:
+ ski = x.getSKI()
+ key = x.getPublicKey().derWritePublic()
+ self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key)
+ for asn in x.asns)
+
+ if scan_roas is not None:
+ try:
+ p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE)
+ for line in p.stdout:
+ line = line.split()
+ asn = line[1]
+ self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr)
+ for addr in line[2:])
+ except OSError, e:
+ sys.exit("Could not run %s: %s" % (scan_roas, e))
+
+ if include_routercerts and scan_routercerts is not None:
+ try:
+ p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE)
+ for line in p.stdout:
+ line = line.split()
+ gski = line[0]
+ key = line[-1]
+ self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key)
+ for asn in line[1:-1])
+ except OSError, e:
+ sys.exit("Could not run %s: %s" % (scan_routercerts, e))
+
+ self.sort()
+ for i in xrange(len(self) - 2, -1, -1):
+ if self[i] == self[i + 1]:
+ del self[i + 1]
+ return self
+
+ @classmethod
+ def load(cls, filename):
+ """
+ Load an AXFRSet from a file, parse filename to obtain version and serial.
+ """
+
+ fn1, fn2, fn3 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit()
+ version = int(fn3[1:])
+ self = cls._load_file(filename, version)
+ self.serial = rpki.rtr.channels.Timestamp(fn1)
+ return self
+
+ def filename(self):
+ """
+ Generate filename for this AXFRSet.
+ """
+
+ return "%d.ax.v%d" % (self.serial, self.version)
+
+ @classmethod
+ def load_current(cls, version):
+ """
+ Load current AXFRSet. Return None if can't.
+ """
+
+ serial = rpki.rtr.server.read_current(version)[0]
+ if serial is None:
+ return None
+ try:
+ return cls.load("%d.ax.v%d" % (serial, version))
+ except IOError:
+ return None
+
+ def save_axfr(self):
+ """
+ Write AXFRSet to file with magic filename.
+ """
+
+ f = open(self.filename(), "wb")
+ for p in self:
+ f.write(p.to_pdu())
+ f.close()
+
+ def destroy_old_data(self):
+ """
+ Destroy old data files, presumably because our nonce changed and
+ the old serial numbers are no longer valid.
+ """
+
+ for i in glob.iglob("*.ix.*.v%d" % self.version):
+ os.unlink(i)
+ for i in glob.iglob("*.ax.v%d" % self.version):
+ if i != self.filename():
+ os.unlink(i)
+
+ @staticmethod
+ def new_nonce(force_zero_nonce):
+ """
+ Create and return a new nonce value.
+ """
+
+ if force_zero_nonce:
+ return 0
+ try:
+ return int(random.SystemRandom().getrandbits(16))
+ except NotImplementedError:
+ return int(random.getrandbits(16))
+
+ def mark_current(self, force_zero_nonce = False):
+ """
+ Save current serial number and nonce, creating new nonce if
+ necessary. Creating a new nonce triggers cleanup of old state, as
+ the new nonce invalidates all old serial numbers.
+ """
+
+ assert self.version in rpki.rtr.pdus.PDU.version_map
+ old_serial, nonce = rpki.rtr.server.read_current(self.version)
+ if old_serial is None or self.seq_ge(old_serial, self.serial):
+ logging.debug("Creating new nonce and deleting stale data")
+ nonce = self.new_nonce(force_zero_nonce)
+ self.destroy_old_data()
+ rpki.rtr.server.write_current(self.serial, nonce, self.version)
+
+ def save_ixfr(self, other):
+ """
+ Comparing this AXFRSet with an older one and write the resulting
+ IXFRSet to file with magic filename. Since we store PDUSets
+ in sorted order, computing the difference is a trivial linear
+ comparison.
+ """
+
+ f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb")
+ old = other
+ new = self
+ len_old = len(old)
+ len_new = len(new)
+ i_old = i_new = 0
+ while i_old < len_old and i_new < len_new:
+ if old[i_old] < new[i_new]:
+ f.write(old[i_old].to_pdu(announce = 0))
+ i_old += 1
+ elif old[i_old] > new[i_new]:
+ f.write(new[i_new].to_pdu(announce = 1))
+ i_new += 1
+ else:
+ i_old += 1
+ i_new += 1
+ for i in xrange(i_old, len_old):
+ f.write(old[i].to_pdu(announce = 0))
+ for i in xrange(i_new, len_new):
+ f.write(new[i].to_pdu(announce = 1))
+ f.close()
+
+ def show(self):
+ """
+ Print this AXFRSet.
+ """
+
+ logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version)
+ for p in self:
+ logging.debug(p)
- return "%d.ax.v%d" % (self.serial, self.version)
- @classmethod
- def load_current(cls, version):
+class IXFRSet(PDUSet):
"""
- Load current AXFRSet. Return None if can't.
+ Object representing an incremental set of PDUs, that is, the
+ differences between one versioned and (theoretically) consistant set
+ of prefixes and router certificates extracted from rcynic's output
+ and another, with the announce fields set or cleared as necessary to
+ indicate the changes.
"""
- serial = rpki.rtr.server.read_current(version)[0]
- if serial is None:
- return None
- try:
- return cls.load("%d.ax.v%d" % (serial, version))
- except IOError:
- return None
+ @classmethod
+ def load(cls, filename):
+ """
+ Load an IXFRSet from a file, parse filename to obtain version and serials.
+ """
- def save_axfr(self):
- """
- Write AXFRSet to file with magic filename.
- """
+ fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".")
+ assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit()
+ version = int(fn4[1:])
+ self = cls._load_file(filename, version)
+ self.from_serial = rpki.rtr.channels.Timestamp(fn3)
+ self.to_serial = rpki.rtr.channels.Timestamp(fn1)
+ return self
- f = open(self.filename(), "wb")
- for p in self:
- f.write(p.to_pdu())
- f.close()
+ def filename(self):
+ """
+ Generate filename for this IXFRSet.
+ """
- def destroy_old_data(self):
- """
- Destroy old data files, presumably because our nonce changed and
- the old serial numbers are no longer valid.
- """
+ return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version)
- for i in glob.iglob("*.ix.*.v%d" % self.version):
- os.unlink(i)
- for i in glob.iglob("*.ax.v%d" % self.version):
- if i != self.filename():
- os.unlink(i)
+ def show(self):
+ """
+ Print this IXFRSet.
+ """
- @staticmethod
- def new_nonce(force_zero_nonce):
- """
- Create and return a new nonce value.
- """
+ logging.debug("# IXFR %d (%s) -> %d (%s) v%d",
+ self.from_serial, self.from_serial,
+ self.to_serial, self.to_serial,
+ self.version)
+ for p in self:
+ logging.debug(p)
- if force_zero_nonce:
- return 0
- try:
- return int(random.SystemRandom().getrandbits(16))
- except NotImplementedError:
- return int(random.getrandbits(16))
- def mark_current(self, force_zero_nonce = False):
+def kick_all(serial):
"""
- Save current serial number and nonce, creating new nonce if
- necessary. Creating a new nonce triggers cleanup of old state, as
- the new nonce invalidates all old serial numbers.
+ Kick any existing server processes to wake them up.
"""
- assert self.version in rpki.rtr.pdus.PDU.version_map
- old_serial, nonce = rpki.rtr.server.read_current(self.version)
- if old_serial is None or self.seq_ge(old_serial, self.serial):
- logging.debug("Creating new nonce and deleting stale data")
- nonce = self.new_nonce(force_zero_nonce)
- self.destroy_old_data()
- rpki.rtr.server.write_current(self.serial, nonce, self.version)
+ try:
+ os.stat(rpki.rtr.server.kickme_dir)
+ except OSError:
+ logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir)
+ os.makedirs(rpki.rtr.server.kickme_dir)
+
+ msg = "Good morning, serial %d is ready" % serial
+ sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base):
+ try:
+ logging.debug("# Kicking %s", name)
+ sock.sendto(msg, name)
+ except socket.error:
+ try:
+ logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name)
+ os.unlink(name)
+ except Exception, e:
+ logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e)
+ except Exception, e:
+ logging.warning("# Failed to kick %s and don't understand why: %s", name, e)
+ sock.close()
- def save_ixfr(self, other):
- """
- Comparing this AXFRSet with an older one and write the resulting
- IXFRSet to file with magic filename. Since we store PDUSets
- in sorted order, computing the difference is a trivial linear
- comparison.
- """
- f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb")
- old = other
- new = self
- len_old = len(old)
- len_new = len(new)
- i_old = i_new = 0
- while i_old < len_old and i_new < len_new:
- if old[i_old] < new[i_new]:
- f.write(old[i_old].to_pdu(announce = 0))
- i_old += 1
- elif old[i_old] > new[i_new]:
- f.write(new[i_new].to_pdu(announce = 1))
- i_new += 1
- else:
- i_old += 1
- i_new += 1
- for i in xrange(i_old, len_old):
- f.write(old[i].to_pdu(announce = 0))
- for i in xrange(i_new, len_new):
- f.write(new[i].to_pdu(announce = 1))
- f.close()
-
- def show(self):
- """
- Print this AXFRSet.
+def cronjob_main(args):
"""
-
- logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version)
- for p in self:
- logging.debug(p)
+ Run this right after running rcynic to wade through the ROAs and
+ router certificates that rcynic collects and translate that data
+ into the form used in the rpki-router protocol. Output is an
+ updated database containing both full dumps (AXFR) and incremental
+ dumps against a specific prior version (IXFR). After updating the
+ database, kicks any active servers, so that they can notify their
+ clients that a new version is available.
+ """
+
+ if args.rpki_rtr_dir:
+ try:
+ if not os.path.isdir(args.rpki_rtr_dir):
+ os.makedirs(args.rpki_rtr_dir)
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ logging.critical(str(e))
+ sys.exit(1)
+
+ for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True):
+
+ logging.debug("# Generating updates for protocol version %d", version)
+
+ old_ixfrs = glob.glob("*.ix.*.v%d" % version)
+
+ current = rpki.rtr.server.read_current(version)[0]
+ cutoff = Timestamp.now(-(24 * 60 * 60))
+ for f in glob.iglob("*.ax.v%d" % version):
+ t = Timestamp(int(f.split(".")[0]))
+ if t < cutoff and t != current:
+ logging.debug("# Deleting old file %s, timestamp %s", f, t)
+ os.unlink(f)
+
+ pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts)
+ if pdus == rpki.rtr.generator.AXFRSet.load_current(version):
+ logging.debug("# No change, new serial not needed")
+ continue
+ pdus.save_axfr()
+ for axfr in glob.iglob("*.ax.v%d" % version):
+ if axfr != pdus.filename():
+ pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr))
+ pdus.mark_current(args.force_zero_nonce)
+
+ logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial)
+
+ rpki.rtr.generator.kick_all(pdus.serial)
+
+ old_ixfrs.sort()
+ for ixfr in old_ixfrs:
+ try:
+ logging.debug("# Deleting old file %s", ixfr)
+ os.unlink(ixfr)
+ except OSError:
+ pass
-class IXFRSet(PDUSet):
- """
- Object representing an incremental set of PDUs, that is, the
- differences between one versioned and (theoretically) consistant set
- of prefixes and router certificates extracted from rcynic's output
- and another, with the announce fields set or cleared as necessary to
- indicate the changes.
- """
-
- @classmethod
- def load(cls, filename):
+def show_main(args):
"""
- Load an IXFRSet from a file, parse filename to obtain version and serials.
+ Display current rpki-rtr server database in textual form.
"""
- fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".")
- assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit()
- version = int(fn4[1:])
- self = cls._load_file(filename, version)
- self.from_serial = rpki.rtr.channels.Timestamp(fn3)
- self.to_serial = rpki.rtr.channels.Timestamp(fn1)
- return self
+ if args.rpki_rtr_dir:
+ try:
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ sys.exit(e)
- def filename(self):
- """
- Generate filename for this IXFRSet.
- """
+ g = glob.glob("*.ax.v*")
+ g.sort()
+ for f in g:
+ rpki.rtr.generator.AXFRSet.load(f).show()
- return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version)
+ g = glob.glob("*.ix.*.v*")
+ g.sort()
+ for f in g:
+ rpki.rtr.generator.IXFRSet.load(f).show()
- def show(self):
+def argparse_setup(subparsers):
"""
- Print this IXFRSet.
+ Set up argparse stuff for commands in this module.
"""
- logging.debug("# IXFR %d (%s) -> %d (%s) v%d",
- self.from_serial, self.from_serial,
- self.to_serial, self.to_serial,
- self.version)
- for p in self:
- logging.debug(p)
-
+ subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__,
+ help = "Generate RPKI-RTR database from rcynic output")
+ subparser.set_defaults(func = cronjob_main, default_log_to = "syslog")
+ subparser.add_argument("--scan-roas", help = "specify an external scan_roas program")
+ subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program")
+ subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero")
+ subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
-def kick_all(serial):
- """
- Kick any existing server processes to wake them up.
- """
-
- try:
- os.stat(rpki.rtr.server.kickme_dir)
- except OSError:
- logging.debug('# Creating directory "%s"', rpki.rtr.server.kickme_dir)
- os.makedirs(rpki.rtr.server.kickme_dir)
-
- msg = "Good morning, serial %d is ready" % serial
- sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
- for name in glob.iglob("%s.*" % rpki.rtr.server.kickme_base):
- try:
- logging.debug("# Kicking %s", name)
- sock.sendto(msg, name)
- except socket.error:
- try:
- logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name)
- os.unlink(name)
- except Exception, e:
- logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e)
- except Exception, e:
- logging.warning("# Failed to kick %s and don't understand why: %s", name, e)
- sock.close()
-
-
-def cronjob_main(args):
- """
- Run this right after running rcynic to wade through the ROAs and
- router certificates that rcynic collects and translate that data
- into the form used in the rpki-router protocol. Output is an
- updated database containing both full dumps (AXFR) and incremental
- dumps against a specific prior version (IXFR). After updating the
- database, kicks any active servers, so that they can notify their
- clients that a new version is available.
- """
-
- if args.rpki_rtr_dir:
- try:
- if not os.path.isdir(args.rpki_rtr_dir):
- os.makedirs(args.rpki_rtr_dir)
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- logging.critical(str(e))
- sys.exit(1)
-
- for version in sorted(rpki.rtr.server.PDU.version_map.iterkeys(), reverse = True):
-
- logging.debug("# Generating updates for protocol version %d", version)
-
- old_ixfrs = glob.glob("*.ix.*.v%d" % version)
-
- current = rpki.rtr.server.read_current(version)[0]
- cutoff = Timestamp.now(-(24 * 60 * 60))
- for f in glob.iglob("*.ax.v%d" % version):
- t = Timestamp(int(f.split(".")[0]))
- if t < cutoff and t != current:
- logging.debug("# Deleting old file %s, timestamp %s", f, t)
- os.unlink(f)
-
- pdus = rpki.rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts)
- if pdus == rpki.rtr.generator.AXFRSet.load_current(version):
- logging.debug("# No change, new serial not needed")
- continue
- pdus.save_axfr()
- for axfr in glob.iglob("*.ax.v%d" % version):
- if axfr != pdus.filename():
- pdus.save_ixfr(rpki.rtr.generator.AXFRSet.load(axfr))
- pdus.mark_current(args.force_zero_nonce)
-
- logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial)
-
- rpki.rtr.generator.kick_all(pdus.serial)
-
- old_ixfrs.sort()
- for ixfr in old_ixfrs:
- try:
- logging.debug("# Deleting old file %s", ixfr)
- os.unlink(ixfr)
- except OSError:
- pass
-
-
-def show_main(args):
- """
- Display current rpki-rtr server database in textual form.
- """
-
- if args.rpki_rtr_dir:
- try:
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- sys.exit(e)
-
- g = glob.glob("*.ax.v*")
- g.sort()
- for f in g:
- rpki.rtr.generator.AXFRSet.load(f).show()
-
- g = glob.glob("*.ix.*.v*")
- g.sort()
- for f in g:
- rpki.rtr.generator.IXFRSet.load(f).show()
-
-def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__,
- help = "Generate RPKI-RTR database from rcynic output")
- subparser.set_defaults(func = cronjob_main, default_log_to = "syslog")
- subparser.add_argument("--scan-roas", help = "specify an external scan_roas program")
- subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program")
- subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero")
- subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
-
- subparser = subparsers.add_parser("show", description = show_main.__doc__,
- help = "Display content of RPKI-RTR database")
- subparser.set_defaults(func = show_main, default_log_to = "stderr")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+ subparser = subparsers.add_parser("show", description = show_main.__doc__,
+ help = "Display content of RPKI-RTR database")
+ subparser.set_defaults(func = show_main, default_log_to = "stderr")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
diff --git a/rpki/rtr/main.py b/rpki/rtr/main.py
index 12de30cc..34f5598d 100644
--- a/rpki/rtr/main.py
+++ b/rpki/rtr/main.py
@@ -31,64 +31,64 @@ import argparse
class Formatter(logging.Formatter):
- converter = time.gmtime
+ converter = time.gmtime
- def __init__(self, debug, fmt, datefmt):
- self.debug = debug
- super(Formatter, self).__init__(fmt, datefmt)
+ def __init__(self, debug, fmt, datefmt):
+ self.debug = debug
+ super(Formatter, self).__init__(fmt, datefmt)
- def format(self, record):
- if getattr(record, "connection", None) is None:
- record.connection = ""
- return super(Formatter, self).format(record)
+ def format(self, record):
+ if getattr(record, "connection", None) is None:
+ record.connection = ""
+ return super(Formatter, self).format(record)
- def formatException(self, ei):
- if self.debug:
- return super(Formatter, self).formatException(ei)
- else:
- return str(ei[1])
+ def formatException(self, ei):
+ if self.debug:
+ return super(Formatter, self).formatException(ei)
+ else:
+ return str(ei[1])
def main():
- os.environ["TZ"] = "UTC"
- time.tzset()
-
- from rpki.rtr.server import argparse_setup as argparse_setup_server
- from rpki.rtr.client import argparse_setup as argparse_setup_client
- from rpki.rtr.generator import argparse_setup as argparse_setup_generator
-
- if "rpki.rtr.bgpdump" in sys.modules:
- from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump
- else:
- def argparse_setup_bgpdump(ignored):
- pass
-
- argparser = argparse.ArgumentParser(description = __doc__)
- argparser.add_argument("--debug", action = "store_true", help = "debugging mode")
- argparser.add_argument("--log-level", default = "debug",
- choices = ("debug", "info", "warning", "error", "critical"),
- type = lambda s: s.lower())
- argparser.add_argument("--log-to",
- choices = ("syslog", "stderr"))
- subparsers = argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode")
- argparse_setup_server(subparsers)
- argparse_setup_client(subparsers)
- argparse_setup_generator(subparsers)
- argparse_setup_bgpdump(subparsers)
- args = argparser.parse_args()
-
- fmt = "rpki-rtr/" + args.mode + "%(connection)s[%(process)d] %(message)s"
-
- if (args.log_to or args.default_log_to) == "stderr":
- handler = logging.StreamHandler()
- fmt = "%(asctime)s " + fmt
- elif os.path.exists("/dev/log"):
- handler = logging.handlers.SysLogHandler("/dev/log")
- else:
- handler = logging.handlers.SysLogHandler()
-
- handler.setFormatter(Formatter(args.debug, fmt, "%Y-%m-%dT%H:%M:%SZ"))
- logging.root.addHandler(handler)
- logging.root.setLevel(int(getattr(logging, args.log_level.upper())))
-
- return args.func(args)
+ os.environ["TZ"] = "UTC"
+ time.tzset()
+
+ from rpki.rtr.server import argparse_setup as argparse_setup_server
+ from rpki.rtr.client import argparse_setup as argparse_setup_client
+ from rpki.rtr.generator import argparse_setup as argparse_setup_generator
+
+ if "rpki.rtr.bgpdump" in sys.modules:
+ from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump
+ else:
+ def argparse_setup_bgpdump(ignored):
+ pass
+
+ argparser = argparse.ArgumentParser(description = __doc__)
+ argparser.add_argument("--debug", action = "store_true", help = "debugging mode")
+ argparser.add_argument("--log-level", default = "debug",
+ choices = ("debug", "info", "warning", "error", "critical"),
+ type = lambda s: s.lower())
+ argparser.add_argument("--log-to",
+ choices = ("syslog", "stderr"))
+ subparsers = argparser.add_subparsers(title = "Commands", metavar = "", dest = "mode")
+ argparse_setup_server(subparsers)
+ argparse_setup_client(subparsers)
+ argparse_setup_generator(subparsers)
+ argparse_setup_bgpdump(subparsers)
+ args = argparser.parse_args()
+
+ fmt = "rpki-rtr/" + args.mode + "%(connection)s[%(process)d] %(message)s"
+
+ if (args.log_to or args.default_log_to) == "stderr":
+ handler = logging.StreamHandler()
+ fmt = "%(asctime)s " + fmt
+ elif os.path.exists("/dev/log"):
+ handler = logging.handlers.SysLogHandler("/dev/log")
+ else:
+ handler = logging.handlers.SysLogHandler()
+
+ handler.setFormatter(Formatter(args.debug, fmt, "%Y-%m-%dT%H:%M:%SZ"))
+ logging.root.addHandler(handler)
+ logging.root.setLevel(int(getattr(logging, args.log_level.upper())))
+
+ return args.func(args)
diff --git a/rpki/rtr/pdus.py b/rpki/rtr/pdus.py
index 0d2e5928..94f579a1 100644
--- a/rpki/rtr/pdus.py
+++ b/rpki/rtr/pdus.py
@@ -28,292 +28,292 @@ import rpki.POW
# Exceptions
class PDUException(Exception):
- """
- Parent exception type for exceptions that signal particular protocol
- errors. String value of exception instance will be the message to
- put in the ErrorReportPDU, error_report_code value of exception
- will be the numeric code to use.
- """
-
- def __init__(self, msg = None, pdu = None):
- super(PDUException, self).__init__()
- assert msg is None or isinstance(msg, (str, unicode))
- self.error_report_msg = msg
- self.error_report_pdu = pdu
-
- def __str__(self):
- return self.error_report_msg or self.__class__.__name__
-
- def make_error_report(self, version):
- return ErrorReportPDU(version = version,
- errno = self.error_report_code,
- errmsg = self.error_report_msg,
- errpdu = self.error_report_pdu)
+ """
+ Parent exception type for exceptions that signal particular protocol
+ errors. String value of exception instance will be the message to
+ put in the ErrorReportPDU, error_report_code value of exception
+ will be the numeric code to use.
+ """
+
+ def __init__(self, msg = None, pdu = None):
+ super(PDUException, self).__init__()
+ assert msg is None or isinstance(msg, (str, unicode))
+ self.error_report_msg = msg
+ self.error_report_pdu = pdu
+
+ def __str__(self):
+ return self.error_report_msg or self.__class__.__name__
+
+ def make_error_report(self, version):
+ return ErrorReportPDU(version = version,
+ errno = self.error_report_code,
+ errmsg = self.error_report_msg,
+ errpdu = self.error_report_pdu)
class UnsupportedProtocolVersion(PDUException):
- error_report_code = 4
+ error_report_code = 4
class UnsupportedPDUType(PDUException):
- error_report_code = 5
+ error_report_code = 5
class CorruptData(PDUException):
- error_report_code = 0
+ error_report_code = 0
# Decorators
def wire_pdu(cls, versions = None):
- """
- Class decorator to add a PDU class to the set of known PDUs
- for all supported protocol versions.
- """
+ """
+ Class decorator to add a PDU class to the set of known PDUs
+ for all supported protocol versions.
+ """
- for v in PDU.version_map.iterkeys() if versions is None else versions:
- assert cls.pdu_type not in PDU.version_map[v]
- PDU.version_map[v][cls.pdu_type] = cls
- return cls
+ for v in PDU.version_map.iterkeys() if versions is None else versions:
+ assert cls.pdu_type not in PDU.version_map[v]
+ PDU.version_map[v][cls.pdu_type] = cls
+ return cls
def wire_pdu_only(*versions):
- """
- Class decorator to add a PDU class to the set of known PDUs
- for specific protocol versions.
- """
+ """
+ Class decorator to add a PDU class to the set of known PDUs
+ for specific protocol versions.
+ """
- assert versions and all(v in PDU.version_map for v in versions)
- return lambda cls: wire_pdu(cls, versions)
+ assert versions and all(v in PDU.version_map for v in versions)
+ return lambda cls: wire_pdu(cls, versions)
def clone_pdu_root(root_pdu_class):
- """
- Replace a PDU root class's version_map with a two-level deep copy of itself,
- and return a class decorator which subclasses can use to replace their
- parent classes with themselves in the resulting cloned version map.
+ """
+ Replace a PDU root class's version_map with a two-level deep copy of itself,
+ and return a class decorator which subclasses can use to replace their
+ parent classes with themselves in the resulting cloned version map.
- This function is not itself a decorator, it returns one.
- """
+ This function is not itself a decorator, it returns one.
+ """
- root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems())
+ root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems())
- def decorator(cls):
- for pdu_map in root_pdu_class.version_map.itervalues():
- for pdu_type, pdu_class in pdu_map.items():
- if pdu_class in cls.__bases__:
- pdu_map[pdu_type] = cls
- return cls
+ def decorator(cls):
+ for pdu_map in root_pdu_class.version_map.itervalues():
+ for pdu_type, pdu_class in pdu_map.items():
+ if pdu_class in cls.__bases__:
+ pdu_map[pdu_type] = cls
+ return cls
- return decorator
+ return decorator
# PDUs
class PDU(object):
- """
- Base PDU. Real PDUs are subclasses of this class.
- """
+ """
+ Base PDU. Real PDUs are subclasses of this class.
+ """
- version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu
+ version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu
- _pdu = None # Cached when first generated
+ _pdu = None # Cached when first generated
- header_struct = struct.Struct("!BB2xL")
+ header_struct = struct.Struct("!BB2xL")
- def __init__(self, version):
- assert version in self.version_map
- self.version = version
+ def __init__(self, version):
+ assert version in self.version_map
+ self.version = version
- def __cmp__(self, other):
- return cmp(self.to_pdu(), other.to_pdu())
+ def __cmp__(self, other):
+ return cmp(self.to_pdu(), other.to_pdu())
- @property
- def default_version(self):
- return max(self.version_map.iterkeys())
+ @property
+ def default_version(self):
+ return max(self.version_map.iterkeys())
- def check(self):
- pass
+ def check(self):
+ pass
- @classmethod
- def read_pdu(cls, reader):
- return reader.update(need = cls.header_struct.size, callback = cls.got_header)
+ @classmethod
+ def read_pdu(cls, reader):
+ return reader.update(need = cls.header_struct.size, callback = cls.got_header)
- @classmethod
- def got_header(cls, reader):
- if not reader.ready():
- return None
- assert reader.available() >= cls.header_struct.size
- version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
- reader.check_version(version)
- if pdu_type not in cls.version_map[version]:
- raise UnsupportedPDUType(
- "Received unsupported PDU type %d" % pdu_type)
- if length < 8:
- raise CorruptData(
- "Received PDU with length %d, which is too short to be valid" % length)
- self = cls.version_map[version][pdu_type](version = version)
- return reader.update(need = length, callback = self.got_pdu)
+ @classmethod
+ def got_header(cls, reader):
+ if not reader.ready():
+ return None
+ assert reader.available() >= cls.header_struct.size
+ version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
+ reader.check_version(version)
+ if pdu_type not in cls.version_map[version]:
+ raise UnsupportedPDUType(
+ "Received unsupported PDU type %d" % pdu_type)
+ if length < 8:
+ raise CorruptData(
+ "Received PDU with length %d, which is too short to be valid" % length)
+ self = cls.version_map[version][pdu_type](version = version)
+ return reader.update(need = length, callback = self.got_pdu)
class PDUWithSerial(PDU):
- """
- Base class for PDUs consisting of just a serial number and nonce.
- """
-
- header_struct = struct.Struct("!BBHLL")
-
- def __init__(self, version, serial = None, nonce = None):
- super(PDUWithSerial, self).__init__(version)
- if serial is not None:
- assert isinstance(serial, int)
- self.serial = serial
- if nonce is not None:
- assert isinstance(nonce, int)
- self.nonce = nonce
-
- def __str__(self):
- return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce)
-
- def to_pdu(self):
"""
- Generate the wire format PDU.
+ Base class for PDUs consisting of just a serial number and nonce.
"""
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
- self.header_struct.size, self.serial)
- return self._pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if length != 12:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ header_struct = struct.Struct("!BBHLL")
+
+ def __init__(self, version, serial = None, nonce = None):
+ super(PDUWithSerial, self).__init__(version)
+ if serial is not None:
+ assert isinstance(serial, int)
+ self.serial = serial
+ if nonce is not None:
+ assert isinstance(nonce, int)
+ self.nonce = nonce
+
+ def __str__(self):
+ return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce)
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
+
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
+ self.header_struct.size, self.serial)
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 12:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
class PDUWithNonce(PDU):
- """
- Base class for PDUs consisting of just a nonce.
- """
-
- header_struct = struct.Struct("!BBHL")
-
- def __init__(self, version, nonce = None):
- super(PDUWithNonce, self).__init__(version)
- if nonce is not None:
- assert isinstance(nonce, int)
- self.nonce = nonce
-
- def __str__(self):
- return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce)
-
- def to_pdu(self):
"""
- Generate the wire format PDU.
+ Base class for PDUs consisting of just a nonce.
"""
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size)
- return self._pdu
+ header_struct = struct.Struct("!BBHL")
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, self.nonce, length = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if length != 8:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ def __init__(self, version, nonce = None):
+ super(PDUWithNonce, self).__init__(version)
+ if nonce is not None:
+ assert isinstance(nonce, int)
+ self.nonce = nonce
+ def __str__(self):
+ return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce)
-class PDUEmpty(PDU):
- """
- Base class for empty PDUs.
- """
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
- header_struct = struct.Struct("!BBHL")
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size)
+ return self._pdu
- def __str__(self):
- return "[%s]" % self.__class__.__name__
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 8:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
- def to_pdu(self):
+
+class PDUEmpty(PDU):
"""
- Generate the wire format PDU for this prefix.
+ Base class for empty PDUs.
"""
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size)
- return self._pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, zero, length = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if zero != 0:
- raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self)
- if length != 8:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ header_struct = struct.Struct("!BBHL")
+
+ def __str__(self):
+ return "[%s]" % self.__class__.__name__
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU for this prefix.
+ """
+
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size)
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, zero, length = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if zero != 0:
+ raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self)
+ if length != 8:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
@wire_pdu
class SerialNotifyPDU(PDUWithSerial):
- """
- Serial Notify PDU.
- """
+ """
+ Serial Notify PDU.
+ """
- pdu_type = 0
+ pdu_type = 0
@wire_pdu
class SerialQueryPDU(PDUWithSerial):
- """
- Serial Query PDU.
- """
+ """
+ Serial Query PDU.
+ """
- pdu_type = 1
+ pdu_type = 1
- def __init__(self, version, serial = None, nonce = None):
- super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce)
+ def __init__(self, version, serial = None, nonce = None):
+ super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce)
@wire_pdu
class ResetQueryPDU(PDUEmpty):
- """
- Reset Query PDU.
- """
+ """
+ Reset Query PDU.
+ """
- pdu_type = 2
+ pdu_type = 2
- def __init__(self, version):
- super(ResetQueryPDU, self).__init__(self.default_version if version is None else version)
+ def __init__(self, version):
+ super(ResetQueryPDU, self).__init__(self.default_version if version is None else version)
@wire_pdu
class CacheResponsePDU(PDUWithNonce):
- """
- Cache Response PDU.
- """
+ """
+ Cache Response PDU.
+ """
- pdu_type = 3
+ pdu_type = 3
def EndOfDataPDU(version, *args, **kwargs):
- """
- Factory for the EndOfDataPDU classes, which take different forms in
- different protocol versions.
- """
+ """
+ Factory for the EndOfDataPDU classes, which take different forms in
+ different protocol versions.
+ """
- if version == 0:
- return EndOfDataPDUv0(version, *args, **kwargs)
- if version == 1:
- return EndOfDataPDUv1(version, *args, **kwargs)
- raise NotImplementedError
+ if version == 0:
+ return EndOfDataPDUv0(version, *args, **kwargs)
+ if version == 1:
+ return EndOfDataPDUv1(version, *args, **kwargs)
+ raise NotImplementedError
# Min, max, and default values, from the current RFC 6810 bis I-D.
@@ -324,325 +324,325 @@ def EndOfDataPDU(version, *args, **kwargs):
default_refresh = 3600
def valid_refresh(refresh):
- if not isinstance(refresh, int) or refresh < 120 or refresh > 86400:
- raise ValueError
- return refresh
+ if not isinstance(refresh, int) or refresh < 120 or refresh > 86400:
+ raise ValueError
+ return refresh
default_retry = 600
def valid_retry(retry):
- if not isinstance(retry, int) or retry < 120 or retry > 7200:
- raise ValueError
- return retry
+ if not isinstance(retry, int) or retry < 120 or retry > 7200:
+ raise ValueError
+ return retry
default_expire = 7200
def valid_expire(expire):
- if not isinstance(expire, int) or expire < 600 or expire > 172800:
- raise ValueError
- return expire
+ if not isinstance(expire, int) or expire < 600 or expire > 172800:
+ raise ValueError
+ return expire
@wire_pdu_only(0)
class EndOfDataPDUv0(PDUWithSerial):
- """
- End of Data PDU, protocol version 0.
- """
+ """
+ End of Data PDU, protocol version 0.
+ """
- pdu_type = 7
+ pdu_type = 7
- def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None):
- super(EndOfDataPDUv0, self).__init__(version, serial, nonce)
- self.refresh = valid_refresh(default_refresh if refresh is None else refresh)
- self.retry = valid_retry( default_retry if retry is None else retry)
- self.expire = valid_expire( default_expire if expire is None else expire)
+ def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None):
+ super(EndOfDataPDUv0, self).__init__(version, serial, nonce)
+ self.refresh = valid_refresh(default_refresh if refresh is None else refresh)
+ self.retry = valid_retry( default_retry if retry is None else retry)
+ self.expire = valid_expire( default_expire if expire is None else expire)
@wire_pdu_only(1)
class EndOfDataPDUv1(EndOfDataPDUv0):
- """
- End of Data PDU, protocol version 1.
- """
+ """
+ End of Data PDU, protocol version 1.
+ """
- header_struct = struct.Struct("!BBHLLLLL")
+ header_struct = struct.Struct("!BBHLLLLL")
- def __str__(self):
- return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % (
- self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire)
+ def __str__(self):
+ return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % (
+ self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire)
- def to_pdu(self):
- """
- Generate the wire format PDU.
- """
+ def to_pdu(self):
+ """
+ Generate the wire format PDU.
+ """
- if self._pdu is None:
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
- self.header_struct.size, self.serial,
- self.refresh, self.retry, self.expire)
- return self._pdu
+ if self._pdu is None:
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce,
+ self.header_struct.size, self.serial,
+ self.refresh, self.retry, self.expire)
+ return self._pdu
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b = reader.get(self.header_struct.size)
- version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \
- = self.header_struct.unpack(b)
- assert version == self.version and pdu_type == self.pdu_type
- if length != 24:
- raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
- assert b == self.to_pdu()
- return self
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b = reader.get(self.header_struct.size)
+ version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \
+ = self.header_struct.unpack(b)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != 24:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
+ assert b == self.to_pdu()
+ return self
@wire_pdu
class CacheResetPDU(PDUEmpty):
- """
- Cache reset PDU.
- """
+ """
+ Cache reset PDU.
+ """
- pdu_type = 8
+ pdu_type = 8
class PrefixPDU(PDU):
- """
- Object representing one prefix. This corresponds closely to one PDU
- in the rpki-router protocol, so closely that we use lexical ordering
- of the wire format of the PDU as the ordering for this class.
-
- This is a virtual class, but the .from_text() constructor
- instantiates the correct concrete subclass (IPv4PrefixPDU or
- IPv6PrefixPDU) depending on the syntax of its input text.
- """
-
- header_struct = struct.Struct("!BB2xLBBBx")
- asnum_struct = struct.Struct("!L")
-
- def __str__(self):
- plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen)
- return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm,
- ":".join(("%02X" % ord(b) for b in self.to_pdu())))
-
- def show(self):
- logging.debug("# Class: %s", self.__class__.__name__)
- logging.debug("# ASN: %s", self.asn)
- logging.debug("# Prefix: %s", self.prefix)
- logging.debug("# Prefixlen: %s", self.prefixlen)
- logging.debug("# MaxPrefixlen: %s", self.max_prefixlen)
- logging.debug("# Announce: %s", self.announce)
-
- def check(self):
- """
- Check attributes to make sure they're within range.
- """
-
- if self.announce not in (0, 1):
- raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
- if self.prefix.bits != self.address_byte_count * 8:
- raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self)
- if self.prefixlen < 0 or self.prefixlen > self.prefix.bits:
- raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self)
- if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits:
- raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self)
- pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
- if len(self.to_pdu()) != pdulen:
- raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
-
- def to_pdu(self, announce = None):
- """
- Generate the wire format PDU for this prefix.
- """
-
- if announce is not None:
- assert announce in (0, 1)
- elif self._pdu is not None:
- return self._pdu
- pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
- pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen,
- announce if announce is not None else self.announce,
- self.prefixlen, self.max_prefixlen) +
- self.prefix.toBytes() +
- self.asnum_struct.pack(self.asn))
- if announce is None:
- assert self._pdu is None
- self._pdu = pdu
- return pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- b1 = reader.get(self.header_struct.size)
- b2 = reader.get(self.address_byte_count)
- b3 = reader.get(self.asnum_struct.size)
- version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1)
- assert version == self.version and pdu_type == self.pdu_type
- if length != len(b1) + len(b2) + len(b3):
- raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self)
- self.prefix = rpki.POW.IPAddress.fromBytes(b2)
- self.asn = self.asnum_struct.unpack(b3)[0]
- assert b1 + b2 + b3 == self.to_pdu()
- return self
+ """
+ Object representing one prefix. This corresponds closely to one PDU
+ in the rpki-router protocol, so closely that we use lexical ordering
+ of the wire format of the PDU as the ordering for this class.
+
+ This is a virtual class, but the .from_text() constructor
+ instantiates the correct concrete subclass (IPv4PrefixPDU or
+ IPv6PrefixPDU) depending on the syntax of its input text.
+ """
+
+ header_struct = struct.Struct("!BB2xLBBBx")
+ asnum_struct = struct.Struct("!L")
+
+ def __str__(self):
+ plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen)
+ return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm,
+ ":".join(("%02X" % ord(b) for b in self.to_pdu())))
+
+ def show(self):
+ logging.debug("# Class: %s", self.__class__.__name__)
+ logging.debug("# ASN: %s", self.asn)
+ logging.debug("# Prefix: %s", self.prefix)
+ logging.debug("# Prefixlen: %s", self.prefixlen)
+ logging.debug("# MaxPrefixlen: %s", self.max_prefixlen)
+ logging.debug("# Announce: %s", self.announce)
+
+ def check(self):
+ """
+ Check attributes to make sure they're within range.
+ """
+
+ if self.announce not in (0, 1):
+ raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
+ if self.prefix.bits != self.address_byte_count * 8:
+ raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self)
+ if self.prefixlen < 0 or self.prefixlen > self.prefix.bits:
+ raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self)
+ if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits:
+ raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self)
+ pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
+ if len(self.to_pdu()) != pdulen:
+ raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
+
+ def to_pdu(self, announce = None):
+ """
+ Generate the wire format PDU for this prefix.
+ """
+
+ if announce is not None:
+ assert announce in (0, 1)
+ elif self._pdu is not None:
+ return self._pdu
+ pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size
+ pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen,
+ announce if announce is not None else self.announce,
+ self.prefixlen, self.max_prefixlen) +
+ self.prefix.toBytes() +
+ self.asnum_struct.pack(self.asn))
+ if announce is None:
+ assert self._pdu is None
+ self._pdu = pdu
+ return pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ b1 = reader.get(self.header_struct.size)
+ b2 = reader.get(self.address_byte_count)
+ b3 = reader.get(self.asnum_struct.size)
+ version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1)
+ assert version == self.version and pdu_type == self.pdu_type
+ if length != len(b1) + len(b2) + len(b3):
+ raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self)
+ self.prefix = rpki.POW.IPAddress.fromBytes(b2)
+ self.asn = self.asnum_struct.unpack(b3)[0]
+ assert b1 + b2 + b3 == self.to_pdu()
+ return self
@wire_pdu
class IPv4PrefixPDU(PrefixPDU):
- """
- IPv4 flavor of a prefix.
- """
+ """
+ IPv4 flavor of a prefix.
+ """
- pdu_type = 4
- address_byte_count = 4
+ pdu_type = 4
+ address_byte_count = 4
@wire_pdu
class IPv6PrefixPDU(PrefixPDU):
- """
- IPv6 flavor of a prefix.
- """
+ """
+ IPv6 flavor of a prefix.
+ """
- pdu_type = 6
- address_byte_count = 16
+ pdu_type = 6
+ address_byte_count = 16
@wire_pdu_only(1)
class RouterKeyPDU(PDU):
- """
- Router Key PDU.
- """
-
- pdu_type = 9
-
- header_struct = struct.Struct("!BBBxL20sL")
-
- def __str__(self):
- return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn,
- base64.urlsafe_b64encode(self.ski).rstrip("="),
- ":".join(("%02X" % ord(b) for b in self.to_pdu())))
-
- def check(self):
- """
- Check attributes to make sure they're within range.
- """
-
- if self.announce not in (0, 1):
- raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
- if len(self.ski) != 20:
- raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self)
- pdulen = self.header_struct.size + len(self.key)
- if len(self.to_pdu()) != pdulen:
- raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
-
- def to_pdu(self, announce = None):
- if announce is not None:
- assert announce in (0, 1)
- elif self._pdu is not None:
- return self._pdu
- pdulen = self.header_struct.size + len(self.key)
- pdu = (self.header_struct.pack(self.version,
- self.pdu_type,
- announce if announce is not None else self.announce,
- pdulen,
- self.ski,
- self.asn)
- + self.key)
- if announce is None:
- assert self._pdu is None
- self._pdu = pdu
- return pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- header = reader.get(self.header_struct.size)
- version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header)
- assert version == self.version and pdu_type == self.pdu_type
- remaining = length - self.header_struct.size
- if remaining <= 0:
- raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self)
- self.key = reader.get(remaining)
- assert header + self.key == self.to_pdu()
- return self
+ """
+ Router Key PDU.
+ """
+
+ pdu_type = 9
+
+ header_struct = struct.Struct("!BBBxL20sL")
+
+ def __str__(self):
+ return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn,
+ base64.urlsafe_b64encode(self.ski).rstrip("="),
+ ":".join(("%02X" % ord(b) for b in self.to_pdu())))
+
+ def check(self):
+ """
+ Check attributes to make sure they're within range.
+ """
+
+ if self.announce not in (0, 1):
+ raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
+ if len(self.ski) != 20:
+ raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self)
+ pdulen = self.header_struct.size + len(self.key)
+ if len(self.to_pdu()) != pdulen:
+ raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
+
+ def to_pdu(self, announce = None):
+ if announce is not None:
+ assert announce in (0, 1)
+ elif self._pdu is not None:
+ return self._pdu
+ pdulen = self.header_struct.size + len(self.key)
+ pdu = (self.header_struct.pack(self.version,
+ self.pdu_type,
+ announce if announce is not None else self.announce,
+ pdulen,
+ self.ski,
+ self.asn)
+ + self.key)
+ if announce is None:
+ assert self._pdu is None
+ self._pdu = pdu
+ return pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ header = reader.get(self.header_struct.size)
+ version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
+ remaining = length - self.header_struct.size
+ if remaining <= 0:
+ raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self)
+ self.key = reader.get(remaining)
+ assert header + self.key == self.to_pdu()
+ return self
@wire_pdu
class ErrorReportPDU(PDU):
- """
- Error Report PDU.
- """
-
- pdu_type = 10
-
- header_struct = struct.Struct("!BBHL")
- string_struct = struct.Struct("!L")
-
- errors = {
- 2 : "No Data Available" }
-
- fatal = {
- 0 : "Corrupt Data",
- 1 : "Internal Error",
- 3 : "Invalid Request",
- 4 : "Unsupported Protocol Version",
- 5 : "Unsupported PDU Type",
- 6 : "Withdrawal of Unknown Record",
- 7 : "Duplicate Announcement Received" }
-
- assert set(errors) & set(fatal) == set()
-
- errors.update(fatal)
-
- codes = dict((v, k) for k, v in errors.items())
-
- def __init__(self, version, errno = None, errpdu = None, errmsg = None):
- super(ErrorReportPDU, self).__init__(version)
- assert errno is None or errno in self.errors
- self.errno = errno
- self.errpdu = errpdu
- self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno]
-
- def __str__(self):
- return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg)
-
- def to_counted_string(self, s):
- return self.string_struct.pack(len(s)) + s
-
- def read_counted_string(self, reader, remaining):
- assert remaining >= self.string_struct.size
- n = self.string_struct.unpack(reader.get(self.string_struct.size))[0]
- assert remaining >= self.string_struct.size + n
- return n, reader.get(n), (remaining - self.string_struct.size - n)
-
- def to_pdu(self):
- """
- Generate the wire format PDU for this error report.
- """
-
- if self._pdu is None:
- assert isinstance(self.errno, int)
- assert not isinstance(self.errpdu, ErrorReportPDU)
- p = self.errpdu
- if p is None:
- p = ""
- elif isinstance(p, PDU):
- p = p.to_pdu()
- assert isinstance(p, str)
- pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg)
- self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen)
- self._pdu += self.to_counted_string(p)
- self._pdu += self.to_counted_string(self.errmsg.encode("utf8"))
- return self._pdu
-
- def got_pdu(self, reader):
- if not reader.ready():
- return None
- header = reader.get(self.header_struct.size)
- version, pdu_type, self.errno, length = self.header_struct.unpack(header)
- assert version == self.version and pdu_type == self.pdu_type
- remaining = length - self.header_struct.size
- self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining)
- self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining)
- if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen:
- raise CorruptData("Got PDU length %d, expected %d" % (
- length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen))
- assert (header
- + self.to_counted_string(self.errpdu)
- + self.to_counted_string(self.errmsg.encode("utf8"))
- == self.to_pdu())
- return self
+ """
+ Error Report PDU.
+ """
+
+ pdu_type = 10
+
+ header_struct = struct.Struct("!BBHL")
+ string_struct = struct.Struct("!L")
+
+ errors = {
+ 2 : "No Data Available" }
+
+ fatal = {
+ 0 : "Corrupt Data",
+ 1 : "Internal Error",
+ 3 : "Invalid Request",
+ 4 : "Unsupported Protocol Version",
+ 5 : "Unsupported PDU Type",
+ 6 : "Withdrawal of Unknown Record",
+ 7 : "Duplicate Announcement Received" }
+
+ assert set(errors) & set(fatal) == set()
+
+ errors.update(fatal)
+
+ codes = dict((v, k) for k, v in errors.items())
+
+ def __init__(self, version, errno = None, errpdu = None, errmsg = None):
+ super(ErrorReportPDU, self).__init__(version)
+ assert errno is None or errno in self.errors
+ self.errno = errno
+ self.errpdu = errpdu
+ self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno]
+
+ def __str__(self):
+ return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg)
+
+ def to_counted_string(self, s):
+ return self.string_struct.pack(len(s)) + s
+
+ def read_counted_string(self, reader, remaining):
+ assert remaining >= self.string_struct.size
+ n = self.string_struct.unpack(reader.get(self.string_struct.size))[0]
+ assert remaining >= self.string_struct.size + n
+ return n, reader.get(n), (remaining - self.string_struct.size - n)
+
+ def to_pdu(self):
+ """
+ Generate the wire format PDU for this error report.
+ """
+
+ if self._pdu is None:
+ assert isinstance(self.errno, int)
+ assert not isinstance(self.errpdu, ErrorReportPDU)
+ p = self.errpdu
+ if p is None:
+ p = ""
+ elif isinstance(p, PDU):
+ p = p.to_pdu()
+ assert isinstance(p, str)
+ pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg)
+ self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen)
+ self._pdu += self.to_counted_string(p)
+ self._pdu += self.to_counted_string(self.errmsg.encode("utf8"))
+ return self._pdu
+
+ def got_pdu(self, reader):
+ if not reader.ready():
+ return None
+ header = reader.get(self.header_struct.size)
+ version, pdu_type, self.errno, length = self.header_struct.unpack(header)
+ assert version == self.version and pdu_type == self.pdu_type
+ remaining = length - self.header_struct.size
+ self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining)
+ self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining)
+ if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen:
+ raise CorruptData("Got PDU length %d, expected %d" % (
+ length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen))
+ assert (header
+ + self.to_counted_string(self.errpdu)
+ + self.to_counted_string(self.errmsg.encode("utf8"))
+ == self.to_pdu())
+ return self
diff --git a/rpki/rtr/server.py b/rpki/rtr/server.py
index 1c7a5e78..f57c3037 100644
--- a/rpki/rtr/server.py
+++ b/rpki/rtr/server.py
@@ -44,37 +44,37 @@ kickme_base = os.path.join(kickme_dir, "kickme")
class PDU(rpki.rtr.pdus.PDU):
- """
- Generic server PDU.
- """
-
- def send_file(self, server, filename):
"""
- Send a content of a file as a cache response. Caller should catch IOError.
+ Generic server PDU.
"""
- fn2 = os.path.splitext(filename)[1]
- assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version
-
- f = open(filename, "rb")
- server.push_pdu(CacheResponsePDU(version = server.version,
- nonce = server.current_nonce))
- server.push_file(f)
- server.push_pdu(EndOfDataPDU(version = server.version,
- serial = server.current_serial,
- nonce = server.current_nonce,
- refresh = server.refresh,
- retry = server.retry,
- expire = server.expire))
-
- def send_nodata(self, server):
- """
- Send a nodata error.
- """
+ def send_file(self, server, filename):
+ """
+ Send a content of a file as a cache response. Caller should catch IOError.
+ """
+
+ fn2 = os.path.splitext(filename)[1]
+ assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version
- server.push_pdu(ErrorReportPDU(version = server.version,
- errno = ErrorReportPDU.codes["No Data Available"],
- errpdu = self))
+ f = open(filename, "rb")
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
+ server.push_file(f)
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce,
+ refresh = server.refresh,
+ retry = server.retry,
+ expire = server.expire))
+
+ def send_nodata(self, server):
+ """
+ Send a nodata error.
+ """
+
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["No Data Available"],
+ errpdu = self))
clone_pdu = clone_pdu_root(PDU)
@@ -82,512 +82,512 @@ clone_pdu = clone_pdu_root(PDU)
@clone_pdu
class SerialQueryPDU(PDU, rpki.rtr.pdus.SerialQueryPDU):
- """
- Serial Query PDU.
- """
-
- def serve(self, server):
- """
- Received a serial query, send incremental transfer in response.
- If client is already up to date, just send an empty incremental
- transfer.
"""
-
- server.logger.debug(self)
- if server.get_serial() is None:
- self.send_nodata(server)
- elif server.current_nonce != self.nonce:
- server.logger.info("[Client requested wrong nonce, resetting client]")
- server.push_pdu(CacheResetPDU(version = server.version))
- elif server.current_serial == self.serial:
- server.logger.debug("[Client is already current, sending empty IXFR]")
- server.push_pdu(CacheResponsePDU(version = server.version,
- nonce = server.current_nonce))
- server.push_pdu(EndOfDataPDU(version = server.version,
- serial = server.current_serial,
- nonce = server.current_nonce,
- refresh = server.refresh,
- retry = server.retry,
- expire = server.expire))
- elif disable_incrementals:
- server.push_pdu(CacheResetPDU(version = server.version))
- else:
- try:
- self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version))
- except IOError:
- server.push_pdu(CacheResetPDU(version = server.version))
+ Serial Query PDU.
+ """
+
+ def serve(self, server):
+ """
+ Received a serial query, send incremental transfer in response.
+ If client is already up to date, just send an empty incremental
+ transfer.
+ """
+
+ server.logger.debug(self)
+ if server.get_serial() is None:
+ self.send_nodata(server)
+ elif server.current_nonce != self.nonce:
+ server.logger.info("[Client requested wrong nonce, resetting client]")
+ server.push_pdu(CacheResetPDU(version = server.version))
+ elif server.current_serial == self.serial:
+ server.logger.debug("[Client is already current, sending empty IXFR]")
+ server.push_pdu(CacheResponsePDU(version = server.version,
+ nonce = server.current_nonce))
+ server.push_pdu(EndOfDataPDU(version = server.version,
+ serial = server.current_serial,
+ nonce = server.current_nonce,
+ refresh = server.refresh,
+ retry = server.retry,
+ expire = server.expire))
+ elif disable_incrementals:
+ server.push_pdu(CacheResetPDU(version = server.version))
+ else:
+ try:
+ self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version))
+ except IOError:
+ server.push_pdu(CacheResetPDU(version = server.version))
@clone_pdu
class ResetQueryPDU(PDU, rpki.rtr.pdus.ResetQueryPDU):
- """
- Reset Query PDU.
- """
-
- def serve(self, server):
"""
- Received a reset query, send full current state in response.
+ Reset Query PDU.
"""
- server.logger.debug(self)
- if server.get_serial() is None:
- self.send_nodata(server)
- else:
- try:
- fn = "%d.ax.v%d" % (server.current_serial, server.version)
- self.send_file(server, fn)
- except IOError:
- server.push_pdu(ErrorReportPDU(version = server.version,
- errno = ErrorReportPDU.codes["Internal Error"],
- errpdu = self,
- errmsg = "Couldn't open %s" % fn))
+ def serve(self, server):
+ """
+ Received a reset query, send full current state in response.
+ """
+
+ server.logger.debug(self)
+ if server.get_serial() is None:
+ self.send_nodata(server)
+ else:
+ try:
+ fn = "%d.ax.v%d" % (server.current_serial, server.version)
+ self.send_file(server, fn)
+ except IOError:
+ server.push_pdu(ErrorReportPDU(version = server.version,
+ errno = ErrorReportPDU.codes["Internal Error"],
+ errpdu = self,
+ errmsg = "Couldn't open %s" % fn))
@clone_pdu
class ErrorReportPDU(rpki.rtr.pdus.ErrorReportPDU):
- """
- Error Report PDU.
- """
-
- def serve(self, server):
"""
- Received an ErrorReportPDU from client. Not much we can do beyond
- logging it, then killing the connection if error was fatal.
+ Error Report PDU.
"""
- server.logger.error(self)
- if self.errno in self.fatal:
- server.logger.error("[Shutting down due to reported fatal protocol error]")
- sys.exit(1)
+ def serve(self, server):
+ """
+ Received an ErrorReportPDU from client. Not much we can do beyond
+ logging it, then killing the connection if error was fatal.
+ """
+
+ server.logger.error(self)
+ if self.errno in self.fatal:
+ server.logger.error("[Shutting down due to reported fatal protocol error]")
+ sys.exit(1)
def read_current(version):
- """
- Read current serial number and nonce. Return None for both if
- serial and nonce not recorded. For backwards compatibility, treat
- file containing just a serial number as having a nonce of zero.
- """
-
- if version is None:
- return None, None
- try:
- with open("current.v%d" % version, "r") as f:
- values = tuple(int(s) for s in f.read().split())
- return values[0], values[1]
- except IndexError:
- return values[0], 0
- except IOError:
- return None, None
+ """
+ Read current serial number and nonce. Return None for both if
+ serial and nonce not recorded. For backwards compatibility, treat
+ file containing just a serial number as having a nonce of zero.
+ """
+
+ if version is None:
+ return None, None
+ try:
+ with open("current.v%d" % version, "r") as f:
+ values = tuple(int(s) for s in f.read().split())
+ return values[0], values[1]
+ except IndexError:
+ return values[0], 0
+ except IOError:
+ return None, None
def write_current(serial, nonce, version):
- """
- Write serial number and nonce.
- """
+ """
+ Write serial number and nonce.
+ """
- curfn = "current.v%d" % version
- tmpfn = curfn + "%d.tmp" % os.getpid()
- with open(tmpfn, "w") as f:
- f.write("%d %d\n" % (serial, nonce))
- os.rename(tmpfn, curfn)
+ curfn = "current.v%d" % version
+ tmpfn = curfn + "%d.tmp" % os.getpid()
+ with open(tmpfn, "w") as f:
+ f.write("%d %d\n" % (serial, nonce))
+ os.rename(tmpfn, curfn)
class FileProducer(object):
- """
- File-based producer object for asynchat.
- """
+ """
+ File-based producer object for asynchat.
+ """
- def __init__(self, handle, buffersize):
- self.handle = handle
- self.buffersize = buffersize
+ def __init__(self, handle, buffersize):
+ self.handle = handle
+ self.buffersize = buffersize
- def more(self):
- return self.handle.read(self.buffersize)
+ def more(self):
+ return self.handle.read(self.buffersize)
class ServerWriteChannel(rpki.rtr.channels.PDUChannel):
- """
- Kludge to deal with ssh's habit of sometimes (compile time option)
- invoking us with two unidirectional pipes instead of one
- bidirectional socketpair. All the server logic is in the
- ServerChannel class, this class just deals with sending the
- server's output to a different file descriptor.
- """
-
- def __init__(self):
"""
- Set up stdout.
+ Kludge to deal with ssh's habit of sometimes (compile time option)
+ invoking us with two unidirectional pipes instead of one
+ bidirectional socketpair. All the server logic is in the
+ ServerChannel class, this class just deals with sending the
+ server's output to a different file descriptor.
"""
- super(ServerWriteChannel, self).__init__(root_pdu_class = PDU)
- self.init_file_dispatcher(sys.stdout.fileno())
+ def __init__(self):
+ """
+ Set up stdout.
+ """
- def readable(self):
- """
- This channel is never readable.
- """
+ super(ServerWriteChannel, self).__init__(root_pdu_class = PDU)
+ self.init_file_dispatcher(sys.stdout.fileno())
- return False
+ def readable(self):
+ """
+ This channel is never readable.
+ """
- def push_file(self, f):
- """
- Write content of a file to stream.
- """
+ return False
- try:
- self.push_with_producer(FileProducer(f, self.ac_out_buffer_size))
- except OSError, e:
- if e.errno != errno.EAGAIN:
- raise
+ def push_file(self, f):
+ """
+ Write content of a file to stream.
+ """
+ try:
+ self.push_with_producer(FileProducer(f, self.ac_out_buffer_size))
+ except OSError, e:
+ if e.errno != errno.EAGAIN:
+ raise
-class ServerChannel(rpki.rtr.channels.PDUChannel):
- """
- Server protocol engine, handles upcalls from PDUChannel to
- implement protocol logic.
- """
- def __init__(self, logger, refresh, retry, expire):
+class ServerChannel(rpki.rtr.channels.PDUChannel):
"""
- Set up stdin and stdout as connection and start listening for
- first PDU.
+ Server protocol engine, handles upcalls from PDUChannel to
+ implement protocol logic.
"""
- super(ServerChannel, self).__init__(root_pdu_class = PDU)
- self.init_file_dispatcher(sys.stdin.fileno())
- self.writer = ServerWriteChannel()
- self.logger = logger
- self.refresh = refresh
- self.retry = retry
- self.expire = expire
- self.get_serial()
- self.start_new_pdu()
-
- def writable(self):
- """
- This channel is never writable.
- """
+ def __init__(self, logger, refresh, retry, expire):
+ """
+ Set up stdin and stdout as connection and start listening for
+ first PDU.
+ """
- return False
+ super(ServerChannel, self).__init__(root_pdu_class = PDU)
+ self.init_file_dispatcher(sys.stdin.fileno())
+ self.writer = ServerWriteChannel()
+ self.logger = logger
+ self.refresh = refresh
+ self.retry = retry
+ self.expire = expire
+ self.get_serial()
+ self.start_new_pdu()
- def push(self, data):
- """
- Redirect to writer channel.
- """
+ def writable(self):
+ """
+ This channel is never writable.
+ """
- return self.writer.push(data)
+ return False
- def push_with_producer(self, producer):
- """
- Redirect to writer channel.
- """
+ def push(self, data):
+ """
+ Redirect to writer channel.
+ """
- return self.writer.push_with_producer(producer)
+ return self.writer.push(data)
- def push_pdu(self, pdu):
- """
- Redirect to writer channel.
- """
+ def push_with_producer(self, producer):
+ """
+ Redirect to writer channel.
+ """
- return self.writer.push_pdu(pdu)
+ return self.writer.push_with_producer(producer)
- def push_file(self, f):
- """
- Redirect to writer channel.
- """
+ def push_pdu(self, pdu):
+ """
+ Redirect to writer channel.
+ """
- return self.writer.push_file(f)
+ return self.writer.push_pdu(pdu)
- def deliver_pdu(self, pdu):
- """
- Handle received PDU.
- """
+ def push_file(self, f):
+ """
+ Redirect to writer channel.
+ """
- pdu.serve(self)
+ return self.writer.push_file(f)
- def get_serial(self):
- """
- Read, cache, and return current serial number, or None if we can't
- find the serial number file. The latter condition should never
- happen, but maybe we got started in server mode while the cronjob
- mode instance is still building its database.
- """
+ def deliver_pdu(self, pdu):
+ """
+ Handle received PDU.
+ """
- self.current_serial, self.current_nonce = read_current(self.version)
- return self.current_serial
+ pdu.serve(self)
- def check_serial(self):
- """
- Check for a new serial number.
- """
+ def get_serial(self):
+ """
+ Read, cache, and return current serial number, or None if we can't
+ find the serial number file. The latter condition should never
+ happen, but maybe we got started in server mode while the cronjob
+ mode instance is still building its database.
+ """
- old_serial = self.current_serial
- return old_serial != self.get_serial()
+ self.current_serial, self.current_nonce = read_current(self.version)
+ return self.current_serial
- def notify(self, data = None, force = False):
- """
- Cronjob instance kicked us: check whether our serial number has
- changed, and send a notify message if so.
+ def check_serial(self):
+ """
+ Check for a new serial number.
+ """
- We have to check rather than just blindly notifying when kicked
- because the cronjob instance has no good way of knowing which
- protocol version we're running, thus has no good way of knowing
- whether we care about a particular change set or not.
- """
+ old_serial = self.current_serial
+ return old_serial != self.get_serial()
- if force or self.check_serial():
- self.push_pdu(SerialNotifyPDU(version = self.version,
- serial = self.current_serial,
- nonce = self.current_nonce))
- else:
- self.logger.debug("Cronjob kicked me but I see no serial change, ignoring")
+ def notify(self, data = None, force = False):
+ """
+ Cronjob instance kicked us: check whether our serial number has
+ changed, and send a notify message if so.
+
+ We have to check rather than just blindly notifying when kicked
+ because the cronjob instance has no good way of knowing which
+ protocol version we're running, thus has no good way of knowing
+ whether we care about a particular change set or not.
+ """
+
+ if force or self.check_serial():
+ self.push_pdu(SerialNotifyPDU(version = self.version,
+ serial = self.current_serial,
+ nonce = self.current_nonce))
+ else:
+ self.logger.debug("Cronjob kicked me but I see no serial change, ignoring")
class KickmeChannel(asyncore.dispatcher, object):
- """
- asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to
- kick servers when it's time to send notify PDUs to clients.
- """
-
- def __init__(self, server):
- asyncore.dispatcher.__init__(self) # Old-style class
- self.server = server
- self.sockname = "%s.%d" % (kickme_base, os.getpid())
- self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM)
- try:
- self.bind(self.sockname)
- os.chmod(self.sockname, 0660)
- except socket.error, e:
- self.server.logger.exception("Couldn't bind() kickme socket: %r", e)
- self.close()
- except OSError, e:
- self.server.logger.exception("Couldn't chmod() kickme socket: %r", e)
-
- def writable(self):
"""
- This socket is read-only, never writable.
+ asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to
+ kick servers when it's time to send notify PDUs to clients.
"""
- return False
+ def __init__(self, server):
+ asyncore.dispatcher.__init__(self) # Old-style class
+ self.server = server
+ self.sockname = "%s.%d" % (kickme_base, os.getpid())
+ self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+ try:
+ self.bind(self.sockname)
+ os.chmod(self.sockname, 0660)
+ except socket.error, e:
+ self.server.logger.exception("Couldn't bind() kickme socket: %r", e)
+ self.close()
+ except OSError, e:
+ self.server.logger.exception("Couldn't chmod() kickme socket: %r", e)
+
+ def writable(self):
+ """
+ This socket is read-only, never writable.
+ """
+
+ return False
+
+ def handle_connect(self):
+ """
+ Ignore connect events (not very useful on datagram socket).
+ """
+
+ pass
+
+ def handle_read(self):
+ """
+ Handle receipt of a datagram.
+ """
+
+ data = self.recv(512)
+ self.server.notify(data)
+
+ def cleanup(self):
+ """
+ Clean up this dispatcher's socket.
+ """
+
+ self.close()
+ try:
+ os.unlink(self.sockname)
+ except: # pylint: disable=W0702
+ pass
- def handle_connect(self):
- """
- Ignore connect events (not very useful on datagram socket).
- """
+ def log(self, msg):
+ """
+ Intercept asyncore's logging.
+ """
- pass
+ self.server.logger.info(msg)
- def handle_read(self):
- """
- Handle receipt of a datagram.
- """
+ def log_info(self, msg, tag = "info"):
+ """
+ Intercept asyncore's logging.
+ """
- data = self.recv(512)
- self.server.notify(data)
+ self.server.logger.info("asyncore: %s: %s", tag, msg)
- def cleanup(self):
- """
- Clean up this dispatcher's socket.
- """
+ def handle_error(self):
+ """
+ Handle errors caught by asyncore main loop.
+ """
- self.close()
- try:
- os.unlink(self.sockname)
- except: # pylint: disable=W0702
- pass
+ self.server.logger.exception("[Unhandled exception]")
+ self.server.logger.critical("[Exiting after unhandled exception]")
+ sys.exit(1)
- def log(self, msg):
+
+def _hostport_tag():
"""
- Intercept asyncore's logging.
+ Construct hostname/address + port when we're running under a
+ protocol we understand well enough to do that. This is all
+ kludgery. Just grit your teeth, or perhaps just close your eyes.
"""
- self.server.logger.info(msg)
+ proto = None
- def log_info(self, msg, tag = "info"):
- """
- Intercept asyncore's logging.
- """
+ if proto is None:
+ try:
+ host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername()
+ proto = "tcp"
+ except: # pylint: disable=W0702
+ pass
- self.server.logger.info("asyncore: %s: %s", tag, msg)
+ if proto is None:
+ try:
+ host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2]
+ proto = "tcp"
+ except: # pylint: disable=W0702
+ pass
- def handle_error(self):
+ if proto is None:
+ try:
+ host, port = os.environ["SSH_CONNECTION"].split()[0:2]
+ proto = "ssh"
+ except: # pylint: disable=W0702
+ pass
+
+ if proto is None:
+ try:
+ host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT")
+ proto = "ssl"
+ except: # pylint: disable=W0702
+ pass
+
+ if proto is None:
+ return ""
+ elif not port:
+ return "/%s/%s" % (proto, host)
+ elif ":" in host:
+ return "/%s/%s.%s" % (proto, host, port)
+ else:
+ return "/%s/%s:%s" % (proto, host, port)
+
+
+def server_main(args):
"""
- Handle errors caught by asyncore main loop.
+ Implement the server side of the rpkk-router protocol. Other than
+ one PF_UNIX socket inode, this doesn't write anything to disk, so it
+ can be run with minimal privileges. Most of the work has already
+ been done by the database generator, so all this server has to do is
+ pass the results along to a client.
"""
- self.server.logger.exception("[Unhandled exception]")
- self.server.logger.critical("[Exiting after unhandled exception]")
- sys.exit(1)
+ logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag()))
+ logger.debug("[Starting]")
-def _hostport_tag():
- """
- Construct hostname/address + port when we're running under a
- protocol we understand well enough to do that. This is all
- kludgery. Just grit your teeth, or perhaps just close your eyes.
- """
-
- proto = None
+ if args.rpki_rtr_dir:
+ try:
+ os.chdir(args.rpki_rtr_dir)
+ except OSError, e:
+ sys.exit(e)
- if proto is None:
+ kickme = None
try:
- host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername()
- proto = "tcp"
- except: # pylint: disable=W0702
- pass
+ server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
+ kickme = rpki.rtr.server.KickmeChannel(server = server)
+ asyncore.loop(timeout = None)
+ signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition
+ except KeyboardInterrupt:
+ sys.exit(0)
+ finally:
+ signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition
+ if kickme is not None:
+ kickme.cleanup()
- if proto is None:
- try:
- host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2]
- proto = "tcp"
- except: # pylint: disable=W0702
- pass
- if proto is None:
- try:
- host, port = os.environ["SSH_CONNECTION"].split()[0:2]
- proto = "ssh"
- except: # pylint: disable=W0702
- pass
+def listener_main(args):
+ """
+ Totally insecure TCP listener for rpki-rtr protocol. We only
+ implement this because it's all that the routers currently support.
+ In theory, we will all be running TCP-AO in the future, at which
+ point this listener will go away or become a TCP-AO listener.
+ """
- if proto is None:
- try:
- host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT")
- proto = "ssl"
- except: # pylint: disable=W0702
- pass
+ # Perhaps we should daemonize? Deal with that later.
- if proto is None:
- return ""
- elif not port:
- return "/%s/%s" % (proto, host)
- elif ":" in host:
- return "/%s/%s.%s" % (proto, host, port)
- else:
- return "/%s/%s:%s" % (proto, host, port)
+ # server_main() handles args.rpki_rtr_dir.
+ listener = None
+ try:
+ listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
+ listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+ except: # pylint: disable=W0702
+ if listener is not None:
+ listener.close()
+ listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ try:
+ listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ except AttributeError:
+ pass
+ listener.bind(("", args.port))
+ listener.listen(5)
+ logging.debug("[Listening on port %s]", args.port)
+ while True:
+ try:
+ s, ai = listener.accept()
+ except KeyboardInterrupt:
+ sys.exit(0)
+ logging.debug("[Received connection from %r]", ai)
+ pid = os.fork()
+ if pid == 0:
+ os.dup2(s.fileno(), 0) # pylint: disable=E1103
+ os.dup2(s.fileno(), 1) # pylint: disable=E1103
+ s.close()
+ #os.closerange(3, os.sysconf("SC_OPEN_MAX"))
+ server_main(args)
+ sys.exit()
+ else:
+ logging.debug("[Spawned server %d]", pid)
+ while True:
+ try:
+ pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612
+ if pid:
+ logging.debug("[Server %s exited]", pid)
+ continue
+ except: # pylint: disable=W0702
+ pass
+ break
-def server_main(args):
- """
- Implement the server side of the rpkk-router protocol. Other than
- one PF_UNIX socket inode, this doesn't write anything to disk, so it
- can be run with minimal privileges. Most of the work has already
- been done by the database generator, so all this server has to do is
- pass the results along to a client.
- """
- logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag()))
+def argparse_setup(subparsers):
+ """
+ Set up argparse stuff for commands in this module.
+ """
- logger.debug("[Starting]")
+ # These could have been lambdas, but doing it this way results in
+ # more useful error messages on argparse failures.
- if args.rpki_rtr_dir:
- try:
- os.chdir(args.rpki_rtr_dir)
- except OSError, e:
- sys.exit(e)
-
- kickme = None
- try:
- server = rpki.rtr.server.ServerChannel(logger = logger, refresh = args.refresh, retry = args.retry, expire = args.expire)
- kickme = rpki.rtr.server.KickmeChannel(server = server)
- asyncore.loop(timeout = None)
- signal.signal(signal.SIGINT, signal.SIG_IGN) # Theorized race condition
- except KeyboardInterrupt:
- sys.exit(0)
- finally:
- signal.signal(signal.SIGINT, signal.SIG_IGN) # Observed race condition
- if kickme is not None:
- kickme.cleanup()
+ def refresh(v):
+ return rpki.rtr.pdus.valid_refresh(int(v))
+ def retry(v):
+ return rpki.rtr.pdus.valid_retry(int(v))
-def listener_main(args):
- """
- Totally insecure TCP listener for rpki-rtr protocol. We only
- implement this because it's all that the routers currently support.
- In theory, we will all be running TCP-AO in the future, at which
- point this listener will go away or become a TCP-AO listener.
- """
-
- # Perhaps we should daemonize? Deal with that later.
-
- # server_main() handles args.rpki_rtr_dir.
-
- listener = None
- try:
- listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
- listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
- except: # pylint: disable=W0702
- if listener is not None:
- listener.close()
- listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- try:
- listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
- except AttributeError:
- pass
- listener.bind(("", args.port))
- listener.listen(5)
- logging.debug("[Listening on port %s]", args.port)
- while True:
- try:
- s, ai = listener.accept()
- except KeyboardInterrupt:
- sys.exit(0)
- logging.debug("[Received connection from %r]", ai)
- pid = os.fork()
- if pid == 0:
- os.dup2(s.fileno(), 0) # pylint: disable=E1103
- os.dup2(s.fileno(), 1) # pylint: disable=E1103
- s.close()
- #os.closerange(3, os.sysconf("SC_OPEN_MAX"))
- server_main(args)
- sys.exit()
- else:
- logging.debug("[Spawned server %d]", pid)
- while True:
- try:
- pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612
- if pid:
- logging.debug("[Server %s exited]", pid)
- continue
- except: # pylint: disable=W0702
- pass
- break
+ def expire(v):
+ return rpki.rtr.pdus.valid_expire(int(v))
+ # Some duplication of arguments here, not enough to be worth huge
+ # effort to clean up, worry about it later in any case.
-def argparse_setup(subparsers):
- """
- Set up argparse stuff for commands in this module.
- """
-
- # These could have been lambdas, but doing it this way results in
- # more useful error messages on argparse failures.
-
- def refresh(v):
- return rpki.rtr.pdus.valid_refresh(int(v))
-
- def retry(v):
- return rpki.rtr.pdus.valid_retry(int(v))
-
- def expire(v):
- return rpki.rtr.pdus.valid_expire(int(v))
-
- # Some duplication of arguments here, not enough to be worth huge
- # effort to clean up, worry about it later in any case.
-
- subparser = subparsers.add_parser("server", description = server_main.__doc__,
- help = "RPKI-RTR protocol server")
- subparser.set_defaults(func = server_main, default_log_to = "syslog")
- subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
- subparser.add_argument("--retry", type = retry, help = "override default retry timer")
- subparser.add_argument("--expire", type = expire, help = "override default expire timer")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
-
- subparser = subparsers.add_parser("listener", description = listener_main.__doc__,
- help = "TCP listener for RPKI-RTR protocol server")
- subparser.set_defaults(func = listener_main, default_log_to = "syslog")
- subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
- subparser.add_argument("--retry", type = retry, help = "override default retry timer")
- subparser.add_argument("--expire", type = expire, help = "override default expire timer")
- subparser.add_argument("port", type = int, help = "TCP port on which to listen")
- subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+ subparser = subparsers.add_parser("server", description = server_main.__doc__,
+ help = "RPKI-RTR protocol server")
+ subparser.set_defaults(func = server_main, default_log_to = "syslog")
+ subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
+ subparser.add_argument("--retry", type = retry, help = "override default retry timer")
+ subparser.add_argument("--expire", type = expire, help = "override default expire timer")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")
+
+ subparser = subparsers.add_parser("listener", description = listener_main.__doc__,
+ help = "TCP listener for RPKI-RTR protocol server")
+ subparser.set_defaults(func = listener_main, default_log_to = "syslog")
+ subparser.add_argument("--refresh", type = refresh, help = "override default refresh timer")
+ subparser.add_argument("--retry", type = retry, help = "override default retry timer")
+ subparser.add_argument("--expire", type = expire, help = "override default expire timer")
+ subparser.add_argument("port", type = int, help = "TCP port on which to listen")
+ subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database")