diff options
author | Rob Austein <sra@hactrn.net> | 2009-04-07 05:52:40 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2009-04-07 05:52:40 +0000 |
commit | 5bdd24d94bd8403ddd123dfcd5b7480f1e80a3ac (patch) | |
tree | eadc66f7a9cab03a0d2280115c538e33c047bd85 /rtr-origin | |
parent | 1f00d2beeffa1208429a6d894a1b367b8b80e992 (diff) |
Clean up horrible mess of duplicated parsing code.
svn path=/rtr-origin/rtr-origin.py; revision=2319
Diffstat (limited to 'rtr-origin')
-rw-r--r-- | rtr-origin/rtr-origin.py | 284 |
1 files changed, 126 insertions, 158 deletions
diff --git a/rtr-origin/rtr-origin.py b/rtr-origin/rtr-origin.py index da3e54fa..6c59d47f 100644 --- a/rtr-origin/rtr-origin.py +++ b/rtr-origin/rtr-origin.py @@ -42,12 +42,43 @@ PERFORMANCE OF THIS SOFTWARE. """ import sys, os, struct, time, glob, socket, fcntl, signal -import asyncore, asynchat, subprocess +import asyncore, asynchat, subprocess, traceback import rpki.x509, rpki.ipaddrs, rpki.sundial os.environ["TZ"] = "UTC" time.tzset() +class read_buffer(object): + """Wrapper around synchronous/asynchronous read state.""" + + def __init__(self): + self.buffer = "" + + def update(self, need, callback): + self.need = need + self.callback = callback + return self.callback(self) + + def available(self): + return len(self.buffer) + + def needed(self): + return self.need - self.available() + + def ready(self): + return self.available() >= self.need + + def consume(self, n): + b = self.buffer[:n] + self.buffer = self.buffer[n:] + return b + + def feed(self, b): + self.buffer += b + + def retry(self): + return self.callback(self) + class pdu(object): """Object representing a generic PDU in the rpki-router protocol. Real PDUs are subclasses of this class. @@ -67,40 +98,18 @@ class pdu(object): pass @classmethod - def from_pdu_file(cls, f): - """Read one wire format PDU from a file. This is intended to be - used in an iterator, so it raises StopIteration on end of file. - """ - assert cls._pdu is None - b = f.read(cls.common_header_struct.size) - if b == "": - raise StopIteration - version, pdu_type = cls.common_header_struct.unpack(b) - assert version == cls.version, "PDU version is %d, expected %d" % (version, cls.version) - self = cls.pdu_map[pdu_type]() - self.from_pdu_file_helper(f, b) - self.check() - return self - - @classmethod - def initial_asynchat_decoder(cls, chat): - """Set up initial read for asynchat PDU reader.""" - chat.set_terminator(cls.common_header_struct.size) - chat.set_next_decoder(cls.chat_decode_common_header) + def read_pdu(cls, reader): + return reader.update(need = cls.common_header_struct.size, callback = cls.got_common_header) @classmethod - def chat_decode_common_header(cls, chat, b): - """Decode PDU header from an asynchat reader.""" - assert cls._pdu is None - version, pdu_type = cls.common_header_struct.unpack(b) + def got_common_header(cls, reader): + if not reader.ready(): + return None + assert reader.available() >= cls.common_header_struct.size + version, pdu_type = cls.common_header_struct.unpack(reader.buffer[:cls.common_header_struct.size]) assert version == cls.version, "PDU version is %d, expected %d" % (version, cls.version) self = cls.pdu_map[pdu_type]() - if len(b) >= self.header_struct.size: - return self.chat_decode_header(chat, b) - else: - chat.set_terminator(self.header_struct.size - cls.common_header_struct.size) - chat.set_next_decoder(self.chat_decode_header) - return None + return reader.update(need = self.header_struct.size, callback = self.got_header) def consume(self, client): """Handle results in test client. Default is just to print the PDU.""" @@ -127,17 +136,11 @@ class pdu_with_serial(pdu): self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.serial) return self._pdu - def from_pdu_file_helper(self, f, b): - """Read one wire format prefix PDU from a file.""" - b += f.read(self.header_struct.size - len(b)) - version, pdu_type, zero, self.serial = self.header_struct.unpack(b) - assert zero == 0 - assert b == self.to_pdu() - - def chat_decode_header(self, chat, b): - """Decode PDU from an asynchat reader.""" + def got_header(self, reader): + if not reader.ready(): + return None + b = reader.consume(self.header_struct.size) version, pdu_type, zero, self.serial = self.header_struct.unpack(b) - chat.consume(self.header_struct.size) assert zero == 0 assert b == self.to_pdu() return self @@ -156,17 +159,11 @@ class pdu_empty(pdu): self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0) return self._pdu - def from_pdu_file_helper(self, f, b): - """Read one wire format prefix PDU from a file.""" - b += f.read(self.header_struct.size - len(b)) - version, pdu_type, zero = self.header_struct.unpack(b) - assert zero == 0 - assert b == self.to_pdu() - - def chat_decode_header(self, chat, b): - """Decode PDU from an asynchat reader.""" + def got_header(self, reader): + if not reader.ready(): + return None + b = reader.consume(self.header_struct.size) version, pdu_type, zero = self.header_struct.unpack(b) - chat.consume(self.header_struct.size) assert zero == 0 assert b == self.to_pdu() return self @@ -317,41 +314,20 @@ class prefix(pdu): self._pdu = pdu return pdu - def from_pdu_file_helper(self, f, b): - """Read one wire format prefix PDU from a file.""" - b += f.read(self.header_struct.size - len(b)) - p = b - version, pdu_type, self.color, self.announce, self.prefixlen, self.max_prefixlen, source = self.header_struct.unpack(b) - assert source == self.source - b = f.read(self.addr_type.bits / 8) - p += b - self.prefix = self.addr_type.from_bytes(b) - b = f.read(self.asnum_struct.size) - p += b - self.asn = self.asnum_struct.unpack(b)[0] - assert p == self.to_pdu() - - def chat_decode_header(self, chat, b): - """Decode PDU header from an asynchat reader.""" - version, pdu_type, self.color, self.announce, self.prefixlen, self.max_prefixlen, source = self.header_struct.unpack(b) + def got_header(self, reader): + return reader.update(need = self.header_struct.size + self.addr_type.bits / 8 + self.asnum_struct.size, callback = self.got_pdu) + + def got_pdu(self, reader): + if not reader.ready(): + return None + b1 = reader.consume(self.header_struct.size) + b2 = reader.consume(self.addr_type.bits / 8) + b3 = reader.consume(self.asnum_struct.size) + version, pdu_type, self.color, self.announce, self.prefixlen, self.max_prefixlen, source = self.header_struct.unpack(b1) assert source == self.source - chat.consume(self.header_struct.size) - chat.set_terminator(self.addr_type.bits / 8) - chat.set_next_decoder(self.chat_decode_prefix) - return None - - def chat_decode_prefix(self, chat, b): - """Decode prefix from an asynchat reader.""" - self.prefix = self.addr_type.from_bytes(b) - chat.consume(self.addr_type.bits / 8) - chat.set_terminator(self.asnum_struct.size) - chat.set_next_decoder(self.chat_decode_asnum) - return None - - def chat_decode_asnum(self, chat, b): - """Decode autonomous system number from an asynchat reader.""" - self.asn = self.asnum_struct.unpack(b)[0] - chat.consume(self.asnum_struct.size) + self.prefix = self.addr_type.from_bytes(b2) + self.asn = self.asnum_struct.unpack(b3)[0] + assert b1 + b2 + b3 == self.to_pdu() return self class ipv4_prefix(prefix): @@ -397,58 +373,25 @@ class error_report(pdu): p + self.errmsg) return self._pdu - def from_pdu_file_helper(self, f, b): - """Read one wire format prefix PDU from a file.""" - b += f.read(self.header_struct.size - len(b)) - version, pdu_type, self.errno, self.pdulen, self.errlen = self.header_struct.unpack(b) - if self.pdulen: - # This is wrong, we should be checking the length but methods - # don't allows that yet. - self.errpdu = pdu.from_pdu_file(f) - if self.errlen: - self.errmsg = f.read(self.errlen) - - def chat_decode_header(self, chat, b): - """Decode PDU header from an asynchat reader.""" - version, pdu_type, self.errno, self.pdulen, self.errlen = self.header_struct.unpack(b) - chat.consume(self.header_struct.size) - if self.pdulen: - chat.set_terminator(self.pdulen) - chat.set_next_decoder(self.chat_decode_pdu) - return None - else: - return self.chat_decode_pdu(chat, b) - - def chat_decode_pdu(self, chat, b): - """Decode encapsulated PDU from an asynchat reader.""" - self.pdu = b[:self.pdulen] - chat.consume(self.pdulen) - if self.errlen: - chat.set_terminator(self.errlen) - chat.set_next_decoder(self.chat_decode_errmsg) + def got_header(self, reader): + if not reader.ready(): return None - else: - return self.chat_decode_errmsg(chat, b) + version, pdu_type, self.errno, self.pdulen, self.errlen = self.header_struct.unpack(reader.buffer[:self.header_struct.size]) + return reader.update(need = self.header_struct.size + self.pdulen + self.errlen, callback = self.got_pdu) - def chat_decode_errmsg(self, chat, b): - """Decode error message number from an asynchat reader.""" - self.errmsg = b[:self.errlen] - chat.consume(self.errlen) + def got_pdu(self, reader): + if not reader.ready(): + return None + b = reader.consume(self.header_struct.size) + self.errpdu = reader.consume(self.pdulen) + self.errmsg = reader.consume(self.errlen) + assert b + self.errpdu + self.errmsg == self.to_pdu() return self prefix.afi_map = { "\x00\x01" : ipv4_prefix, "\x00\x02" : ipv6_prefix } pdu.pdu_map = dict((p.pdu_type, p) for p in (ipv4_prefix, ipv6_prefix, serial_notify, serial_query, reset_query, cache_response, end_of_data, cache_reset, error_report)) -class pdufile(file): - """File subclass with PDU iterator.""" - - def __iter__(self): - return self - - def next(self): - return pdu.from_pdu_file(self) - class prefix_set(list): """Object representing a set of prefixes, that is, one versioned and (theoretically) consistant set of prefixes extracted from rcynic's @@ -481,11 +424,18 @@ class prefix_set(list): def _load_file(cls, filename): """Low-level method to read prefix_set from a file.""" self = cls() - f = pdufile(filename, "rb") - for p in f: + f = open(filename, "rb") + r = read_buffer() + while True: + p = pdu.read_pdu(r) + while p is None: + b = f.read(r.needed()) + if b == "": + assert r.available() == 0 + return self + r.feed(b) + p = r.retry() self.append(p) - f.close() - return self @classmethod def load_axfr(cls, filename): @@ -512,7 +462,7 @@ class prefix_set(list): def save_axfr(self): """Write AXFR-style prefix_set to file with magic filename.""" - f = pdufile("%d.ax" % self.serial, "wb") + f = open("%d.ax" % self.serial, "wb") for p in self: f.write(p.to_pdu()) f.close() @@ -535,7 +485,7 @@ class prefix_set(list): Since we store prefix_sets in sorted order, computing the difference is a trivial linear comparison. """ - f = pdufile("%d.ix.%d" % (self.serial, other.serial), "wb") + f = open("%d.ix.%d" % (self.serial, other.serial), "wb") old = other[:] new = self[:] while old and new: @@ -609,32 +559,32 @@ class pdu_asynchat(asynchat.async_chat): resulting PDUs. """ - def start_new_pdu(self): - """Starting read of a new PDU, set up initial decoder.""" - self.buffer = "" - self.next_decoder = None - pdu.initial_asynchat_decoder(self) - assert self.next_decoder is not None + def __init__(self, conn = None): + asynchat.async_chat.__init__(self, conn = conn) + self.reader = read_buffer() - def consume(self, n): - """Consume n bytes from the input buffer.""" - self.buffer = self.buffer[n:] + def start_new_pdu(self): + """Start read of a new PDU.""" + p = pdu.read_pdu(self.reader) + while p is not None: + self.deliver_pdu(p) + p = pdu.read_pdu(self.reader) + assert not self.reader.ready() + self.set_terminator(self.reader.needed()) def collect_incoming_data(self, data): - """Collect data into the input buffer.""" - self.buffer += data - - def set_next_decoder(self, decoder): - """Set decoder to use with the next chunk of data.""" - self.next_decoder = decoder - + """Collect data into the read buffer.""" + self.reader.feed(data) + def found_terminator(self): - """Got requested data, hand it to decoder. If we get back a PDU, - pass it up, then loop back to listen for another PDU. + """Got requested data, see if we now have a PDU. If so, pass it + along, then restart cycle for a new PDU. """ - pdu = self.next_decoder(self, self.buffer) - if pdu is not None: - self.deliver_pdu(pdu) + p = self.reader.retry() + if p is None: + self.set_terminator(self.reader.needed()) + else: + self.deliver_pdu(p) self.start_new_pdu() def deliver_pdu(self, pdu): @@ -658,6 +608,15 @@ class pdu_asynchat(asynchat.async_chat): """Intercept asynchat's logging.""" log("asynchat: %s: %s" % (tag, msg)) + def handle_error(self): + """Handle errors caught by asyncore main loop. Asyncore has a + default handler for this but I find its customized backtraces very + hard to read. + """ + log(traceback.format_exc()) + log("Closing channel %s" % repr(self)) + self.close() + class server_asynchat(pdu_asynchat): """Server protocol engine, handles upcalls from pdu_asynchat to implement protocol logic. @@ -665,7 +624,7 @@ class server_asynchat(pdu_asynchat): def __init__(self): """Set up stdin as connection and start listening for first PDU.""" - asynchat.async_chat.__init__(self) + pdu_asynchat.__init__(self) # # I don't know a sane way to get asynchat.async_chat.__init__() to # call asyncore.file_dispatcher.__init__(), so shut your eyes for @@ -741,7 +700,7 @@ class client_asynchat(pdu_asynchat): self.ssh = subprocess.Popen(["/usr/local/bin/python", "rtr-origin.py", "server"], stdin = s[0], stdout = s[0], close_fds = True) else: self.ssh = subprocess.Popen(sshargs, executable = "/usr/bin/ssh", stdin = s[0], stdout = s[0], close_fds = True) - asynchat.async_chat.__init__(self, conn = s[1]) + pdu_asynchat.__init__(self, conn = s[1]) self.start_new_pdu() def deliver_pdu(self, pdu): @@ -802,6 +761,15 @@ class server_wakeup(asyncore.dispatcher): """Intercept asyncore's logging.""" log("asyncore: %s: %s" % (tag, msg)) + def handle_error(self): + """Handle errors caught by asyncore main loop. Asyncore has a + default handler for this but I find its customized backtraces very + hard to read. + """ + log(traceback.format_exc()) + log("Closing channel %s" % repr(self)) + self.close() + def server_main(): """Main program for server mode. Not really written yet.""" wakeup = None |