diff options
-rwxr-xr-x | rtr-origin/rtr-origin.py | 164 |
1 files changed, 139 insertions, 25 deletions
diff --git a/rtr-origin/rtr-origin.py b/rtr-origin/rtr-origin.py index 545933da..aa74f194 100755 --- a/rtr-origin/rtr-origin.py +++ b/rtr-origin/rtr-origin.py @@ -93,7 +93,6 @@ class v6addr(ipaddr): af = socket.AF_INET6 size = 16 - def read_current(): """ Read current serial number and nonce. Return None for both if @@ -457,15 +456,25 @@ class reset_query(pdu_empty): fn = "%d.ax" % server.current_serial self.send_file(server, fn) except IOError: - server.push_pdu(error_report(errno = error_report.codes["Internal Error"], errpdu = self, errmsg = "Couldn't open %s" % fn)) + server.push_pdu(error_report(errno = error_report.codes["Internal Error"], + errpdu = self, errmsg = "Couldn't open %s" % fn)) class cache_response(pdu_nonce): """ - Incremental Response PDU. + Cache Response PDU. """ pdu_type = 3 + def consume(self, client): + """ + Handle cache_response. + """ + blather(self) + if self.nonce != client.current_nonce: + blather("[Nonce changed, resetting]") + client.cache_reset() + class end_of_data(pdu_with_serial): """ End of Data PDU. @@ -478,8 +487,7 @@ class end_of_data(pdu_with_serial): Handle end_of_data response. """ blather(self) - client.current_serial = self.serial - client.current_nonce = self.nonce + client.end_of_data(self.serial, self.nonce) class cache_reset(pdu_empty): """ @@ -493,6 +501,7 @@ class cache_reset(pdu_empty): Handle cache_reset response, by issuing a reset_query. """ blather(self) + client.cache_reset() client.push_pdu(reset_query()) class prefix(pdu): @@ -539,6 +548,13 @@ class prefix(pdu): blather("# MaxPrefixlen: %s" % self.max_prefixlen) blather("# Announce: %s" % self.announce) + def consume(self, client): + """ + Handle one incoming prefix PDU + """ + blather(self) + client.consume_prefix(self) + def check(self): """ Check attributes to make sure they're within range. @@ -1213,10 +1229,16 @@ class client_channel(pdu_channel): current_serial = None current_nonce = None + sql = None + host = None + port = None + cache_id = None - def __init__(self, sock, proc, killsig): + def __init__(self, sock, proc, killsig, host, port): self.killsig = killsig self.proc = proc + self.host = host + self.port = port pdu_channel.__init__(self, conn = sock) self.start_new_pdu() @@ -1229,8 +1251,10 @@ class client_channel(pdu_channel): blather("[Running ssh: %s]" % " ".join(args)) s = socket.socketpair() return cls(sock = s[1], - proc = subprocess.Popen(args, executable = "/usr/bin/ssh", stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGKILL) + proc = subprocess.Popen(args, executable = "/usr/bin/ssh", + stdin = s[0], stdout = s[0], close_fds = True), + killsig = signal.SIGKILL, + host = host, port = port) @classmethod def tcp(cls, host, port): @@ -1257,11 +1281,12 @@ class client_channel(pdu_channel): blather("[socket.connect() failed: %s]" % e) s.close() continue - return cls(sock = s, proc = None, killsig = None) + return cls(sock = s, proc = None, killsig = None, + host = host, port = port) sys.exit(1) @classmethod - def loopback(cls): + def loopback(cls, host, port): """ Set up loopback connection and start listening for first PDU. """ @@ -1272,7 +1297,8 @@ class client_channel(pdu_channel): argv.extend(("--syslog", sys.argv[sys.argv.index("--syslog") + 1])) return cls(sock = s[1], proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGINT) + killsig = signal.SIGINT, + host = host, port = port) @classmethod def tls(cls, host, port): @@ -1291,7 +1317,84 @@ class client_channel(pdu_channel): s = socket.socketpair() return cls(sock = s[1], proc = subprocess.Popen(args, stdin = s[0], stdout = s[0], close_fds = True), - killsig = signal.SIGKILL) + killsig = signal.SIGKILL, + host = host, port = port) + + def setup_sql(self, sqlname): + """ + Set up an SQLite database to contain the table we receive. If + necessary, we will create the database. + """ + import sqlite3 + missing = not os.path.exists(sqlname) + self.sql = sqlite3.connect(sqlname, detect_types = sqlite3.PARSE_DECLTYPES) + self.sql.text_factory = str + cur = self.sql.cursor() + cur.execute("PRAGMA foreign_keys = on") + if missing: + cur.execute(''' + CREATE TABLE cache ( + cache_id INTEGER PRIMARY KEY NOT NULL, + host TEXT NOT NULL, + port TEXT NOT NULL, + nonce INTEGER, + serial INTEGER, + updated INTEGER, + UNIQUE (host, port))''') + cur.execute(''' + CREATE TABLE prefix ( + cache_id INTEGER NOT NULL + REFERENCES cache(cache_id) + ON DELETE CASCADE + ON UPDATE CASCADE, + asn INTEGER NOT NULL, + prefix TEXT NOT NULL, + prefixlen INTEGER NOT NULL, + max_prefixlen INTEGER NOT NULL, + UNIQUE (cache_id, asn, prefix, prefixlen, max_prefixlen))''') + cur.execute("SELECT cache_id, nonce, serial FROM cache WHERE host = ? AND port = ?", + (self.host, self.port)) + try: + self.cache_id, self.current_nonce, self.current_serial = cur.fetchone() + except TypeError: + cur.execute("INSERT INTO cache (host, port) VALUES (?, ?)", (self.host, self.port)) + self.cache_id = cur.lastrowid + self.sql.commit() + + def cache_reset(self): + """ + Handle cache_reset actions. + """ + self.current_serial = None + if self.sql: + cur = self.sql.cursor() + cur.execute("DELETE FROM prefix WHERE cache_id = ?", (self.cache_id,)) + cur.execute("UPDATE cache SET serial = NULL WHERE cache_id = ?", (self.cache_id,)) + + def end_of_data(self, serial, nonce): + """ + Handle end_of_data actions. + """ + self.current_serial = serial + self.current_nonce = nonce + if self.sql: + self.sql.execute("UPDATE cache SET serial = ?, nonce = ?, updated = datetime('now') WHERE cache_id = ?", + (serial, nonce, self.cache_id)) + self.sql.commit() + + def consume_prefix(self, prefix): + """ + Handle one prefix PDU. + """ + if self.sql: + values = (self.cache_id, prefix.asn, str(prefix.prefix), prefix.prefixlen, prefix.max_prefixlen) + if prefix.announce: + self.sql.execute("INSERT INTO prefix (cache_id, asn, prefix, prefixlen, max_prefixlen) VALUES (?, ?, ?, ?, ?)", + values) + else: + self.sql.execute("DELETE FROM prefix " + "WHERE cache_id = ? AND asn = ? AND prefix = ? AND prefixlen = ? AND max_prefixlen = ?", + values) def deliver_pdu(self, pdu): """ @@ -1700,24 +1803,35 @@ def client_main(argv): The remaining arguments should be a hostname (or IP address) and a TCP port number. - If the first argument is "tls", the client will attempt to open a TLS connection to the server. The - remaining arguments should be a hostname (or IP address) and a TCP - port number. + If the first argument is "tls", the client will attempt to open a + TLS connection to the server. The remaining arguments should be a + hostname (or IP address) and a TCP port number. + + An optional final name is the name of a file containing a SQLite + database in which to store the received table. If specified, this + database will be created if missing. """ blather("[Startup]") client = None + if not argv: + argv = ["loopback"] + proto = argv[0] + if proto == "loopback" and len(argv) in (1, 2): + constructor = client_channel.loopback + host, port = "", "" + sqlname = None if len(argv) == 1 else argv[1] + elif proto in ("ssh", "tcp", "tls") and len(argv) in (3, 4): + constructor = getattr(client_channel, proto) + host, port = argv[1:3] + sqlname = None if len(argv) == 3 else argv[3] + else: + sys.exit("Unexpected arguments: %s" % " ".join(argv)) + try: - if not argv or (argv[0] == "loopback" and len(argv) == 1): - client = client_channel.loopback() - elif argv[0] == "ssh" and len(argv) == 3: - client = client_channel.ssh(*argv[1:]) - elif argv[0] == "tcp" and len(argv) == 3: - client = client_channel.tcp(*argv[1:]) - elif argv[0] == "tls" and len(argv) == 3: - client = client_channel.tls(*argv[1:]) - else: - sys.exit("Unexpected arguments: %r" % (argv,)) + client = constructor(host, port) + if sqlname: + client.setup_sql(sqlname) while True: if client.current_serial is None or client.current_nonce is None: client.push_pdu(reset_query()) |