aboutsummaryrefslogtreecommitdiff
path: root/rtr-origin
diff options
context:
space:
mode:
Diffstat (limited to 'rtr-origin')
-rwxr-xr-xrtr-origin/rtr-origin.py89
1 files changed, 71 insertions, 18 deletions
diff --git a/rtr-origin/rtr-origin.py b/rtr-origin/rtr-origin.py
index 5ebc0401..f10c237e 100755
--- a/rtr-origin/rtr-origin.py
+++ b/rtr-origin/rtr-origin.py
@@ -178,6 +178,36 @@ class read_buffer(object):
"""
return self.callback(self)
+class PDUException(Exception):
+ """
+ Parent exception type for exceptions that signal particular protocol
+ errors. String value of exception instance will be the message to
+ put in the error_report PDU, error_report_code value of exception
+ will be the numeric code to use.
+ """
+
+ def __init__(self, msg = None, pdu = None):
+ assert msg is None or isinstance(msg, (str, unicode))
+ self.error_report_msg = msg
+ self.error_report_pdu = pdu
+
+ def __str__(self):
+ return self.error_report_msg or self.__class__.__name__
+
+ def make_error_report(self):
+ return error_report(errno = self.error_report_code,
+ errmsg = self.error_report_msg,
+ errpdu = self.error_report_pdu)
+
+class UnsupportedProtocolVersion(PDUException):
+ error_report_code = 4
+
+class UnsupportedPDUType(PDUException):
+ error_report_code = 5
+
+class CorruptData(PDUException):
+ error_report_code = 0
+
class pdu(object):
"""
Object representing a generic PDU in the rpki-router protocol.
@@ -209,8 +239,15 @@ class pdu(object):
return None
assert reader.available() >= cls.header_struct.size
version, pdu_type, whatever, length = cls.header_struct.unpack(reader.buffer[:cls.header_struct.size])
- assert version == cls.version, "PDU version is %d, expected %d" % (version, cls.version)
- assert length >= 8
+ if version != cls.version:
+ raise UnsupportedProtocolVersion(
+ "Received PDU version %d, expected %d" % (version, cls.version))
+ if pdu_type not in cls.pdu_map:
+ raise UnsupportedPDUType(
+ "Received unsupported PDU type %d" % pdu_type)
+ if length < 8:
+ raise CorruptData(
+ "Received PDU with length %d, which is too short to be valid" % length)
self = cls.pdu_map[pdu_type]()
return reader.update(need = length, callback = self.got_pdu)
@@ -267,7 +304,8 @@ class pdu_with_serial(pdu):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, self.nonce, length, self.serial = self.header_struct.unpack(b)
- assert length == 12
+ if length != 12:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
assert b == self.to_pdu()
return self
@@ -299,7 +337,8 @@ class pdu_nonce(pdu):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, self.nonce, length = self.header_struct.unpack(b)
- assert length == 8
+ if length != 8:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
assert b == self.to_pdu()
return self
@@ -326,8 +365,10 @@ class pdu_empty(pdu):
return None
b = reader.get(self.header_struct.size)
version, pdu_type, zero, length = self.header_struct.unpack(b)
- assert zero == 0
- assert length == 8
+ if zero != 0:
+ raise CorruptData("Must-be-zero field isn't zero" % length, pdu = self)
+ if length != 8:
+ raise CorruptData("PDU length of %d can't be right" % length, pdu = self)
assert b == self.to_pdu()
return self
@@ -487,11 +528,15 @@ class prefix(pdu):
"""
Check attributes to make sure they're within range.
"""
- assert self.announce in (0, 1)
- assert self.prefixlen >= 0 and self.prefixlen <= self.addr_type.size * 8
- assert self.max_prefixlen >= self.prefixlen and self.max_prefixlen <= self.addr_type.size * 8
+ if self.announce not in (0, 1):
+ raise CorruptData("Announce value %d is neither zero nor one" % self.announce, pdu = self)
+ if self.prefixlen < 0 or self.prefixlen > self.addr_type.size * 8:
+ raise CorruptData("Implausible prefix length %d" % self.prefixlen, pdu = self)
+ if self.max_prefixlen < self.prefixlen or self.max_prefixlen > self.addr_type.size * 8:
+ raise CorruptData("Implausible max prefix length %d" % self.max_prefixlen, pdu = self)
pdulen = self.header_struct.size + self.addr_type.size + self.asnum_struct.size
- assert len(self.to_pdu()) == pdulen, "Expected %d byte PDU, got %d" % pd(pdulen, len(self.to_pdu()))
+ if len(self.to_pdu()) != pdulen:
+ raise CorruptData("Expected %d byte PDU, got %d" % (pdulen, len(self.to_pdu())), pdu = self)
def to_pdu(self, announce = None):
"""
@@ -519,7 +564,8 @@ class prefix(pdu):
b2 = reader.get(self.addr_type.size)
b3 = reader.get(self.asnum_struct.size)
version, pdu_type, length, self.announce, self.prefixlen, self.max_prefixlen = self.header_struct.unpack(b1)
- assert length == len(b1) + len(b2) + len(b3)
+ if length != len(b1) + len(b2) + len(b3):
+ raise CorruptData("Got PDU length %d, expected %d" % (length, len(b1) + len(b2) + len(b3)), pdu = self)
self.prefix = self.addr_type(value = b2)
self.asn = self.asnum_struct.unpack(b3)[0]
assert b1 + b2 + b3 == self.to_pdu()
@@ -631,7 +677,7 @@ class error_report(pdu):
def to_pdu(self):
"""
- Generate the wire format PDU for this prefix.
+ Generate the wire format PDU for this error report.
"""
if self._pdu is None:
assert isinstance(self.errno, int)
@@ -656,7 +702,9 @@ class error_report(pdu):
remaining = length - self.header_struct.size
self.pdulen, self.errpdu, remaining = self.read_counted_string(reader, remaining)
self.errlen, self.errmsg, remaining = self.read_counted_string(reader, remaining)
- assert length == self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen
+ if length != self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen:
+ raise CorruptData("Got PDU length %d, expected %d" % (
+ length, self.header_struct.size + self.string_struct.size * 2 + self.pdulen + self.errlen))
assert header + self.to_counted_string(self.errpdu) + self.to_counted_string(self.errmsg.encode("utf8")) == self.to_pdu()
return self
@@ -946,12 +994,17 @@ class pdu_channel(asynchat.async_chat):
"""
Start read of a new PDU.
"""
- p = pdu.read_pdu(self.reader)
- while p is not None:
- self.deliver_pdu(p)
+ try:
p = pdu.read_pdu(self.reader)
- assert not self.reader.ready()
- self.set_terminator(self.reader.needed())
+ while p is not None:
+ self.deliver_pdu(p)
+ p = pdu.read_pdu(self.reader)
+ except PDUException, e:
+ self.push_pdu(e.make_error_report())
+ self.close_when_done()
+ else:
+ assert not self.reader.ready()
+ self.set_terminator(self.reader.needed())
def collect_incoming_data(self, data):
"""