diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 105e1228bb8f..f430cce9312a 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -226,11 +226,9 @@ def persist( context = EventContext() context.current_state_ids = state_ids context.prev_state_ids = state_ids - elif not backfill: + else: state_handler = self.hs.get_state_handler() context = yield state_handler.compute_event_context(event) - else: - context = EventContext() context.push_actions = push_actions diff --git a/tests/test_state.py b/tests/test_state.py index feb84f3d48b8..253aa62f2a0d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -99,6 +99,10 @@ def register_events(self, events): for e in events: self._event_id_to_event[e.event_id] = e + def store_state_group(self, *args, **kwargs): + self._next_group += 1 + return self._next_group + class DictObj(dict): def __init__(self, **kwargs): @@ -144,6 +148,7 @@ def setUp(self): "get_events", "get_next_state_group", "get_state_group_delta", + "store_state_group", ] ) hs = Mock(spec_set=[ @@ -316,6 +321,7 @@ def test_branch_have_banned_conflict(self): store = StateGroupStore() self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids self.store.get_events = store.get_events + self.store.store_state_group = store.store_state_group store.register_events(graph.walk()) context_store = {} @@ -399,6 +405,7 @@ def test_branch_have_perms_conflict(self): store = StateGroupStore() self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids self.store.get_events = store.get_events + self.store.store_state_group = store.store_state_group store.register_events(graph.walk()) context_store = {}