diff options
Diffstat (limited to 'rtr-origin')
-rwxr-xr-x | rtr-origin/rtr-origin.py | 89 |
1 files changed, 71 insertions, 18 deletions
diff --git a/rtr-origin/rtr-origin.py b/rtr-origin/rtr-origin.py index 5ebc0401..f10c237e 100755 --- a/rtr-origin/rtr-origin.py +++ b/rtr-origin/rtr-origin.py @@ -178,6 +178,36 @@ class read_buffer(object): """ return self.callback(self) +class PDUException(Exception): + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the error_report PDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self): + return error_report(errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) + +class UnsupportedProtocolVersion(PDUException): + error_report_code = 4 + +class UnsupportedPDUType(PDUException): + error_report_code = 5 + +class CorruptData(PDUException): + error_report_code = 0 + class pdu(object): """ Object representing a generic PDU in the rpki-router protocol. @@ -209,8 +239,15 @@ class pdu(object): return None assert reader.available() >= cls.header_struct.size version, pdu_type, whatever, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) - assert version == cls.version, "PDU version is %d, expected %d" % (version, cls.version) - assert length >= 8 + if version != cls.version: + raise UnsupportedProtocolVersion( + "Received PDU version %d, expected %d" % (version, cls.version)) + if pdu_type not in cls.pdu_map: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) self = cls.pdu_map[pdu_type]() return reader.update(need = length, callback = self.got_pdu) @@ -267,7 +304,8 @@ class pdu_with_serial(pdu): return None b = reader.get(self.header_struct.size) version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) - assert length == 12 + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() return self @@ -299,7 +337,8 @@ class pdu_nonce(pdu): return None b = reader.get(self.header_struct.size) version, pdu_type, self.nonce, length = self.header_struct.unpack(b) - assert length == 8 + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() return self @@ -326,8 +365,10 @@ class pdu_empty(pdu): return None b = reader.get(self.header_struct.size) version, pdu_type, zero, length = self.header_struct.unpack(b) - assert zero == 0 - assert length == 8 + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() return self @@ -487,11 +528,15 @@ class prefix(pdu): """ Check attributes to make sure they're within range. """ - assert self.announce in (0, 1) - assert self.prefixlen >= 0 and self.prefixlen <= self.addr_type.size * 8 - assert self.max_prefixlen >= self.prefixlen and self.max_prefixlen <= self.addr_type.size * 8 + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.addr_type.size * 8: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.addr_type.size * 8: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) pdulen = self.header_struct.size + self.addr_type.size + self.asnum_struct.size - assert len(self.to_pdu()) == pdulen, "Expected %d byte PDU, got %d" % pd(pdulen, len(self.to_pdu())) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) def to_pdu(self, announce = None): """ @@ -519,7 +564,8 @@ class prefix(pdu): b2 = reader.get(self.addr_type.size) b3 = reader.get(self.asnum_struct.size) version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) - assert length == len(b1) + len(b2) + len(b3) + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) self.prefix = self.addr_type(value = b2) self.asn = self.asnum_struct.unpack(b3)[0] assert b1 + b2 + b3 == self.to_pdu() @@ -631,7 +677,7 @@ class error_report(pdu): def to_pdu(self): """ - Generate the wire format PDU for this prefix. + Generate the wire format PDU for this error report. """ if self._pdu is None: assert isinstance(self.errno, int) @@ -656,7 +702,9 @@ class error_report(pdu): remaining = length - self.header_struct.size self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) - assert length == self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) assert header + self.to_counted_string(self.errpdu) + self.to_counted_string(self.errmsg.encode("utf8")) == self.to_pdu() return self @@ -946,12 +994,17 @@ class pdu_channel(asynchat.async_chat): """ Start read of a new PDU. """ - p = pdu.read_pdu(self.reader) - while p is not None: - self.deliver_pdu(p) + try: p = pdu.read_pdu(self.reader) - assert not self.reader.ready() - self.set_terminator(self.reader.needed()) + while p is not None: + self.deliver_pdu(p) + p = pdu.read_pdu(self.reader) + except PDUException, e: + self.push_pdu(e.make_error_report()) + self.close_when_done() + else: + assert not self.reader.ready() + self.set_terminator(self.reader.needed()) def collect_incoming_data(self, data): """ |