diff options
author | Rob Austein <sra@hactrn.net> | 2015-10-26 06:29:00 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2015-10-26 06:29:00 +0000 |
commit | b46deb1417dc3596e9ac9fe2fe8cc0b7f42457e7 (patch) | |
tree | ca0dc0276d1adc168bc3337ce0564c4ec4957c1b /rpki/rtr/pdus.py | |
parent | 397beaf6d9900dc3b3cb612c89ebf1d57b1d16f6 (diff) |
"Any programmer who fails to comply with the standard naming, formatting,
or commenting conventions should be shot. If it so happens that it is
inconvenient to shoot him, then he is to be politely requested to recode
his program in adherence to the above standard."
-- Michael Spier, Digital Equipment Corporation
svn path=/branches/tk705/; revision=6152
Diffstat (limited to 'rpki/rtr/pdus.py')
-rw-r--r-- | rpki/rtr/pdus.py | 960 |
1 files changed, 480 insertions, 480 deletions
diff --git a/rpki/rtr/pdus.py b/rpki/rtr/pdus.py index 0d2e5928..94f579a1 100644 --- a/rpki/rtr/pdus.py +++ b/rpki/rtr/pdus.py @@ -28,292 +28,292 @@ import rpki.POW # Exceptions class PDUException(Exception): - """ - Parent exception type for exceptions that signal particular protocol - errors. String value of exception instance will be the message to - put in the ErrorReportPDU, error_report_code value of exception - will be the numeric code to use. - """ - - def __init__(self, msg = None, pdu = None): - super(PDUException, self).__init__() - assert msg is None or isinstance(msg, (str, unicode)) - self.error_report_msg = msg - self.error_report_pdu = pdu - - def __str__(self): - return self.error_report_msg or self.__class__.__name__ - - def make_error_report(self, version): - return ErrorReportPDU(version = version, - errno = self.error_report_code, - errmsg = self.error_report_msg, - errpdu = self.error_report_pdu) + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the ErrorReportPDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + super(PDUException, self).__init__() + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self, version): + return ErrorReportPDU(version = version, + errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) class UnsupportedProtocolVersion(PDUException): - error_report_code = 4 + error_report_code = 4 class UnsupportedPDUType(PDUException): - error_report_code = 5 + error_report_code = 5 class CorruptData(PDUException): - error_report_code = 0 + error_report_code = 0 # Decorators def wire_pdu(cls, versions = None): - """ - Class decorator to add a PDU class to the set of known PDUs - for all supported protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for all supported protocol versions. + """ - for v in PDU.version_map.iterkeys() if versions is None else versions: - assert cls.pdu_type not in PDU.version_map[v] - PDU.version_map[v][cls.pdu_type] = cls - return cls + for v in PDU.version_map.iterkeys() if versions is None else versions: + assert cls.pdu_type not in PDU.version_map[v] + PDU.version_map[v][cls.pdu_type] = cls + return cls def wire_pdu_only(*versions): - """ - Class decorator to add a PDU class to the set of known PDUs - for specific protocol versions. - """ + """ + Class decorator to add a PDU class to the set of known PDUs + for specific protocol versions. + """ - assert versions and all(v in PDU.version_map for v in versions) - return lambda cls: wire_pdu(cls, versions) + assert versions and all(v in PDU.version_map for v in versions) + return lambda cls: wire_pdu(cls, versions) def clone_pdu_root(root_pdu_class): - """ - Replace a PDU root class's version_map with a two-level deep copy of itself, - and return a class decorator which subclasses can use to replace their - parent classes with themselves in the resulting cloned version map. + """ + Replace a PDU root class's version_map with a two-level deep copy of itself, + and return a class decorator which subclasses can use to replace their + parent classes with themselves in the resulting cloned version map. - This function is not itself a decorator, it returns one. - """ + This function is not itself a decorator, it returns one. + """ - root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) + root_pdu_class.version_map = dict((k, v.copy()) for k, v in root_pdu_class.version_map.iteritems()) - def decorator(cls): - for pdu_map in root_pdu_class.version_map.itervalues(): - for pdu_type, pdu_class in pdu_map.items(): - if pdu_class in cls.__bases__: - pdu_map[pdu_type] = cls - return cls + def decorator(cls): + for pdu_map in root_pdu_class.version_map.itervalues(): + for pdu_type, pdu_class in pdu_map.items(): + if pdu_class in cls.__bases__: + pdu_map[pdu_type] = cls + return cls - return decorator + return decorator # PDUs class PDU(object): - """ - Base PDU. Real PDUs are subclasses of this class. - """ + """ + Base PDU. Real PDUs are subclasses of this class. + """ - version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu + version_map = {0 : {}, 1 : {}} # Updated by @wire_pdu - _pdu = None # Cached when first generated + _pdu = None # Cached when first generated - header_struct = struct.Struct("!BB2xL") + header_struct = struct.Struct("!BB2xL") - def __init__(self, version): - assert version in self.version_map - self.version = version + def __init__(self, version): + assert version in self.version_map + self.version = version - def __cmp__(self, other): - return cmp(self.to_pdu(), other.to_pdu()) + def __cmp__(self, other): + return cmp(self.to_pdu(), other.to_pdu()) - @property - def default_version(self): - return max(self.version_map.iterkeys()) + @property + def default_version(self): + return max(self.version_map.iterkeys()) - def check(self): - pass + def check(self): + pass - @classmethod - def read_pdu(cls, reader): - return reader.update(need = cls.header_struct.size, callback = cls.got_header) + @classmethod + def read_pdu(cls, reader): + return reader.update(need = cls.header_struct.size, callback = cls.got_header) - @classmethod - def got_header(cls, reader): - if not reader.ready(): - return None - assert reader.available() >= cls.header_struct.size - version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) - reader.check_version(version) - if pdu_type not in cls.version_map[version]: - raise UnsupportedPDUType( - "Received unsupported PDU type %d" % pdu_type) - if length < 8: - raise CorruptData( - "Received PDU with length %d, which is too short to be valid" % length) - self = cls.version_map[version][pdu_type](version = version) - return reader.update(need = length, callback = self.got_pdu) + @classmethod + def got_header(cls, reader): + if not reader.ready(): + return None + assert reader.available() >= cls.header_struct.size + version, pdu_type, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) + reader.check_version(version) + if pdu_type not in cls.version_map[version]: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) + self = cls.version_map[version][pdu_type](version = version) + return reader.update(need = length, callback = self.got_pdu) class PDUWithSerial(PDU): - """ - Base class for PDUs consisting of just a serial number and nonce. - """ - - header_struct = struct.Struct("!BBHLL") - - def __init__(self, version, serial = None, nonce = None): - super(PDUWithSerial, self).__init__(version) - if serial is not None: - assert isinstance(serial, int) - self.serial = serial - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce - - def __str__(self): - return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) - - def to_pdu(self): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a serial number and nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 12: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHLL") + + def __init__(self, version, serial = None, nonce = None): + super(PDUWithSerial, self).__init__(version) + if serial is not None: + assert isinstance(serial, int) + self.serial = serial + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + + def __str__(self): + return "[%s, serial #%d nonce %d]" % (self.__class__.__name__, self.serial, self.nonce) + + def to_pdu(self): + """ + Generate the wire format PDU. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self class PDUWithNonce(PDU): - """ - Base class for PDUs consisting of just a nonce. - """ - - header_struct = struct.Struct("!BBHL") - - def __init__(self, version, nonce = None): - super(PDUWithNonce, self).__init__(version) - if nonce is not None: - assert isinstance(nonce, int) - self.nonce = nonce - - def __str__(self): - return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) - - def to_pdu(self): """ - Generate the wire format PDU. + Base class for PDUs consisting of just a nonce. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) - return self._pdu + header_struct = struct.Struct("!BBHL") - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + def __init__(self, version, nonce = None): + super(PDUWithNonce, self).__init__(version) + if nonce is not None: + assert isinstance(nonce, int) + self.nonce = nonce + def __str__(self): + return "[%s, nonce %d]" % (self.__class__.__name__, self.nonce) -class PDUEmpty(PDU): - """ - Base class for empty PDUs. - """ + def to_pdu(self): + """ + Generate the wire format PDU. + """ - header_struct = struct.Struct("!BBHL") + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, self.header_struct.size) + return self._pdu - def __str__(self): - return "[%s]" % self.__class__.__name__ + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self - def to_pdu(self): + +class PDUEmpty(PDU): """ - Generate the wire format PDU for this prefix. + Base class for empty PDUs. """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, zero, length = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if zero != 0: - raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) - if length != 8: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + header_struct = struct.Struct("!BBHL") + + def __str__(self): + return "[%s]" % self.__class__.__name__ + + def to_pdu(self): + """ + Generate the wire format PDU for this prefix. + """ + + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.header_struct.size) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, zero, length = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class SerialNotifyPDU(PDUWithSerial): - """ - Serial Notify PDU. - """ + """ + Serial Notify PDU. + """ - pdu_type = 0 + pdu_type = 0 @wire_pdu class SerialQueryPDU(PDUWithSerial): - """ - Serial Query PDU. - """ + """ + Serial Query PDU. + """ - pdu_type = 1 + pdu_type = 1 - def __init__(self, version, serial = None, nonce = None): - super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) + def __init__(self, version, serial = None, nonce = None): + super(SerialQueryPDU, self).__init__(self.default_version if version is None else version, serial, nonce) @wire_pdu class ResetQueryPDU(PDUEmpty): - """ - Reset Query PDU. - """ + """ + Reset Query PDU. + """ - pdu_type = 2 + pdu_type = 2 - def __init__(self, version): - super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) + def __init__(self, version): + super(ResetQueryPDU, self).__init__(self.default_version if version is None else version) @wire_pdu class CacheResponsePDU(PDUWithNonce): - """ - Cache Response PDU. - """ + """ + Cache Response PDU. + """ - pdu_type = 3 + pdu_type = 3 def EndOfDataPDU(version, *args, **kwargs): - """ - Factory for the EndOfDataPDU classes, which take different forms in - different protocol versions. - """ + """ + Factory for the EndOfDataPDU classes, which take different forms in + different protocol versions. + """ - if version == 0: - return EndOfDataPDUv0(version, *args, **kwargs) - if version == 1: - return EndOfDataPDUv1(version, *args, **kwargs) - raise NotImplementedError + if version == 0: + return EndOfDataPDUv0(version, *args, **kwargs) + if version == 1: + return EndOfDataPDUv1(version, *args, **kwargs) + raise NotImplementedError # Min, max, and default values, from the current RFC 6810 bis I-D. @@ -324,325 +324,325 @@ def EndOfDataPDU(version, *args, **kwargs): default_refresh = 3600 def valid_refresh(refresh): - if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: - raise ValueError - return refresh + if not isinstance(refresh, int) or refresh < 120 or refresh > 86400: + raise ValueError + return refresh default_retry = 600 def valid_retry(retry): - if not isinstance(retry, int) or retry < 120 or retry > 7200: - raise ValueError - return retry + if not isinstance(retry, int) or retry < 120 or retry > 7200: + raise ValueError + return retry default_expire = 7200 def valid_expire(expire): - if not isinstance(expire, int) or expire < 600 or expire > 172800: - raise ValueError - return expire + if not isinstance(expire, int) or expire < 600 or expire > 172800: + raise ValueError + return expire @wire_pdu_only(0) class EndOfDataPDUv0(PDUWithSerial): - """ - End of Data PDU, protocol version 0. - """ + """ + End of Data PDU, protocol version 0. + """ - pdu_type = 7 + pdu_type = 7 - def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): - super(EndOfDataPDUv0, self).__init__(version, serial, nonce) - self.refresh = valid_refresh(default_refresh if refresh is None else refresh) - self.retry = valid_retry( default_retry if retry is None else retry) - self.expire = valid_expire( default_expire if expire is None else expire) + def __init__(self, version, serial = None, nonce = None, refresh = None, retry = None, expire = None): + super(EndOfDataPDUv0, self).__init__(version, serial, nonce) + self.refresh = valid_refresh(default_refresh if refresh is None else refresh) + self.retry = valid_retry( default_retry if retry is None else retry) + self.expire = valid_expire( default_expire if expire is None else expire) @wire_pdu_only(1) class EndOfDataPDUv1(EndOfDataPDUv0): - """ - End of Data PDU, protocol version 1. - """ + """ + End of Data PDU, protocol version 1. + """ - header_struct = struct.Struct("!BBHLLLLL") + header_struct = struct.Struct("!BBHLLLLL") - def __str__(self): - return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( - self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) + def __str__(self): + return "[%s, serial #%d nonce %d refresh %d retry %d expire %d]" % ( + self.__class__.__name__, self.serial, self.nonce, self.refresh, self.retry, self.expire) - def to_pdu(self): - """ - Generate the wire format PDU. - """ + def to_pdu(self): + """ + Generate the wire format PDU. + """ - if self._pdu is None: - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, - self.header_struct.size, self.serial, - self.refresh, self.retry, self.expire) - return self._pdu + if self._pdu is None: + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.nonce, + self.header_struct.size, self.serial, + self.refresh, self.retry, self.expire) + return self._pdu - def got_pdu(self, reader): - if not reader.ready(): - return None - b = reader.get(self.header_struct.size) - version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ - = self.header_struct.unpack(b) - assert version == self.version and pdu_type == self.pdu_type - if length != 24: - raise CorruptData("PDU length of %d can't be right" % length, pdu = self) - assert b == self.to_pdu() - return self + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.get(self.header_struct.size) + version, pdu_type, self.nonce, length, self.serial, self.refresh, self.retry, self.expire \ + = self.header_struct.unpack(b) + assert version == self.version and pdu_type == self.pdu_type + if length != 24: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) + assert b == self.to_pdu() + return self @wire_pdu class CacheResetPDU(PDUEmpty): - """ - Cache reset PDU. - """ + """ + Cache reset PDU. + """ - pdu_type = 8 + pdu_type = 8 class PrefixPDU(PDU): - """ - Object representing one prefix. This corresponds closely to one PDU - in the rpki-router protocol, so closely that we use lexical ordering - of the wire format of the PDU as the ordering for this class. - - This is a virtual class, but the .from_text() constructor - instantiates the correct concrete subclass (IPv4PrefixPDU or - IPv6PrefixPDU) depending on the syntax of its input text. - """ - - header_struct = struct.Struct("!BB2xLBBBx") - asnum_struct = struct.Struct("!L") - - def __str__(self): - plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def show(self): - logging.debug("# Class: %s", self.__class__.__name__) - logging.debug("# ASN: %s", self.asn) - logging.debug("# Prefix: %s", self.prefix) - logging.debug("# Prefixlen: %s", self.prefixlen) - logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) - logging.debug("# Announce: %s", self.announce) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if self.prefix.bits != self.address_byte_count * 8: - raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) - if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: - raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) - if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: - raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - """ - Generate the wire format PDU for this prefix. - """ - - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size - pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, - announce if announce is not None else self.announce, - self.prefixlen, self.max_prefixlen) + - self.prefix.toBytes() + - self.asnum_struct.pack(self.asn)) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - b1 = reader.get(self.header_struct.size) - b2 = reader.get(self.address_byte_count) - b3 = reader.get(self.asnum_struct.size) - version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) - assert version == self.version and pdu_type == self.pdu_type - if length != len(b1) + len(b2) + len(b3): - raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) - self.prefix = rpki.POW.IPAddress.fromBytes(b2) - self.asn = self.asnum_struct.unpack(b3)[0] - assert b1 + b2 + b3 == self.to_pdu() - return self + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. + + This is a virtual class, but the .from_text() constructor + instantiates the correct concrete subclass (IPv4PrefixPDU or + IPv6PrefixPDU) depending on the syntax of its input text. + """ + + header_struct = struct.Struct("!BB2xLBBBx") + asnum_struct = struct.Struct("!L") + + def __str__(self): + plm = "%s/%s-%s" % (self.prefix, self.prefixlen, self.max_prefixlen) + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, plm, + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def show(self): + logging.debug("# Class: %s", self.__class__.__name__) + logging.debug("# ASN: %s", self.asn) + logging.debug("# Prefix: %s", self.prefix) + logging.debug("# Prefixlen: %s", self.prefixlen) + logging.debug("# MaxPrefixlen: %s", self.max_prefixlen) + logging.debug("# Announce: %s", self.announce) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefix.bits != self.address_byte_count * 8: + raise CorruptData("IP address length %d does not match expectation" % self.prefix.bits, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.prefix.bits: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.prefix.bits: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + """ + Generate the wire format PDU for this prefix. + """ + + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + self.prefix.bits/8 + self.asnum_struct.size + pdu = (self.header_struct.pack(self.version, self.pdu_type, pdulen, + announce if announce is not None else self.announce, + self.prefixlen, self.max_prefixlen) + + self.prefix.toBytes() + + self.asnum_struct.pack(self.asn)) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + b1 = reader.get(self.header_struct.size) + b2 = reader.get(self.address_byte_count) + b3 = reader.get(self.asnum_struct.size) + version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) + assert version == self.version and pdu_type == self.pdu_type + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) + self.prefix = rpki.POW.IPAddress.fromBytes(b2) + self.asn = self.asnum_struct.unpack(b3)[0] + assert b1 + b2 + b3 == self.to_pdu() + return self @wire_pdu class IPv4PrefixPDU(PrefixPDU): - """ - IPv4 flavor of a prefix. - """ + """ + IPv4 flavor of a prefix. + """ - pdu_type = 4 - address_byte_count = 4 + pdu_type = 4 + address_byte_count = 4 @wire_pdu class IPv6PrefixPDU(PrefixPDU): - """ - IPv6 flavor of a prefix. - """ + """ + IPv6 flavor of a prefix. + """ - pdu_type = 6 - address_byte_count = 16 + pdu_type = 6 + address_byte_count = 16 @wire_pdu_only(1) class RouterKeyPDU(PDU): - """ - Router Key PDU. - """ - - pdu_type = 9 - - header_struct = struct.Struct("!BBBxL20sL") - - def __str__(self): - return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, - base64.urlsafe_b64encode(self.ski).rstrip("="), - ":".join(("%02X" % ord(b) for b in self.to_pdu()))) - - def check(self): - """ - Check attributes to make sure they're within range. - """ - - if self.announce not in (0, 1): - raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) - if len(self.ski) != 20: - raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) - pdulen = self.header_struct.size + len(self.key) - if len(self.to_pdu()) != pdulen: - raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) - - def to_pdu(self, announce = None): - if announce is not None: - assert announce in (0, 1) - elif self._pdu is not None: - return self._pdu - pdulen = self.header_struct.size + len(self.key) - pdu = (self.header_struct.pack(self.version, - self.pdu_type, - announce if announce is not None else self.announce, - pdulen, - self.ski, - self.asn) - + self.key) - if announce is None: - assert self._pdu is None - self._pdu = pdu - return pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - if remaining <= 0: - raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) - self.key = reader.get(remaining) - assert header + self.key == self.to_pdu() - return self + """ + Router Key PDU. + """ + + pdu_type = 9 + + header_struct = struct.Struct("!BBBxL20sL") + + def __str__(self): + return "%s %8s %-32s %s" % ("+" if self.announce else "-", self.asn, + base64.urlsafe_b64encode(self.ski).rstrip("="), + ":".join(("%02X" % ord(b) for b in self.to_pdu()))) + + def check(self): + """ + Check attributes to make sure they're within range. + """ + + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if len(self.ski) != 20: + raise CorruptData("Implausible SKI length %d" % len(self.ski), pdu = self) + pdulen = self.header_struct.size + len(self.key) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) + + def to_pdu(self, announce = None): + if announce is not None: + assert announce in (0, 1) + elif self._pdu is not None: + return self._pdu + pdulen = self.header_struct.size + len(self.key) + pdu = (self.header_struct.pack(self.version, + self.pdu_type, + announce if announce is not None else self.announce, + pdulen, + self.ski, + self.asn) + + self.key) + if announce is None: + assert self._pdu is None + self._pdu = pdu + return pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.announce, length, self.ski, self.asn = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + if remaining <= 0: + raise CorruptData("Got PDU length %d, minimum is %d" % (length, self.header_struct.size + 1), pdu = self) + self.key = reader.get(remaining) + assert header + self.key == self.to_pdu() + return self @wire_pdu class ErrorReportPDU(PDU): - """ - Error Report PDU. - """ - - pdu_type = 10 - - header_struct = struct.Struct("!BBHL") - string_struct = struct.Struct("!L") - - errors = { - 2 : "No Data Available" } - - fatal = { - 0 : "Corrupt Data", - 1 : "Internal Error", - 3 : "Invalid Request", - 4 : "Unsupported Protocol Version", - 5 : "Unsupported PDU Type", - 6 : "Withdrawal of Unknown Record", - 7 : "Duplicate Announcement Received" } - - assert set(errors) & set(fatal) == set() - - errors.update(fatal) - - codes = dict((v, k) for k, v in errors.items()) - - def __init__(self, version, errno = None, errpdu = None, errmsg = None): - super(ErrorReportPDU, self).__init__(version) - assert errno is None or errno in self.errors - self.errno = errno - self.errpdu = errpdu - self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] - - def __str__(self): - return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) - - def to_counted_string(self, s): - return self.string_struct.pack(len(s)) + s - - def read_counted_string(self, reader, remaining): - assert remaining >= self.string_struct.size - n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] - assert remaining >= self.string_struct.size + n - return n, reader.get(n), (remaining - self.string_struct.size - n) - - def to_pdu(self): - """ - Generate the wire format PDU for this error report. - """ - - if self._pdu is None: - assert isinstance(self.errno, int) - assert not isinstance(self.errpdu, ErrorReportPDU) - p = self.errpdu - if p is None: - p = "" - elif isinstance(p, PDU): - p = p.to_pdu() - assert isinstance(p, str) - pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) - self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) - self._pdu += self.to_counted_string(p) - self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) - return self._pdu - - def got_pdu(self, reader): - if not reader.ready(): - return None - header = reader.get(self.header_struct.size) - version, pdu_type, self.errno, length = self.header_struct.unpack(header) - assert version == self.version and pdu_type == self.pdu_type - remaining = length - self.header_struct.size - self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) - self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) - if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: - raise CorruptData("Got PDU length %d, expected %d" % ( - length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) - assert (header - + self.to_counted_string(self.errpdu) - + self.to_counted_string(self.errmsg.encode("utf8")) - == self.to_pdu()) - return self + """ + Error Report PDU. + """ + + pdu_type = 10 + + header_struct = struct.Struct("!BBHL") + string_struct = struct.Struct("!L") + + errors = { + 2 : "No Data Available" } + + fatal = { + 0 : "Corrupt Data", + 1 : "Internal Error", + 3 : "Invalid Request", + 4 : "Unsupported Protocol Version", + 5 : "Unsupported PDU Type", + 6 : "Withdrawal of Unknown Record", + 7 : "Duplicate Announcement Received" } + + assert set(errors) & set(fatal) == set() + + errors.update(fatal) + + codes = dict((v, k) for k, v in errors.items()) + + def __init__(self, version, errno = None, errpdu = None, errmsg = None): + super(ErrorReportPDU, self).__init__(version) + assert errno is None or errno in self.errors + self.errno = errno + self.errpdu = errpdu + self.errmsg = errmsg if errmsg is not None or errno is None else self.errors[errno] + + def __str__(self): + return "[%s, error #%s: %r]" % (self.__class__.__name__, self.errno, self.errmsg) + + def to_counted_string(self, s): + return self.string_struct.pack(len(s)) + s + + def read_counted_string(self, reader, remaining): + assert remaining >= self.string_struct.size + n = self.string_struct.unpack(reader.get(self.string_struct.size))[0] + assert remaining >= self.string_struct.size + n + return n, reader.get(n), (remaining - self.string_struct.size - n) + + def to_pdu(self): + """ + Generate the wire format PDU for this error report. + """ + + if self._pdu is None: + assert isinstance(self.errno, int) + assert not isinstance(self.errpdu, ErrorReportPDU) + p = self.errpdu + if p is None: + p = "" + elif isinstance(p, PDU): + p = p.to_pdu() + assert isinstance(p, str) + pdulen = self.header_struct.size + self.string_struct.size * 2 + len(p) + len(self.errmsg) + self._pdu = self.header_struct.pack(self.version, self.pdu_type, self.errno, pdulen) + self._pdu += self.to_counted_string(p) + self._pdu += self.to_counted_string(self.errmsg.encode("utf8")) + return self._pdu + + def got_pdu(self, reader): + if not reader.ready(): + return None + header = reader.get(self.header_struct.size) + version, pdu_type, self.errno, length = self.header_struct.unpack(header) + assert version == self.version and pdu_type == self.pdu_type + remaining = length - self.header_struct.size + self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) + self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) + assert (header + + self.to_counted_string(self.errpdu) + + self.to_counted_string(self.errmsg.encode("utf8")) + == self.to_pdu()) + return self |