aboutsummaryrefslogtreecommitdiff
path: root/rpki/rtr
diff options
context:
space:
mode:
Diffstat (limited to 'rpki/rtr')
-rw-r--r--rpki/rtr/client.py54
-rw-r--r--rpki/rtr/main.py8
2 files changed, 33 insertions, 29 deletions
diff --git a/rpki/rtr/client.py b/rpki/rtr/client.py
index 6d567a1b..eb9ce5a0 100644
--- a/rpki/rtr/client.py
+++ b/rpki/rtr/client.py
@@ -181,18 +181,21 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
expire = rpki.rtr.pdus.default_expire
updated = Timestamp(0)
- def __init__(self, sock, proc, killsig, host, port, version):
+ def __init__(self, sock, proc, killsig, args, host = None, port = None):
self.killsig = killsig
self.proc = proc
- self.host = host
- self.port = port
+ self.args = args
+ self.host = args.host if host is None else host
+ self.port = args.port if port is None else port
super(ClientChannel, self).__init__(sock = sock, root_pdu_class = PDU)
- if version is not None:
- self.version = version
+ if args.force_version is not None:
+ self.version = args.force_version
self.start_new_pdu()
+ if args.sql_database:
+ self.setup_sql()
@classmethod
- def ssh(cls, host, port, version):
+ def ssh(cls, args):
"""
Set up ssh connection and start listening for first PDU.
"""
@@ -203,11 +206,10 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
return cls(sock = s[1],
proc = subprocess.Popen(argv, executable = "/usr/bin/ssh",
stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGKILL,
- host = host, port = port, version = version)
+ killsig = signal.SIGKILL, args = args)
@classmethod
- def tcp(cls, host, port, version):
+ def tcp(cls, args):
"""
Set up TCP connection and start listening for first PDU.
"""
@@ -232,12 +234,11 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
logging.exception("[socket.connect() failed: %s]", e)
s.close()
continue
- return cls(sock = s, proc = None, killsig = None,
- host = host, port = port, version = version)
+ return cls(sock = s, proc = None, killsig = None, args = args)
sys.exit(1)
@classmethod
- def loopback(cls, host, port, version):
+ def loopback(cls, args):
"""
Set up loopback connection and start listening for first PDU.
"""
@@ -247,11 +248,11 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
argv = (sys.executable, sys.argv[0], "server")
return cls(sock = s[1],
proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGINT,
- host = host or "none", port = port or "none", version = version)
+ killsig = signal.SIGINT, args = args,
+ host = args.host or "none", port = args.port or "none")
@classmethod
- def tls(cls, host, port, version):
+ def tls(cls, args):
"""
Set up TLS connection and start listening for first PDU.
@@ -268,18 +269,17 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
s = socket.socketpair()
return cls(sock = s[1],
proc = subprocess.Popen(argv, stdin = s[0], stdout = s[0], close_fds = True),
- killsig = signal.SIGKILL,
- host = host, port = port, version = version)
+ killsig = signal.SIGKILL, args = args)
- def setup_sql(self, sqlname, reset_session):
+ def setup_sql(self):
"""
Set up an SQLite database to contain the table we receive. If
necessary, we will create the database.
"""
import sqlite3
- missing = not os.path.exists(sqlname)
- self.sql = sqlite3.connect(sqlname, detect_types = sqlite3.PARSE_DECLTYPES)
+ missing = not os.path.exists(self.args.sql_database)
+ self.sql = sqlite3.connect(self.args.sql_database, detect_types = sqlite3.PARSE_DECLTYPES)
self.sql.text_factory = str
cur = self.sql.cursor()
cur.execute("PRAGMA foreign_keys = on")
@@ -319,7 +319,7 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
key TEXT NOT NULL,
UNIQUE (cache_id, asn, ski),
UNIQUE (cache_id, asn, key))''')
- elif reset_session:
+ elif self.args.reset_session:
cur.execute("DELETE FROM cache WHERE host = ? and port = ?", (self.host, self.port))
cur.execute("SELECT cache_id, version, nonce, serial, refresh, retry, expire, updated "
"FROM cache WHERE host = ? AND port = ?",
@@ -452,6 +452,10 @@ class ClientChannel(rpki.rtr.channels.PDUChannel):
super(ClientChannel, self).handle_close()
+# Hack to let us subclass this from scripts without needing to rewrite client_main().
+
+ClientChannelClass = ClientChannel
+
def client_main(args):
"""
Test client, intended primarily for debugging.
@@ -459,13 +463,12 @@ def client_main(args):
logging.debug("[Startup]")
- constructor = getattr(rpki.rtr.client.ClientChannel, args.protocol)
+ assert issubclass(ClientChannelClass, ClientChannel)
+ constructor = getattr(ClientChannelClass, args.protocol)
client = None
try:
- client = constructor(host = args.host, port = args.port, version = args.force_version)
- if args.sql_database:
- client.setup_sql(args.sql_database, args.reset_session)
+ client = constructor(args)
polled = client.updated
wakeup = None
@@ -523,3 +526,4 @@ def argparse_setup(subparsers):
subparser.add_argument("protocol", choices = ("loopback", "tcp", "ssh", "tls"), help = "connection protocol")
subparser.add_argument("host", nargs = "?", help = "server host")
subparser.add_argument("port", nargs = "?", help = "server port")
+ return subparser
diff --git a/rpki/rtr/main.py b/rpki/rtr/main.py
index 6add407d..29a4873b 100644
--- a/rpki/rtr/main.py
+++ b/rpki/rtr/main.py
@@ -28,10 +28,6 @@ import logging
import logging.handlers
import argparse
-from rpki.rtr.server import argparse_setup as argparse_setup_server
-from rpki.rtr.client import argparse_setup as argparse_setup_client
-from rpki.rtr.generator import argparse_setup as argparse_setup_generator
-
class Formatter(logging.Formatter):
@@ -57,6 +53,10 @@ def main():
os.environ["TZ"] = "UTC"
time.tzset()
+ from rpki.rtr.server import argparse_setup as argparse_setup_server
+ from rpki.rtr.client import argparse_setup as argparse_setup_client
+ from rpki.rtr.generator import argparse_setup as argparse_setup_generator
+
if "rpki.rtr.bgpdump" in sys.modules:
from rpki.rtr.bgpdump import argparse_setup as argparse_setup_bgpdump
else: