diff --git a/betty/ancestry.py b/betty/ancestry.py index 5cf77f9ec..864906a58 100644 --- a/betty/ancestry.py +++ b/betty/ancestry.py @@ -195,7 +195,15 @@ def __init__(self, entity_id: str, name: str = None): Entity.__init__(self, entity_id) self._name = name self._coordinates = None - self._events = set() + + def handle_event_addition(event: Event): + event.place = self + + def handle_event_removal(event: Event): + event.place = None + + self._events = EventHandlingSet( + handle_event_addition, handle_event_removal) @property def label(self) -> str: @@ -210,7 +218,7 @@ def coordinates(self, coordinates: Coordinates): self._coordinates = coordinates @property - def events(self): + def events(self) -> Iterable: return self._events @@ -261,8 +269,13 @@ def place(self) -> Optional[Place]: return self._place @place.setter - def place(self, place: Place): + def place(self, place: Optional[Place]): + previous_place = self._place self._place = place + if previous_place is not None: + previous_place.events.remove(self) + if place is not None: + place.events.add(self) @property def type(self): diff --git a/betty/gramps.py b/betty/gramps.py index 8bd411836..a286ab572 100644 --- a/betty/gramps.py +++ b/betty/gramps.py @@ -210,8 +210,6 @@ def _parse_event(places: Dict[str, Place], element: Element) -> Tuple[str, Event # Parse the event place. place_handle = xpath1(element, './ns:place/@hlink') if place_handle: - place = places[place_handle] - event.place = place - place.events.add(event) + event.place = places[place_handle] return handle, event diff --git a/betty/tests/test_ancestry.py b/betty/tests/test_ancestry.py index 13df6ae28..b953c67ce 100644 --- a/betty/tests/test_ancestry.py +++ b/betty/tests/test_ancestry.py @@ -1,6 +1,6 @@ from unittest import TestCase -from betty.ancestry import EventHandlingSet, Person, Family +from betty.ancestry import EventHandlingSet, Person, Family, Event, Place class EventHandlingSetTest(TestCase): @@ -109,3 +109,24 @@ def test_children_should_sync_references(self): sut.children.remove(child) self.assertEquals([], list(sut.children)) self.assertEquals(None, child.descendant_family) + + +class PlaceTest(TestCase): + def test_events_should_sync_references(self): + event = Event('1', Event.Type.BIRTH) + sut = Place('1') + sut.events.add(event) + self.assertIn(event, sut.events) + self.assertEquals(sut, event.place) + + +class EventTest(TestCase): + def test_place_should_sync_references(self): + place = Place('1') + sut = Event('1', Event.Type.BIRTH) + sut.place = place + self.assertEquals(place, sut.place) + self.assertIn(sut, place.events) + sut.place = None + self.assertEquals(None, sut.place) + self.assertNotIn(sut, place.events) diff --git a/betty/tests/test_gramps.py b/betty/tests/test_gramps.py index 0d12b5917..6b23ef035 100644 --- a/betty/tests/test_gramps.py +++ b/betty/tests/test_gramps.py @@ -24,6 +24,11 @@ def test_place_should_include_coordinates(self): self.assertEquals('52.366667', place.coordinates.latitude) self.assertEquals('4.9', place.coordinates.longitude) + def test_place_should_include_events(self): + place = self.ancestry.places['P0000'] + event = self.ancestry.events['E0000'] + self.assertIn(event, place.events) + class ParsePersonTest(GrampsTestCase): def test_person_should_include_individual_name(self): @@ -55,7 +60,8 @@ def test_family_should_include_children(self): class ParseEventTest(GrampsTestCase): def test_event_should_include_place(self): event = self.ancestry.events['E0000'] - self.assertEquals('P0000', event.place.id) + place = self.ancestry.places['P0000'] + self.assertEquals(place, event.place) def test_event_should_include_date(self): event = self.ancestry.events['E0000']