diff options
Diffstat (limited to 'rpki/rtr/pdus.py')
-rw-r--r-- | rpki/rtr/pdus.py | 641 |
1 files changed, 641 insertions, 0 deletions
diff --git a/rpki/rtr/pdus.py b/rpki/rtr/pdus.py new file mode 100644 index 00000000..d8921a07 --- /dev/null +++ b/rpki/rtr/pdus.py @@ -0,0 +1,641 @@ +# $Id$ +# +# Copyright (C) 2014 Dragon Research Labs ("DRL") +# Portions copyright (C) 2009-2013 Internet Systems Consortium ("ISC") +# +# Permission to use, copy, modify, and distribute this software for any +# purpose with or without fee is hereby granted, provided that the above +# copyright notices and this permission notice appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND DRL AND ISC DISCLAIM ALL +# WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL DRL OR +# ISC BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL +# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA +# OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER +# TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. + +""" +PDU classes for the RPKI-RTR protocol (RFC 6810 et sequalia). +""" + +import struct +import base64 +import logging +import rpki.POW + +# Exceptions + +class PDUException(Exception): + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the ErrorReportPDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + super(PDUException, self).__init__() + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self, version): + return ErrorReportPDU(version = version, + errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) + +class UnsupportedProtocolVersion(PDUException): + error_report_code = 4 + +class UnsupportedPDUType(PDUException): + error_report_code = 5 + +class CorruptData(PDUException): + error_report_code = 0 + +# Decorators + +def wire_pdu(cls, versions = None): + """ + Class decorator to add a PDU class to the set of known PDUs + for all supported protocol versions. + """ + + for v in PDU.version_map.iterkeys() if versions is None else versions: + assert cls.pdu_type not in PDU.version_map[v] + PDU.version_map[v][cls.pdu_type] = cls + return cls + + +def wire_pdu_only(*versions): + """ + Class decorator to add a PDU class to the set of known PDUs + for specific protocol versions. + """ + + assert versions and all(v in PDU.version_map for v in versions) + return lambda cls: wire_pdu(cls, versions) + +def clone_pdu_root(root_pdu_class): + """ + Replace a PDU root class's version_map with a two-level deep copy of itself, + and return a class decorator which subclasses can use to replace their + parent classes with themselves in the resulting cloned version map. + + This function is not itself a decorator, it returns one. + """ + + root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) + + def decorator(cls): + for pdu_map in root_pdu_class.version_map.itervalues(): + for pdu_type, pdu_class in pdu_map.items(): + if pdu_class in cls.__bases__: + pdu_map[pdu_type] = cls + + return decorator + + +# PDUs + +class PDU(object): + """ + Base PDU. Real PDUs are subclasses of this class. + """ + + version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu + + _pdu = None # Cached when first generated + + header_struct = struct.Struct("!BB2xL") + + def __init__(self, version): + assert version in self.version_map + self.version = version + + def __cmp__(self, other): + return cmp(self.to_pdu(), other.to_pdu()) + + @property + def default_version(self): + return max(self.version_map.iterkeys()) + + def check(self): + pass + + @classmethod + def read_pdu(cls, reader): + return reader.update(need = cls.header_struct.size, callback = cls.got_header) + + @classmethod + def got_header(cls, reader): + if not reader.ready(): + return None + assert reader.available() >= cls.header_struct.size + version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) + reader.check_version(version) + if pdu_type not in cls.version_map[version]: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) + self = cls.version_map[version][pdu_type](version = version) + return reader.update(need = length, callback = self.got_pdu) + + +class PDUWithSerial(PDU): + """ + Base class for PDUs consisting of just a serial number and nonce. + """ + + header_struct = struct.Struct("!BBHLL") + + def __init__(self, version, serial = None, nonce = None): + super(PDUWithSerial, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + + +class PDUWithNonce(PDU): + """ + Base class for PDUs consisting of just a nonce. + """ + + header_struct = struct.Struct("!BBHL") + + def __init__(self, version, nonce = None): + super(PDUWithNonce, self).__init__(version) + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + + +class PDUEmpty(PDU): + """ + Base class for empty PDUs. + """ + + header_struct = struct.Struct("!BBHL") + + def __str__(self): + return "[%s]" % self.__class__.__name__ + + def to_pdu(self): + """ + Generate the wire format PDU for this prefix. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, zero, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + +@wire_pdu +class SerialNotifyPDU(PDUWithSerial): + """ + Serial Notify PDU. + """ + + pdu_type = 0 + + +@wire_pdu +class SerialQueryPDU(PDUWithSerial): + """ + Serial Query PDU. + """ + + pdu_type = 1 + + def __init__(self, version, serial = None, nonce = None): + super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) + + +@wire_pdu +class ResetQueryPDU(PDUEmpty): + """ + Reset Query PDU. + """ + + pdu_type = 2 + + def __init__(self, version): + super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) + + +@wire_pdu +class CacheResponsePDU(PDUWithNonce): + """ + Cache Response PDU. + """ + + pdu_type = 3 + + +def EndOfDataPDU(version, *args, **kwargs): + """ + Factory for the EndOfDataPDU classes, which take different forms in + different protocol versions. + """ + + if version == 0: + return EndOfDataPDUv0(version, *args, **kwargs) + if version == 1: + return EndOfDataPDUv1(version, *args, **kwargs) + raise NotImplementedError + + +@wire_pdu_only(0) +class EndOfDataPDUv0(PDUWithSerial): + """ + End of Data PDU, protocol version 0. + """ + + pdu_type = 7 + + # Default values, from the current RFC 6810 bis I-D. + # Putting these here lets us use them in our client API for both + # protocol versions, even though they can only be set in the + # protocol in version 1. + + refresh = 3600 + retry = 600 + expire = 7200 + + +@wire_pdu_only(1) +class EndOfDataPDUv1(EndOfDataPDUv0): + """ + End of Data PDU, protocol version 1. + """ + + header_struct = struct.Struct("!BBHLLLLL") + + def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): + super(EndOfDataPDUv1, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + if refresh is not None: + assert isinstance(refresh, int) + self.refresh = refresh + if retry is not None: + assert isinstance(retry, int) + self.retry = retry + if expire is not None: + assert isinstance(expire, int) + self.expire = expire + + def __str__(self): + return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( + self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial, + self.refresh, self.retry, self.expire) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ + = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 24: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self + + +@wire_pdu +class CacheResetPDU(PDUEmpty): + """ + Cache reset PDU. + """ + + pdu_type = 8 + + +class PrefixPDU(PDU): + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + header_struct = struct.Struct("!BB2xLBBBx") + asnum_struct = struct.Struct("!L") + + def __str__(self): + plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def show(self): + logging.debug("# Class: %s", self.__class__.__name__) + logging.debug("# ASN: %s", self.asn) + logging.debug("# Prefix: %s", self.prefix) + logging.debug("# Prefixlen: %s", self.prefixlen) + logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) + logging.debug("# Announce: %s", self.announce) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefix.bits != self.address_byte_count * 8: + raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this prefix. + """ + + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, + announce if announce is not None else self.announce, + self.prefixlen, self.max_prefixlen) + + self.prefix.toBytes() + + self.asnum_struct.pack(self.asn)) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b1 = reader.get(self.header_struct.size) + b2 = reader.get(self.address_byte_count) + b3 = reader.get(self.asnum_struct.size) + version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) + assert version == self.version and pdu_type == self.pdu_type + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) + self.prefix = rpki.POW.IPAddress.fromBytes(b2) + self.asn = self.asnum_struct.unpack(b3)[0] + assert b1 + b2 + b3 == self.to_pdu() + return self + + +@wire_pdu +class IPv4PrefixPDU(PrefixPDU): + """ + IPv4 flavor of a prefix. + """ + + pdu_type = 4 + address_byte_count = 4 + +@wire_pdu +class IPv6PrefixPDU(PrefixPDU): + """ + IPv6 flavor of a prefix. + """ + + pdu_type = 6 + address_byte_count = 16 + +@wire_pdu_only(1) +class RouterKeyPDU(PDU): + """ + Router Key PDU. + """ + + pdu_type = 9 + + header_struct = struct.Struct("!BBBxL20sL") + + def __str__(self): + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, + base64.urlsafe_b64encode(self.ski).rstrip("="), + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if len(self.ski) != 20: + raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) + pdulen = self.header_struct.size + len(self.key) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + len(self.key) + pdu = (self.header_struct.pack(self.version, + self.pdu_type, + announce if announce is not None else self.announce, + pdulen, + self.ski, + self.asn) + + self.key) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + if remaining <= 0: + raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) + self.key = reader.get(remaining) + assert header + self.key == self.to_pdu() + return self + + +@wire_pdu +class ErrorReportPDU(PDU): + """ + Error Report PDU. + """ + + pdu_type = 10 + + header_struct = struct.Struct("!BBHL") + string_struct = struct.Struct("!L") + + errors = { + 2 : "No Data Available" } + + fatal = { + 0 : "Corrupt Data", + 1 : "Internal Error", + 3 : "Invalid Request", + 4 : "Unsupported Protocol Version", + 5 : "Unsupported PDU Type", + 6 : "Withdrawal of Unknown Record", + 7 : "Duplicate Announcement Received" } + + assert set(errors) & set(fatal) == set() + + errors.update(fatal) + + codes = dict((v, k) for k, v in errors.items()) + + def __init__(self, version, errno = None, errpdu = None, errmsg = None): + super(ErrorReportPDU, self).__init__(version) + assert errno is None or errno in self.errors + self.errno = errno + self.errpdu = errpdu + self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] + + def __str__(self): + return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) + + def to_counted_string(self, s): + return self.string_struct.pack(len(s)) + s + + def read_counted_string(self, reader, remaining): + assert remaining >= self.string_struct.size + n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] + assert remaining >= self.string_struct.size + n + return n, reader.get(n), (remaining - self.string_struct.size - n) + + def to_pdu(self): + """ + Generate the wire format PDU for this error report. + """ + + if self._pdu is None: + assert isinstance(self.errno, int) + assert not isinstance(self.errpdu, ErrorReportPDU) + p = self.errpdu + if p is None: + p = "" + elif isinstance(p, PDU): + p = p.to_pdu() + assert isinstance(p, str) + pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) + self._pdu += self.to_counted_string(p) + self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.errno, length = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) + self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) + assert (header + + self.to_counted_string(self.errpdu) + + self.to_counted_string(self.errmsg.encode("utf8")) + == self.to_pdu()) + return self |