diff options
author | Rob Austein <sra@hactrn.net> | 2011-12-22 01:23:55 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2011-12-22 01:23:55 +0000 |
commit | 45df799f5fd1509cc0cefd809a4895e99e8e6982 (patch) | |
tree | b06fa388e41b76d8648290b55abbcab0fb29d237 /rtr-origin | |
parent | 33f8bf22f6849b90f8f53b329392f2ab69dbf1fc (diff) |
Generate proper error reports for unknown protocol version, unknown
PDU type, and various forms of corrupt data. We were catching all of
them already, but not reporting them correctly.
svn path=/trunk/; revision=4131
Diffstat (limited to 'rtr-origin')
-rwxr-xr-x | rtr-origin/rtr-origin.py | 89 |
1 files changed, 71 insertions, 18 deletions
diff --git a/rtr-origin/rtr-origin.py b/rtr-origin/rtr-origin.py index 5ebc0401..f10c237e 100755 --- a/rtr-origin/rtr-origin.py +++ b/rtr-origin/rtr-origin.py @@ -178,6 +178,36 @@ class read_buffer(object): """ return self.callback(self) +class PDUException(Exception): + """ + Parent exception type for exceptions that signal particular protocol + errors. String value of exception instance will be the message to + put in the error_report PDU, error_report_code value of exception + will be the numeric code to use. + """ + + def __init__(self, msg = None, pdu = None): + assert msg is None or isinstance(msg, (str, unicode)) + self.error_report_msg = msg + self.error_report_pdu = pdu + + def __str__(self): + return self.error_report_msg or self.__class__.__name__ + + def make_error_report(self): + return error_report(errno = self.error_report_code, + errmsg = self.error_report_msg, + errpdu = self.error_report_pdu) + +class UnsupportedProtocolVersion(PDUException): + error_report_code = 4 + +class UnsupportedPDUType(PDUException): + error_report_code = 5 + +class CorruptData(PDUException): + error_report_code = 0 + class pdu(object): """ Object representing a generic PDU in the rpki-router protocol. @@ -209,8 +239,15 @@ class pdu(object): return None assert reader.available() >= cls.header_struct.size version, pdu_type, whatever, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size]) - assert version == cls.version, "PDU version is %d, expected %d" % (version, cls.version) - assert length >= 8 + if version != cls.version: + raise UnsupportedProtocolVersion( + "Received PDU version %d, expected %d" % (version, cls.version)) + if pdu_type not in cls.pdu_map: + raise UnsupportedPDUType( + "Received unsupported PDU type %d" % pdu_type) + if length < 8: + raise CorruptData( + "Received PDU with length %d, which is too short to be valid" % length) self = cls.pdu_map[pdu_type]() return reader.update(need = length, callback = self.got_pdu) @@ -267,7 +304,8 @@ class pdu_with_serial(pdu): return None b = reader.get(self.header_struct.size) version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b) - assert length == 12 + if length != 12: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() return self @@ -299,7 +337,8 @@ class pdu_nonce(pdu): return None b = reader.get(self.header_struct.size) version, pdu_type, self.nonce, length = self.header_struct.unpack(b) - assert length == 8 + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() return self @@ -326,8 +365,10 @@ class pdu_empty(pdu): return None b = reader.get(self.header_struct.size) version, pdu_type, zero, length = self.header_struct.unpack(b) - assert zero == 0 - assert length == 8 + if zero != 0: + raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self) + if length != 8: + raise CorruptData("PDU length of %d can't be right" % length, pdu = self) assert b == self.to_pdu() return self @@ -487,11 +528,15 @@ class prefix(pdu): """ Check attributes to make sure they're within range. """ - assert self.announce in (0, 1) - assert self.prefixlen >= 0 and self.prefixlen <= self.addr_type.size * 8 - assert self.max_prefixlen >= self.prefixlen and self.max_prefixlen <= self.addr_type.size * 8 + if self.announce not in (0, 1): + raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self) + if self.prefixlen < 0 or self.prefixlen > self.addr_type.size * 8: + raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self) + if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.addr_type.size * 8: + raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self) pdulen = self.header_struct.size + self.addr_type.size + self.asnum_struct.size - assert len(self.to_pdu()) == pdulen, "Expected %d byte PDU, got %d" % pd(pdulen, len(self.to_pdu())) + if len(self.to_pdu()) != pdulen: + raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self) def to_pdu(self, announce = None): """ @@ -519,7 +564,8 @@ class prefix(pdu): b2 = reader.get(self.addr_type.size) b3 = reader.get(self.asnum_struct.size) version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1) - assert length == len(b1) + len(b2) + len(b3) + if length != len(b1) + len(b2) + len(b3): + raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self) self.prefix = self.addr_type(value = b2) self.asn = self.asnum_struct.unpack(b3)[0] assert b1 + b2 + b3 == self.to_pdu() @@ -631,7 +677,7 @@ class error_report(pdu): def to_pdu(self): """ - Generate the wire format PDU for this prefix. + Generate the wire format PDU for this error report. """ if self._pdu is None: assert isinstance(self.errno, int) @@ -656,7 +702,9 @@ class error_report(pdu): remaining = length - self.header_struct.size self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining) self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining) - assert length == self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen + if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen: + raise CorruptData("Got PDU length %d, expected %d" % ( + length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen)) assert header + self.to_counted_string(self.errpdu) + self.to_counted_string(self.errmsg.encode("utf8")) == self.to_pdu() return self @@ -946,12 +994,17 @@ class pdu_channel(asynchat.async_chat): """ Start read of a new PDU. """ - p = pdu.read_pdu(self.reader) - while p is not None: - self.deliver_pdu(p) + try: p = pdu.read_pdu(self.reader) - assert not self.reader.ready() - self.set_terminator(self.reader.needed()) + while p is not None: + self.deliver_pdu(p) + p = pdu.read_pdu(self.reader) + except PDUException, e: + self.push_pdu(e.make_error_report()) + self.close_when_done() + else: + assert not self.reader.ready() + self.set_terminator(self.reader.needed()) def collect_incoming_data(self, data): """ |