aboutsummaryrefslogtreecommitdiff
path: root/scripts/rpki/sql.py
blob: f7214a476c7c0113ea7ba9a172bfef90936e9b4f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# $Id$

import MySQLdb

def connect(cfg, section="sql"):
  """Open a MySQL connection configured from an rpki.config.parser object.

  The user name, database name and password are read, in that order,
  from the "sql-username", "sql-database" and "sql-password" options of
  the given config section (default "sql").  Whatever MySQLdb.connect()
  returns is handed back to the caller unchanged.
  """
  connect_args = {}
  connect_args["user"] = cfg.get(section, "sql-username")
  connect_args["db"] = cfg.get(section, "sql-database")
  connect_args["passwd"] = cfg.get(section, "sql-password")
  return MySQLdb.connect(**connect_args)

class sql_persistant(object):
  """Mixin for persistent classes whose instances are stored in SQL.

  Derived classes are expected to provide (none of these are defined
  on this mixin):

    sql_select_cmd, sql_insert_cmd, sql_update_cmd, sql_delete_cmd
      Canned SQL command strings.  Each is expanded with Python's
      %-operator against a dict of values and the result is executed.

    sql_attributes
      Tuple of attribute names used to build that dict (sql_makedict).

    sql_objectify(*row)
      Method that loads one row returned by sql_select_cmd into self.
  """

  ## @var sql_children
  # Dictionary listing this class's children in the tree of SQL
  # tables.  Key is the name of the attribute in this class at which a
  # list of the resulting child objects is stored; value is the class
  # object of a child.
  sql_children = {}

  ## @var sql_in_db
  # Whether this object is already in SQL or not.  Perhaps this should
  # instead be a None value in the object's ID field?
  sql_in_db = False

  ## @var sql_dirty
  # Whether this object has been modified and needs to be written back
  # to SQL.
  sql_dirty = False

  ## @var sql_attributes
  # Tuple of attributes to translate between this Python object and its SQL representation.
  sql_attributes = None                 # Must be overridden by derived type

  ## @var sql_id_name
  # Name of the (auto-increment) ID column for this table, or None if it doesn't have one.
  sql_id_name = None

  @classmethod
  def sql_fetch(cls, db, cur=None, arg_dict=None, **kwargs):
    """Fetch rows from SQL using cls.sql_select_cmd and instantiate
    them as objects, returning a list of the instantiated objects.

    The query is built by %-formatting cls.sql_select_cmd with arg_dict,
    or with the keyword arguments if arg_dict is None.  For every
    fetched object, each child class listed in sql_children is fetched
    recursively, using the parent's sql_makedict() as the child query's
    arguments, and the resulting list is stored on the parent.

    This is a class method because in general we don't even know how
    many matches the SQL lookup will return until after we've
    performed it.

    @param db: database connection; used here only for db.cursor().
    @param cur: optional cursor to reuse; a new one is made if None.
    @param arg_dict: values for the query; alternative to **kwargs.
    @return: list of cls instances, each with sql_in_db set.
    @raise TypeError: if both arg_dict and keyword arguments are given.
    """
    if cur is None:
      cur = db.cursor()
    if arg_dict is None:
      arg_dict = kwargs
    elif kwargs:
      raise TypeError("sql_fetch() takes arg_dict or keyword arguments, not both")
    # NOTE(review): values are %-interpolated into the SQL text rather
    # than passed as query parameters, so they are not escaped; callers
    # must not pass untrusted data here.
    cur.execute(cls.sql_select_cmd % arg_dict)
    result = []
    # fetchall() drains the whole result set up front, which is what
    # allows the child queries below to reuse the same cursor.
    for row in cur.fetchall():
      obj = cls()
      # The row came from SQL, so a later sql_store() must UPDATE
      # (if dirty) rather than INSERT a duplicate.
      obj.sql_in_db = True
      obj.sql_objectify(*row)
      result.append(obj)
      attr_dict = obj.sql_makedict()
      for kid_name, kid_type in obj.sql_children.items():
        setattr(obj, kid_name, kid_type.sql_fetch(db, cur, attr_dict))
    return result

  def sql_objectify(self, *row):
    """Initialize self from one row returned by sql_select_cmd.

    sql_fetch() calls this with the row's columns as positional
    arguments.  Must be overridden by derived classes.
    """
    raise NotImplementedError

  def sql_store(self, db, cur=None):
    """Save an object and its descendants to SQL.

    An object not yet in SQL is inserted with sql_insert_cmd; if the
    class names an ID column (sql_id_name), the new row's ID is copied
    into that attribute from cur.lastrowid.  An object already in SQL
    is written with sql_update_cmd only when sql_dirty is set.  Children
    are then stored recursively, after the parent, on the same cursor.
    """
    if cur is None:
      cur = db.cursor()
    if not self.sql_in_db:
      cur.execute(self.sql_insert_cmd % self.sql_makedict())
      if self.sql_id_name is not None:
        setattr(self, self.sql_id_name, cur.lastrowid)
    elif self.sql_dirty:
      cur.execute(self.sql_update_cmd % self.sql_makedict())
    self.sql_dirty = False
    self.sql_in_db = True
    for kids in self.sql_children:
      for kid in getattr(self, kids):
        kid.sql_store(db, cur)

  def sql_delete(self, db, cur=None):
    """Delete an object and its descendants from SQL.

    Children are deleted (recursively) first, then this object's row
    is deleted with sql_delete_cmd if it is in SQL, and sql_in_db is
    cleared.
    """
    if cur is None:
      cur = db.cursor()
    # Delete in the reverse of sql_store()'s order (children before
    # parent), in case child tables reference the parent row, e.g.
    # via foreign keys (not visible here).
    for kids in self.sql_children:
      for kid in getattr(self, kids):
        kid.sql_delete(db, cur)
    if self.sql_in_db:
      cur.execute(self.sql_delete_cmd % self.sql_makedict())
      self.sql_in_db = False

  def sql_makedict(self):
    """Copy the attributes named in sql_attributes from this object
    into a dict for use with the canned SQL commands.
    """
    return dict((a, getattr(self, a)) for a in self.sql_attributes)