Skip to content

Commit

Permalink
Merge pull request #186 from Forced-Alignment-and-Vowel-Extraction/ab…
Browse files Browse the repository at this point in the history
…stract-base-classes

Abstract Base Class Refactor
  • Loading branch information
JoFrhwld authored Jun 11, 2024
2 parents 04d173b + d83bee7 commit 975bed9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 100 deletions.
25 changes: 3 additions & 22 deletions src/aligned_textgrid/aligned_textgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from typing import Type, Sequence, Literal
from copy import copy
import numpy as np
from collections.abc import Sequence
import warnings


class AlignedTextGrid(WithinMixins):
class AlignedTextGrid(Sequence, WithinMixins):
"""An aligned Textgrid
Args:
Expand Down Expand Up @@ -80,10 +81,7 @@ def __init__(
tgr.within = self
self.entry_classes = [[tier.entry_class for tier in tg] for tg in self.tier_groups]


def __contains__(self, item):
return item in self.tier_groups


def __getitem__(
self,
idx: int | list
Expand All @@ -100,20 +98,9 @@ def __getitem__(
out_list.append(tier[x])
return(out_list)

def __iter__(self):
self._idx = 0
return self

def __len__(self):
return len(self.tier_groups)

def __next__(self):
if self._idx < len(self.tier_groups):
out = self.tier_groups[self._idx]
self._idx += 1
return(out)
raise StopIteration

def __repr__(self):
n_groups = len(self.tier_groups)
group_names = [x.name for x in self.tier_groups]
Expand Down Expand Up @@ -143,12 +130,6 @@ def __getattr__(
def __setstate__(self, d):
self.__dict__ = d

def index(
self,
group: TierGroup|PointsGroup
)->int:
return self.tier_groups.index(group)

def _extend_classes(
self,
tg: Textgrid,
Expand Down
70 changes: 2 additions & 68 deletions src/aligned_textgrid/mixins/tiermixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,42 +27,7 @@ def last(self):
if hasattr(self, "sequence_list"):
raise IndexError(f"{type(self).__name__} tier with name"\
f" {self.name} has empty sequence_list")
raise AttributeError(f"{type(self).__name__} is not indexable.")

def __contains__(self, item):
return item in self.sequence_list

def __getitem__(self, idx):
return self.sequence_list[idx]

def __iter__(self):
self._idx = 0
return self

def __len__(self):
return len(self.sequence_list)

def __next__(self):
if self._idx < len(self.sequence_list):
out = self.sequence_list[self._idx]
self._idx += 1
return(out)
raise StopIteration

def index(
self,
entry
) -> int:
"""Return index of a tier entry
Args:
entry (SequencePoint |SequenceInterval):
A SequenceInterval or a PointInterval to get the index of.
Returns:
(int): The entry's index
"""
return self.sequence_list.index(entry)
raise AttributeError(f"{type(self).__name__} is not indexable.")\


class TierGroupMixins:
Expand All @@ -71,9 +36,6 @@ class TierGroupMixins:
Attributes:
[]: Indexable and iterable
"""

def __contains__(self, item):
return item in self.tier_list

def __getattr__(
self,
Expand All @@ -91,35 +53,7 @@ def __getattr__(

if len(match_list) < 1:
raise AttributeError(f"{type(self).__name__} has no attribute {name}")

def __getitem__(
self,
idx: int|list
):
if type(idx) is int:
return self.tier_list[idx]
if len(idx) != len(self):
raise Exception("Attempt to index with incompatible list")
if type(idx) is list:
out_list = []
for x, tier in zip(idx, self.tier_list):
out_list.append(tier[x])
return(out_list)

def __iter__(self):
self._idx = 0
return self

def __len__(self):
return len(self.tier_list)

def __next__(self):
if self._idx < len(self.tier_list):
out = self.tier_list[self._idx]
self._idx += 1
return(out)
raise StopIteration


@property
def name(self):
if self._name:
Expand Down
31 changes: 26 additions & 5 deletions src/aligned_textgrid/points/tiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from aligned_textgrid.mixins.within import WithinMixins
import numpy as np
from typing import Type
from collections.abc import Sequence
import warnings

class SequencePointTier(TierMixins, WithinMixins):
class SequencePointTier(Sequence, TierMixins, WithinMixins):
"""A SequencePointTier class
Args:
Expand Down Expand Up @@ -44,8 +45,6 @@ def __init__(
tier:PointTier|list[Point] = [],
entry_class:Type[SequencePoint] = SequencePoint
):

super().__init__()
if isinstance(tier, PointTier):
self.entry_list = tier.entries
self.name = tier.name
Expand All @@ -63,6 +62,12 @@ def __init__(
this_point = self.entry_class(entry)
self.sequence_list += [this_point]
self.__set_precedence()

def __getitem__(self, idx):
return self.sequence_list[idx]

def __len__(self):
return len(self.sequence_list)

def __set_precedence(self):
for idx,seq in enumerate(self.sequence_list):
Expand Down Expand Up @@ -164,7 +169,7 @@ def save_as_tg(
out_tg.addTier(tier = point_tier)
out_tg.save(save_path, "long_textgrid")

class PointsGroup(TierGroupMixins, WithinMixins):
class PointsGroup(Sequence, TierGroupMixins, WithinMixins):
"""A collection of point tiers
Args:
Expand All @@ -181,10 +186,26 @@ def __init__(
self,
tiers: list[SequencePointTier] = [SequencePointTier()]
):
super().__init__()
self.tier_list = tiers
self.contains = self.tier_list

def __getitem__(
self,
idx: int|list
):
if type(idx) is int:
return self.tier_list[idx]
if len(idx) != len(self):
raise Exception("Attempt to index with incompatible list")
if type(idx) is list:
out_list = []
for x, tier in zip(idx, self.tier_list):
out_list.append(tier[x])
return(out_list)

def __len__(self):
return len(self.tier_list)

def get_nearest_points_index(
self,
time: float
Expand Down
33 changes: 28 additions & 5 deletions src/aligned_textgrid/sequences/tiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from aligned_textgrid.mixins.within import WithinMixins
import numpy as np
from typing import Type
from collections.abc import Sequence

import warnings

class SequenceTier(TierMixins, WithinMixins):
class SequenceTier(Sequence, TierMixins, WithinMixins):
"""A sequence tier
Given a `praatio` `IntervalTier` or list of `Interval`s, creates
Expand Down Expand Up @@ -43,7 +45,6 @@ def __init__(
tier: list[Interval] | IntervalTier = [],
entry_class: Type[SequenceInterval] = SequenceInterval
):
super().__init__()
if isinstance(tier, IntervalTier):
self.entry_list = tier.entries
self.name = tier.name
Expand All @@ -63,6 +64,13 @@ def __init__(
self.sequence_list += [this_seq]
self.__set_precedence()

def __getitem__(self, idx):
return self.sequence_list[idx]

def __len__(self):
return len(self.sequence_list)


def __set_precedence(self):
for idx,seq in enumerate(self.sequence_list):
self.__set_intier(seq)
Expand Down Expand Up @@ -107,7 +115,6 @@ def pop(

def __repr__(self):
return f"Sequence tier of {self.entry_class.__name__}; .superset_class: {self.superset_class.__name__}; .subset_class: {self.subset_class.__name__}"


@property
def starts(self):
Expand Down Expand Up @@ -182,7 +189,7 @@ def save_as_tg(
out_tg.save(save_path, "long_textgrid")


class TierGroup(TierGroupMixins, WithinMixins):
class TierGroup(Sequence,TierGroupMixins, WithinMixins):
"""Tier Grouping
Args:
Expand All @@ -206,7 +213,6 @@ def __init__(
self,
tiers: list[SequenceTier] = [SequenceTier()]
):
super().__init__()
self.tier_list = self._arrange_tiers(tiers)
#self.entry_classes = [x.__class__ for x in self.tier_list]
self._name = self.make_name()
Expand All @@ -233,6 +239,23 @@ def __init__(
for u,l in zip(upper_tier, lower_sequences):
u.set_subset_list(l)
u.validate()

def __getitem__(
self,
idx: int|list
):
if type(idx) is int:
return self.tier_list[idx]
if len(idx) != len(self):
raise Exception("Attempt to index with incompatible list")
if type(idx) is list:
out_list = []
for x, tier in zip(idx, self.tier_list):
out_list.append(tier[x])
return(out_list)

def __len__(self):
return len(self.tier_list)

def __repr__(self):
n_tiers = len(self.tier_list)
Expand Down

0 comments on commit 975bed9

Please sign in to comment.