diff options
-rw-r--r-- | scripts/biz-certs/Bob-CA.srl | 2 | ||||
-rw-r--r-- | scripts/rpki/left_right.py | 33 | ||||
-rw-r--r-- | scripts/rpki/sql.py | 98 | ||||
-rw-r--r-- | scripts/rpki/sundial.py | 9 | ||||
-rw-r--r-- | scripts/rpki/x509.py | 9 |
5 files changed, 67 insertions, 84 deletions
diff --git a/scripts/biz-certs/Bob-CA.srl b/scripts/biz-certs/Bob-CA.srl index f06575a7..2c2ab5ae 100644 --- a/scripts/biz-certs/Bob-CA.srl +++ b/scripts/biz-certs/Bob-CA.srl @@ -1 +1 @@ -90801F1ED19454AB +90801F1ED19454AD diff --git a/scripts/rpki/left_right.py b/scripts/rpki/left_right.py index c4abffa9..f57accc1 100644 --- a/scripts/rpki/left_right.py +++ b/scripts/rpki/left_right.py @@ -61,26 +61,6 @@ class base_elt(object): class data_elt(base_elt, rpki.sql.sql_persistant): """Virtual class for top-level left-right protocol data elements.""" - def sql_decode(self, vals): - """Decode SQL form of a data_elt object.""" - rpki.sql.sql_persistant.sql_decode(self, vals) - if "cms_ta" in vals: - self.cms_ta = rpki.x509.X509(DER = vals["cms_ta"]) - if "https_ta" in vals: - self.https_ta = rpki.x509.X509(DER = vals["https_ta"]) - if "private_key_id" in vals: - self.private_key_id = rpki.x509.RSA(DER = vals["private_key_id"]) - if "public_key" in vals: - self.public_key = rpki.x509.RSA(DER = vals["public_key"]) - - def sql_encode(self): - """Encode SQL form of a data_elt object.""" - d = rpki.sql.sql_persistant.sql_encode(self) - for i in ("cms_ta", "https_ta", "private_key_id", "public_key"): - if i in d and d[i] is not None: - d[i] = d[i].get_DER() - return d - def make_reply(self, r_pdu = None): """Construct a reply PDU.""" if r_pdu is None: @@ -303,7 +283,9 @@ class bsc_elt(data_elt): elements = ('signing_cert',) booleans = ("generate_keypair", "clear_signing_certs") - sql_template = rpki.sql.template("bsc", "bsc_id", "self_id", "public_key", "private_key_id", "hash_alg") + sql_template = rpki.sql.template("bsc", "bsc_id", "self_id", + ("public_key", rpki.x509.RSApublic), + ("private_key_id", rpki.x509.RSA), "hash_alg") pkcs10_cert_request = None public_key = None @@ -384,7 +366,8 @@ class parent_elt(data_elt): booleans = ("rekey", "reissue", "revoke") sql_template = rpki.sql.template("parent", "parent_id", "self_id", "bsc_id", "repository_id", - "cms_ta", "https_ta", "peer_contact_uri", "sia_base") + ("cms_ta", rpki.x509.X509), ("https_ta", rpki.x509.X509), + "peer_contact_uri", "sia_base") cms_ta = None https_ta = None @@ -461,7 +444,7 @@ class child_elt(data_elt): elements = ("cms_ta",) booleans = ("reissue", ) - sql_template = rpki.sql.template("child", "child_id", "self_id", "bsc_id", "cms_ta") + sql_template = rpki.sql.template("child", "child_id", "self_id", "bsc_id", ("cms_ta", rpki.x509.X509)) cms_ta = None @@ -528,8 +511,8 @@ class repository_elt(data_elt): attributes = ("action", "type", "self_id", "repository_id", "bsc_id", "peer_contact_uri") elements = ("cms_ta", "https_ta") - sql_template = rpki.sql.template("repository", "repository_id", "self_id", "bsc_id", "cms_ta", - "peer_contact_uri") + sql_template = rpki.sql.template("repository", "repository_id", "self_id", "bsc_id", + ("cms_ta", rpki.x509.X509), "peer_contact_uri") cms_ta = None https_ta = None diff --git a/scripts/rpki/sql.py b/scripts/rpki/sql.py index 570fdee4..1074e829 100644 --- a/scripts/rpki/sql.py +++ b/scripts/rpki/sql.py @@ -13,13 +13,15 @@ def connect(cfg, section="sql"): class template(object): """SQL template generator.""" - def __init__(self, table_name, *columns): + def __init__(self, table_name, index_column, *data_columns): """Build a SQL template.""" - index_column = columns[0] - data_columns = columns[1:] + type_map = dict((x[0],x[1]) for x in data_columns if isinstance(x, tuple)) + data_columns = tuple(isinstance(x, tuple) and x[0] or x for x in data_columns) + columns = (index_column,) + data_columns self.table = table_name self.index = index_column self.columns = columns + self.map = type_map self.select = "SELECT %s FROM %s" % (", ".join(columns), table_name) self.insert = "INSERT %s (%s) VALUES (%s)" % (table_name, ", ".join(data_columns), ", ".join("%(" + s + ")s" for s in data_columns)) @@ -154,20 +156,26 @@ class sql_persistant(object): def sql_encode(self): """Convert object attributes into a dict for use with canned SQL queries. This is a default version that assumes a one-to-one - mapping between column names in SQL and attribute names in Python, - with no datatype conversion. If you need something fancier, - override this. + mapping between column names in SQL and attribute names in Python. + If you need something fancier, override this. """ - return dict((a, getattr(self, a, None)) for a in self.sql_template.columns) + d = dict((a, getattr(self, a, None)) for a in self.sql_template.columns) + for i in self.sql_template.map: + if d.get(i) is not None: + d[i] = self.sql_template.map[i].to_sql(d[i]) + return d def sql_decode(self, vals): """Initialize an object with values returned by self.sql_fetch(). This is a default version that assumes a one-to-one mapping - between column names in SQL and attribute names in Python, with no - datatype conversion. If you need something fancier, override this. + between column names in SQL and attribute names in Python. If you + need something fancier, override this. """ for a in self.sql_template.columns: - setattr(self, a, vals[a]) + if vals.get(a) is not None and a in self.sql_template.map: + setattr(self, a, self.sql_template.map[a].from_sql(vals[a])) + else: + setattr(self, a, vals[a]) def sql_fetch_hook(self, gctx): """Customization hook.""" @@ -192,21 +200,16 @@ class sql_persistant(object): class ca_obj(sql_persistant): """Internal CA object.""" - sql_template = template("ca", "ca_id", "last_crl_sn", "next_crl_update", "last_issued_sn", - "last_manifest_sn", "next_manifest_update", "sia_uri", "parent_id", - "parent_resource_class") + sql_template = template("ca", "ca_id", "last_crl_sn", + ("next_crl_update", rpki.sundial.datetime), + "last_issued_sn", "last_manifest_sn", + ("next_manifest_update", rpki.sundial.datetime), + "sia_uri", "parent_id", "parent_resource_class") last_crl_sn = 0 last_issued_sn = 0 last_manifest_sn = 0 - def sql_decode(self, vals): - """Decode SQL representation of a ca_obj.""" - sql_persistant.sql_decode(self, vals) - for i in ("next_crl_update", "next_manifest_update"): - if vals.get(i) is not None: - setattr(self, i, rpki.sundial.datetime.fromdatetime(vals[i])) - def construct_sia_uri(self, gctx, parent, rc): """Construct the sia_uri value for this CA given configured information and the parent's up-down protocol list_response PDU. @@ -310,39 +313,29 @@ class ca_obj(sql_persistant): class ca_detail_obj(sql_persistant): """Internal CA detail object.""" - sql_template = template("ca_detail", "ca_detail_id", "private_key_id", "public_key", "latest_ca_cert", - "manifest_private_key_id", "manifest_public_key", "latest_manifest_cert", - "latest_manifest", "latest_crl", "state", "state_timer", "ca_cert_uri", "ca_id") - + sql_template = template("ca_detail", + "ca_detail_id", + ("private_key_id", rpki.x509.RSA), + ("public_key", rpki.x509.RSApublic), + ("latest_ca_cert", rpki.x509.X509), + ("manifest_private_key_id", rpki.x509.RSA), + ("manifest_public_key", rpki.x509.RSApublic), + ("latest_manifest_cert", rpki.x509.X509), + ("latest_manifest", rpki.x509.SignedManifest), + ("latest_crl", rpki.x509.CRL), + "state", + ("state_timer", rpki.sundial.datetime), + "ca_cert_uri", + "ca_id") + def sql_decode(self, vals): - """Decode SQL representation of a ca_detail_obj.""" + """Extra assertions for SQL decode of a ca_detail_obj.""" sql_persistant.sql_decode(self, vals) - for i,t in (("private_key_id", rpki.x509.RSA), - ("public_key", rpki.x509.RSApublic), - ("latest_ca_cert", rpki.x509.X509), - ("manifest_private_key_id", rpki.x509.RSA), - ("manifest_public_key", rpki.x509.RSApublic), - ("latest_manifest_cert", rpki.x509.X509), - ("latest_manifest", rpki.x509.SignedManifest), - ("latest_crl", rpki.x509.CRL)): - if getattr(self, i, None) is not None: - setattr(self, i, t(DER = getattr(self, i))) - if vals.get("state_timer") is not None: - self.state_timer = rpki.sundial.datetime.fromdatetime(vals["state_timer"]) assert (self.public_key is None and self.private_key_id is None) or \ self.public_key.get_DER() == self.private_key_id.get_public_DER() assert (self.manifest_public_key is None and self.manifest_private_key_id is None) or \ self.manifest_public_key.get_DER() == self.manifest_private_key_id.get_public_DER() - def sql_encode(self): - """Encode SQL representation of a ca_detail_obj.""" - d = sql_persistant.sql_encode(self) - for i in ("private_key_id", "public_key", "latest_ca_cert", "manifest_private_key_id", - "manifest_public_key", "latest_manifest_cert", "latest_manifest", "latest_crl"): - if d[i] is not None: - d[i] = d[i].get_DER() - return d - @classmethod def sql_fetch_active(cls, gctx, ca_id): """Fetch the current active ca_detail_obj associated with a given ca_id.""" @@ -503,7 +496,7 @@ class ca_detail_obj(sql_persistant): class child_cert_obj(sql_persistant): """Certificate that has been issued to a child.""" - sql_template = template("child_cert", "child_cert_id", "cert", "child_id", "ca_detail_id", "ski", "revoked") + sql_template = template("child_cert", "child_cert_id", ("cert", rpki.x509.X509), "child_id", "ca_detail_id", "ski", "revoked") def __init__(self, child_id = None, ca_detail_id = None, cert = None): """Initialize a child_cert_obj.""" @@ -514,17 +507,6 @@ class child_cert_obj(sql_persistant): if child_id or ca_detail_id or cert: self.sql_mark_dirty() - def sql_decode(self, vals): - """Decode SQL representation of a child_cert_obj.""" - sql_persistant.sql_decode(self, vals) - self.cert = rpki.x509.X509(DER = self.cert) - - def sql_encode(self): - """Encode SQL representation of a child_cert_obj.""" - d = sql_persistant.sql_encode(self) - d["cert"] = self.cert.get_DER() - return d - def reissue(self, gctx, ca_detail, resources, sia, valid_until): """Reissue an existing child_cert_obj, reusing the public key.""" diff --git a/scripts/rpki/sundial.py b/scripts/rpki/sundial.py index 96a0d591..320abfd6 100644 --- a/scripts/rpki/sundial.py +++ b/scripts/rpki/sundial.py @@ -82,6 +82,15 @@ class datetime(pydatetime.datetime): """Force correct class for timedelta results.""" return self.fromdatetime(pydatetime.datetime.__sub__(self, other)) + @classmethod + def from_sql(cls, x): + """Convert from SQL storage format.""" + return cls.fromdatetime(x) + + def to_sql(self): + """Convert to SQL storage format.""" + return self + # Alias to simplify imports for callers timedelta = pydatetime.timedelta diff --git a/scripts/rpki/x509.py b/scripts/rpki/x509.py index e5539d83..488c1d73 100644 --- a/scripts/rpki/x509.py +++ b/scripts/rpki/x509.py @@ -174,6 +174,15 @@ class DER_object(object): pass return resources + @classmethod + def from_sql(cls, x): + """Convert from SQL storage format.""" + return cls(DER = x) + + def to_sql(self): + """Convert to SQL storage format.""" + return self.get_DER() + class X509(DER_object): """X.509 certificates. |