diff --git a/python/veles/db/backend.py b/python/veles/db/backend.py index 42a9f7c7..766d246b 100644 --- a/python/veles/db/backend.py +++ b/python/veles/db/backend.py @@ -19,15 +19,15 @@ import six -from veles.proto import msgpackwrap +from veles.proto import msgpackwrap, check from veles.schema.nodeid import NodeID -from veles.proto.node import Node, PosFilter +from veles.proto.node import TriggerState, Node, PosFilter from veles.proto.exceptions import WritePastEndError from veles.util.bigint import bigint_encode, bigint_decode DB_APP_ID = int('veles', 36) -DB_VERSION = 2 +DB_VERSION = 3 DB_BINDATA_PAGE_SIZE = 0x10000 DB_SCHEMA = [ @@ -71,6 +71,79 @@ data BLOB NOT NULL, PRIMARY KEY (id, name, page) ) + """, """ + CREATE TABLE trigger( + tid INTEGER NOT NULL PRIMARY KEY, + nid BLOB NOT NULL REFERENCES node(id), + name VARCHAR NOT NULL, + state VARCHAR NOT NULL, + exception BLOB, + UNIQUE (nid, name), + CHECK (state IN ('pending', 'done', 'exception')) + ) + """, """ + CREATE INDEX trigger_pending ON trigger(name) + WHERE state = 'pending' + """, """ + CREATE TABLE trigger_check_node( + tid INTEGER NOT NULL REFERENCES trigger(tid), + nid BLOB NOT NULL, + mode VARCHAR NOT NULL, + PRIMARY KEY (tid, nid), + CHECK (mode IN ('present', 'pos', 'parent', 'tags')) + ) + """, """ + CREATE INDEX trigger_check_node_idx + ON trigger_check_node(nid, mode) + """, """ + CREATE TABLE trigger_check_node_name( + tid INTEGER NOT NULL REFERENCES trigger(tid), + nid BLOB NOT NULL REFERENCES node(id), + name VARCHAR NOT NULL, + mode VARCHAR NOT NULL, + PRIMARY KEY (tid, nid, name), + CHECK (mode IN ('tag', 'attr', 'data', 'bindata_size', 'trigger')) + ) + """, """ + CREATE INDEX trigger_check_node_name_idx + ON trigger_check_node_name(nid, name, mode) + """, """ + CREATE TABLE trigger_check_node_bindata( + tid INTEGER NOT NULL REFERENCES trigger(tid), + nid BLOB NOT NULL REFERENCES node(id), + name VARCHAR NOT NULL, + start BLOB NOT NULL, + end BLOB NULL, + PRIMARY KEY (tid, nid, name, start) + ) + """, """ + CREATE INDEX trigger_check_node_bindata_idx + ON trigger_check_node_bindata(nid, name, start) + """, """ + CREATE INDEX trigger_check_node_bindata_tid + ON trigger_check_node_bindata(tid) + """, """ + CREATE TABLE trigger_check_list( + tclid INTEGER NOT NULL PRIMARY KEY, + tid INTEGER NOT NULL REFERENCES trigger(tid), + nid BLOB NULL REFERENCES node(id), + pos_start_from BLOB NULL, + pos_start_to BLOB NULL, + pos_end_from BLOB NULL, + pos_end_to BLOB NULL + ) + """, """ + CREATE INDEX trigger_check_list_idx + ON trigger_check_list(nid) + """, """ + CREATE INDEX trigger_check_list_tid + ON trigger_check_list(tid) + """, """ + CREATE TABLE trigger_check_list_tag( + tclid INTEGER NOT NULL REFERENCES trigger_check_list(tclid), + name VARCHAR NOT NULL, + PRIMARY KEY (tclid, name) + ) """ ] @@ -79,7 +152,7 @@ # # - link support # - xref support -# - trigger model +# - trigger busting if six.PY3: def buffer(x): @@ -179,12 +252,16 @@ def get(self, id): key: page * DB_BINDATA_PAGE_SIZE + lastlen for key, page, lastlen in c.fetchall() } + c.execute(""" + SELECT name, state FROM trigger WHERE nid = ? + """, (raw_id,)) + triggers = {k: TriggerState(v) for k, v in c.fetchall()} return Node(id=id, parent=parent, pos_start=db_bigint_decode(pos_start), pos_end=db_bigint_decode(pos_end), tags=tags, attr=attr, - data=data, bindata=bindata) + data=data, bindata=bindata, triggers=triggers) - def create(self, node, commit=True): + def create(self, node): if not isinstance(node, Node): raise TypeError('node has wrong type') if node.id == NodeID.root_id: @@ -215,10 +292,8 @@ def create(self, node, commit=True): (raw_id, key, buffer(self.packer.pack(val))) for key, val in node.attr.items() ]) - if commit: - self.commit() - def set_pos(self, id, pos_start, pos_end, commit=True): + def set_pos(self, id, pos_start, pos_end): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') if (not isinstance(pos_start, six.integer_types) @@ -238,10 +313,8 @@ def set_pos(self, id, pos_start, pos_end, commit=True): db_bigint_encode(pos_end), raw_id )) - if commit: - self.commit() - def set_parent(self, id, parent_id, commit=True): + def set_parent(self, id, parent_id): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') if not isinstance(parent_id, NodeID): @@ -257,10 +330,8 @@ def set_parent(self, id, parent_id, commit=True): SET parent = ? WHERE id = ? """, (raw_parent, raw_id)) - if commit: - self.commit() - def add_tag(self, id, tag, commit=True): + def add_tag(self, id, tag): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') if not isinstance(tag, six.text_type): @@ -274,10 +345,8 @@ def add_tag(self, id, tag, commit=True): c.execute(""" INSERT INTO node_tag (id, name) VALUES (?, ?) """, (raw_id, tag)) - if commit: - self.commit() - def del_tag(self, id, tag, commit=True): + def del_tag(self, id, tag): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') if not isinstance(tag, six.text_type): @@ -288,10 +357,8 @@ def del_tag(self, id, tag, commit=True): DELETE FROM node_tag WHERE id = ? AND name = ? """, (raw_id, tag)) - if commit: - self.commit() - def set_attr(self, id, key, val, commit=True): + def set_attr(self, id, key, val): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') if not isinstance(key, six.text_type): @@ -305,8 +372,6 @@ def set_attr(self, id, key, val, commit=True): c.execute(""" INSERT INTO node_attr (id, name, data) VALUES (?, ?, ?) """, (raw_id, key, buffer(self.packer.pack(val)))) - if commit: - self.commit() def get_data(self, id, key): if not isinstance(id, NodeID): @@ -324,7 +389,7 @@ def get_data(self, id, key): (data,), = rows return self._load(data) - def set_data(self, id, key, data, commit=True): + def set_data(self, id, key, data): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') if not isinstance(key, six.text_type): @@ -338,8 +403,6 @@ def set_data(self, id, key, data, commit=True): c.execute(""" INSERT INTO node_data (id, name, data) VALUES (?, ?, ?) """, (raw_id, key, buffer(self.packer.pack(data)))) - if commit: - self.commit() def get_bindata(self, id, key, start=0, end=None): if not isinstance(id, NodeID): @@ -378,7 +441,7 @@ def get_bindata(self, id, key, start=0, end=None): data = b''.join(bytes(x) for x, in c.fetchall()) return data[start:end] - def set_bindata(self, id, key, start, data, truncate=False, commit=True): + def set_bindata(self, id, key, start, data, truncate=False): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') start = operator.index(start) @@ -460,11 +523,211 @@ def set_bindata(self, id, key, start, data, truncate=False, commit=True): ) for page in six.moves.range(page_first, page_end) ]) - # We're done here. - if commit: - self.commit() + def find_trigger(self, nid, key): + raw_id = buffer(nid.bytes) + c = self.db.cursor() + c.execute(""" + SELECT tid FROM trigger WHERE nid = ? AND name = ? + """, (raw_id, key)) + rows = c.fetchall() + if not rows: + return None + (tid,), = rows + return tid + + def add_trigger(self, nid, key): + tid = self.find_trigger(nid, key) + if tid is not None: + return tid + raw_id = buffer(nid.bytes) + c = self.db.cursor() + c.execute(""" + INSERT INTO trigger(nid, name, state) + VALUES (?, ?, ?) + """, (raw_id, key, 'pending')) + return c.lastrowid - def delete(self, id, commit=True): + def bust_trigger(self, tid): + c = self.db.cursor() + c.execute(""" + DELETE FROM trigger_check_node WHERE tid = ? + """, (tid,)) + c.execute(""" + DELETE FROM trigger_check_node_name WHERE tid = ? + """, (tid,)) + c.execute(""" + DELETE FROM trigger_check_node_bindata WHERE tid = ? + """, (tid,)) + c.execute(""" + DELETE FROM trigger_check_list_tag + WHERE tclid IN ( + SELECT tclid + FROM trigger_check_list + WHERE tid = ? + ) + """, (tid,)) + c.execute(""" + DELETE FROM trigger_check_list WHERE tid = ? + """, (tid,)) + c.execute(""" + UPDATE trigger + SET state = 'pending', exception = NULL + WHERE tid = ? + """, (tid,)) + + def bust_triggers_all(self, nid): + raw_id = buffer(nid.bytes) + c = self.db.cursor() + c.execute(""" + SELECT tid FROM trigger_check_node + WHERE nid = ? + UNION SELECT tid FROM trigger_check_node_name + WHERE nid = ? + UNION SELECT tid FROM trigger_check_node_bindata + WHERE nid = ? + UNION SELECT tid FROM trigger_check_list + WHERE nid = ? + """, (raw_id, raw_id, raw_id, raw_id)) + tids = {tid for tid, in c.fetchall()} + for tid in tids: + self.bust_trigger(tid) + + def bust_triggers_node(self, nid, mode): + raw_id = buffer(nid.bytes) + c = self.db.cursor() + c.execute(""" + SELECT tid FROM trigger_check_node + WHERE nid = ? AND mode = ? + """, (raw_id, mode)) + tids = {tid for tid, in c.fetchall()} + for tid in tids: + self.bust_trigger(tid) + + def bust_triggers_node_name(self, nid, mode, name): + raw_id = buffer(nid.bytes) + c = self.db.cursor() + c.execute(""" + SELECT tid FROM trigger_check_node_name + WHERE nid = ? AND mode = ? AND name = ? + """, (raw_id, mode, name)) + tids = {tid for tid, in c.fetchall()} + for tid in tids: + self.bust_trigger(tid) + + def _trigger_check_node(self, tid, ch, mode): + raw_id = buffer(ch.node.bytes) + c = self.db.cursor() + c.execute(""" + INSERT INTO trigger_check_node (tid, nid, mode) + VALUES (?, ?, ?) + """, ( + tid, raw_id, mode, + )) + + def _trigger_check_node_name(self, tid, ch, mode): + raw_id = buffer(ch.node.bytes) + if isinstance(ch, check.CheckTag): + name = ch.tag + else: + name = ch.key + c = self.db.cursor() + c.execute(""" + INSERT INTO trigger_check_node_name (tid, nid, name, mode) + VALUES (?, ?, ?, ?) + """, ( + tid, raw_id, name, mode, + )) + + def _trigger_check_node_bindata(self, tid, ch, mode): + raw_id = buffer(ch.node.bytes) + c = self.db.cursor() + c.execute(""" + INSERT INTO trigger_check_node_bindata (tid, nid, name, start, end) + VALUES (?, ?, ?, ?, ?) + """, ( + tid, raw_id, ch.key, + db_bigint_encode(ch.start), + db_bigint_encode(ch.end), + )) + + def _trigger_check_list(self, tid, ch, mode): + if ch.parent == NodeID.root_id: + raw_id = None + else: + raw_id = buffer(ch.parent.bytes) + c = self.db.cursor() + c.execute(""" + INSERT INTO trigger_check_list ( + tid, nid, + pos_start_from, + pos_start_to, + pos_end_from, + pos_end_to + ) + VALUES (?, ?, ?, ?, ?, ?) + """, ( + tid, raw_id, + db_bigint_encode(ch.pos_filter.start_from), + db_bigint_encode(ch.pos_filter.start_to), + db_bigint_encode(ch.pos_filter.end_from), + db_bigint_encode(ch.pos_filter.end_to), + )) + tclid = c.lastrowid + c.executemany(""" + INSERT INTO trigger_check_list_tag (tclid, name) + VALUES (?, ?) + """, [ + (tclid, tag) + for tag in ch.tags + ]) + + def mark_trigger(self, nid, key, exception, checks): + tid = self.find_trigger(nid, key) + c = self.db.cursor() + if tid is None: + return + if exception is None: + c.execute(""" + UPDATE trigger + SET state = 'done', exception = NULL + WHERE tid = ? + """, (tid,)) + else: + c.execute(""" + UPDATE trigger + SET state = 'exception', exception = ? + WHERE tid = ? + """, (buffer(self.packer.pack(exception.dump()), tid)) + handlers = { + check.CheckGone: (self._trigger_check_node, 'present'), + check.CheckParent: (self._trigger_check_node, 'parent'), + check.CheckPos: (self._trigger_check_node, 'pos'), + check.CheckTags: (self._trigger_check_node, 'tags'), + check.CheckTag: (self._trigger_check_node_name, 'tag'), + check.CheckAttr: (self._trigger_check_node_name, 'attr'), + check.CheckData: (self._trigger_check_node_name, 'data'), + check.CheckBinDataSize: + (self._trigger_check_node_name, 'bindata_size'), + check.CheckBinData: (self._trigger_check_node_bindata, None), + check.CheckTrigger: (self._trigger_check_node_name, 'trigger'), + check.CheckList: (self._trigger_check_list, None), + } + for ch in checks: + handler, mode = handlers[type(ch)] + handler(tid, ch, mode) + + def del_trigger(self, nid, key): + tid = self.find_trigger(nid, key) + if tid is None: + return None + self.bust_trigger(tid) + c = self.db.cursor() + c.execute(""" + DELETE FROM trigger WHERE tid = ? + """, (tid,)) + return tid + + def delete(self, id): if not isinstance(id, NodeID): raise TypeError('node id has wrong type') raw_id = buffer(id.bytes) @@ -481,11 +744,20 @@ def delete(self, id, commit=True): c.execute(""" DELETE FROM node_bindata WHERE id = ? """, (raw_id,)) + c.execute(""" + SELECT tid FROM trigger + WHERE nid = ? + """, (raw_id,)) + tids = {tid for tid, in c.fetchall()} + for tid in tids: + self.bust_trigger(tid) + c.execute(""" + DELETE FROM trigger WHERE nid = ? + """, (raw_id,)) c.execute(""" DELETE FROM node WHERE id = ? """, (raw_id,)) - if commit: - self.commit() + return tids def list(self, parent, tags=frozenset(), pos_filter=PosFilter()): if not isinstance(parent, NodeID): @@ -528,6 +800,25 @@ def list(self, parent, tags=frozenset(), pos_filter=PosFilter()): c.execute(stmt, args) return {NodeID(bytes(x)) for x, in c.fetchall()} + def get_pending_triggers(self, skip=set(), limit=1): + c = self.db.cursor() + if skip: + skip_str = "AND tid NOT IN ({})".format( + ', '.join('?' for x in skip)) + else: + skip_str = '' + c.execute(""" + SELECT tid, nid, name + FROM trigger + WHERE state = 'pending' + {} + LIMIT ? + """ + skip_str, list(skip) + [limit]) + return { + (tid, nid, name) + for tid, nid, name in c.fetchall() + } + def begin(self): if six.PY3: assert not self.db.in_transaction diff --git a/python/veles/db/tracker.py b/python/veles/db/tracker.py index 2818dba1..9ea3601e 100644 --- a/python/veles/db/tracker.py +++ b/python/veles/db/tracker.py @@ -21,7 +21,7 @@ from veles.schema.nodeid import NodeID from veles.proto import operation, check -from veles.proto.node import Node, PosFilter +from veles.proto.node import TriggerState, Node, PosFilter from veles.proto.exceptions import ( ObjectGoneError, ObjectExistsError, @@ -63,6 +63,7 @@ def __init__(self, db): self.nodes = weakref.WeakValueDictionary() self.get_cached_node = lru_cache(maxsize=DB_CACHE_SIZE)( self._get_cached_node) + self.triggers_gone_callbacks = set() def _get_cached_node(self, nid): try: @@ -184,9 +185,10 @@ def _op_create(self, xact, op, dbnode): id=dbnode.id, parent=op.parent, pos_start=op.pos_start, pos_end=op.pos_end, tags=op.tags, attr=op.attr, data=set(op.data), - bindata={x: len(y) for x, y in op.bindata.items()} + bindata={x: len(y) for x, y in op.bindata.items()}, + triggers={x: TriggerState.pending for x in op.triggers}, ) - self.db.create(dbnode.node, commit=False) + self.db.create(dbnode.node) dbnode.parent = parent for key, val in op.data.items(): self.db.set_data(dbnode.id, key, val) @@ -195,6 +197,8 @@ def _op_create(self, xact, op, dbnode): self.db.set_bindata(dbnode.id, key, 0, val) for sub in dbnode.bindata_subs.get(key, set()): xact.bindata_changed(sub) + for key in op.triggers: + self.db.add_trigger(dbnode.id, key) def _op_delete(self, xact, op, dbnode): if dbnode.node is None: @@ -203,7 +207,7 @@ def _op_delete(self, xact, op, dbnode): subnode = self.get_cached_node(oid) xact.save(subnode) self._op_delete(xact, op, subnode) - self.db.delete(dbnode.id, commit=False) + xact.triggers_gone |= self.db.delete(dbnode.id) dbnode.node = None dbnode.parent = None @@ -220,7 +224,7 @@ def _op_set_parent(self, xact, op, dbnode): if cur.id == dbnode.id: raise ParentCycleError() cur = cur.parent - self.db.set_parent(dbnode.id, parent.id, commit=False) + self.db.set_parent(dbnode.id, parent.id) dbnode.node.parent = parent.id dbnode.parent = parent @@ -230,7 +234,7 @@ def _op_set_pos(self, xact, op, dbnode): if (op.pos_start == dbnode.node.pos_start and op.pos_end == dbnode.node.pos_end): return - self.db.set_pos(dbnode.id, op.pos_start, op.pos_end, commit=False) + self.db.set_pos(dbnode.id, op.pos_start, op.pos_end) dbnode.node.pos_start = op.pos_start dbnode.node.pos_end = op.pos_end @@ -239,7 +243,7 @@ def _op_add_tag(self, xact, op, dbnode): raise ObjectGoneError() if op.tag in dbnode.node.tags: return - self.db.add_tag(dbnode.id, op.tag, commit=False) + self.db.add_tag(dbnode.id, op.tag) dbnode.node.tags.add(op.tag) def _op_del_tag(self, xact, op, dbnode): @@ -247,7 +251,7 @@ def _op_del_tag(self, xact, op, dbnode): raise ObjectGoneError() if op.tag not in dbnode.node.tags: return - self.db.del_tag(dbnode.id, op.tag, commit=False) + self.db.del_tag(dbnode.id, op.tag) dbnode.node.tags.remove(op.tag) def _op_set_attr(self, xact, op, dbnode): @@ -255,7 +259,7 @@ def _op_set_attr(self, xact, op, dbnode): raise ObjectGoneError() if dbnode.node.attr.get(op.key) == op.data: return - self.db.set_attr(dbnode.id, op.key, op.data, commit=False) + self.db.set_attr(dbnode.id, op.key, op.data) if op.data is None: del dbnode.node.attr[op.key] else: @@ -264,8 +268,8 @@ def _op_set_attr(self, xact, op, dbnode): def _op_set_data(self, xact, op, dbnode): if dbnode.node is None: raise ObjectGoneError() - self.db.set_data(dbnode.id, op.key, op.data, commit=False) xact.set_data(dbnode, op.key, op.data) + self.db.set_data(dbnode.id, op.key, op.data) if op.data is None and op.key in dbnode.node.data: dbnode.node.data.remove(op.key) elif op.data is not None and op.key not in dbnode.node.data: @@ -275,7 +279,7 @@ def _op_set_bindata(self, xact, op, dbnode): if dbnode.node is None: raise ObjectGoneError() self.db.set_bindata(dbnode.id, op.key, op.start, op.data, - op.truncate, commit=False) + op.truncate) old_len = dbnode.node.bindata.get(op.key, 0) if op.truncate: new_len = op.start + len(op.data) @@ -298,14 +302,20 @@ def _op_set_bindata(self, xact, op, dbnode): def _op_add_trigger(self, xact, op, dbnode): if dbnode.node is None: raise ObjectGoneError() - # XXX - raise NotImplementedError + if op.trigger in dbnode.node.triggers: + return + self.db.add_trigger(dbnode.id, op.trigger) + dbnode.node.triggers[op.trigger] = TriggerState.pending def _op_del_trigger(self, xact, op, dbnode): if dbnode.node is None: raise ObjectGoneError() - # XXX - raise NotImplementedError + if op.trigger not in dbnode.node.triggers: + return + tid = self.db.del_trigger(dbnode.id, op.trigger) + assert tid is not None + del dbnode.node.triggers[op.trigger] + xact.triggers_gone.add(tid) def transaction(self, checks, ops): if not self.checks_ok(checks): @@ -329,6 +339,18 @@ def transaction(self, checks, ops): xact.save(dbnode) handlers[type(op)](xact, op, dbnode) + def _triggers_gone(self, tids): + if not tids: + return + for cb in self.triggers_gone_callbacks: + cb(tids) + + def register_triggers_gone_callback(self, cb): + self.triggers_gone_callbacks.add(cb) + + def unregister_triggers_gone_callback(self, cb): + self.triggers_gone_callbacks.remove(cb) + # subscribers def register_subscriber(self, sub): diff --git a/python/veles/db/transaction.py b/python/veles/db/transaction.py index c1f3abde..d40ddf50 100644 --- a/python/veles/db/transaction.py +++ b/python/veles/db/transaction.py @@ -21,52 +21,54 @@ class Transaction(object): def __init__(self, tracker): self.tracker = tracker self.undo = {} + self.old_data = {} self.list_changes = {} self.data_subs = {} self.bindata_subs = set() self.gone_subs = set() self.node_subs = {} + self.triggers_gone = set() - def __enter__(self): - self.tracker.db.begin() - return self - - def __exit__(self, exc, val, tb): - if exc is None: - # Everything alright, let's commit. - self.tracker.db.commit() - for dbnode, (node, parent) in self.undo.items(): - if dbnode.node != node: - self.handle_changed_node(dbnode, node, parent) - # Send out all gone subs. - for sub in self.gone_subs: - sub.error(ObjectGoneError()) - # Send out all node subs. - for sub, node in self.node_subs.items(): - sub.node_changed(node) - # Now send out all buffered list changes. - for sub, (changed, gone) in self.list_changes.items(): - if sub in self.gone_subs: - continue - sub.list_changed(list(changed.values()), list(gone)) - # Send out data subs. - for sub, data in self.data_subs.items(): - if sub in self.gone_subs: - continue - sub.data_changed(data) - # Send out bindata subs. - for sub in self.bindata_subs: - if sub in self.gone_subs: - continue - data = self.tracker.get_bindata( - sub.node, sub.key, sub.start, sub.end) - sub.bindata_changed(data) - else: - # Whoops. Undo changes. - self.tracker.db.rollback() - for dbnode, (node, parent) in self.undo.items(): - dbnode.node = node - dbnode.parent = parent + def precommit(self): + for dbnode, (node, parent) in self.undo.items(): + if dbnode.node is None or node is None: + # Node created or deleted - bust everything. + self.tracker.db.bust_triggers_all(dbnode.id) + else: + if dbnode.node.parent != node.parent: + self.tracker.db.bust_triggers_node(dbnode.id, 'parent') + if (dbnode.node.pos_start != node.pos_start or + dbnode.node.pos_end != node.pos_end): + self.tracker.db.bust_triggers_node(dbnode.id, 'pos') + if dbnode.node.tags != node.tags: + self.tracker.db.bust_triggers_node(dbnode.id, 'tags') + for tag in (dbnode.node.tags ^ node.tags): + self.tracker.db.bust_triggers_node_name( + dbnode.id, 'tag', tag) + for attr in (set(dbnode.node.attr) | set(node.attr)): + if dbnode.node.attr.get(attr) != node.attr.get(attr): + self.tracker.db.bust_triggers_node_name( + dbnode.id, 'attr', attr) + all_bindata = set(dbnode.node.bindata) | set(node.bindata) + for key in all_bindata: + if (dbnode.node.bindata.get(key) != + node.bindata.get(key)): + self.tracker.db.bust_triggers_node_name( + dbnode.id, 'bindata_size', key) + all_triggers = set(dbnode.node.triggers) | set(node.triggers) + for trigger in all_triggers: + if (dbnode.node.triggers.get(trigger) != + node.triggers.get(trigger)): + self.tracker.db.bust_triggers_node_name( + dbnode.id, 'trigger', trigger) + for (dbnode, key), data in self.old_data.items(): + if data != self.tracker.db.get_data(dbnode.id, key): + self.tracker.db.bust_triggers_node_name( + dbnode.id, 'data', key) + for sub in dbnode.data_subs.get(key, ()): + self.data_subs[sub] = data + # XXX: bust bindata + # XXX: bust list def handle_changed_node(self, dbnode, old_node, old_parent): # Queue normal subs. Also handle all subs for created @@ -109,6 +111,60 @@ def handle_changed_node(self, dbnode, old_node, old_parent): dbnodes.remove(dbnode) self.list_remove(sub, dbnode) + def postcommit(self): + self.tracker._triggers_gone(self.triggers_gone) + for dbnode, (node, parent) in self.undo.items(): + if dbnode.node != node: + self.handle_changed_node(dbnode, node, parent) + # Send out all gone subs. + for sub in self.gone_subs: + sub.error(ObjectGoneError()) + # Send out all node subs. + for sub, node in self.node_subs.items(): + sub.node_changed(node) + # Now send out all buffered list changes. + for sub, (changed, gone) in self.list_changes.items(): + if sub in self.gone_subs: + continue + sub.list_changed(list(changed.values()), list(gone)) + # Send out data subs. + for sub, data in self.data_subs.items(): + if sub in self.gone_subs: + continue + sub.data_changed(data) + # Send out bindata subs. + for sub in self.bindata_subs: + if sub in self.gone_subs: + continue + data = self.tracker.get_bindata( + sub.node, sub.key, sub.start, sub.end) + sub.bindata_changed(data) + + def rollback(self): + self.tracker.db.rollback() + for dbnode, (node, parent) in self.undo.items(): + dbnode.node = node + dbnode.parent = parent + + def __enter__(self): + self.tracker.db.begin() + return self + + def __exit__(self, exc, val, tb): + if exc is None: + # Everything alright, let's commit. + try: + self.precommit() + except: + # This should never happen, but just in case... + self.rollback() + raise + self.tracker.db.commit() + self.postcommit() + else: + # Whoops. Undo changes. + self.rollback() + def list_ensure(self, sub): if sub not in self.list_changes: self.list_changes[sub] = ({}, set()) @@ -122,8 +178,9 @@ def list_remove(self, sub, dbnode): self.list_changes[sub][1].add(dbnode.id) def set_data(self, dbnode, key, data): - for sub in dbnode.data_subs.get(key, ()): - self.data_subs[sub] = data + if (dbnode, key) not in self.old_data: + old_data = self.tracker.db.get_data(dbnode.id, key) + self.old_data[dbnode.id, key] = old_data def bindata_changed(self, sub): self.bindata_subs.add(sub) diff --git a/python/veles/tests/db/test_backend.py b/python/veles/tests/db/test_backend.py index 0c65158e..992d1c30 100644 --- a/python/veles/tests/db/test_backend.py +++ b/python/veles/tests/db/test_backend.py @@ -22,9 +22,10 @@ from veles.db.backend import DbBackend from veles.data.bindata import BinData -from veles.proto.node import Node +from veles.proto.node import TriggerState, Node, PosFilter from veles.schema.nodeid import NodeID from veles.proto.exceptions import WritePastEndError +from veles.proto import check from veles.tests.proto.test_pos_filter import ( NODES as LIST_NODES, @@ -41,7 +42,8 @@ def test_simple(self): pos_start=0x123, pos_end=0x456789abcdef1122334456789abcdef, data={'my_key'}, - bindata={'my_bindata': 12345}) + bindata={'my_bindata': 12345}, + triggers={'my_trigger': TriggerState.exception}) db.create(node) n1 = db.get(NodeID()) self.assertEqual(n1, None) @@ -74,7 +76,9 @@ def test_persist(self): pos_end=0x456, data={'my_key'}, bindata={'my_bindata': 12345}) + db1.begin() db1.create(node) + db1.commit() db1.close() db2 = DbBackend(path) n1 = db2.get(NodeID()) @@ -109,10 +113,28 @@ def test_delete(self): db.create(node) db.set_data(node.id, 'c', 'd') db.set_bindata(node.id, 'e', start=0, data=b'f', truncate=False) + db.add_trigger(node.id, 'g') + db.add_trigger(node.id, 'h') + db.mark_trigger(node.id, 'h', None, [ + check.CheckGone(node=NodeID()), + check.CheckParent(node=NodeID(), parent=NodeID()), + check.CheckTags(node=NodeID(), tags={'a'}), + check.CheckPos(node=NodeID(), pos_start=3, pos_end=None), + check.CheckList( + parent=NodeID.root_id, + tags={'a', 'b'}, + pos_filter=PosFilter(), + nodes=set() + ), + ]) n2 = db.get(node.id) self.assertEqual(n2.tags, {'my_node'}) self.assertEqual(db.get_data(node.id, 'c'), 'd') self.assertEqual(db.get_bindata(node.id, 'e'), b'f') + self.assertEqual(n2.triggers, { + 'g': TriggerState.pending, + 'h': TriggerState.done + }) db.delete(node.id) self.assertEqual(db.get(node.id), None) self.assertEqual(db.get_data(node.id, 'c'), None)