aboutsummaryrefslogtreecommitdiff
path: root/scripts/rpki/x509.py
blob: 865a193e613b34a4d2349947088cad654aa96d22 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# $Id$

"""
One X.509 implementation to rule them all and in the darkness hide the
twisty maze of partially overlapping X.509 support packages in Python.

There are several existing packages, none of which do quite what I
need, due to age, lack of documentation, specialization, or lack of
foresight on somebody's part (perhaps mine).  This module attempts to
bring together the functionality I need in a way that hides at least
some of the nasty details.  This involves a lot of format conversion.
"""

import POW, tlslite.api, POW.pkix

class X509(object):
  """
  Class to hold all the different representations of X.509 certs we're
  using and convert between them.

  One representation is set explicitly (via the constructor or set());
  the others are derived on demand by the get_*() methods and cached on
  the instance.  Representations held:

    DER      -- DER encoding, as read by POW.derRead(POW.X509_CERTIFICATE, ...)
    PEM      -- PEM text, as read by POW.pemRead(POW.X509_CERTIFICATE, ...)
    POW      -- POW certificate object
    POWpkix  -- POW.pkix.Certificate object
    tlslite  -- tlslite X509 object
  """

  DER = None
  PEM = None
  POW = None
  POWpkix = None
  tlslite = None

  # Cache built by get_POW_extensions().  It is derived data, so empty()
  # ignores it, but clear() resets it so that a newly set certificate never
  # reports the previous certificate's extensions.
  POW_extensions = None

  def empty(self):
    """Return True if no representation of a certificate is held."""
    return self.DER is None and self.PEM is None and self.POW is None and self.POWpkix is None and self.tlslite is None

  def clear(self):
    """Drop every representation held, plus any derived caches."""
    self.DER = None
    self.PEM = None
    self.POW = None
    self.POWpkix = None
    self.tlslite = None
    self.POW_extensions = None

  def __init__(self, **kw):
    """Optionally initialize from a single keyword, as accepted by set()."""
    if len(kw):
      self.set(**kw)

  def set(self, **kw):
    """
    Replace whatever this object holds with a single new representation.

    Exactly one keyword argument is accepted:
      DER, PEM, POW, POWpkix, tlslite -- the value is stored as-is;
      PEM_file, DER_file              -- the named file is read and its
                                         contents stored as PEM / DER.

    Raises RuntimeError when given zero or several keywords, or an
    unrecognized keyword.
    """
    # Check the count before indexing, so an empty call raises RuntimeError
    # rather than an IndexError from an empty key list.
    if len(kw) != 1:
      raise RuntimeError                  # Should create our own exception classes
    # list() works for both Python 2 key lists and Python 3 key views.
    name = list(kw)[0]
    if name in ("DER", "PEM", "POW", "POWpkix", "tlslite"):
      self.clear()
      setattr(self, name, kw[name])
      return
    if name in ("PEM_file", "DER_file"):
      # DER is binary data: open it in binary mode so that no newline
      # translation can corrupt it.  PEM is text.
      if name == "DER_file":
        mode = "rb"
      else:
        mode = "r"
      f = open(kw[name], mode)
      try:
        text = f.read()
      finally:
        f.close()
      # Clear only after a successful read, so a failed read leaves the
      # current contents untouched.
      self.clear()
      if name == "PEM_file":
        self.PEM = text
      else:
        self.DER = text
      return
    raise RuntimeError                    # Should create our own exception classes

  def get_DER(self):
    """
    Return the DER form, deriving and caching it if necessary.

    Conversion is tried from POW, then POWpkix, then PEM (by way of POW).
    Raises RuntimeError when only a tlslite representation is held, since
    no tlslite-to-DER conversion is implemented here.
    """
    assert not self.empty()
    if self.DER:
      return self.DER
    if self.POW:
      self.DER = self.POW.derWrite()
      return self.get_DER()
    if self.POWpkix:
      self.DER = self.POWpkix.toString()
      return self.get_DER()
    if self.PEM:
      # Parse the PEM into a POW object; the recursive call then takes the POW branch.
      self.POW = POW.pemRead(POW.X509_CERTIFICATE, self.PEM)
      return self.get_DER()
    raise RuntimeError

  def get_POW(self):
    """Return the POW object, building it from DER if necessary."""
    assert not self.empty()
    if not self.POW:
      self.POW = POW.derRead(POW.X509_CERTIFICATE, self.get_DER())
    return self.POW

  def get_PEM(self):
    """Return the PEM form, generating it through POW if necessary."""
    assert not self.empty()
    if not self.PEM:
      self.PEM = self.get_POW().pemWrite()
    return self.PEM

  def get_POWpkix(self):
    """Return the POW.pkix.Certificate object, parsing DER if necessary."""
    assert not self.empty()
    if not self.POWpkix:
      cert = POW.pkix.Certificate()
      cert.fromString(self.get_DER())
      self.POWpkix = cert
    return self.POWpkix

  def get_tlslite(self):
    """Return the tlslite X509 object, parsing DER if necessary."""
    assert not self.empty()
    if not self.tlslite:
      cert = tlslite.X509.X509()
      cert.parseBinary(self.get_DER())
      self.tlslite = cert
    return self.tlslite

  def getIssuer(self):
    """Return the issuer name as reported by the POW object."""
    return self.get_POW().getIssuer()

  def getSubject(self):
    """Return the subject name as reported by the POW object."""
    return self.get_POW().getSubject()

  def get_POW_extensions(self):
    """
    Return a dict of this cert's extensions, built once and then cached.

    For each result of the POW object's getExtension(i), element 0 is used
    as the key and element 2 as the value.
    """
    # "is None" rather than truthiness: a cert with no extensions yields {}
    # and should not be rescanned on every call.
    if self.POW_extensions is None:
      cert = self.get_POW()
      exts = {}
      for i in range(cert.countExtensions()):
        x = cert.getExtension(i)
        exts[x[0]] = x[2]
      self.POW_extensions = exts
    return self.POW_extensions

  def getAKI(self):
    """Return the authorityKeyIdentifier extension value (KeyError if absent)."""
    return self.get_POW_extensions()["authorityKeyIdentifier"]

  def getSKI(self):
    """Return the subjectKeyIdentifier extension value (KeyError if absent)."""
    return self.get_POW_extensions()["subjectKeyIdentifier"]

def sort_chain(bag):
  """
  Sort a bag of certs into a chain, leaf first.  Various other routines
  want their certs presented in this order.

  bag is any iterable of objects with getIssuer() and getSubject() methods,
  such as X509 instances.  The names those methods return must be hashable,
  because they are used as dictionary keys.  Returns a new list ordered as:
  leaf, its issuer, that cert's issuer, and so on.  A self-signed cert may
  be included at the top; a lone self-signed cert is its own leaf.

  Raises AssertionError if the bag does not contain exactly one leaf (a
  cert that issued none of the others), or if some cert in the bag is not
  on the leaf's issuer path.
  """

  # Materialize once so that generators and other one-shot iterables work.
  bag = list(bag)
  subject_map = dict([(x.getSubject(), x) for x in bag])

  # Names of certs that issued some *other* cert in the bag.  A self-signed
  # cert names itself as its issuer.  Counting that would make a lone
  # self-signed cert look like an issuer, leaving no leaf at all.
  issuer_names = set([x.getIssuer() for x in bag
                      if x.getIssuer() != x.getSubject()])

  chain = list(bag)
  issuers = []

  for subject in subject_map:
    if subject in issuer_names:
      cert = subject_map[subject]
      issuers.append(cert)
      chain.remove(cert)

  assert len(chain) == 1, "expected exactly one leaf cert, found %d" % len(chain)

  while issuers:
    # Use .get() rather than [] so that an issuer missing from the bag trips
    # the assertion below instead of raising a bare KeyError.  The membership
    # test also catches a cert being reached twice, e.g. through a loop of
    # certs that issue each other.
    issuer = subject_map.get(chain[-1].getIssuer())
    assert issuer is not None and issuer in issuers, "cert in bag is not on the leaf's issuer path"
    chain.append(issuer)
    issuers.remove(issuer)

  return chain