diff --git a/src/aligned_textgrid/aligned_textgrid.py b/src/aligned_textgrid/aligned_textgrid.py index 054cab8..395de0b 100644 --- a/src/aligned_textgrid/aligned_textgrid.py +++ b/src/aligned_textgrid/aligned_textgrid.py @@ -11,8 +11,9 @@ from aligned_textgrid.sequences.tiers import SequenceTier, TierGroup from aligned_textgrid.points.tiers import SequencePointTier, PointsGroup from aligned_textgrid.mixins.within import WithinMixins +from aligned_textgrid.mixins.tiermixins import TierGroupMixins from aligned_textgrid.custom_classes import custom_classes, clone_class, get_class_hierarchy -from typing import Type, Sequence, Literal +from typing import Type, Literal from copy import copy import numpy as np from collections.abc import Sequence @@ -24,7 +25,11 @@ class AlignedTextGrid(Sequence, WithinMixins): """An aligned Textgrid Args: - textgrid (str|Path|praatio.textgrid.Textgrid, optional): A `praatio` TextGrid + textgrid (str|Path|praatio.textgrid.Textgrid|Sequence[TierGroup|PointsGroup], optional): + An object to create a new AlignedTextGrid which can be one of: + i) A path-like value (str|pathlib.Path) to a TextGrid file. + ii) A praatio.textgrid.TextGrid object. + iii) A list of [](`~aligned_textgrid.TierGroup`)s entry_classes (Sequence[Sequence[Type[SequenceInterval]]] | Sequence[Type[SequenceInterval]], optional): If a single list of `SequenceInterval` subclasses is given, they will be repeated as many times as necessary to assign a class to every tier. @@ -53,7 +58,7 @@ class for each tier within each tier group. Say, if only the first speaker def __init__( self, - textgrid: Textgrid|str|Path = None, + textgrid: Textgrid|str|Path|Sequence[TierGroup|PointsGroup] = None, entry_classes: Sequence[Sequence[Type[SequenceInterval]]] | Sequence[Type[SequenceInterval]] @@ -61,23 +66,22 @@ def __init__( *, textgrid_path: str = None ): - self.entry_classes = None - self.entry_classes = self._reclone_classes(entry_classes) + self._cloned_classes = [] + self._tier_groups = [] + self.contains = self.tier_groups + if textgrid_path: textgrid = textgrid_path if textgrid: - self._process_textgrid_arg(textgrid) + self._process_textgrid_arg(textgrid, entry_classes) else: warnings.warn('Initializing an empty AlignedTextGrid') - self._init_empty() return - - self.tier_groups = self._relate_tiers() + self.contains = self.tier_groups for tgr in self.tier_groups: tgr.within = self - self.entry_classes = [[tier.entry_class for tier in tg] for tg in self.tier_groups] self._set_group_names() def __getitem__( @@ -110,7 +114,17 @@ def __repr__(self): def __setstate__(self, d): self.__dict__ = d - def _process_textgrid_arg(self, arg): + def _process_textgrid_arg(self, arg, entry_classes): + + # if passed a list of TierGroups + if isinstance(arg, Sequence) and \ + len(arg)>0 and \ + all([isinstance(v, TierGroupMixins) for v in arg]): + for trg in arg: + self.append(trg) + return + + # if passed a Path-like value if isinstance(arg, str) or isinstance(arg, Path): arg_str = str(arg) tg = openTextgrid( @@ -119,10 +133,15 @@ def _process_textgrid_arg(self, arg): duplicateNamesMode='rename' ) + #if passed a praatio.textgrid.Textgrid if isinstance(arg, Textgrid): tg = arg - - self.tg_tiers, self.entry_classes = self._nestify_tiers(tg, self.entry_classes) + + # do nestifying etc here. + tg_tiers, entry_classes = self._nestify_tiers(tg, entry_classes) + tier_groups = self._relate_tiers(tg_tiers, entry_classes) + self.tier_groups = tier_groups + def _extend_classes( self, @@ -164,22 +183,14 @@ def _reclone_classes( ) -> list[SequenceInterval|SequencePoint]|list[list[SequenceInterval|SequencePoint]]: flat_classes = entry_classes + if type(entry_classes[0]) is list: flat_classes = [c for tg in entry_classes for c in tg] unique_classes = list(set(flat_classes)) - orig_classes = [] - orig_class_names = [] - if self.entry_classes: - orig_classes = [c for tg in self.entry_classes for c in tg] - orig_class_names = [c.__name__ for c in orig_classes] + cloned_class_names = [cl.__name__ for cl in self._cloned_classes] - - already_cloned = [c.__name__ in orig_classes for c in flat_classes] - if all(already_cloned): - return - points = [c for c in unique_classes if issubclass(c, SequencePoint)] tops = [ c @@ -190,17 +201,17 @@ def _reclone_classes( points_clone = [] for p in points: - if p.__name__ in orig_class_names: + if p.__name__ in cloned_class_names: points_clone.append( - orig_classes[orig_class_names.index(p.__name__)] + self._cloned_classes[cloned_class_names.index(p.__name__)] ) else: points_clone.append(clone_class(p)) tops_clone = [] for t in tops: - if t.__name__ in orig_class_names: + if t.__name__ in cloned_class_names: tops_clone.append( - orig_classes[orig_class_names.index(t.__name__)] + self._cloned_classes[cloned_class_names.index(t.__name__)] ) else: tops_clone.append(clone_class(t)) @@ -210,16 +221,7 @@ def _reclone_classes( full_seq_clone += get_class_hierarchy(tclone, []) full_clone = points_clone + full_seq_clone - - if type(entry_classes[0]) is list: - new_entry_classes = [ - self._swap_classes(tg_classes, full_clone) - for tg_classes in entry_classes - ] - else: - new_entry_classes = self._swap_classes(entry_classes, full_clone) - - return new_entry_classes + self._cloned_classes += full_clone def _swap_classes( self, @@ -233,13 +235,6 @@ def _swap_classes( out_classes = [new_classes[i] for i in new_idx] return out_classes - - - def _init_empty(self): - self.tier_groups = [] - self.contains = self.tier_groups - self.entry_classes = [] - self.tg_tiers = None def _nestify_tiers( self, @@ -308,7 +303,7 @@ def _nestify_tiers( return tier_list, entry_list - def _relate_tiers(self): + def _relate_tiers(self, tg_tiers, entry_classes): """_Private method_ creates RelatedTier objects for each set of @@ -319,8 +314,13 @@ def _relate_tiers(self): """ tier_groups = [] + self._reclone_classes(entry_classes) + if type(entry_classes[0]) is list: + entry_classes = [self._swap_classes(ecs, self._cloned_classes) for ecs in entry_classes] + else: + entry_classes = self._swap_classes(entry_classes, self._cloned_classes) - for tier_group, classes in zip(self.tg_tiers, self.entry_classes): + for tier_group, classes in zip(tg_tiers, entry_classes): sequence_tier_list = [] point_tier_list = [] for tier, entry_class in zip(tier_group, classes): @@ -355,6 +355,27 @@ def _set_group_names(self): for idx, name in enumerate(tier_group_names): setattr(self, name, self.tier_groups[idx]) + @property + def tier_groups(self) -> list[TierGroup|PointsGroup|None]: + if self._tier_groups: + return self._tier_groups + return [] + + @tier_groups.setter + def tier_groups(self, new:Sequence[TierGroup|PointsGroup]) -> None: + if not(isinstance(new, Sequence) and all([isinstance(v, TierGroupMixins) for v in new])): + raise ValueError("Ony a list of TierGroups can be set as tier groups.") + + if len(new) < 1: + return + + self._tier_groups = new + self.contains = self.tier_groups + + @property + def entry_classes(self): + return [tgr.entry_classes for tgr in self.tier_groups] + @property def tier_names(self) -> list[str]: if len(self) == 0: @@ -414,15 +435,17 @@ def append(self, tier_group:TierGroup): tier_group (TierGroup): The TierGroup to append to the AlignedTextGrid """ - new_classes = self._reclone_classes(tier_group.entry_classes) + self._reclone_classes(tier_group.entry_classes) + new_classes = self._swap_classes(tier_group.entry_classes, self._cloned_classes) for cl, tier in zip(new_classes, tier_group): entries = [cl._cast(i) for i in tier] tier.__init__(entries) tier_group.__init__(tier_group) - self.tier_groups.append(tier_group) - self.entry_classes = [[tier.entry_class for tier in tg] for tg in self.tier_groups] + new_tgs = self.tier_groups + [tier_group] + self.tier_groups = new_tgs + self._set_group_names() def cleanup(self)->None: """Cleanup gaps in AlignedTextGrid @@ -579,17 +602,14 @@ def interleave_class( new_class.set_subset_class(down_class) new_tiergoups = [] - new_entry_classes = [] for tg in self.tier_groups: if not specified_class in tg.entry_classes: new_tiergoups.append(tg) - new_entry_classes.append(tg.entry_classes) else: copy_tier = [tier for tier in tg if tier.entry_class is copy_class][0] new_tier = SequenceTier( - [seq.return_interval() for seq in copy_tier], - entry_class = new_class + [new_class(seq) for seq in copy_tier] ) if not copy_labels: for seq in new_tier: @@ -610,10 +630,8 @@ def interleave_class( new_tg = TierGroup(tier_list) new_tg.name = tg.name new_tiergoups.append(new_tg) - new_entry_classes.append(new_tg.entry_classes) self.tier_groups = new_tiergoups - self.entry_classes = new_entry_classes for tgr in self.tier_groups: tgr.within = self self.contains = self.tier_groups @@ -697,8 +715,6 @@ def pop_class( new_tier_groups.append(new_tg) self.tier_groups = new_tier_groups - new_entry_classes = [tg.entry_classes for tg in self.tier_groups] - self.entry_classes = new_entry_classes for tgr in self.tier_groups: tgr.within = self self.contains = self.tier_groups diff --git a/src/aligned_textgrid/mixins/within.py b/src/aligned_textgrid/mixins/within.py index 51ae08e..418ffb2 100644 --- a/src/aligned_textgrid/mixins/within.py +++ b/src/aligned_textgrid/mixins/within.py @@ -1,4 +1,5 @@ -from typing import TypeVar, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import TypeVar, TYPE_CHECKING if TYPE_CHECKING: from aligned_textgrid import SequenceInterval, \ SequenceTier,\ diff --git a/src/aligned_textgrid/points/tiers.py b/src/aligned_textgrid/points/tiers.py index f1c5b34..9d0a215 100644 --- a/src/aligned_textgrid/points/tiers.py +++ b/src/aligned_textgrid/points/tiers.py @@ -232,7 +232,7 @@ class PointsGroup(Sequence, TierGroupMixins, WithinMixins): [](`~aligned_textgrid.mixins.within.WithinMixins`) Args: - tiers (list[SequencePointTier]: + tiers (list[SequencePointTier]|PointsGroup): A list of SequencePointTiers Attributes: @@ -240,8 +240,10 @@ class PointsGroup(Sequence, TierGroupMixins, WithinMixins): """ def __init__( self, - tiers: list[SequencePointTier] = [SequencePointTier()] + tiers: list[SequencePointTier]|Self = [SequencePointTier()] ): + if isinstance(tiers, PointsGroup): + tiers = [tier for tier in tiers] self.tier_list = tiers self.contains = self.tier_list self._set_tier_names() diff --git a/tests/test_add_append.py b/tests/test_add_append.py index 2d84720..db0f26a 100644 --- a/tests/test_add_append.py +++ b/tests/test_add_append.py @@ -400,6 +400,73 @@ def test_tier_group_concat(self): assert word1.fol is word2 +class TestATG: + + def test_atg_append(self): + tg1 = TierGroup() + tg2 = TierGroup() + pg1 = PointsGroup() + atg = AlignedTextGrid() + + atg.append(tg1) + atg.append(tg2) + atg.append(pg1) + + assert tg1 in atg + assert tg2 in atg + assert pg1 in atg + + assert tg1 in atg.contains + assert tg2 in atg.contains + assert pg1 in atg.contains + + assert tg1.within is atg + assert tg2.within is atg + assert pg1.within is atg + + def test_clone_on_append(self): + MyWord, = custom_classes(["MyWord"]) + tgr = TierGroup([ + SequenceTier([ + MyWord((0,10,"test")) + ]) + ]) + + atg = AlignedTextGrid() + atg.append(tgr) + + assert tgr in atg + + assert not tgr.entry_classes[0] is MyWord + assert tgr.entry_classes[0].__name__ == "MyWord" + assert isinstance(tgr.MyWord.first, MyWord) + + def test_clone_reuse(self): + MyWord, = custom_classes(["MyWord"]) + tgr1 = TierGroup([ + SequenceTier([ + MyWord((0,10,"test")) + ]) + ]) + + tgr2 = TierGroup([ + SequenceTier([ + MyWord((0,10,"test")) + ]) + ]) + + atg = AlignedTextGrid() + + assert tgr1.entry_classes[0] is MyWord + assert tgr2.entry_classes[0] is MyWord + + atg.append(tgr1) + atg.append(tgr2) + + assert not tgr1.entry_classes[0] is MyWord + assert not tgr2.entry_classes[0] is MyWord + + assert tgr1.entry_classes[0] is tgr2.entry_classes[0] class TestCleanups: def test_sequence_cleanup(self): diff --git a/tests/test_aligned_textgrid.py b/tests/test_aligned_textgrid.py index 23f39b2..11ef3d8 100644 --- a/tests/test_aligned_textgrid.py +++ b/tests/test_aligned_textgrid.py @@ -119,6 +119,45 @@ def test_read_multi(self): assert len(atg_multi[0]) == 2 and len(atg_multi[2]) == 2 assert len(atg_multi[1]) == 1 and len(atg_multi[3]) == 1 +class TestManualCreation: + def test_manual_creation(self): + tg1 = TierGroup() + tg2 = TierGroup() + + atg = AlignedTextGrid([tg1, tg2]) + + assert tg1 in atg + assert tg2 in atg + + assert tg1 in atg.contains + assert tg2 in atg.contains + + assert tg1.within is atg + assert tg2.within is atg + + def test_append(self): + atg = AlignedTextGrid( + "tests/test_data/KY25A_1.TextGrid", + entry_classes=[MyWord, MyPhone] + ) + + orig_tgr = atg[0] + word_tier = orig_tgr.MyWord + phone_tier = orig_tgr.MyPhone + + empty_tgr = TierGroup() + + atg.append(empty_tgr) + + assert orig_tgr in atg + assert empty_tgr in atg + + assert orig_tgr.within is atg + assert empty_tgr.within is atg + + assert atg[0].MyWord is word_tier + assert atg[0].MyPhone is phone_tier + class TestClassSetting: atg1 = AlignedTextGrid( textgrid_path="tests/test_data/KY25A_1.TextGrid",