diff options
Diffstat (limited to 'rpki/sql.py')
-rw-r--r-- | rpki/sql.py | 51 |
1 files changed, 48 insertions, 3 deletions
diff --git a/rpki/sql.py b/rpki/sql.py index 31ed40ee..55e6f7cb 100644 --- a/rpki/sql.py +++ b/rpki/sql.py @@ -56,11 +56,12 @@ class session(object): ping_threshold = rpki.sundial.timedelta(seconds = 60) - def __init__(self, cfg): + def __init__(self, cfg, autocommit = True): self.username = cfg.get("sql-username") self.database = cfg.get("sql-database") self.password = cfg.get("sql-password") + self.autocommit = autocommit self.conv = MySQLdb.converters.conversions.copy() self.conv.update({ @@ -78,7 +79,7 @@ class session(object): passwd = self.password, conv = self.conv) self.cur = self.db.cursor() - self.db.autocommit(True) + self.db.autocommit(self.autocommit) self.timestamp = rpki.sundial.now() def close(self): @@ -113,11 +114,37 @@ class session(object): def lastrowid(self): return self.cur.lastrowid + def commit(self): + """ + Sweep cache, then commit SQL. + """ + + self.sweep() + logger.debug("Executing SQL COMMIT") + self.db.commit() + + def rollback(self): + """ + SQL rollback, then clear cache and dirty cache. + + NB: We have no way of clearing other references to cached objects, + so if you call this method you MUST forget any state that might + cause you to retain such references. This is probably tricky, and + is itself a good argument for switching to something like the + Django ORM's @commit_on_success semantics, but we do what we can. + """ + + logger.debug("Executing SQL ROLLBACK, discarding SQL cache and dirty set") + self.db.rollback() + self.dirty.clear() + self.cache.clear() + def cache_clear(self): """ Clear the SQL object cache. Shouldn't be necessary now that the cache uses weak references, but should be harmless. """ + logger.debug("Clearing SQL cache") self.assert_pristine() self.cache.clear() @@ -126,14 +153,15 @@ class session(object): """ Assert that there are no dirty objects in the cache. """ + assert not self.dirty, "Dirty objects in SQL cache: %s" % self.dirty def sweep(self): """ Write any dirty objects out to SQL. """ + for s in self.dirty.copy(): - #if s.sql_cache_debug: logger.debug("Sweeping (%s) %r", "deleting" if s.sql_deleted else "storing", s) if s.sql_deleted: s.sql_delete() @@ -150,6 +178,7 @@ class template(object): """ Build a SQL template. """ + type_map = dict((x[0], x[1]) for x in data_columns if isinstance(x, tuple)) data_columns = tuple(isinstance(x, tuple) and x[0] or x for x in data_columns) columns = (index_column,) + data_columns @@ -220,6 +249,7 @@ class sql_persistent(object): """ Fetch one object from SQL, based on an arbitrary SQL WHERE expression. """ + results = cls.sql_fetch_where(gctx, where, args, also_from) if len(results) == 0: return None @@ -235,6 +265,7 @@ class sql_persistent(object): """ Fetch all objects of this type from SQL. """ + return cls.sql_fetch_where(gctx, None) @classmethod @@ -242,6 +273,7 @@ class sql_persistent(object): """ Fetch objects of this type matching an arbitrary SQL WHERE expression. """ + if where is None: assert args is None and also_from is None if cls.sql_debug: @@ -269,6 +301,7 @@ class sql_persistent(object): """ Initialize one Python object from the result of a SQL query. """ + self = cls() self.gctx = gctx self.sql_decode(dict(zip(cls.sql_template.columns, row))) @@ -281,6 +314,7 @@ class sql_persistent(object): """ Mark this object as needing to be written back to SQL. """ + if self.sql_cache_debug and not self.sql_is_dirty: logger.debug("Marking %r SQL dirty", self) self.gctx.sql.dirty.add(self) @@ -289,6 +323,7 @@ class sql_persistent(object): """ Mark this object as not needing to be written back to SQL. """ + if self.sql_cache_debug and self.sql_is_dirty: logger.debug("Marking %r SQL clean", self) self.gctx.sql.dirty.discard(self) @@ -298,12 +333,14 @@ class sql_persistent(object): """ Query whether this object needs to be written back to SQL. """ + return self in self.gctx.sql.dirty def sql_mark_deleted(self): """ Mark this object as needing to be deleted in SQL. """ + self.sql_deleted = True self.sql_mark_dirty() @@ -311,6 +348,7 @@ class sql_persistent(object): """ Store this object to SQL. """ + args = self.sql_encode() if not self.sql_in_db: if self.sql_debug: @@ -333,6 +371,7 @@ class sql_persistent(object): """ Delete this object from SQL. """ + if self.sql_in_db: id = getattr(self, self.sql_template.index) # pylint: disable=W0622 if self.sql_debug: @@ -352,6 +391,7 @@ class sql_persistent(object): mapping between column names in SQL and attribute names in Python. If you need something fancier, override this. """ + d = dict((a, getattr(self, a, None)) for a in self.sql_template.columns) for i in self.sql_template.map: if d.get(i) is not None: @@ -365,6 +405,7 @@ class sql_persistent(object): between column names in SQL and attribute names in Python. If you need something fancier, override this. """ + for a in self.sql_template.columns: if vals.get(a) is not None and a in self.sql_template.map: setattr(self, a, self.sql_template.map[a].from_sql(vals[a])) @@ -375,18 +416,21 @@ class sql_persistent(object): """ Customization hook. """ + pass def sql_insert_hook(self): """ Customization hook. """ + pass def sql_update_hook(self): """ Customization hook. """ + self.sql_delete_hook() self.sql_insert_hook() @@ -394,6 +438,7 @@ class sql_persistent(object): """ Customization hook. """ + pass |