diff options
Diffstat (limited to 'rpki/rtr/pdus.py')
-rw-r--r-- | rpki/rtr/pdus.py | 960 |
1 files changed, 480 insertions, 480 deletions
diff --git a/rpki/rtr/pdus.py b/rpki/rtr/pdus.py index 0d2e5928..94f579a1 100644 --- a/rpki/rtr/pdus.py +++ b/rpki/rtr/pdus.py @@ -28,292 +28,292 @@ import rpki.POW # Exceptions class PDUException(Exception): - """ - Parent exception type for exceptions that signal particular protocol - errors. String value of exception instance will be the message to - put in the ErrorReportPDU, error_report_code value of exception - will be the numeric code to use. - """ - - def __init__(self, msg = None, pdu = None): - super(PDUException, self).__init__() - assert msg is None or isinstance(msg, (str, unicode)) - self.error_report_msg = msg - self.error_report_pdu = pdu - - def __str__(self): - return self.error_report_msg or self.__class__.__name__ - - def make_error_report(self, version): - return ErrorReportPDU(version = version, - errno = self.error_report_code, - errmsg = self.error_report_msg, - errpdu = self.error_report_pdu) + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the ErrorReportPDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + super(PDUException, self).__init__() + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self, version): + return ErrorReportPDU(version = version, + errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) class UnsupportedProtocolVersion(PDUException): - error_report_code = 4 + error_report_code = 4 class UnsupportedPDUType(PDUException): - error_report_code = 5 + error_report_code = 5 class CorruptData(PDUException): - error_report_code = 0 + error_report_code = 0 # Decorators def wire_pdu(cls, versions = None): - """ - Class decorator to add a PDU class to the set of known PDUs - for all supported protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for all supported protocol versions. + """ - for v in PDU.version_map.iterkeys() if versions is None else versions: - assert cls.pdu_type not in PDU.version_map[v] - PDU.version_map[v][cls.pdu_type] = cls - return cls + for v in PDU.version_map.iterkeys() if versions is None else versions: + assert cls.pdu_type not in PDU.version_map[v] + PDU.version_map[v][cls.pdu_type] = cls + return cls def wire_pdu_only(*versions): - """ - Class decorator to add a PDU class to the set of known PDUs - for specific protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for specific protocol versions. + """ - assert versions and all(v in PDU.version_map for v in versions) - return lambda cls: wire_pdu(cls, versions) + assert versions and all(v in PDU.version_map for v in versions) + return lambda cls: wire_pdu(cls, versions) def clone_pdu_root(root_pdu_class): - """ - Replace a PDU root class's version_map with a two-level deep copy of itself, - and return a class decorator which subclasses can use to replace their - parent classes with themselves in the resulting cloned version map. + """ + Replace a PDU root class's version_map with a two-level deep copy of itself, + and return a class decorator which subclasses can use to replace their + parent classes with themselves in the resulting cloned version map. - This function is not itself a decorator, it returns one. - """ + This function is not itself a decorator, it returns one. + """ - root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) + root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) - def decorator(cls): - for pdu_map in root_pdu_class.version_map.itervalues(): - for pdu_type, pdu_class in pdu_map.items(): - if pdu_class in cls.__bases__: - pdu_map[pdu_type] = cls - return cls + def decorator(cls): + for pdu_map in root_pdu_class.version_map.itervalues(): + for pdu_type, pdu_class in pdu_map.items(): + if pdu_class in cls.__bases__: + pdu_map[pdu_type] = cls + return cls - return decorator + return decorator # PDUs class PDU(object): - """ - Base PDU. Real PDUs are subclasses of this class. - """ + """ + Base PDU. Real PDUs are subclasses of this class. + """ - version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu + version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu - _pdu = None # Cached when first generated + _pdu = None # Cached when first generated - header_struct = struct.Struct("!BB2xL") + header_struct = struct.Struct("!BB2xL") - def __init__(self, version): - assert version in self.version_map - self.version = version + def __init__(self, version): + assert version in self.version_map + self.version = version - def __cmp__(self, other): - return cmp(self.to_pdu(), other.to_pdu()) + def __cmp__(self, other): + return cmp(self.to_pdu(), other.to_pdu()) - @property - def default_version(self): - return max(self.version_map.iterkeys()) + @property + def default_version(self): + return max(self.version_map.iterkeys()) - def check(self): - pass + def check(self): + pass - @classmethod - def read_pdu(cls, reader): - return reader.update(need = cls.header_struct.size, callback = cls.got_header) + @classmethod + def read_pdu(cls, reader): + return reader.update(need = cls.header_struct.size, callback = cls.got_header) - @classmethod - def got_header(cls, reader): - if not reader.ready(): - return None - assert reader.available() >= cls.header_struct.size - version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) - reader.check_version(version) - if pdu_type not in cls.version_map[version]: - raise UnsupportedPDUType( - "Received unsupported PDU type %d" % pdu_type) - if length < 8: - raise CorruptData( - "Received PDU with length %d, which is too short to be valid" % length) - self = cls.version_map[version][pdu_type](version = version) - return reader.update(need = length, callback = self.got_pdu) + @classmethod + def got_header(cls, reader): + if not reader.ready(): + return None + assert reader.available() >= cls.header_struct.size + version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) + reader.check_version(version) + if pdu_type not in cls.version_map[version]: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) + self = cls.version_map[version][pdu_type](version = version) + return reader.update(need = length, callback = self.got_pdu) class PDUWithSerial(PDU): - """ - Base class for PDUs consisting of just a serial number and nonce. - """ - - header_struct = struct.Struct("!BBHLL") - - def __init__(self, version, serial = None, nonce = None): - super(PDUWithSerial, self).__init__(version) - if serial is not None: - assert isinstance(serial, int) - self.serial = serial - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce - - def __str__(self): - return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) - - def to_pdu(self): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a serial number and nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 12: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHLL") + + def __init__(self, version, serial = None, nonce = None): + super(PDUWithSerial, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self class PDUWithNonce(PDU): - """ - Base class for PDUs consisting of just a nonce. - """ - - header_struct = struct.Struct("!BBHL") - - def __init__(self, version, nonce = None): - super(PDUWithNonce, self).__init__(version) - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce - - def __str__(self): - return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) - - def to_pdu(self): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) - return self._pdu + header_struct = struct.Struct("!BBHL") - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + def __init__(self, version, nonce = None): + super(PDUWithNonce, self).__init__(version) + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + def __str__(self): + return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) -class PDUEmpty(PDU): - """ - Base class for empty PDUs. - """ + def to_pdu(self): + """ + Generate the wire format PDU. + """ - header_struct = struct.Struct("!BBHL") + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) + return self._pdu - def __str__(self): - return "[%s]" % self.__class__.__name__ + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self - def to_pdu(self): + +class PDUEmpty(PDU): """ - Generate the wire format PDU for this prefix. + Base class for empty PDUs. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, zero, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if zero != 0: - raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHL") + + def __str__(self): + return "[%s]" % self.__class__.__name__ + + def to_pdu(self): + """ + Generate the wire format PDU for this prefix. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, zero, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class SerialNotifyPDU(PDUWithSerial): - """ - Serial Notify PDU. - """ + """ + Serial Notify PDU. + """ - pdu_type = 0 + pdu_type = 0 @wire_pdu class SerialQueryPDU(PDUWithSerial): - """ - Serial Query PDU. - """ + """ + Serial Query PDU. + """ - pdu_type = 1 + pdu_type = 1 - def __init__(self, version, serial = None, nonce = None): - super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) + def __init__(self, version, serial = None, nonce = None): + super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) @wire_pdu class ResetQueryPDU(PDUEmpty): - """ - Reset Query PDU. - """ + """ + Reset Query PDU. + """ - pdu_type = 2 + pdu_type = 2 - def __init__(self, version): - super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) + def __init__(self, version): + super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) @wire_pdu class CacheResponsePDU(PDUWithNonce): - """ - Cache Response PDU. - """ + """ + Cache Response PDU. + """ - pdu_type = 3 + pdu_type = 3 def EndOfDataPDU(version, *args, **kwargs): - """ - Factory for the EndOfDataPDU classes, which take different forms in - different protocol versions. - """ + """ + Factory for the EndOfDataPDU classes, which take different forms in + different protocol versions. + """ - if version == 0: - return EndOfDataPDUv0(version, *args, **kwargs) - if version == 1: - return EndOfDataPDUv1(version, *args, **kwargs) - raise NotImplementedError + if version == 0: + return EndOfDataPDUv0(version, *args, **kwargs) + if version == 1: + return EndOfDataPDUv1(version, *args, **kwargs) + raise NotImplementedError # Min, max, and default values, from the current RFC 6810 bis I-D. @@ -324,325 +324,325 @@ def EndOfDataPDU(version, *args, **kwargs): default_refresh = 3600 def valid_refresh(refresh): - if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: - raise ValueError - return refresh + if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: + raise ValueError + return refresh default_retry = 600 def valid_retry(retry): - if not isinstance(retry, int) or retry < 120 or retry > 7200: - raise ValueError - return retry + if not isinstance(retry, int) or retry < 120 or retry > 7200: + raise ValueError + return retry default_expire = 7200 def valid_expire(expire): - if not isinstance(expire, int) or expire < 600 or expire > 172800: - raise ValueError - return expire + if not isinstance(expire, int) or expire < 600 or expire > 172800: + raise ValueError + return expire @wire_pdu_only(0) class EndOfDataPDUv0(PDUWithSerial): - """ - End of Data PDU, protocol version 0. - """ + """ + End of Data PDU, protocol version 0. + """ - pdu_type = 7 + pdu_type = 7 - def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): - super(EndOfDataPDUv0, self).__init__(version, serial, nonce) - self.refresh = valid_refresh(default_refresh if refresh is None else refresh) - self.retry = valid_retry( default_retry if retry is None else retry) - self.expire = valid_expire( default_expire if expire is None else expire) + def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): + super(EndOfDataPDUv0, self).__init__(version, serial, nonce) + self.refresh = valid_refresh(default_refresh if refresh is None else refresh) + self.retry = valid_retry( default_retry if retry is None else retry) + self.expire = valid_expire( default_expire if expire is None else expire) @wire_pdu_only(1) class EndOfDataPDUv1(EndOfDataPDUv0): - """ - End of Data PDU, protocol version 1. - """ + """ + End of Data PDU, protocol version 1. + """ - header_struct = struct.Struct("!BBHLLLLL") + header_struct = struct.Struct("!BBHLLLLL") - def __str__(self): - return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( - self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) + def __str__(self): + return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( + self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) - def to_pdu(self): - """ - Generate the wire format PDU. - """ + def to_pdu(self): + """ + Generate the wire format PDU. + """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial, - self.refresh, self.retry, self.expire) - return self._pdu + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial, + self.refresh, self.retry, self.expire) + return self._pdu - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ - = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 24: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ + = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 24: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class CacheResetPDU(PDUEmpty): - """ - Cache reset PDU. - """ + """ + Cache reset PDU. + """ - pdu_type = 8 + pdu_type = 8 class PrefixPDU(PDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - header_struct = struct.Struct("!BB2xLBBBx") - asnum_struct = struct.Struct("!L") - - def __str__(self): - plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def show(self): - logging.debug("# Class: %s", self.__class__.__name__) - logging.debug("# ASN: %s", self.asn) - logging.debug("# Prefix: %s", self.prefix) - logging.debug("# Prefixlen: %s", self.prefixlen) - logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) - logging.debug("# Announce: %s", self.announce) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if self.prefix.bits != self.address_byte_count * 8: - raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) - if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: - raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) - if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: - raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - """ - Generate the wire format PDU for this prefix. - """ - - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, - announce if announce is not None else self.announce, - self.prefixlen, self.max_prefixlen) + - self.prefix.toBytes() + - self.asnum_struct.pack(self.asn)) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b1 = reader.get(self.header_struct.size) - b2 = reader.get(self.address_byte_count) - b3 = reader.get(self.asnum_struct.size) - version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) - assert version == self.version and pdu_type == self.pdu_type - if length != len(b1) + len(b2) + len(b3): - raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) - self.prefix = rpki.POW.IPAddress.fromBytes(b2) - self.asn = self.asnum_struct.unpack(b3)[0] - assert b1 + b2 + b3 == self.to_pdu() - return self + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + header_struct = struct.Struct("!BB2xLBBBx") + asnum_struct = struct.Struct("!L") + + def __str__(self): + plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def show(self): + logging.debug("# Class: %s", self.__class__.__name__) + logging.debug("# ASN: %s", self.asn) + logging.debug("# Prefix: %s", self.prefix) + logging.debug("# Prefixlen: %s", self.prefixlen) + logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) + logging.debug("# Announce: %s", self.announce) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefix.bits != self.address_byte_count * 8: + raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this prefix. + """ + + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, + announce if announce is not None else self.announce, + self.prefixlen, self.max_prefixlen) + + self.prefix.toBytes() + + self.asnum_struct.pack(self.asn)) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b1 = reader.get(self.header_struct.size) + b2 = reader.get(self.address_byte_count) + b3 = reader.get(self.asnum_struct.size) + version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) + assert version == self.version and pdu_type == self.pdu_type + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) + self.prefix = rpki.POW.IPAddress.fromBytes(b2) + self.asn = self.asnum_struct.unpack(b3)[0] + assert b1 + b2 + b3 == self.to_pdu() + return self @wire_pdu class IPv4PrefixPDU(PrefixPDU): - """ - IPv4 flavor of a prefix. - """ + """ + IPv4 flavor of a prefix. + """ - pdu_type = 4 - address_byte_count = 4 + pdu_type = 4 + address_byte_count = 4 @wire_pdu class IPv6PrefixPDU(PrefixPDU): - """ - IPv6 flavor of a prefix. - """ + """ + IPv6 flavor of a prefix. + """ - pdu_type = 6 - address_byte_count = 16 + pdu_type = 6 + address_byte_count = 16 @wire_pdu_only(1) class RouterKeyPDU(PDU): - """ - Router Key PDU. - """ - - pdu_type = 9 - - header_struct = struct.Struct("!BBBxL20sL") - - def __str__(self): - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, - base64.urlsafe_b64encode(self.ski).rstrip("="), - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if len(self.ski) != 20: - raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) - pdulen = self.header_struct.size + len(self.key) - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + len(self.key) - pdu = (self.header_struct.pack(self.version, - self.pdu_type, - announce if announce is not None else self.announce, - pdulen, - self.ski, - self.asn) - + self.key) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - if remaining <= 0: - raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) - self.key = reader.get(remaining) - assert header + self.key == self.to_pdu() - return self + """ + Router Key PDU. + """ + + pdu_type = 9 + + header_struct = struct.Struct("!BBBxL20sL") + + def __str__(self): + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, + base64.urlsafe_b64encode(self.ski).rstrip("="), + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if len(self.ski) != 20: + raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) + pdulen = self.header_struct.size + len(self.key) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + len(self.key) + pdu = (self.header_struct.pack(self.version, + self.pdu_type, + announce if announce is not None else self.announce, + pdulen, + self.ski, + self.asn) + + self.key) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + if remaining <= 0: + raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) + self.key = reader.get(remaining) + assert header + self.key == self.to_pdu() + return self @wire_pdu class ErrorReportPDU(PDU): - """ - Error Report PDU. - """ - - pdu_type = 10 - - header_struct = struct.Struct("!BBHL") - string_struct = struct.Struct("!L") - - errors = { - 2 : "No Data Available" } - - fatal = { - 0 : "Corrupt Data", - 1 : "Internal Error", - 3 : "Invalid Request", - 4 : "Unsupported Protocol Version", - 5 : "Unsupported PDU Type", - 6 : "Withdrawal of Unknown Record", - 7 : "Duplicate Announcement Received" } - - assert set(errors) & set(fatal) == set() - - errors.update(fatal) - - codes = dict((v, k) for k, v in errors.items()) - - def __init__(self, version, errno = None, errpdu = None, errmsg = None): - super(ErrorReportPDU, self).__init__(version) - assert errno is None or errno in self.errors - self.errno = errno - self.errpdu = errpdu - self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] - - def __str__(self): - return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) - - def to_counted_string(self, s): - return self.string_struct.pack(len(s)) + s - - def read_counted_string(self, reader, remaining): - assert remaining >= self.string_struct.size - n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] - assert remaining >= self.string_struct.size + n - return n, reader.get(n), (remaining - self.string_struct.size - n) - - def to_pdu(self): - """ - Generate the wire format PDU for this error report. - """ - - if self._pdu is None: - assert isinstance(self.errno, int) - assert not isinstance(self.errpdu, ErrorReportPDU) - p = self.errpdu - if p is None: - p = "" - elif isinstance(p, PDU): - p = p.to_pdu() - assert isinstance(p, str) - pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) - self._pdu += self.to_counted_string(p) - self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.errno, length = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) - self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) - if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: - raise CorruptData("Got PDU length %d, expected %d" % ( - length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) - assert (header - + self.to_counted_string(self.errpdu) - + self.to_counted_string(self.errmsg.encode("utf8")) - == self.to_pdu()) - return self + """ + Error Report PDU. + """ + + pdu_type = 10 + + header_struct = struct.Struct("!BBHL") + string_struct = struct.Struct("!L") + + errors = { + 2 : "No Data Available" } + + fatal = { + 0 : "Corrupt Data", + 1 : "Internal Error", + 3 : "Invalid Request", + 4 : "Unsupported Protocol Version", + 5 : "Unsupported PDU Type", + 6 : "Withdrawal of Unknown Record", + 7 : "Duplicate Announcement Received" } + + assert set(errors) & set(fatal) == set() + + errors.update(fatal) + + codes = dict((v, k) for k, v in errors.items()) + + def __init__(self, version, errno = None, errpdu = None, errmsg = None): + super(ErrorReportPDU, self).__init__(version) + assert errno is None or errno in self.errors + self.errno = errno + self.errpdu = errpdu + self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] + + def __str__(self): + return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) + + def to_counted_string(self, s): + return self.string_struct.pack(len(s)) + s + + def read_counted_string(self, reader, remaining): + assert remaining >= self.string_struct.size + n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] + assert remaining >= self.string_struct.size + n + return n, reader.get(n), (remaining - self.string_struct.size - n) + + def to_pdu(self): + """ + Generate the wire format PDU for this error report. + """ + + if self._pdu is None: + assert isinstance(self.errno, int) + assert not isinstance(self.errpdu, ErrorReportPDU) + p = self.errpdu + if p is None: + p = "" + elif isinstance(p, PDU): + p = p.to_pdu() + assert isinstance(p, str) + pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) + self._pdu += self.to_counted_string(p) + self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.errno, length = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) + self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) + assert (header + + self.to_counted_string(self.errpdu) + + self.to_counted_string(self.errmsg.encode("utf8")) + == self.to_pdu()) + return self |