diff options
author | Rob Austein <sra@hactrn.net> | 2014-04-26 02:32:07 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2014-04-26 02:32:07 +0000 |
commit | 5e40749b8591f241f482afadb295cf66a20c8d04 (patch) | |
tree | edc56c5f3e8ad65425ca5ca73b1512f82458c436 /rpki | |
parent | 9319a6d6a9a54d516075f57cf09962f469a91b34 (diff) |
Refactor rpki-rtr code.
svn path=/trunk/; revision=5817
Diffstat (limited to 'rpki')
-rw-r--r-- | rpki/rpki_rtr/__init__.py | 0 | ||||
-rw-r--r-- | rpki/rpki_rtr/channels.py | 246 | ||||
-rw-r--r-- | rpki/rpki_rtr/client.py | 558 | ||||
-rw-r--r-- | rpki/rpki_rtr/generator.py | 573 | ||||
-rw-r--r-- | rpki/rpki_rtr/pdus.py | 641 | ||||
-rw-r--r-- | rpki/rpki_rtr/server.py | 618 |
6 files changed, 2636 insertions, 0 deletions
diff --git a/rpki/rpki_rtr/__init__.py b/rpki/rpki_rtr/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/rpki/rpki_rtr/__init__.py diff --git a/rpki/rpki_rtr/channels.py b/rpki/rpki_rtr/channels.py new file mode 100644 index 00000000..d2a8972f --- /dev/null +++ b/rpki/rpki_rtr/channels.py @@ -0,0 +1,246 @@ +# $Id$ +# +# Copyright (C) 2014 Dragon Research Labs ("DRL") +# Portions copyright (C) 2009-2013 Internet Systems Consortium ("ISC") +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notices and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND DRL AND ISC DISCLAIM ALL +# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL OR +# ISC BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA +# OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +# TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. + +""" +I/O system of RPKI-RTR protocol implementation. +""" + +import os +import sys +import time +import fcntl +import errno +import logging +import asyncore +import asynchat +import rpki.rpki_rtr.pdus + + +class Timestamp(int): + """ + Wrapper around time module. + """ + + def __new__(cls, t): + # http://stackoverflow.com/questions/7471255/pythons-super-and-new-confused-me + #return int.__new__(cls, t) + return super(Timestamp, cls).__new__(cls, t) + + @classmethod + def now(cls, delta = 0): + return cls(time.time() + delta) + + def __str__(self): + return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self)) + + +class ReadBuffer(object): + """ + Wrapper around synchronous/asynchronous read state. + + This also handles tracking the current protocol version, + because it has to go somewhere and there's no better place. + """ + + def __init__(self): + self.buffer = "" + self.version = None + + def update(self, need, callback): + """ + Update count of needed bytes and callback, then dispatch to callback. + """ + + self.need = need + self.callback = callback + return self.retry() + + def retry(self): + """ + Try dispatching to the callback again. + """ + + return self.callback(self) + + def available(self): + """ + How much data do we have available in this buffer? + """ + + return len(self.buffer) + + def needed(self): + """ + How much more data does this buffer need to become ready? + """ + + return self.need - self.available() + + def ready(self): + """ + Is this buffer ready to read yet? + """ + + return self.available() >= self.need + + def get(self, n): + """ + Hand some data to the caller. + """ + + b = self.buffer[:n] + self.buffer = self.buffer[n:] + return b + + def put(self, b): + """ + Accumulate some data. + """ + + self.buffer += b + + def check_version(self, version): + """ + Track version number of PDUs read from this buffer. + Once set, the version must not change. + """ + + if self.version is not None and version != self.version: + raise rpki.rpki_rtr.pdus.CorruptData( + "Received PDU version %d, expected %d" % (version, self.version)) + if self.version is None and version not in rpki.rpki_rtr.pdus.PDU.version_map: + raise rpki.rpki_rtr.pdus.UnsupportedProtocolVersion( + "Received PDU version %d, known versions %s" % (version, ", ".PDU.version_map.iterkeys())) + self.version = version + + +class PDUChannel(asynchat.async_chat, object): + """ + asynchat subclass that understands our PDUs. This just handles + network I/O. Specific engines (client, server) should be subclasses + of this with methods that do something useful with the resulting + PDUs. + """ + + def __init__(self, root_pdu_class, sock = None): + asynchat.async_chat.__init__(self, sock) # Old-style class, can't use super() + self.reader = ReadBuffer() + assert issubclass(root_pdu_class, rpki.rpki_rtr.pdus.PDU) + self.root_pdu_class = root_pdu_class + + @property + def version(self): + return self.reader.version + + @version.setter + def version(self, version): + self.reader.check_version(version) + + def start_new_pdu(self): + """ + Start read of a new PDU. + """ + + try: + p = self.root_pdu_class.read_pdu(self.reader) + while p is not None: + self.deliver_pdu(p) + p = self.root_pdu_class.read_pdu(self.reader) + except rpki.rpki_rtr.pdus.PDUException, e: + self.push_pdu(e.make_error_report(version = self.version)) + self.close_when_done() + else: + assert not self.reader.ready() + self.set_terminator(self.reader.needed()) + + def collect_incoming_data(self, data): + """ + Collect data into the read buffer. + """ + + self.reader.put(data) + + def found_terminator(self): + """ + Got requested data, see if we now have a PDU. If so, pass it + along, then restart cycle for a new PDU. + """ + + p = self.reader.retry() + if p is None: + self.set_terminator(self.reader.needed()) + else: + self.deliver_pdu(p) + self.start_new_pdu() + + def push_pdu(self, pdu): + """ + Write PDU to stream. + """ + + try: + self.push(pdu.to_pdu()) + except OSError, e: + if e.errno != errno.EAGAIN: + raise + + def log(self, msg): + """ + Intercept asyncore's logging. + """ + + logging.info(msg) + + def log_info(self, msg, tag = "info"): + """ + Intercept asynchat's logging. + """ + + logging.info("asynchat: %s: %s", tag, msg) + + def handle_error(self): + """ + Handle errors caught by asyncore main loop. + """ + + logging.exception("[Unhandled exception]") + logging.critical("[Exiting after unhandled exception]") + sys.exit(1) + + def init_file_dispatcher(self, fd): + """ + Kludge to plug asyncore.file_dispatcher into asynchat. Call from + subclass's __init__() method, after calling + PDUChannel.__init__(), and don't read this on a full stomach. + """ + + self.connected = True + self._fileno = fd + self.socket = asyncore.file_wrapper(fd) + self.add_channel() + flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + def handle_close(self): + """ + Exit when channel closed. + """ + + asynchat.async_chat.handle_close(self) + sys.exit(0) diff --git a/rpki/rpki_rtr/client.py b/rpki/rpki_rtr/client.py new file mode 100644 index 00000000..8143e1df --- /dev/null +++ b/rpki/rpki_rtr/client.py @@ -0,0 +1,558 @@ +# $Id$ +# +# Copyright (C) 2014 Dragon Research Labs ("DRL") +# Portions copyright (C) 2009-2013 Internet Systems Consortium ("ISC") +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notices and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND DRL AND ISC DISCLAIM ALL +# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL OR +# ISC BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA +# OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +# TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. + +""" +Client implementation for the RPKI-RTR protocol (RFC 6810 et sequalia). +""" + +import os +import sys +import time +import base64 +import socket +import signal +import logging +import asyncore +import subprocess +import rpki.rpki_rtr.pdus +import rpki.rpki_rtr.channels + +from rpki.rpki_rtr.pdus import ResetQueryPDU, SerialQueryPDU +from rpki.rpki_rtr.channels import Timestamp + + +class PDU(rpki.rpki_rtr.pdus.PDU): + """ + Object representing a generic PDU in the rpki-router protocol. + Real PDUs are subclasses of this class. + """ + + def consume(self, client): + """ + Handle results in test client. Default behavior is just to print + out the PDU; data PDU subclasses may override this. + """ + + logging.debug(self) + + +clone_pdu = rpki.rpki_rtr.pdus.clone_pdu_root(PDU) + + +@clone_pdu +class SerialNotifyPDU(rpki.rpki_rtr.pdus.SerialNotifyPDU): + """ + Serial Notify PDU. + """ + + def consume(self, client): + """ + Respond to a SerialNotifyPDU with either a SerialQueryPDU or a + ResetQueryPDU, depending on what we already know. + """ + + logging.debug(self) + if client.current_serial is None or client.current_nonce != self.nonce: + client.push_pdu(ResetQueryPDU(version = client.version)) + elif self.serial != client.current_serial: + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.current_serial, + nonce = client.current_nonce)) + else: + logging.debug("[Notify did not change serial number, ignoring]") + + +@clone_pdu +class CacheResponsePDU(rpki.rpki_rtr.pdus.CacheResponsePDU): + """ + Cache Response PDU. + """ + + def consume(self, client): + """ + Handle CacheResponsePDU. + """ + + logging.debug(self) + if self.nonce != client.current_nonce: + logging.debug("[Nonce changed, resetting]") + client.cache_reset() + +@clone_pdu +class EndOfDataPDUv0(rpki.rpki_rtr.pdus.EndOfDataPDUv0): + """ + End of Data PDU, protocol version 0. + """ + + def consume(self, client): + """ + Handle EndOfDataPDU response. + """ + + logging.debug(self) + client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) + +@clone_pdu +class EndOfDataPDUv1(rpki.rpki_rtr.pdus.EndOfDataPDUv1): + """ + End of Data PDU, protocol version 1. + """ + + def consume(self, client): + """ + Handle EndOfDataPDU response. + """ + + logging.debug(self) + client.end_of_data(self.version, self.serial, self.nonce, self.refresh, self.retry, self.expire) + + +@clone_pdu +class CacheResetPDU(rpki.rpki_rtr.pdus.CacheResetPDU): + """ + Cache reset PDU. + """ + + def consume(self, client): + """ + Handle CacheResetPDU response, by issuing a ResetQueryPDU. + """ + + logging.debug(self) + client.cache_reset() + client.push_pdu(ResetQueryPDU(version = client.version)) + + +class PrefixPDU(rpki.rpki_rtr.pdus.PrefixPDU): + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + def consume(self, client): + """ + Handle one incoming prefix PDU + """ + + logging.debug(self) + client.consume_prefix(self) + + +@clone_pdu +class IPv4PrefixPDU(PrefixPDU, rpki.rpki_rtr.pdus.IPv4PrefixPDU): + """ + IPv4 flavor of a prefix. + """ + + pass + +@clone_pdu +class IPv6PrefixPDU(PrefixPDU, rpki.rpki_rtr.pdus.IPv6PrefixPDU): + """ + IPv6 flavor of a prefix. + """ + + pass + +@clone_pdu +class RouterKeyPDU(rpki.rpki_rtr.pdus.RouterKeyPDU): + """ + Router Key PDU. + """ + + def consume(self, client): + """ + Handle one incoming Router Key PDU + """ + + logging.debug(self) + client.consume_routerkey(self) + + +class ClientChannel(rpki.rpki_rtr.channels.PDUChannel): + """ + Client protocol engine, handles upcalls from PDUChannel. + """ + + current_serial = None + current_nonce = None + sql = None + host = None + port = None + cache_id = None + + # For initial test purposes, let's use the minimum allowed values + # from the RFC 6810 bis I-D as the initial defaults for refresh and + # retry, and the maximum allowed for expire; these will be overriden + # as soon as we receive an EndOfDataPDU. + # + refresh = 120 + retry = 120 + expire = 172800 + + def __init__(self, sock, proc, killsig, host, port): + self.killsig = killsig + self.proc = proc + self.host = host + self.port = port + super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU) + self.start_new_pdu() + + @classmethod + def ssh(cls, host, port): + """ + Set up ssh connection and start listening for first PDU. + """ + + argv = ("ssh", "-p", port, "-s", host, "rpki-rtr") + logging.debug("[Running ssh: %s]", " ".join(argv)) + s = socket.socketpair() + return cls(sock = s[1], + proc = subprocess.Popen(argv, executable = "/usr/bin/ssh", + stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGKILL, + host = host, port = port) + + @classmethod + def tcp(cls, host, port): + """ + Set up TCP connection and start listening for first PDU. + """ + + logging.debug("[Starting raw TCP connection to %s:%s]", host, port) + try: + addrinfo = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM) + except socket.error, e: + logging.debug("[socket.getaddrinfo() failed: %s]", e) + else: + for ai in addrinfo: + af, socktype, proto, cn, sa = ai # pylint: disable=W0612 + logging.debug("[Trying addr %s port %s]", sa[0], sa[1]) + try: + s = socket.socket(af, socktype, proto) + except socket.error, e: + logging.debug("[socket.socket() failed: %s]", e) + continue + try: + s.connect(sa) + except socket.error, e: + logging.exception("[socket.connect() failed: %s]", e) + s.close() + continue + return cls(sock = s, proc = None, killsig = None, + host = host, port = port) + sys.exit(1) + + @classmethod + def loopback(cls, host, port): + """ + Set up loopback connection and start listening for first PDU. + """ + + s = socket.socketpair() + logging.debug("[Using direct subprocess kludge for testing]") + argv = (sys.executable, sys.argv[0], "server") + return cls(sock = s[1], + proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGINT, + host = host, port = port) + + @classmethod + def tls(cls, host, port): + """ + Set up TLS connection and start listening for first PDU. + + NB: This uses OpenSSL's "s_client" command, which does not + check server certificates properly, so this is not suitable for + production use. Fixing this would be a trivial change, it just + requires using a client program which does check certificates + properly (eg, gnutls-cli, or stunnel's client mode if that works + for such purposes this week). + """ + + argv = ("openssl", "s_client", "-tls1", "-quiet", "-connect", "%s:%s" % (host, port)) + logging.debug("[Running: %s]", " ".join(argv)) + s = socket.socketpair() + return cls(sock = s[1], + proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGKILL, + host = host, port = port) + + def setup_sql(self, sqlname): + """ + Set up an SQLite database to contain the table we receive. If + necessary, we will create the database. + """ + + import sqlite3 + missing = not os.path.exists(sqlname) + self.sql = sqlite3.connect(sqlname, detect_types = sqlite3.PARSE_DECLTYPES) + self.sql.text_factory = str + cur = self.sql.cursor() + cur.execute("PRAGMA foreign_keys = on") + if missing: + cur.execute(''' + CREATE TABLE cache ( + cache_id INTEGER PRIMARY KEY NOT NULL, + host TEXT NOT NULL, + port TEXT NOT NULL, + version INTEGER, + nonce INTEGER, + serial INTEGER, + updated INTEGER, + refresh INTEGER, + retry INTEGER, + expire INTEGER, + UNIQUE (host, port))''') + cur.execute(''' + CREATE TABLE prefix ( + cache_id INTEGER NOT NULL + REFERENCES cache(cache_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + asn INTEGER NOT NULL, + prefix TEXT NOT NULL, + prefixlen INTEGER NOT NULL, + max_prefixlen INTEGER NOT NULL, + UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''') + + cur.execute(''' + CREATE TABLE routerkey ( + cache_id INTEGER NOT NULL + REFERENCES cache(cache_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + asn INTEGER NOT NULL, + ski TEXT NOT NULL, + key TEXT NOT NULL, + UNIQUE (cache_id, asn, ski), + UNIQUE (cache_id, asn, key))''') + + cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire " + "FROM cache WHERE host = ? AND port = ?", + (self.host, self.port)) + try: + self.cache_id, version, self.current_nonce, self.current_serial, refresh, retry, expire = cur.fetchone() + if version is not None: + self.version = version + if refresh is not None: + self.refresh = refresh + if retry is not None: + self.retry = retry + if expire is not None: + self.expire = expire + except TypeError: + cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port)) + self.cache_id = cur.lastrowid + self.sql.commit() + logging.info("[Session %d version %s nonce %s serial %s refresh %s retry %s expire %s]", + self.cache_id, self.version, self.current_nonce, + self.current_serial, self.refresh, self.retry, self.expire) + + def cache_reset(self): + """ + Handle CacheResetPDU actions. + """ + + self.current_serial = None + if self.sql: + # + # For some reason there was no commit here. Dunno why. + # See if adding one breaks anything.... + # + cur = self.sql.cursor() + cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,)) + cur.execute("DELETE FROM routerkey WHERE cache_id = ?", (self.cache_id,)) + cur.execute("UPDATE cache SET version = ?, serial = NULL WHERE cache_id = ?", (self.version, self.cache_id)) + self.sql.commit() + + def end_of_data(self, version, serial, nonce, refresh, retry, expire): + """ + Handle EndOfDataPDU actions. + """ + + assert version == self.version + self.current_serial = serial + self.current_nonce = nonce + self.refresh = refresh + self.retry = retry + self.expire = expire + if self.sql: + self.sql.execute("UPDATE cache SET" + " version = ?, serial = ?, nonce = ?," + " refresh = ?, retry = ?, expire = ?," + " updated = datetime('now') " + "WHERE cache_id = ?", + (version, serial, nonce, refresh, retry, expire, self.cache_id)) + self.sql.commit() + + def consume_prefix(self, prefix): + """ + Handle one prefix PDU. + """ + + if self.sql: + values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen) + if prefix.announce: + self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) " + "VALUES (?, ?, ?, ?, ?)", + values) + else: + self.sql.execute("DELETE FROM prefix " + "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?", + values) + + def consume_routerkey(self, routerkey): + """ + Handle one Router Key PDU. + """ + + if self.sql: + values = (self.cache_id, routerkey.asn, + base64.urlsafe_b64encode(routerkey.ski).rstrip("="), + base64.b64encode(routerkey.key)) + if routerkey.announce: + self.sql.execute("INSERT INTO routerkey (cache_id, asn, ski, key) " + "VALUES (?, ?, ?, ?)", + values) + else: + self.sql.execute("DELETE FROM routerkey " + "WHERE cache_id = ? AND asn = ? AND (ski = ? OR key = ?)", + values) + + def deliver_pdu(self, pdu): + """ + Handle received PDU. + """ + + pdu.consume(self) + + def push_pdu(self, pdu): + """ + Log outbound PDU then write it to stream. + """ + + logging.debug(pdu) + super(ClientChannel, self).push_pdu(pdu) + + def cleanup(self): + """ + Force clean up this client's child process. If everything goes + well, child will have exited already before this method is called, + but we may need to whack it with a stick if something breaks. + """ + + if self.proc is not None and self.proc.returncode is None: + try: + os.kill(self.proc.pid, self.killsig) + except OSError: + pass + + def handle_close(self): + """ + Intercept close event so we can log it, then shut down. + """ + + logging.debug("Server closed channel") + super(ClientChannel, self).handle_close() + + +def client_main(args): + """ + Toy client, intended only for debugging. + + This program takes one or more arguments. The first argument + determines what kind of connection it should open to the server, the + remaining arguments are connection details specific to this + particular type of connection. + + If the first argument is "loopback", the client will run a copy of + the server directly in a subprocess, and communicate with it via a + PF_UNIX socket pair. This sub-mode takes no further arguments. + + If the first argument is "ssh", the client will attempt to run ssh + in as subprocess to connect to the server using the ssh subsystem + mechanism as specified for this protocol. The remaining arguments + should be a hostname (or IP address in a form acceptable to ssh) and + a TCP port number. + + If the first argument is "tcp", the client will attempt to open a + direct (and completely insecure!) TCP connection to the server. + The remaining arguments should be a hostname (or IP address) and + a TCP port number. + + If the first argument is "tls", the client will attempt to open a + TLS connection to the server. The remaining arguments should be a + hostname (or IP address) and a TCP port number. + + An optional final name is the name of a file containing a SQLite + database in which to store the received table. If specified, this + database will be created if missing. + """ + + logging.debug("[Startup]") + + constructor = getattr(rpki.rpki_rtr.client.ClientChannel, args.protocol) + + client = None + try: + client = constructor(args.host, args.port) + if args.sql_database: + client.setup_sql(args.sql_database) + while True: + if client.current_serial is None or client.current_nonce is None: + client.push_pdu(ResetQueryPDU(version = client.version)) + else: + client.push_pdu(SerialQueryPDU(version = client.version, + serial = client.current_serial, + nonce = client.current_nonce)) + polled = Timestamp.now() + wakeup = None + while True: + if wakeup != polled + client.refresh: + wakeup = Timestamp(polled + client.refresh) + logging.info("[Last client poll %s, next %s]", polled, wakeup) + remaining = wakeup - time.time() + if remaining < 0: + break + asyncore.loop(timeout = remaining, count = 1) + + except KeyboardInterrupt: + sys.exit(0) + finally: + if client is not None: + client.cleanup() + + +def argparse_setup(subparsers): + """ + Set up argparse stuff for commands in this module. + """ + + subparser = subparsers.add_parser("client", description = client_main.__doc__, + help = "Test client for RPKI-RTR protocol") + subparser.set_defaults(func = client_main, default_log_to = "stderr") + subparser.add_argument("--sql-database", help = "filename for sqlite3 database of client state") + subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol") + subparser.add_argument("host", nargs = "?", help = "server host") + subparser.add_argument("port", nargs = "?", help = "server port") diff --git a/rpki/rpki_rtr/generator.py b/rpki/rpki_rtr/generator.py new file mode 100644 index 00000000..5ef2c3dc --- /dev/null +++ b/rpki/rpki_rtr/generator.py @@ -0,0 +1,573 @@ +# $Id$ +# +# Copyright (C) 2014 Dragon Research Labs ("DRL") +# Portions copyright (C) 2009-2013 Internet Systems Consortium ("ISC") +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notices and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND DRL AND ISC DISCLAIM ALL +# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL OR +# ISC BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA +# OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +# TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. + +""" +Database generator for RPKI-RTR server (RFC 6810 et sequalia). +""" + +import os +import sys +import glob +import base64 +import socket +import logging +import subprocess +import rpki.POW +import rpki.oids +import rpki.rpki_rtr.pdus +import rpki.rpki_rtr.channels +import rpki.rpki_rtr.server + +from rpki.rpki_rtr.channels import Timestamp + +class PrefixPDU(rpki.rpki_rtr.pdus.PrefixPDU): + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + @staticmethod + def from_text(version, asn, addr): + """ + Construct a prefix from its text form. + """ + + cls = IPv6PrefixPDU if ":" in addr else IPv4PrefixPDU + self = cls(version = version) + self.asn = long(asn) + p, l = addr.split("/") + self.prefix = rpki.POW.IPAddress(p) + if "-" in l: + self.prefixlen, self.max_prefixlen = tuple(int(i) for i in l.split("-")) + else: + self.prefixlen = self.max_prefixlen = int(l) + self.announce = 1 + self.check() + return self + + @staticmethod + def from_roa(version, asn, prefix_tuple): + """ + Construct a prefix from a ROA. + """ + + address, length, maxlength = prefix_tuple + cls = IPv6PrefixPDU if address.version == 6 else IPv4PrefixPDU + self = cls(version = version) + self.asn = asn + self.prefix = address + self.prefixlen = length + self.max_prefixlen = length if maxlength is None else maxlength + self.announce = 1 + self.check() + return self + + +class IPv4PrefixPDU(PrefixPDU): + """ + IPv4 flavor of a prefix. + """ + + pdu_type = 4 + address_byte_count = 4 + +class IPv6PrefixPDU(PrefixPDU): + """ + IPv6 flavor of a prefix. + """ + + pdu_type = 6 + address_byte_count = 16 + +class RouterKeyPDU(rpki.rpki_rtr.pdus.RouterKeyPDU): + """ + Router Key PDU. + """ + + @classmethod + def from_text(cls, version, asn, gski, key): + """ + Construct a router key from its text form. + """ + + self = cls(version = version) + self.asn = long(asn) + self.ski = base64.urlsafe_b64decode(gski + "=") + self.key = base64.b64decode(key) + self.announce = 1 + self.check() + return self + + @classmethod + def from_certificate(cls, version, asn, ski, key): + """ + Construct a router key from a certificate. + """ + + self = cls(version = version) + self.asn = asn + self.ski = ski + self.key = key + self.announce = 1 + self.check() + return self + + +class ROA(rpki.POW.ROA): # pylint: disable=W0232 + """ + Minor additions to rpki.POW.ROA. + """ + + @classmethod + def derReadFile(cls, fn): # pylint: disable=E1002 + self = super(ROA, cls).derReadFile(fn) + self.extractWithoutVerifying() + return self + + @property + def prefixes(self): + v4, v6 = self.getPrefixes() + if v4 is not None: + for p in v4: + yield p + if v6 is not None: + for p in v6: + yield p + +class X509(rpki.POW.X509): # pylint: disable=W0232 + """ + Minor additions to rpki.POW.X509. + """ + + @property + def asns(self): + resources = self.getRFC3779() + if resources is not None and resources[0] is not None: + for min_asn, max_asn in resources[0]: + for asn in xrange(min_asn, max_asn + 1): + yield asn + + +class PDUSet(list): + """ + Object representing a set of PDUs, that is, one versioned and + (theoretically) consistant set of prefixes and router keys extracted + from rcynic's output. + """ + + def __init__(self, version): + assert version in rpki.rpki_rtr.pdus.PDU.version_map + super(PDUSet, self).__init__() + self.version = version + + @classmethod + def _load_file(cls, filename, version): + """ + Low-level method to read PDUSet from a file. + """ + + self = cls(version = version) + f = open(filename, "rb") + r = rpki.rpki_rtr.channels.ReadBuffer() + while True: + p = rpki.rpki_rtr.pdus.PDU.read_pdu(r) + while p is None: + b = f.read(r.needed()) + if b == "": + assert r.available() == 0 + return self + r.put(b) + p = r.retry() + assert p.version == self.version + self.append(p) + + @staticmethod + def seq_ge(a, b): + return ((a - b) % (1 << 32)) < (1 << 31) + + +class AXFRSet(PDUSet): + """ + Object representing a complete set of PDUs, that is, one versioned + and (theoretically) consistant set of prefixes and router + certificates extracted from rcynic's output, all with the announce + field set. + """ + + @classmethod + def parse_rcynic(cls, rcynic_dir, version, scan_roas = None, scan_routercerts = None): + """ + Parse ROAS and router certificates fetched (and validated!) by + rcynic to create a new AXFRSet. + + In normal operation, we use os.walk() and the rpki.POW library to + parse these data directly, but we can, if so instructed, use + external programs instead, for testing, simulation, or to provide + a way to inject local data. + + At some point the ability to parse these data from external + programs may move to a separate constructor function, so that we + can make this one a bit simpler and faster. + """ + + self = cls(version = version) + self.serial = rpki.rpki_rtr.channels.Timestamp.now() + + include_routercerts = RouterKeyPDU.pdu_type in rpki.rpki_rtr.pdus.PDU.version_map[version] + + if scan_roas is None or (scan_routercerts is None and include_routercerts): + for root, dirs, files in os.walk(rcynic_dir): # pylint: disable=W0612 + for fn in files: + if scan_roas is None and fn.endswith(".roa"): + roa = ROA.derReadFile(os.path.join(root, fn)) + asn = roa.getASID() + self.extend(PrefixPDU.from_roa(version = version, asn = asn, prefix_tuple = prefix_tuple) + for prefix_tuple in roa.prefixes) + if include_routercerts and scan_routercerts is None and fn.endswith(".cer"): + x = X509.derReadFile(os.path.join(root, fn)) + eku = x.getEKU() + if eku is not None and rpki.oids.id_kp_bgpsec_router in eku: + ski = x.getSKI() + key = x.getPublicKey().derWritePublic() + self.extend(RouterKeyPDU.from_certificate(version = version, asn = asn, ski = ski, key = key) + for asn in x.asns) + + if scan_roas is not None: + try: + p = subprocess.Popen((scan_roas, rcynic_dir), stdout = subprocess.PIPE) + for line in p.stdout: + line = line.split() + asn = line[1] + self.extend(PrefixPDU.from_text(version = version, asn = asn, addr = addr) + for addr in line[2:]) + except OSError, e: + sys.exit("Could not run %s: %s" % (scan_roas, e)) + + if include_routercerts and scan_routercerts is not None: + try: + p = subprocess.Popen((scan_routercerts, rcynic_dir), stdout = subprocess.PIPE) + for line in p.stdout: + line = line.split() + gski = line[0] + key = line[-1] + self.extend(RouterKeyPDU.from_text(version = version, asn = asn, gski = gski, key = key) + for asn in line[1:-1]) + except OSError, e: + sys.exit("Could not run %s: %s" % (scan_routercerts, e)) + + self.sort() + for i in xrange(len(self) - 2, -1, -1): + if self[i] == self[i + 1]: + del self[i + 1] + return self + + @classmethod + def load(cls, filename): + """ + Load an AXFRSet from a file, parse filename to obtain version and serial. + """ + + fn1, fn2, fn3 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ax" and fn3.startswith("v") and fn3[1:].isdigit() + version = int(fn3[1:]) + self = cls._load_file(filename, version) + self.serial = rpki.rpki_rtr.channels.Timestamp(fn1) + return self + + def filename(self): + """ + Generate filename for this AXFRSet. + """ + + return "%d.ax.v%d" % (self.serial, self.version) + + @classmethod + def load_current(cls, version): + """ + Load current AXFRSet. Return None if can't. + """ + + serial = rpki.rpki_rtr.server.read_current(version)[0] + if serial is None: + return None + try: + return cls.load("%d.ax.v%d" % (serial, version)) + except IOError: + return None + + def save_axfr(self): + """ + Write AXFRSet to file with magic filename. + """ + + f = open(self.filename(), "wb") + for p in self: + f.write(p.to_pdu()) + f.close() + + def destroy_old_data(self): + """ + Destroy old data files, presumably because our nonce changed and + the old serial numbers are no longer valid. + """ + + for i in glob.iglob("*.ix.*.v%d" % self.version): + os.unlink(i) + for i in glob.iglob("*.ax.v%d" % self.version): + if i != self.filename(): + os.unlink(i) + + def mark_current(self): + """ + Save current serial number and nonce, creating new nonce if + necessary. Creating a new nonce triggers cleanup of old state, as + the new nonce invalidates all old serial numbers. + """ + + assert self.version in rpki.rpki_rtr.pdus.PDU.version_map + old_serial, nonce = rpki.rpki_rtr.server.read_current(self.version) + if old_serial is None or self.seq_ge(old_serial, self.serial): + logging.debug("Creating new nonce and deleting stale data") + nonce = rpki.rpki_rtr.server.new_nonce() + self.destroy_old_data() + rpki.rpki_rtr.server.write_current(self.serial, nonce, self.version) + + def save_ixfr(self, other): + """ + Comparing this AXFRSet with an older one and write the resulting + IXFRSet to file with magic filename. Since we store PDUSets + in sorted order, computing the difference is a trivial linear + comparison. + """ + + f = open("%d.ix.%d.v%d" % (self.serial, other.serial, self.version), "wb") + old = other + new = self + len_old = len(old) + len_new = len(new) + i_old = i_new = 0 + while i_old < len_old and i_new < len_new: + if old[i_old] < new[i_new]: + f.write(old[i_old].to_pdu(announce = 0)) + i_old += 1 + elif old[i_old] > new[i_new]: + f.write(new[i_new].to_pdu(announce = 1)) + i_new += 1 + else: + i_old += 1 + i_new += 1 + for i in xrange(i_old, len_old): + f.write(old[i].to_pdu(announce = 0)) + for i in xrange(i_new, len_new): + f.write(new[i].to_pdu(announce = 1)) + f.close() + + def show(self): + """ + Print this AXFRSet. + """ + + logging.debug("# AXFR %d (%s) v%d", self.serial, self.serial, self.version) + for p in self: + logging.debug(p) + + +class IXFRSet(PDUSet): + """ + Object representing an incremental set of PDUs, that is, the + differences between one versioned and (theoretically) consistant set + of prefixes and router certificates extracted from rcynic's output + and another, with the announce fields set or cleared as necessary to + indicate the changes. + """ + + @classmethod + def load(cls, filename): + """ + Load an IXFRSet from a file, parse filename to obtain version and serials. + """ + + fn1, fn2, fn3, fn4 = os.path.basename(filename).split(".") + assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() and fn4.startswith("v") and fn4[1:].isdigit() + version = int(fn4[1:]) + self = cls._load_file(filename, version) + self.from_serial = rpki.rpki_rtr.channels.Timestamp(fn3) + self.to_serial = rpki.rpki_rtr.channels.Timestamp(fn1) + return self + + def filename(self): + """ + Generate filename for this IXFRSet. + """ + + return "%d.ix.%d.v%d" % (self.to_serial, self.from_serial, self.version) + + def show(self): + """ + Print this IXFRSet. + """ + + logging.debug("# IXFR %d (%s) -> %d (%s) v%d", + self.from_serial, self.from_serial, + self.to_serial, self.to_serial, + self.version) + for p in self: + logging.debug(p) + + +def kick_all(serial): + """ + Kick any existing server processes to wake them up. + """ + + try: + os.stat(rpki.rpki_rtr.server.kickme_dir) + except OSError: + logging.debug('# Creating directory "%s"', rpki.rpki_rtr.server.kickme_dir) + os.makedirs(rpki.rpki_rtr.server.kickme_dir) + + msg = "Good morning, serial %d is ready" % serial + sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) + for name in glob.iglob("%s.*" % rpki.rpki_rtr.server.kickme_base): + try: + logging.debug("# Kicking %s", name) + sock.sendto(msg, name) + except socket.error: + try: + logging.exception("# Failed to kick %s, probably dead socket, attempting cleanup", name) + os.unlink(name) + except Exception, e: + logging.exception("# Couldn't unlink suspected dead socket %s: %s", name, e) + except Exception, e: + logging.warning("# Failed to kick %s and don't understand why: %s", name, e) + sock.close() + + +def cronjob_main(args): + """ + Run this mode right after rcynic to do the real work of groveling + through the ROAs that rcynic collects and translating that data into + the form used in the rpki-router protocol. This mode prepares both + full dumps (AXFR) and incremental dumps against a specific prior + version (IXFR). [Terminology here borrowed from DNS, as is much of + the protocol design.] Finally, this mode kicks any active servers, + so that they can notify their clients that a new version is + available. + + Run this in the directory where you want to write its output files, + which should also be the directory in which you run this program in + --server mode. + + This mode takes one argument on the command line, which specifies + the directory name of rcynic's authenticated output tree (normally + $somewhere/rcynic-data/authenticated/). + """ + + if args.rpki_rtr_dir: + try: + if not os.path.isdir(args.rpki_rtr_dir): + os.makedirs(args.rpki_rtr_dir) + os.chdir(args.rpki_rtr_dir) + except OSError, e: + logging.critical(str(e)) + sys.exit(1) + + for version in sorted(rpki.rpki_rtr.server.PDU.version_map.iterkeys(), reverse = True): + + logging.debug("# Generating updates for protocol version %d", version) + + old_ixfrs = glob.glob("*.ix.*.v%d" % version) + + current = rpki.rpki_rtr.server.read_current(version)[0] + cutoff = Timestamp.now(-(24 * 60 * 60)) + for f in glob.iglob("*.ax.v%d" % version): + t = Timestamp(int(f.split(".")[0])) + if t < cutoff and t != current: + logging.debug("# Deleting old file %s, timestamp %s", f, t) + os.unlink(f) + + pdus = rpki.rpki_rtr.generator.AXFRSet.parse_rcynic(args.rcynic_dir, version, args.scan_roas, args.scan_routercerts) + if pdus == rpki.rpki_rtr.generator.AXFRSet.load_current(version): + logging.debug("# No change, new serial not needed") + continue + pdus.save_axfr() + for axfr in glob.iglob("*.ax.v%d" % version): + if axfr != pdus.filename(): + pdus.save_ixfr(rpki.rpki_rtr.generator.AXFRSet.load(axfr)) + pdus.mark_current() + + logging.debug("# New serial is %d (%s)", pdus.serial, pdus.serial) + + rpki.rpki_rtr.generator.kick_all(pdus.serial) + + old_ixfrs.sort() + for ixfr in old_ixfrs: + try: + logging.debug("# Deleting old file %s", ixfr) + os.unlink(ixfr) + except OSError: + pass + + +def show_main(args): + """ + Display dumps created by --cronjob mode in textual form. + Intended only for debugging. + + This mode takes no command line arguments. Run it in the directory + where you ran --cronjob mode. + """ + + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) + + g = glob.glob("*.ax.v*") + g.sort() + for f in g: + rpki.rpki_rtr.generator.AXFRSet.load(f).show() + + g = glob.glob("*.ix.*.v*") + g.sort() + for f in g: + rpki.rpki_rtr.generator.IXFRSet.load(f).show() + +def argparse_setup(subparsers): + """ + Set up argparse stuff for commands in this module. + """ + + subparser = subparsers.add_parser("cronjob", description = cronjob_main.__doc__, + help = "Generate RPKI-RTR database from rcynic output") + subparser.set_defaults(func = cronjob_main, default_log_to = "syslog") + subparser.add_argument("--scan-roas", help = "specify an external scan_roas program") + subparser.add_argument("--scan-routercerts", help = "specify an external scan_routercerts program") + subparser.add_argument("rcynic_dir", help = "directory containing validated rcynic output tree") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + + subparser = subparsers.add_parser("show", description = show_main.__doc__, + help = "Display content of RPKI-RTR database") + subparser.set_defaults(func = show_main, default_log_to = "stderr") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") diff --git a/rpki/rpki_rtr/pdus.py b/rpki/rpki_rtr/pdus.py new file mode 100644 index 00000000..d8921a07 --- /dev/null +++ b/rpki/rpki_rtr/pdus.py @@ -0,0 +1,641 @@ +# $Id$ +# +# Copyright (C) 2014 Dragon Research Labs ("DRL") +# Portions copyright (C) 2009-2013 Internet Systems Consortium ("ISC") +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notices and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND DRL AND ISC DISCLAIM ALL +# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL OR +# ISC BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA +# OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +# TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. + +""" +PDU classes for the RPKI-RTR protocol (RFC 6810 et sequalia). +""" + +import struct +import base64 +import logging +import rpki.POW + +# Exceptions + +class PDUException(Exception): + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the ErrorReportPDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + super(PDUException, self).__init__() + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self, version): + return ErrorReportPDU(version = version, + errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) + +class UnsupportedProtocolVersion(PDUException): + error_report_code = 4 + +class UnsupportedPDUType(PDUException): + error_report_code = 5 + +class CorruptData(PDUException): + error_report_code = 0 + +# Decorators + +def wire_pdu(cls, versions = None): + """ + Class decorator to add a PDU class to the set of known PDUs + for all supported protocol versions. + """ + + for v in PDU.version_map.iterkeys() if versions is None else versions: + assert cls.pdu_type not in PDU.version_map[v] + PDU.version_map[v][cls.pdu_type] = cls + return cls + + +def wire_pdu_only(*versions): + """ + Class decorator to add a PDU class to the set of known PDUs + for specific protocol versions. + """ + + assert versions and all(v in PDU.version_map for v in versions) + return lambda cls: wire_pdu(cls, versions) + +def clone_pdu_root(root_pdu_class): + """ + Replace a PDU root class's version_map with a two-level deep copy of itself, + and return a class decorator which subclasses can use to replace their + parent classes with themselves in the resulting cloned version map. + + This function is not itself a decorator, it returns one. + """ + + root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) + + def decorator(cls): + for pdu_map in root_pdu_class.version_map.itervalues(): + for pdu_type, pdu_class in pdu_map.items(): + if pdu_class in cls.__bases__: + pdu_map[pdu_type] = cls + + return decorator + + +# PDUs + +class PDU(object): + """ + Base PDU. Real PDUs are subclasses of this class. + """ + + version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu + + _pdu = None # Cached when first generated + + header_struct = struct.Struct("!BB2xL") + + def __init__(self, version): + assert version in self.version_map + self.version = version + + def __cmp__(self, other): + return cmp(self.to_pdu(), other.to_pdu()) + + @property + def default_version(self): + return max(self.version_map.iterkeys()) + + def check(self): + pass + + @classmethod + def read_pdu(cls, reader): + return reader.update(need = cls.header_struct.size, callback = cls.got_header) + + @classmethod + def got_header(cls, reader): + if not reader.ready(): + return None + assert reader.available() >= cls.header_struct.size + version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) + reader.check_version(version) + if pdu_type not in cls.version_map[version]: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) + self = cls.version_map[version][pdu_type](version = version) + return reader.update(need = length, callback = self.got_pdu) + + +class PDUWithSerial(PDU): + """ + Base class for PDUs consisting of just a serial number and nonce. + """ + + header_struct = struct.Struct("!BBHLL") + + def __init__(self, version, serial = None, nonce = None): + super(PDUWithSerial, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + + +class PDUWithNonce(PDU): + """ + Base class for PDUs consisting of just a nonce. + """ + + header_struct = struct.Struct("!BBHL") + + def __init__(self, version, nonce = None): + super(PDUWithNonce, self).__init__(version) + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + + +class PDUEmpty(PDU): + """ + Base class for empty PDUs. + """ + + header_struct = struct.Struct("!BBHL") + + def __str__(self): + return "[%s]" % self.__class__.__name__ + + def to_pdu(self): + """ + Generate the wire format PDU for this prefix. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, zero, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + +@wire_pdu +class SerialNotifyPDU(PDUWithSerial): + """ + Serial Notify PDU. + """ + + pdu_type = 0 + + +@wire_pdu +class SerialQueryPDU(PDUWithSerial): + """ + Serial Query PDU. + """ + + pdu_type = 1 + + def __init__(self, version, serial = None, nonce = None): + super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) + + +@wire_pdu +class ResetQueryPDU(PDUEmpty): + """ + Reset Query PDU. + """ + + pdu_type = 2 + + def __init__(self, version): + super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) + + +@wire_pdu +class CacheResponsePDU(PDUWithNonce): + """ + Cache Response PDU. + """ + + pdu_type = 3 + + +def EndOfDataPDU(version, *args, **kwargs): + """ + Factory for the EndOfDataPDU classes, which take different forms in + different protocol versions. + """ + + if version == 0: + return EndOfDataPDUv0(version, *args, **kwargs) + if version == 1: + return EndOfDataPDUv1(version, *args, **kwargs) + raise NotImplementedError + + +@wire_pdu_only(0) +class EndOfDataPDUv0(PDUWithSerial): + """ + End of Data PDU, protocol version 0. + """ + + pdu_type = 7 + + # Default values, from the current RFC 6810 bis I-D. + # Putting these here lets us use them in our client API for both + # protocol versions, even though they can only be set in the + # protocol in version 1. + + refresh = 3600 + retry = 600 + expire = 7200 + + +@wire_pdu_only(1) +class EndOfDataPDUv1(EndOfDataPDUv0): + """ + End of Data PDU, protocol version 1. + """ + + header_struct = struct.Struct("!BBHLLLLL") + + def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): + super(EndOfDataPDUv1, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + if refresh is not None: + assert isinstance(refresh, int) + self.refresh = refresh + if retry is not None: + assert isinstance(retry, int) + self.retry = retry + if expire is not None: + assert isinstance(expire, int) + self.expire = expire + + def __str__(self): + return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( + self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial, + self.refresh, self.retry, self.expire) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ + = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 24: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + + +@wire_pdu +class CacheResetPDU(PDUEmpty): + """ + Cache reset PDU. + """ + + pdu_type = 8 + + +class PrefixPDU(PDU): + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + header_struct = struct.Struct("!BB2xLBBBx") + asnum_struct = struct.Struct("!L") + + def __str__(self): + plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def show(self): + logging.debug("# Class: %s", self.__class__.__name__) + logging.debug("# ASN: %s", self.asn) + logging.debug("# Prefix: %s", self.prefix) + logging.debug("# Prefixlen: %s", self.prefixlen) + logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) + logging.debug("# Announce: %s", self.announce) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefix.bits != self.address_byte_count * 8: + raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this prefix. + """ + + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, + announce if announce is not None else self.announce, + self.prefixlen, self.max_prefixlen) + + self.prefix.toBytes() + + self.asnum_struct.pack(self.asn)) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b1 = reader.get(self.header_struct.size) + b2 = reader.get(self.address_byte_count) + b3 = reader.get(self.asnum_struct.size) + version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) + assert version == self.version and pdu_type == self.pdu_type + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) + self.prefix = rpki.POW.IPAddress.fromBytes(b2) + self.asn = self.asnum_struct.unpack(b3)[0] + assert b1 + b2 + b3 == self.to_pdu() + return self + + +@wire_pdu +class IPv4PrefixPDU(PrefixPDU): + """ + IPv4 flavor of a prefix. + """ + + pdu_type = 4 + address_byte_count = 4 + +@wire_pdu +class IPv6PrefixPDU(PrefixPDU): + """ + IPv6 flavor of a prefix. + """ + + pdu_type = 6 + address_byte_count = 16 + +@wire_pdu_only(1) +class RouterKeyPDU(PDU): + """ + Router Key PDU. + """ + + pdu_type = 9 + + header_struct = struct.Struct("!BBBxL20sL") + + def __str__(self): + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, + base64.urlsafe_b64encode(self.ski).rstrip("="), + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if len(self.ski) != 20: + raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) + pdulen = self.header_struct.size + len(self.key) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + len(self.key) + pdu = (self.header_struct.pack(self.version, + self.pdu_type, + announce if announce is not None else self.announce, + pdulen, + self.ski, + self.asn) + + self.key) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + if remaining <= 0: + raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) + self.key = reader.get(remaining) + assert header + self.key == self.to_pdu() + return self + + +@wire_pdu +class ErrorReportPDU(PDU): + """ + Error Report PDU. + """ + + pdu_type = 10 + + header_struct = struct.Struct("!BBHL") + string_struct = struct.Struct("!L") + + errors = { + 2 : "No Data Available" } + + fatal = { + 0 : "Corrupt Data", + 1 : "Internal Error", + 3 : "Invalid Request", + 4 : "Unsupported Protocol Version", + 5 : "Unsupported PDU Type", + 6 : "Withdrawal of Unknown Record", + 7 : "Duplicate Announcement Received" } + + assert set(errors) & set(fatal) == set() + + errors.update(fatal) + + codes = dict((v, k) for k, v in errors.items()) + + def __init__(self, version, errno = None, errpdu = None, errmsg = None): + super(ErrorReportPDU, self).__init__(version) + assert errno is None or errno in self.errors + self.errno = errno + self.errpdu = errpdu + self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] + + def __str__(self): + return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) + + def to_counted_string(self, s): + return self.string_struct.pack(len(s)) + s + + def read_counted_string(self, reader, remaining): + assert remaining >= self.string_struct.size + n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] + assert remaining >= self.string_struct.size + n + return n, reader.get(n), (remaining - self.string_struct.size - n) + + def to_pdu(self): + """ + Generate the wire format PDU for this error report. + """ + + if self._pdu is None: + assert isinstance(self.errno, int) + assert not isinstance(self.errpdu, ErrorReportPDU) + p = self.errpdu + if p is None: + p = "" + elif isinstance(p, PDU): + p = p.to_pdu() + assert isinstance(p, str) + pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) + self._pdu += self.to_counted_string(p) + self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.errno, length = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) + self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) + assert (header + + self.to_counted_string(self.errpdu) + + self.to_counted_string(self.errmsg.encode("utf8")) + == self.to_pdu()) + return self diff --git a/rpki/rpki_rtr/server.py b/rpki/rpki_rtr/server.py new file mode 100644 index 00000000..cd687ad2 --- /dev/null +++ b/rpki/rpki_rtr/server.py @@ -0,0 +1,618 @@ +# $Id$ +# +# Copyright (C) 2014 Dragon Research Labs ("DRL") +# Portions copyright (C) 2009-2013 Internet Systems Consortium ("ISC") +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notices and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND DRL AND ISC DISCLAIM ALL +# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL OR +# ISC BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA +# OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +# TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. + +""" +Server implementation for the RPKI-RTR protocol (RFC 6810 et sequalia). +""" + +import os +import sys +import errno +import socket +import random +import logging +import asyncore +import rpki.POW +import rpki.oids +import rpki.rpki_rtr.pdus +import rpki.rpki_rtr.channels + +from rpki.rpki_rtr.pdus import (clone_pdu_root, + CacheResponsePDU, EndOfDataPDU, CacheResetPDU, CacheResponsePDU, + EndOfDataPDU, CacheResetPDU, CacheResetPDU, SerialNotifyPDU) + + +# Disable incremental updates. Debugging only, should be False in production. +disable_incrementals = False + +# These should be configurable in some sane fashion. +kickme_dir = "sockets" +kickme_base = os.path.join(kickme_dir, "kickme") + + +class PDU(rpki.rpki_rtr.pdus.PDU): + """ + Generic server PDU. + """ + + def send_file(self, server, filename): + """ + Send a content of a file as a cache response. Caller should catch IOError. + """ + + fn2 = os.path.splitext(filename)[1] + assert fn2.startswith(".v") and fn2[2:].isdigit() and int(fn2[2:]) == server.version + + f = open(filename, "rb") + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) + server.push_file(f) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce)) + + def send_nodata(self, server): + """ + Send a nodata error. + """ + + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["No Data Available"], + errpdu = self)) + + +clone_pdu = clone_pdu_root(PDU) + + +@clone_pdu +class SerialQueryPDU(PDU, rpki.rpki_rtr.pdus.SerialQueryPDU): + """ + Serial Query PDU. + """ + + def serve(self, server): + """ + Received a serial query, send incremental transfer in response. + If client is already up to date, just send an empty incremental + transfer. + """ + + server.logger.debug(self) + if server.get_serial() is None: + self.send_nodata(server) + elif server.current_nonce != self.nonce: + server.logger.info("[Client requested wrong nonce, resetting client]") + server.push_pdu(CacheResetPDU(version = server.version)) + elif server.current_serial == self.serial: + server.logger.debug("[Client is already current, sending empty IXFR]") + server.push_pdu(CacheResponsePDU(version = server.version, + nonce = server.current_nonce)) + server.push_pdu(EndOfDataPDU(version = server.version, + serial = server.current_serial, + nonce = server.current_nonce)) + elif disable_incrementals: + server.push_pdu(CacheResetPDU(version = server.version)) + else: + try: + self.send_file(server, "%d.ix.%d.v%d" % (server.current_serial, self.serial, server.version)) + except IOError: + server.push_pdu(CacheResetPDU(version = server.version)) + + +@clone_pdu +class ResetQueryPDU(PDU, rpki.rpki_rtr.pdus.ResetQueryPDU): + """ + Reset Query PDU. + """ + + def serve(self, server): + """ + Received a reset query, send full current state in response. + """ + + server.logger.debug(self) + if server.get_serial() is None: + self.send_nodata(server) + else: + try: + fn = "%d.ax.v%d" % (server.current_serial, server.version) + self.send_file(server, fn) + except IOError: + server.push_pdu(ErrorReportPDU(version = server.version, + errno = ErrorReportPDU.codes["Internal Error"], + errpdu = self, + errmsg = "Couldn't open %s" % fn)) + + +@clone_pdu +class ErrorReportPDU(rpki.rpki_rtr.pdus.ErrorReportPDU): + """ + Error Report PDU. + """ + + def serve(self, server): + """ + Received an ErrorReportPDU from client. Not much we can do beyond + logging it, then killing the connection if error was fatal. + """ + + server.logger.error(self) + if self.errno in self.fatal: + server.logger.error("[Shutting down due to reported fatal protocol error]") + sys.exit(1) + + +def read_current(version): + """ + Read current serial number and nonce. Return None for both if + serial and nonce not recorded. For backwards compatibility, treat + file containing just a serial number as having a nonce of zero. + """ + + if version is None: + return None, None + try: + with open("current.v%d" % version, "r") as f: + values = tuple(int(s) for s in f.read().split()) + return values[0], values[1] + except IndexError: + return values[0], 0 + except IOError: + return None, None + + +def write_current(serial, nonce, version): + """ + Write serial number and nonce. + """ + + curfn = "current.v%d" % version + tmpfn = curfn + "%d.tmp" % os.getpid() + with open(tmpfn, "w") as f: + f.write("%d %d\n" % (serial, nonce)) + os.rename(tmpfn, curfn) + + +def new_nonce(force_zero_nonce = False): + """ + Create and return a new nonce value. + """ + + if force_zero_nonce: + return 0 + try: + return int(random.SystemRandom().getrandbits(16)) + except NotImplementedError: + return int(random.getrandbits(16)) + + +class FileProducer(object): + """ + File-based producer object for asynchat. + """ + + def __init__(self, handle, buffersize): + self.handle = handle + self.buffersize = buffersize + + def more(self): + return self.handle.read(self.buffersize) + + +class ServerWriteChannel(rpki.rpki_rtr.channels.PDUChannel): + """ + Kludge to deal with ssh's habit of sometimes (compile time option) + invoking us with two unidirectional pipes instead of one + bidirectional socketpair. All the server logic is in the + ServerChannel class, this class just deals with sending the + server's output to a different file descriptor. + """ + + def __init__(self): + """ + Set up stdout. + """ + + super(ServerWriteChannel, self).__init__(root_pdu_class = PDU) + self.init_file_dispatcher(sys.stdout.fileno()) + + def readable(self): + """ + This channel is never readable. + """ + + return False + + def push_file(self, f): + """ + Write content of a file to stream. + """ + + try: + self.push_with_producer(FileProducer(f, self.ac_out_buffer_size)) + except OSError, e: + if e.errno != errno.EAGAIN: + raise + + +class ServerChannel(rpki.rpki_rtr.channels.PDUChannel): + """ + Server protocol engine, handles upcalls from PDUChannel to + implement protocol logic. + """ + + def __init__(self, logger): + """ + Set up stdin and stdout as connection and start listening for + first PDU. + """ + + super(ServerChannel, self).__init__(root_pdu_class = PDU) + self.init_file_dispatcher(sys.stdin.fileno()) + self.writer = ServerWriteChannel() + self.logger = logger + self.get_serial() + self.start_new_pdu() + + def writable(self): + """ + This channel is never writable. + """ + + return False + + def push(self, data): + """ + Redirect to writer channel. + """ + + return self.writer.push(data) + + def push_with_producer(self, producer): + """ + Redirect to writer channel. + """ + + return self.writer.push_with_producer(producer) + + def push_pdu(self, pdu): + """ + Redirect to writer channel. + """ + + return self.writer.push_pdu(pdu) + + def push_file(self, f): + """ + Redirect to writer channel. + """ + + return self.writer.push_file(f) + + def deliver_pdu(self, pdu): + """ + Handle received PDU. + """ + + pdu.serve(self) + + def get_serial(self): + """ + Read, cache, and return current serial number, or None if we can't + find the serial number file. The latter condition should never + happen, but maybe we got started in server mode while the cronjob + mode instance is still building its database. + """ + + self.current_serial, self.current_nonce = read_current(self.version) + return self.current_serial + + def check_serial(self): + """ + Check for a new serial number. + """ + + old_serial = self.current_serial + return old_serial != self.get_serial() + + def notify(self, data = None): + """ + Cronjob instance kicked us: check whether our serial number has + changed, and send a notify message if so. + + We have to check rather than just blindly notifying when kicked + because the cronjob instance has no good way of knowing which + protocol version we're running, thus has no good way of knowing + whether we care about a particular change set or not. + """ + + if self.check_serial(): + self.push_pdu(SerialNotifyPDU(version = self.version, + serial = self.current_serial, + nonce = self.current_nonce)) + else: + self.logger.debug("Cronjob kicked me but I see no serial change, ignoring") + + +class KickmeChannel(asyncore.dispatcher, object): + """ + asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to + kick servers when it's time to send notify PDUs to clients. + """ + + def __init__(self, server): + asyncore.dispatcher.__init__(self) # Old-style class + self.server = server + self.sockname = "%s.%d" % (kickme_base, os.getpid()) + self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM) + try: + self.bind(self.sockname) + os.chmod(self.sockname, 0660) + except socket.error, e: + self.server.logger.exception("Couldn't bind() kickme socket: %r", e) + self.close() + except OSError, e: + self.server.logger.exception("Couldn't chmod() kickme socket: %r", e) + + def writable(self): + """ + This socket is read-only, never writable. + """ + + return False + + def handle_connect(self): + """ + Ignore connect events (not very useful on datagram socket). + """ + + pass + + def handle_read(self): + """ + Handle receipt of a datagram. + """ + + data = self.recv(512) + self.server.notify(data) + + def cleanup(self): + """ + Clean up this dispatcher's socket. + """ + + self.close() + try: + os.unlink(self.sockname) + except: # pylint: disable=W0702 + pass + + def log(self, msg): + """ + Intercept asyncore's logging. + """ + + self.server.logger.info(msg) + + def log_info(self, msg, tag = "info"): + """ + Intercept asyncore's logging. + """ + + self.server.logger.info("asyncore: %s: %s", tag, msg) + + def handle_error(self): + """ + Handle errors caught by asyncore main loop. + """ + + self.server.logger.exception("[Unhandled exception]") + self.server.logger.critical("[Exiting after unhandled exception]") + sys.exit(1) + + +def _hostport_tag(): + """ + Construct hostname/address + port when we're running under a + protocol we understand well enough to do that. This is all + kludgery. Just grit your teeth, or perhaps just close your eyes. + """ + + proto = None + + if proto is None: + try: + host, port = socket.fromfd(0, socket.AF_INET, socket.SOCK_STREAM).getpeername() + proto = "tcp" + except: # pylint: disable=W0702 + pass + + if proto is None: + try: + host, port = socket.fromfd(0, socket.AF_INET6, socket.SOCK_STREAM).getpeername()[0:2] + proto = "tcp" + except: # pylint: disable=W0702 + pass + + if proto is None: + try: + host, port = os.environ["SSH_CONNECTION"].split()[0:2] + proto = "ssh" + except: # pylint: disable=W0702 + pass + + if proto is None: + try: + host, port = os.environ["REMOTE_HOST"], os.getenv("REMOTE_PORT") + proto = "ssl" + except: # pylint: disable=W0702 + pass + + if proto is None: + return "" + elif not port: + return "/%s/%s" % (proto, host) + elif ":" in host: + return "/%s/%s.%s" % (proto, host, port) + else: + return "/%s/%s:%s" % (proto, host, port) + + +def server_main(args): + """ + Implement the server side of the rpkk-router protocol. Other than + one PF_UNIX socket inode, this doesn't write anything to disk, so it + can be run with minimal privileges. Most of the hard work has + already been done in --cronjob mode, so all that this mode has to do + is serve up the results. + + In production use this server should run under sshd. The subsystem + mechanism in sshd does not allow us to pass arguments on the command + line, so setting this up might require a wrapper script, but in + production use you will probably want to lock down the public key + used to authenticate the ssh session so that it can only run this + one command, in which case you can just specify the full command + including any arguments in the authorized_keys file. + + Unless you do something special, sshd will have this program running + in whatever it thinks is the home directory associated with the + username given in the ssh prototocol setup, so it may be easiest to + set this up so that the home directory sshd puts this program into + is the one where --cronjob left its files for this mode to pick up. + + This mode must be run in the directory where you ran --cronjob mode. + + This mode takes one optional argument: if provided, the argument is + the name of a directory to which the program should chdir() on + startup; this may simplify setup when running under inetd. + + The server is event driven, so everything interesting happens in the + channel classes. + """ + + logger = logging.LoggerAdapter(logging.root, dict(connection = _hostport_tag())) + + logger.debug("[Starting]") + + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) + + if args.force_zero_nonce: + logger.warning("--force_zero_nonce not implemented at the moment, ignoring") + + kickme = None + try: + server = rpki.rpki_rtr.server.ServerChannel(logger = logger) + kickme = rpki.rpki_rtr.server.KickmeChannel(server = server) + asyncore.loop(timeout = None) + except KeyboardInterrupt: + sys.exit(0) + finally: + if kickme is not None: + kickme.cleanup() + + +def listener_main(args): + """ + Simple plain-TCP listener. Listens on a specified TCP port, upon + receiving a connection, forks the process and starts child executing + at server_main(). + + First argument (required) is numeric port number. + + Second argument (optional) is directory, like --server. + + NB: plain-TCP is completely insecure. We only implement this + because it's all that the routers currently support. In theory, we + will all be running TCP-AO in the future, at which point this will + go away. + """ + + # Perhaps we should daemonize? Deal with that later. + + if args.rpki_rtr_dir: + try: + os.chdir(args.rpki_rtr_dir) + except OSError, e: + sys.exit(e) + + listener = None + try: + listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + listener.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except: # pylint: disable=W0702 + if listener is not None: + listener.close() + listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + except AttributeError: + pass + listener.bind(("", args.port)) + listener.listen(5) + logging.debug("[Listening on port %s]", args.port) + while True: + s, ai = listener.accept() + logging.debug("[Received connection from %r]", ai) + pid = os.fork() + if pid == 0: + os.dup2(s.fileno(), 0) # pylint: disable=E1103 + os.dup2(s.fileno(), 1) # pylint: disable=E1103 + s.close() + #os.closerange(3, os.sysconf("SC_OPEN_MAX")) + # + logging.warning("Should be reconfiguring logging here, but we're lame") + #global log_tag + #log_tag = "rtr-origin/server" + rpki.rpki_rtr.server.hostport_tag() + #syslog.closelog() + #syslog.openlog(log_tag, syslog.LOG_PID, syslog_facility) + server_main(()) + sys.exit() + else: + logging.debug("[Spawned server %d]", pid) + try: + while True: + pid, status = os.waitpid(0, os.WNOHANG) # pylint: disable=W0612 + if pid: + logging.debug("[Server %s exited]", pid) + else: + break + except: # pylint: disable=W0702 + pass + + +def argparse_setup(subparsers): + """ + Set up argparse stuff for commands in this module. + """ + + subparser = subparsers.add_parser("server", description = server_main.__doc__, + help = "RPKI-RTR protocol server") + subparser.set_defaults(func = server_main, default_log_to = "syslog") + subparser.add_argument("--force_zero_nonce", action = "store_true", help = "force nonce value of zero") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") + + subparser = subparsers.add_parser("listener", description = listener_main.__doc__, + help = "TCP listener for RPKI-RTR protocol server") + subparser.set_defaults(func = listener_main, default_log_to = "syslog") + subparser.add_argument("port", type = int, help = "TCP port on which to listen") + subparser.add_argument("rpki_rtr_dir", nargs = "?", help = "directory containing RPKI-RTR database") |