From d83bee7cbb2fd9978bff3653bb0202b1d79451ff Mon Sep 17 00:00:00 2001 From: JoFrhwld Date: Tue, 11 Jun 2024 11:58:50 -0400 Subject: [PATCH] abstract base class usage --- src/aligned_textgrid/aligned_textgrid.py | 25 +------- src/aligned_textgrid/mixins/tiermixins.py | 70 +---------------------- src/aligned_textgrid/points/tiers.py | 31 ++++++++-- src/aligned_textgrid/sequences/tiers.py | 33 +++++++++-- 4 files changed, 59 insertions(+), 100 deletions(-) diff --git a/src/aligned_textgrid/aligned_textgrid.py b/src/aligned_textgrid/aligned_textgrid.py index 30df32e..f7c78a5 100644 --- a/src/aligned_textgrid/aligned_textgrid.py +++ b/src/aligned_textgrid/aligned_textgrid.py @@ -15,10 +15,11 @@ from typing import Type, Sequence, Literal from copy import copy import numpy as np +from collections.abc import Sequence import warnings -class AlignedTextGrid(WithinMixins): +class AlignedTextGrid(Sequence, WithinMixins): """An aligned Textgrid Args: @@ -80,10 +81,7 @@ def __init__( tgr.within = self self.entry_classes = [[tier.entry_class for tier in tg] for tg in self.tier_groups] - - def __contains__(self, item): - return item in self.tier_groups - + def __getitem__( self, idx: int | list @@ -100,20 +98,9 @@ def __getitem__( out_list.append(tier[x]) return(out_list) - def __iter__(self): - self._idx = 0 - return self - def __len__(self): return len(self.tier_groups) - def __next__(self): - if self._idx < len(self.tier_groups): - out = self.tier_groups[self._idx] - self._idx += 1 - return(out) - raise StopIteration - def __repr__(self): n_groups = len(self.tier_groups) group_names = [x.name for x in self.tier_groups] @@ -143,12 +130,6 @@ def __getattr__( def __setstate__(self, d): self.__dict__ = d - def index( - self, - group: TierGroup|PointsGroup - )->int: - return self.tier_groups.index(group) - def _extend_classes( self, tg: Textgrid, diff --git a/src/aligned_textgrid/mixins/tiermixins.py b/src/aligned_textgrid/mixins/tiermixins.py index 14ec226..039236e 100644 --- a/src/aligned_textgrid/mixins/tiermixins.py +++ b/src/aligned_textgrid/mixins/tiermixins.py @@ -27,42 +27,7 @@ def last(self): if hasattr(self, "sequence_list"): raise IndexError(f"{type(self).__name__} tier with name"\ f" {self.name} has empty sequence_list") - raise AttributeError(f"{type(self).__name__} is not indexable.") - - def __contains__(self, item): - return item in self.sequence_list - - def __getitem__(self, idx): - return self.sequence_list[idx] - - def __iter__(self): - self._idx = 0 - return self - - def __len__(self): - return len(self.sequence_list) - - def __next__(self): - if self._idx < len(self.sequence_list): - out = self.sequence_list[self._idx] - self._idx += 1 - return(out) - raise StopIteration - - def index( - self, - entry - ) -> int: - """Return index of a tier entry - - Args: - entry (SequencePoint |SequenceInterval): - A SequenceInterval or a PointInterval to get the index of. - - Returns: - (int): The entry's index - """ - return self.sequence_list.index(entry) + raise AttributeError(f"{type(self).__name__} is not indexable.")\ class TierGroupMixins: @@ -71,9 +36,6 @@ class TierGroupMixins: Attributes: []: Indexable and iterable """ - - def __contains__(self, item): - return item in self.tier_list def __getattr__( self, @@ -91,35 +53,7 @@ def __getattr__( if len(match_list) < 1: raise AttributeError(f"{type(self).__name__} has no attribute {name}") - - def __getitem__( - self, - idx: int|list - ): - if type(idx) is int: - return self.tier_list[idx] - if len(idx) != len(self): - raise Exception("Attempt to index with incompatible list") - if type(idx) is list: - out_list = [] - for x, tier in zip(idx, self.tier_list): - out_list.append(tier[x]) - return(out_list) - - def __iter__(self): - self._idx = 0 - return self - - def __len__(self): - return len(self.tier_list) - - def __next__(self): - if self._idx < len(self.tier_list): - out = self.tier_list[self._idx] - self._idx += 1 - return(out) - raise StopIteration - + @property def name(self): if self._name: diff --git a/src/aligned_textgrid/points/tiers.py b/src/aligned_textgrid/points/tiers.py index c99feff..6d621c7 100644 --- a/src/aligned_textgrid/points/tiers.py +++ b/src/aligned_textgrid/points/tiers.py @@ -9,9 +9,10 @@ from aligned_textgrid.mixins.within import WithinMixins import numpy as np from typing import Type +from collections.abc import Sequence import warnings -class SequencePointTier(TierMixins, WithinMixins): +class SequencePointTier(Sequence, TierMixins, WithinMixins): """A SequencePointTier class Args: @@ -44,8 +45,6 @@ def __init__( tier:PointTier|list[Point] = [], entry_class:Type[SequencePoint] = SequencePoint ): - - super().__init__() if isinstance(tier, PointTier): self.entry_list = tier.entries self.name = tier.name @@ -63,6 +62,12 @@ def __init__( this_point = self.entry_class(entry) self.sequence_list += [this_point] self.__set_precedence() + + def __getitem__(self, idx): + return self.sequence_list[idx] + + def __len__(self): + return len(self.sequence_list) def __set_precedence(self): for idx,seq in enumerate(self.sequence_list): @@ -164,7 +169,7 @@ def save_as_tg( out_tg.addTier(tier = point_tier) out_tg.save(save_path, "long_textgrid") -class PointsGroup(TierGroupMixins, WithinMixins): +class PointsGroup(Sequence, TierGroupMixins, WithinMixins): """A collection of point tiers Args: @@ -181,10 +186,26 @@ def __init__( self, tiers: list[SequencePointTier] = [SequencePointTier()] ): - super().__init__() self.tier_list = tiers self.contains = self.tier_list + + def __getitem__( + self, + idx: int|list + ): + if type(idx) is int: + return self.tier_list[idx] + if len(idx) != len(self): + raise Exception("Attempt to index with incompatible list") + if type(idx) is list: + out_list = [] + for x, tier in zip(idx, self.tier_list): + out_list.append(tier[x]) + return(out_list) + def __len__(self): + return len(self.tier_list) + def get_nearest_points_index( self, time: float diff --git a/src/aligned_textgrid/sequences/tiers.py b/src/aligned_textgrid/sequences/tiers.py index 99cb9b6..3147004 100644 --- a/src/aligned_textgrid/sequences/tiers.py +++ b/src/aligned_textgrid/sequences/tiers.py @@ -10,9 +10,11 @@ from aligned_textgrid.mixins.within import WithinMixins import numpy as np from typing import Type +from collections.abc import Sequence + import warnings -class SequenceTier(TierMixins, WithinMixins): +class SequenceTier(Sequence, TierMixins, WithinMixins): """A sequence tier Given a `praatio` `IntervalTier` or list of `Interval`s, creates @@ -43,7 +45,6 @@ def __init__( tier: list[Interval] | IntervalTier = [], entry_class: Type[SequenceInterval] = SequenceInterval ): - super().__init__() if isinstance(tier, IntervalTier): self.entry_list = tier.entries self.name = tier.name @@ -63,6 +64,13 @@ def __init__( self.sequence_list += [this_seq] self.__set_precedence() + def __getitem__(self, idx): + return self.sequence_list[idx] + + def __len__(self): + return len(self.sequence_list) + + def __set_precedence(self): for idx,seq in enumerate(self.sequence_list): self.__set_intier(seq) @@ -107,7 +115,6 @@ def pop( def __repr__(self): return f"Sequence tier of {self.entry_class.__name__}; .superset_class: {self.superset_class.__name__}; .subset_class: {self.subset_class.__name__}" - @property def starts(self): @@ -182,7 +189,7 @@ def save_as_tg( out_tg.save(save_path, "long_textgrid") -class TierGroup(TierGroupMixins, WithinMixins): +class TierGroup(Sequence,TierGroupMixins, WithinMixins): """Tier Grouping Args: @@ -206,7 +213,6 @@ def __init__( self, tiers: list[SequenceTier] = [SequenceTier()] ): - super().__init__() self.tier_list = self._arrange_tiers(tiers) #self.entry_classes = [x.__class__ for x in self.tier_list] self._name = self.make_name() @@ -233,6 +239,23 @@ def __init__( for u,l in zip(upper_tier, lower_sequences): u.set_subset_list(l) u.validate() + + def __getitem__( + self, + idx: int|list + ): + if type(idx) is int: + return self.tier_list[idx] + if len(idx) != len(self): + raise Exception("Attempt to index with incompatible list") + if type(idx) is list: + out_list = [] + for x, tier in zip(idx, self.tier_list): + out_list.append(tier[x]) + return(out_list) + + def __len__(self): + return len(self.tier_list) def __repr__(self): n_tiers = len(self.tier_list)