aboutsummaryrefslogtreecommitdiff
path: root/rpki/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'rpki/sql.py')
-rw-r--r--rpki/sql.py51
1 files changed, 48 insertions, 3 deletions
diff --git a/rpki/sql.py b/rpki/sql.py
index 31ed40ee..55e6f7cb 100644
--- a/rpki/sql.py
+++ b/rpki/sql.py
@@ -56,11 +56,12 @@ class session(object):
ping_threshold = rpki.sundial.timedelta(seconds = 60)
- def __init__(self, cfg):
+ def __init__(self, cfg, autocommit = True):
self.username = cfg.get("sql-username")
self.database = cfg.get("sql-database")
self.password = cfg.get("sql-password")
+ self.autocommit = autocommit
self.conv = MySQLdb.converters.conversions.copy()
self.conv.update({
@@ -78,7 +79,7 @@ class session(object):
passwd = self.password,
conv = self.conv)
self.cur = self.db.cursor()
- self.db.autocommit(True)
+ self.db.autocommit(self.autocommit)
self.timestamp = rpki.sundial.now()
def close(self):
@@ -113,11 +114,37 @@ class session(object):
def lastrowid(self):
return self.cur.lastrowid
+ def commit(self):
+ """
+ Sweep cache, then commit SQL.
+ """
+
+ self.sweep()
+ logger.debug("Executing SQL COMMIT")
+ self.db.commit()
+
+ def rollback(self):
+ """
+ SQL rollback, then clear cache and dirty cache.
+
+ NB: We have no way of clearing other references to cached objects,
+ so if you call this method you MUST forget any state that might
+ cause you to retain such references. This is probably tricky, and
+ is itself a good argument for switching to something like the
+ Django ORM's @commit_on_success semantics, but we do what we can.
+ """
+
+ logger.debug("Executing SQL ROLLBACK, discarding SQL cache and dirty set")
+ self.db.rollback()
+ self.dirty.clear()
+ self.cache.clear()
+
def cache_clear(self):
"""
Clear the SQL object cache. Shouldn't be necessary now that the
cache uses weak references, but should be harmless.
"""
+
logger.debug("Clearing SQL cache")
self.assert_pristine()
self.cache.clear()
@@ -126,14 +153,15 @@ class session(object):
"""
Assert that there are no dirty objects in the cache.
"""
+
assert not self.dirty, "Dirty objects in SQL cache: %s" % self.dirty
def sweep(self):
"""
Write any dirty objects out to SQL.
"""
+
for s in self.dirty.copy():
- #if s.sql_cache_debug:
logger.debug("Sweeping (%s) %r", "deleting" if s.sql_deleted else "storing", s)
if s.sql_deleted:
s.sql_delete()
@@ -150,6 +178,7 @@ class template(object):
"""
Build a SQL template.
"""
+
type_map = dict((x[0], x[1]) for x in data_columns if isinstance(x, tuple))
data_columns = tuple(isinstance(x, tuple) and x[0] or x for x in data_columns)
columns = (index_column,) + data_columns
@@ -220,6 +249,7 @@ class sql_persistent(object):
"""
Fetch one object from SQL, based on an arbitrary SQL WHERE expression.
"""
+
results = cls.sql_fetch_where(gctx, where, args, also_from)
if len(results) == 0:
return None
@@ -235,6 +265,7 @@ class sql_persistent(object):
"""
Fetch all objects of this type from SQL.
"""
+
return cls.sql_fetch_where(gctx, None)
@classmethod
@@ -242,6 +273,7 @@ class sql_persistent(object):
"""
Fetch objects of this type matching an arbitrary SQL WHERE expression.
"""
+
if where is None:
assert args is None and also_from is None
if cls.sql_debug:
@@ -269,6 +301,7 @@ class sql_persistent(object):
"""
Initialize one Python object from the result of a SQL query.
"""
+
self = cls()
self.gctx = gctx
self.sql_decode(dict(zip(cls.sql_template.columns, row)))
@@ -281,6 +314,7 @@ class sql_persistent(object):
"""
Mark this object as needing to be written back to SQL.
"""
+
if self.sql_cache_debug and not self.sql_is_dirty:
logger.debug("Marking %r SQL dirty", self)
self.gctx.sql.dirty.add(self)
@@ -289,6 +323,7 @@ class sql_persistent(object):
"""
Mark this object as not needing to be written back to SQL.
"""
+
if self.sql_cache_debug and self.sql_is_dirty:
logger.debug("Marking %r SQL clean", self)
self.gctx.sql.dirty.discard(self)
@@ -298,12 +333,14 @@ class sql_persistent(object):
"""
Query whether this object needs to be written back to SQL.
"""
+
return self in self.gctx.sql.dirty
def sql_mark_deleted(self):
"""
Mark this object as needing to be deleted in SQL.
"""
+
self.sql_deleted = True
self.sql_mark_dirty()
@@ -311,6 +348,7 @@ class sql_persistent(object):
"""
Store this object to SQL.
"""
+
args = self.sql_encode()
if not self.sql_in_db:
if self.sql_debug:
@@ -333,6 +371,7 @@ class sql_persistent(object):
"""
Delete this object from SQL.
"""
+
if self.sql_in_db:
id = getattr(self, self.sql_template.index) # pylint: disable=W0622
if self.sql_debug:
@@ -352,6 +391,7 @@ class sql_persistent(object):
mapping between column names in SQL and attribute names in Python.
If you need something fancier, override this.
"""
+
d = dict((a, getattr(self, a, None)) for a in self.sql_template.columns)
for i in self.sql_template.map:
if d.get(i) is not None:
@@ -365,6 +405,7 @@ class sql_persistent(object):
between column names in SQL and attribute names in Python. If you
need something fancier, override this.
"""
+
for a in self.sql_template.columns:
if vals.get(a) is not None and a in self.sql_template.map:
setattr(self, a, self.sql_template.map[a].from_sql(vals[a]))
@@ -375,18 +416,21 @@ class sql_persistent(object):
"""
Customization hook.
"""
+
pass
def sql_insert_hook(self):
"""
Customization hook.
"""
+
pass
def sql_update_hook(self):
"""
Customization hook.
"""
+
self.sql_delete_hook()
self.sql_insert_hook()
@@ -394,6 +438,7 @@ class sql_persistent(object):
"""
Customization hook.
"""
+
pass