diff options
author | Rob Austein <sra@hactrn.net> | 2009-05-29 05:42:52 +0000 |
---|---|---|
committer | Rob Austein <sra@hactrn.net> | 2009-05-29 05:42:52 +0000 |
commit | da7de456c5b471ea123187a0d1c8fea01f271e13 (patch) | |
tree | df043e98771cba53302b8a3219a5f76abe4b41b0 | |
parent | 955920648564f605aa0ded98853e169410ec3831 (diff) |
Handle missing socket directory somewhat better. Comment cleanup.
svn path=/rtr-origin/rtr-origin.py; revision=2464
-rw-r--r-- | rtr-origin/rtr-origin.py | 377 |
1 files changed, 271 insertions, 106 deletions
diff --git a/rtr-origin/rtr-origin.py b/rtr-origin/rtr-origin.py index ce880d63..2ef4426f 100644 --- a/rtr-origin/rtr-origin.py +++ b/rtr-origin/rtr-origin.py @@ -47,45 +47,62 @@ import rpki.x509, rpki.ipaddrs, rpki.sundial, rpki.config import rpki.async class read_buffer(object): - """Wrapper around synchronous/asynchronous read state.""" + """ + Wrapper around synchronous/asynchronous read state. + """ def __init__(self): self.buffer = "" def update(self, need, callback): - """Update count of needed bytes and callback, then dispatch to callback.""" + """ + Update count of needed bytes and callback, then dispatch to callback. + """ self.need = need self.callback = callback return self.callback(self) def available(self): - """How much data do we have available in this buffer?""" + """ + How much data do we have available in this buffer? + """ return len(self.buffer) def needed(self): - """How much more data does this buffer need to become ready?""" + """ + How much more data does this buffer need to become ready? + """ return self.need - self.available() def ready(self): - """Is this buffer ready to read yet?""" + """ + Is this buffer ready to read yet? + """ return self.available() >= self.need def get(self, n): - """Hand some data to the caller.""" + """ + Hand some data to the caller. + """ b = self.buffer[:n] self.buffer = self.buffer[n:] return b def put(self, b): - """Accumulate some data.""" + """ + Accumulate some data. + """ self.buffer += b def retry(self): - """Try dispatching to the callback again.""" + """ + Try dispatching to the callback again. + """ return self.callback(self) class pdu(object): - """Object representing a generic PDU in the rpki-router protocol. + """ + Object representing a generic PDU in the rpki-router protocol. Real PDUs are subclasses of this class. """ @@ -99,7 +116,9 @@ class pdu(object): return cmp(self.to_pdu(), other.to_pdu()) def check(self): - """Check attributes to make sure they're within range.""" + """ + Check attributes to make sure they're within range. + """ pass @classmethod @@ -117,23 +136,31 @@ class pdu(object): return reader.update(need = self.header_struct.size, callback = self.got_header) def consume(self, client): - """Handle results in test client. Default behavior is just to - print out the PDU.""" + """ + Handle results in test client. Default behavior is just to print + out the PDU. + """ log(self) def send_file(self, server, filename): - """Send a content of a file as a cache response. Caller should catch IOError.""" + """ + Send a content of a file as a cache response. Caller should catch IOError. + """ f = open(filename, "rb") server.push_pdu(cache_response()) server.push_file(f) server.push_pdu(end_of_data(serial = server.current_serial)) def send_nodata(self, server): - """Send a nodata error.""" + """ + Send a nodata error. + """ server.push_pdu(error_report(errno = error_report.codes["No Data Available"], errpdu = self)) class pdu_with_serial(pdu): - """Base class for PDUs consisting of just a serial number.""" + """ + Base class for PDUs consisting of just a serial number. + """ header_struct = struct.Struct("!BBHL") @@ -148,7 +175,9 @@ class pdu_with_serial(pdu): return "[%s, serial #%s]" % (self.__class__.__name__, self.serial) def to_pdu(self): - """Generate the wire format PDU for this prefix.""" + """ + Generate the wire format PDU for this prefix. + """ if self._pdu is None: self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0, self.serial) return self._pdu @@ -163,7 +192,9 @@ class pdu_with_serial(pdu): return self class pdu_empty(pdu): - """Base class for emtpy PDUs.""" + """ + Base class for empty PDUs. + """ header_struct = struct.Struct("!BBH") @@ -171,7 +202,9 @@ class pdu_empty(pdu): return "[%s]" % self.__class__.__name__ def to_pdu(self): - """Generate the wire format PDU for this prefix.""" + """ + Generate the wire format PDU for this prefix. + """ if self._pdu is None: self._pdu = self.header_struct.pack(self.version, self.pdu_type, 0) return self._pdu @@ -186,13 +219,16 @@ class pdu_empty(pdu): return self class serial_notify(pdu_with_serial): - """Serial Notify PDU.""" + """ + Serial Notify PDU. + """ pdu_type = 0 def consume(self, client): - """Respond to a serial_notify message with either a serial_query - or reset_query, depending on what we already know. + """ + Respond to a serial_notify message with either a serial_query or + reset_query, depending on what we already know. """ log(self) if client.current_serial is None: @@ -203,12 +239,15 @@ class serial_notify(pdu_with_serial): log("[Notify did not change serial number, ignoring]") class serial_query(pdu_with_serial): - """Serial Query PDU.""" + """ + Serial Query PDU. + """ pdu_type = 1 def serve(self, server): - """Received a serial query, send incremental transfer in response. + """ + Received a serial query, send incremental transfer in response. If client is already up to date, just send an empty incremental transfer. """ @@ -226,12 +265,16 @@ class serial_query(pdu_with_serial): server.push_pdu(cache_reset()) class reset_query(pdu_empty): - """Reset Query PDU.""" + """ + Reset Query PDU. + """ pdu_type = 2 def serve(self, server): - """Received a reset query, send full current state in response.""" + """ + Received a reset query, send full current state in response. + """ log(self) if server.get_serial() is None: self.send_nodata(server) @@ -243,35 +286,45 @@ class reset_query(pdu_empty): server.push_pdu(error_report(errno = error_report.codes["Internal Error"], errpdu = self, errmsg = "Couldn't open %s" % fn)) class cache_response(pdu_empty): - """Incremental Response PDU.""" + """ + Incremental Response PDU. + """ pdu_type = 3 class end_of_data(pdu_with_serial): - """End of Data PDU.""" + """ + End of Data PDU. + """ pdu_type = 7 def consume(self, client): - """Handle end_of_data response.""" + """ + Handle end_of_data response. + """ log(self) client.current_serial = self.serial class cache_reset(pdu_empty): - """Cache reset PDU.""" + """ + Cache reset PDU. + """ pdu_type = 8 def consume(self, client): - """Handle cache_reset response, by issuing a reset_query.""" + """ + Handle cache_reset response, by issuing a reset_query. + """ log(self) client.push_pdu(reset_query()) class prefix(pdu): - """Object representing one prefix. This corresponds closely to one - PDU in the rpki-router protocol, so closely that we use lexical - ordering of the wire format of the PDU as the ordering for this - class. + """ + Object representing one prefix. This corresponds closely to one PDU + in the rpki-router protocol, so closely that we use lexical ordering + of the wire format of the PDU as the ordering for this class. """ source = 0 # Source (0 == RPKI) @@ -281,7 +334,10 @@ class prefix(pdu): @classmethod def from_asn1(cls, asnum, t): - """Read a prefix from a ROA in the tuple format used by our ASN.1 decoder.""" + """ + Read a prefix from a ROA in the tuple format used by our ASN.1 + decoder. + """ assert len(t[0]) <= cls.addr_type.bits x = 0L for y in t[0]: @@ -311,14 +367,18 @@ class prefix(pdu): print "# Announce: ", self.announce def check(self): - """Check attributes to make sure they're within range.""" + """ + Check attributes to make sure they're within range. + """ assert self.announce in (0, 1) assert self.prefixlen >= 0 and self.prefixlen <= self.addr_type.bits assert self.max_prefixlen >= self.prefixlen and self.max_prefixlen <= self.addr_type.bits assert len(self.to_pdu()) == 12 + self.addr_type.bits / 8, "Expected %d byte PDU, got %d" % (12 + self.addr_type.bits / 8, len(self.to_pdu())) def to_pdu(self, announce = None): - """Generate the wire format PDU for this prefix.""" + """ + Generate the wire format PDU for this prefix. + """ if announce is not None: assert announce in (0, 1) elif self._pdu is not None: @@ -350,17 +410,23 @@ class prefix(pdu): return self class ipv4_prefix(prefix): - """IPv4 flavor of a prefix.""" + """ + IPv4 flavor of a prefix. + """ pdu_type = 4 addr_type = rpki.ipaddrs.v4addr class ipv6_prefix(prefix): - """IPv6 flavor of a prefix.""" + """ + IPv6 flavor of a prefix. + """ pdu_type = 6 addr_type = rpki.ipaddrs.v6addr class error_report(pdu): - """Error Report PDU.""" + """ + Error Report PDU. + """ pdu_type = 10 @@ -382,7 +448,9 @@ class error_report(pdu): return "Error #%s: %s" % (self.errno, self.errmsg) def to_pdu(self): - """Generate the wire format PDU for this prefix.""" + """ + Generate the wire format PDU for this prefix. + """ if self._pdu is None: assert isinstance(self.errno, int) assert not isinstance(self.errpdu, error_report) @@ -418,14 +486,17 @@ pdu.pdu_map = dict((p.pdu_type, p) for p in (ipv4_prefix, ipv6_prefix, serial_no cache_response, end_of_data, cache_reset, error_report)) class prefix_set(list): - """Object representing a set of prefixes, that is, one versioned and + """ + Object representing a set of prefixes, that is, one versioned and (theoretically) consistant set of prefixes extracted from rcynic's output. """ @classmethod def _load_file(cls, filename): - """Low-level method to read prefix_set from a file.""" + """ + Low-level method to read prefix_set from a file. + """ self = cls() f = open(filename, "rb") r = read_buffer() @@ -441,14 +512,16 @@ class prefix_set(list): self.append(p) class axfr_set(prefix_set): - """Object representing a complete set of prefixes, that is, one + """ + Object representing a complete set of prefixes, that is, one versioned and (theoretically) consistant set of prefixes extracted from rcynic's output, all with the announce field set. """ @classmethod def parse_rcynic(cls, rcynic_dir): - """Parse ROAS fetched (and validated!) by rcynic to create a new + """ + Parse ROAS fetched (and validated!) by rcynic to create a new axfr_set. """ self = cls() @@ -470,7 +543,9 @@ class axfr_set(prefix_set): @classmethod def load(cls, filename): - """Load an axfr_set from a file, parse filename to obtain serial.""" + """ + Load an axfr_set from a file, parse filename to obtain serial. + """ fn1, fn2 = os.path.basename(filename).split(".") assert fn1.isdigit() and fn2 == "ax" self = cls._load_file(filename) @@ -478,18 +553,24 @@ class axfr_set(prefix_set): return self def filename(self): - """Generate filename for this axfr_set.""" + """ + Generate filename for this axfr_set. + """ return "%d.ax" % self.serial def save_axfr(self): - """Write axfr__set to file with magic filename.""" + """ + Write axfr__set to file with magic filename. + """ f = open(self.filename(), "wb") for p in self: f.write(p.to_pdu()) f.close() def mark_current(self): - """Mark the current serial number as current.""" + """ + Mark the current serial number as current. + """ tmpfn = "current.%d.tmp" % os.getpid() try: f = open(tmpfn, "w") @@ -501,10 +582,11 @@ class axfr_set(prefix_set): raise def save_ixfr(self, other): - """Comparing this axfr_set with an older one and write the - resulting ixfr_set to file with magic filename. Since we store - prefix_sets in sorted order, computing the difference is a trivial - linear comparison. + """ + Comparing this axfr_set with an older one and write the resulting + ixfr_set to file with magic filename. Since we store prefix_sets + in sorted order, computing the difference is a trivial linear + comparison. """ f = open("%d.ix.%d" % (self.serial, other.serial), "wb") old = other[:] @@ -524,21 +606,26 @@ class axfr_set(prefix_set): f.close() def show(self): - """Print this axfr_set.""" + """ + Print this axfr_set. + """ print "# AXFR %d (%s)" % (self.serial, rpki.sundial.datetime.utcfromtimestamp(self.serial)) for p in self: print p class ixfr_set(prefix_set): - """Object representing an incremental set of prefixes, that is, the + """ + Object representing an incremental set of prefixes, that is, the differences between one versioned and (theoretically) consistant set - of prefixes extracted from rcynic's output and another, with the announce - fields set or cleared as necessary to indicate the changes. + of prefixes extracted from rcynic's output and another, with the + announce fields set or cleared as necessary to indicate the changes. """ @classmethod def load(cls, filename): - """Load an ixfr_set from a file, parse filename to obtain serials.""" + """ + Load an ixfr_set from a file, parse filename to obtain serials. + """ fn1, fn2, fn3 = os.path.basename(filename).split(".") assert fn1.isdigit() and fn2 == "ix" and fn3.isdigit() self = cls._load_file(filename) @@ -547,18 +634,24 @@ class ixfr_set(prefix_set): return self def filename(self): - """Generate filename for this ixfr_set.""" + """ + Generate filename for this ixfr_set. + """ return "%d.ix.%d" % (self.to_serial, self.from_serial) def show(self): - """Print this ixfr_set.""" + """ + Print this ixfr_set. + """ print "# IXFR %d (%s) -> %d (%s)" % (self.from_serial, rpki.sundial.datetime.utcfromtimestamp(self.from_serial), self.to_serial, rpki.sundial.datetime.utcfromtimestamp(self.to_serial)) for p in self: print p class file_producer(object): - """File-based producer object for asynchat.""" + """ + File-based producer object for asynchat. + """ def __init__(self, handle, buffersize): self.handle = handle @@ -568,10 +661,11 @@ class file_producer(object): return self.handle.read(self.buffersize) class pdu_channel(asynchat.async_chat): - """asynchat subclass that understands our PDUs. This just handles - the network I/O. Specific engines (client, server) should be - subclasses of this with methods that do something useful with the - resulting PDUs. + """ + asynchat subclass that understands our PDUs. This just handles + network I/O. Specific engines (client, server) should be subclasses + of this with methods that do something useful with the resulting + PDUs. """ def __init__(self, conn = None): @@ -579,7 +673,9 @@ class pdu_channel(asynchat.async_chat): self.reader = read_buffer() def start_new_pdu(self): - """Start read of a new PDU.""" + """ + Start read of a new PDU. + """ p = pdu.read_pdu(self.reader) while p is not None: self.deliver_pdu(p) @@ -588,11 +684,14 @@ class pdu_channel(asynchat.async_chat): self.set_terminator(self.reader.needed()) def collect_incoming_data(self, data): - """Collect data into the read buffer.""" + """ + Collect data into the read buffer. + """ self.reader.put(data) def found_terminator(self): - """Got requested data, see if we now have a PDU. If so, pass it + """ + Got requested data, see if we now have a PDU. If so, pass it along, then restart cycle for a new PDU. """ p = self.reader.retry() @@ -603,34 +702,47 @@ class pdu_channel(asynchat.async_chat): self.start_new_pdu() def push_pdu(self, pdu): - """Write PDU to stream.""" + """ + Write PDU to stream. + """ self.push(pdu.to_pdu()) def push_file(self, f): - """Write content of a file to stream.""" + """ + Write content of a file to stream. + """ self.push_with_producer(file_producer(f, self.ac_out_buffer_size)) def log(self, msg): - """Intercept asyncore's logging.""" + """ + Intercept asyncore's logging. + """ log(msg) def log_info(self, msg, tag = "info"): - """Intercept asynchat's logging.""" + """ + Intercept asynchat's logging. + """ log("asynchat: %s: %s" % (tag, msg)) def handle_error(self): - """Handle errors caught by asyncore main loop.""" + """ + Handle errors caught by asyncore main loop. + """ log(traceback.format_exc()) log("Exiting after unhandled exception") asyncore.close_all() class server_channel(pdu_channel): - """Server protocol engine, handles upcalls from pdu_channel to + """ + Server protocol engine, handles upcalls from pdu_channel to implement protocol logic. """ def __init__(self): - """Set up stdin as connection and start listening for first PDU.""" + """ + Set up stdin as connection and start listening for first PDU. + """ pdu_channel.__init__(self) # # I don't know a sane way to get asynchat.async_chat.__init__() to @@ -652,19 +764,24 @@ class server_channel(pdu_channel): self.start_new_pdu() def deliver_pdu(self, pdu): - """Handle received PDU.""" + """ + Handle received PDU. + """ pdu.serve(self) def handle_close(self): - """Intercept close event so we can shut down other sockets.""" + """ + Intercept close event so we can shut down other sockets. + """ asynchat.async_chat.handle_close(self) asyncore.close_all() def get_serial(self): - """Read, cache, and return current serial number, or None if we - can't find the serial number file. The latter condition should - never happen, but maybe we got started in server mode while the - cronjob mode instance is still building its database. + """ + Read, cache, and return current serial number, or None if we can't + find the serial number file. The latter condition should never + happen, but maybe we got started in server mode while the cronjob + mode instance is still building its database. """ try: f = open("current", "r") @@ -676,19 +793,25 @@ class server_channel(pdu_channel): return self.current_serial def check_serial(self): - """Check for a new serial number.""" + """ + Check for a new serial number. + """ old_serial = self.current_serial return old_serial != self.get_serial() def notify(self, data = None): - """Cronjob instance kicked us, send a notify message.""" + """ + Cronjob instance kicked us, send a notify message. + """ if self.check_serial(): self.push_pdu(serial_notify(serial = self.current_serial)) else: log("Cronjob kicked me without a valid current serial number") class client_channel(pdu_channel): - """Client protocol engine, handles upcalls from pdu_channel.""" + """ + Client protocol engine, handles upcalls from pdu_channel. + """ current_serial = None @@ -697,7 +820,9 @@ class client_channel(pdu_channel): debug_using_direct_server_subprocess = True def __init__(self, *sshargs): - """Set up ssh connection and start listening for first PDU.""" + """ + Set up ssh connection and start listening for first PDU. + """ s = socket.socketpair() if self.debug_using_direct_server_subprocess: log("[Ignoring ssh arguments, using direct subprocess kludge for testing]") @@ -709,11 +834,14 @@ class client_channel(pdu_channel): self.start_new_pdu() def deliver_pdu(self, pdu): - """Handle received PDU.""" + """ + Handle received PDU. + """ pdu.consume(self) def cleanup(self): - """Force clean up this client's child process. If everything goes + """ + Force clean up this client's child process. If everything goes well, child will have exited already before this method is called, but we may need to whack it with a stick if something breaks. """ @@ -726,13 +854,16 @@ class client_channel(pdu_channel): pass def handle_close(self): - """Intercept close event so we can log it.""" + """ + Intercept close event so we can log it. + """ log("Server closed channel") asynchat.async_chat.handle_close(self) class kickme_channel(asyncore.dispatcher): - """asyncore dispatcher for the PF_UNIX socket that cronjob mode uses - to kick servers when it's time to send notify PDUs to clients. + """ + asyncore dispatcher for the PF_UNIX socket that cronjob mode uses to + kick servers when it's time to send notify PDUs to clients. """ def __init__(self, server): @@ -740,23 +871,35 @@ class kickme_channel(asyncore.dispatcher): self.server = server self.sockname = "%s.%d" % (kickme_base, os.getpid()) self.create_socket(socket.AF_UNIX, socket.SOCK_DGRAM) - self.bind(self.sockname) + try: + self.bind(self.sockname) + except socket.error, e: + log("Couldn't bind kickme socket: %r" % e) + self.close() def writable(self): - """This socket is read-only, never writable.""" + """ + This socket is read-only, never writable. + """ return False def handle_connect(self): - """Ignore connect events (not very useful on datagram socket).""" + """ + Ignore connect events (not very useful on datagram socket). + """ pass def handle_read(self): - """Handle receipt of a datagram.""" + """ + Handle receipt of a datagram. + """ data = self.recv(512) self.server.notify(data) def cleanup(self): - """Clean up this dispatcher's socket.""" + """ + Clean up this dispatcher's socket. + """ self.close() try: os.unlink(self.sockname) @@ -764,21 +907,29 @@ class kickme_channel(asyncore.dispatcher): pass def log(self, msg): - """Intercept asyncore's logging.""" + """ + Intercept asyncore's logging. + """ log(msg) def log_info(self, msg, tag = "info"): - """Intercept asyncore's logging.""" + """ + Intercept asyncore's logging. + """ log("asyncore: %s: %s" % (tag, msg)) def handle_error(self): - """Handle errors caught by asyncore main loop.""" + """ + Handle errors caught by asyncore main loop. + """ log(traceback.format_exc()) log("Exiting after unhandled exception") asyncore.close_all() def cronjob_main(argv): - """Main program for cronjob.""" + """ + Main program for cronjob. + """ if len(argv) != 1: raise RuntimeError, "Expected one argument, got %s" % argv @@ -801,6 +952,12 @@ def cronjob_main(argv): print "# New serial is %s" % pdus.serial + try: + os.stat(kickme_dir) + except OSError: + print '# Creating directory "%s"' % kickme_dir + os.makedirs(kickme_dir) + msg = "Good morning, serial %s is ready" % pdus.serial sock = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) for name in glob.iglob("%s.*" % kickme_base): @@ -817,7 +974,10 @@ def cronjob_main(argv): os.unlink(ixfr) def show_main(argv): - """Main program for show mode. Just displays current AXFR and IXFR dumps""" + """ + Main program for show mode. Just displays current AXFR and IXFR + dumps + """ if argv: raise RuntimeError, "Unexpected arguments: %s" % argv @@ -833,8 +993,9 @@ def show_main(argv): ixfr_set.load(f).show() def server_main(argv): - """Main program for server mode. Server is event driven, so - everything interesting happens in the channel classes. + """ + Main program for server mode. Server is event driven, so everything + interesting happens in the channel classes. In production use this server is run under sshd. The subsystem mechanism in sshd does not allow us to pass arguments on the command @@ -857,7 +1018,9 @@ def server_main(argv): kickme.cleanup() class client_timer(rpki.async.timer): - """Timer class for client mode, to handle the periodic serial queries.""" + """ + Timer class for client mode, to handle the periodic serial queries. + """ def __init__(self, client, period): rpki.async.timer.__init__(self) @@ -873,7 +1036,9 @@ class client_timer(rpki.async.timer): self.set(self.period) def client_main(argv): - """Main program for client mode. Not really written yet.""" + """ + Main program for client mode. This is just test code. + """ log("[Startup]") if argv: raise RuntimeError, "Unexpected arguments: %s" % argv @@ -889,7 +1054,6 @@ def client_main(argv): raise def log(msg): - """Temporary.""" rpki.log.warn(str(msg)) os.environ["TZ"] = "UTC" @@ -899,7 +1063,8 @@ cfg_file = "rtr-origin.conf" mode = None -kickme_base = "sockets/kickme" +kickme_dir = "sockets" +kickme_base = os.path.join(kickme_dir, "kickme") main_dispatch = { "cronjob" : cronjob_main, |