Skip to content

Commit

Permalink
Merge pull request #199 from Forced-Alignment-and-Vowel-Extraction/co…
Browse files Browse the repository at this point in the history
…ncat-methods

AlignedTextGrid creation flexibilization
  • Loading branch information
JoFrhwld authored Jun 25, 2024
2 parents 73b3b1c + 2c8a77f commit 7ce4288
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 60 deletions.
130 changes: 73 additions & 57 deletions src/aligned_textgrid/aligned_textgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from aligned_textgrid.sequences.tiers import SequenceTier, TierGroup
from aligned_textgrid.points.tiers import SequencePointTier, PointsGroup
from aligned_textgrid.mixins.within import WithinMixins
from aligned_textgrid.mixins.tiermixins import TierGroupMixins
from aligned_textgrid.custom_classes import custom_classes, clone_class, get_class_hierarchy
from typing import Type, Sequence, Literal
from typing import Type, Literal
from copy import copy
import numpy as np
from collections.abc import Sequence
Expand All @@ -24,7 +25,11 @@ class AlignedTextGrid(Sequence, WithinMixins):
"""An aligned Textgrid
Args:
textgrid (str|Path|praatio.textgrid.Textgrid, optional): A `praatio` TextGrid
textgrid (str|Path|praatio.textgrid.Textgrid|Sequence[TierGroup|PointsGroup], optional):
An object to create a new AlignedTextGrid which can be one of:
i) A path-like value (str|pathlib.Path) to a TextGrid file.
ii) A praatio.textgrid.TextGrid object.
iii) A list of [](`~aligned_textgrid.TierGroup`)s
entry_classes (Sequence[Sequence[Type[SequenceInterval]]] | Sequence[Type[SequenceInterval]], optional):
If a single list of `SequenceInterval` subclasses is given, they will be
repeated as many times as necessary to assign a class to every tier.
Expand Down Expand Up @@ -53,31 +58,30 @@ class for each tier within each tier group. Say, if only the first speaker

def __init__(
self,
textgrid: Textgrid|str|Path = None,
textgrid: Textgrid|str|Path|Sequence[TierGroup|PointsGroup] = None,
entry_classes:
Sequence[Sequence[Type[SequenceInterval]]] |
Sequence[Type[SequenceInterval]]
= [SequenceInterval],
*,
textgrid_path: str = None
):
self.entry_classes = None
self.entry_classes = self._reclone_classes(entry_classes)
self._cloned_classes = []
self._tier_groups = []
self.contains = self.tier_groups

if textgrid_path:
textgrid = textgrid_path

if textgrid:
self._process_textgrid_arg(textgrid)
self._process_textgrid_arg(textgrid, entry_classes)
else:
warnings.warn('Initializing an empty AlignedTextGrid')
self._init_empty()
return

self.tier_groups = self._relate_tiers()

self.contains = self.tier_groups
for tgr in self.tier_groups:
tgr.within = self
self.entry_classes = [[tier.entry_class for tier in tg] for tg in self.tier_groups]
self._set_group_names()

def __getitem__(
Expand Down Expand Up @@ -110,7 +114,17 @@ def __repr__(self):
def __setstate__(self, d):
self.__dict__ = d

def _process_textgrid_arg(self, arg):
def _process_textgrid_arg(self, arg, entry_classes):

# if passed a list of TierGroups
if isinstance(arg, Sequence) and \
len(arg)>0 and \
all([isinstance(v, TierGroupMixins) for v in arg]):
for trg in arg:
self.append(trg)
return

# if passed a Path-like value
if isinstance(arg, str) or isinstance(arg, Path):
arg_str = str(arg)
tg = openTextgrid(
Expand All @@ -119,10 +133,15 @@ def _process_textgrid_arg(self, arg):
duplicateNamesMode='rename'
)

#if passed a praatio.textgrid.Textgrid
if isinstance(arg, Textgrid):
tg = arg

self.tg_tiers, self.entry_classes = self._nestify_tiers(tg, self.entry_classes)

# do nestifying etc here.
tg_tiers, entry_classes = self._nestify_tiers(tg, entry_classes)
tier_groups = self._relate_tiers(tg_tiers, entry_classes)
self.tier_groups = tier_groups


def _extend_classes(
self,
Expand Down Expand Up @@ -164,22 +183,14 @@ def _reclone_classes(
) -> list[SequenceInterval|SequencePoint]|list[list[SequenceInterval|SequencePoint]]:

flat_classes = entry_classes

if type(entry_classes[0]) is list:
flat_classes = [c for tg in entry_classes for c in tg]

unique_classes = list(set(flat_classes))

orig_classes = []
orig_class_names = []
if self.entry_classes:
orig_classes = [c for tg in self.entry_classes for c in tg]
orig_class_names = [c.__name__ for c in orig_classes]
cloned_class_names = [cl.__name__ for cl in self._cloned_classes]


already_cloned = [c.__name__ in orig_classes for c in flat_classes]
if all(already_cloned):
return

points = [c for c in unique_classes if issubclass(c, SequencePoint)]
tops = [
c
Expand All @@ -190,17 +201,17 @@ def _reclone_classes(

points_clone = []
for p in points:
if p.__name__ in orig_class_names:
if p.__name__ in cloned_class_names:
points_clone.append(
orig_classes[orig_class_names.index(p.__name__)]
self._cloned_classes[cloned_class_names.index(p.__name__)]
)
else:
points_clone.append(clone_class(p))
tops_clone = []
for t in tops:
if t.__name__ in orig_class_names:
if t.__name__ in cloned_class_names:
tops_clone.append(
orig_classes[orig_class_names.index(t.__name__)]
self._cloned_classes[cloned_class_names.index(t.__name__)]
)
else:
tops_clone.append(clone_class(t))
Expand All @@ -210,16 +221,7 @@ def _reclone_classes(
full_seq_clone += get_class_hierarchy(tclone, [])

full_clone = points_clone + full_seq_clone

if type(entry_classes[0]) is list:
new_entry_classes = [
self._swap_classes(tg_classes, full_clone)
for tg_classes in entry_classes
]
else:
new_entry_classes = self._swap_classes(entry_classes, full_clone)

return new_entry_classes
self._cloned_classes += full_clone

def _swap_classes(
self,
Expand All @@ -233,13 +235,6 @@ def _swap_classes(

out_classes = [new_classes[i] for i in new_idx]
return out_classes


def _init_empty(self):
self.tier_groups = []
self.contains = self.tier_groups
self.entry_classes = []
self.tg_tiers = None

def _nestify_tiers(
self,
Expand Down Expand Up @@ -308,7 +303,7 @@ def _nestify_tiers(

return tier_list, entry_list

def _relate_tiers(self):
def _relate_tiers(self, tg_tiers, entry_classes):
"""_Private method_
creates RelatedTier objects for each set of
Expand All @@ -319,8 +314,13 @@ def _relate_tiers(self):
"""

tier_groups = []
self._reclone_classes(entry_classes)
if type(entry_classes[0]) is list:
entry_classes = [self._swap_classes(ecs, self._cloned_classes) for ecs in entry_classes]
else:
entry_classes = self._swap_classes(entry_classes, self._cloned_classes)

for tier_group, classes in zip(self.tg_tiers, self.entry_classes):
for tier_group, classes in zip(tg_tiers, entry_classes):
sequence_tier_list = []
point_tier_list = []
for tier, entry_class in zip(tier_group, classes):
Expand Down Expand Up @@ -355,6 +355,27 @@ def _set_group_names(self):
for idx, name in enumerate(tier_group_names):
setattr(self, name, self.tier_groups[idx])

@property
def tier_groups(self) -> list[TierGroup|PointsGroup|None]:
if self._tier_groups:
return self._tier_groups
return []

@tier_groups.setter
def tier_groups(self, new:Sequence[TierGroup|PointsGroup]) -> None:
if not(isinstance(new, Sequence) and all([isinstance(v, TierGroupMixins) for v in new])):
raise ValueError("Ony a list of TierGroups can be set as tier groups.")

if len(new) < 1:
return

self._tier_groups = new
self.contains = self.tier_groups

@property
def entry_classes(self):
return [tgr.entry_classes for tgr in self.tier_groups]

@property
def tier_names(self) -> list[str]:
if len(self) == 0:
Expand Down Expand Up @@ -414,15 +435,17 @@ def append(self, tier_group:TierGroup):
tier_group (TierGroup):
The TierGroup to append to the AlignedTextGrid
"""
new_classes = self._reclone_classes(tier_group.entry_classes)
self._reclone_classes(tier_group.entry_classes)
new_classes = self._swap_classes(tier_group.entry_classes, self._cloned_classes)
for cl, tier in zip(new_classes, tier_group):
entries = [cl._cast(i) for i in tier]
tier.__init__(entries)

tier_group.__init__(tier_group)

self.tier_groups.append(tier_group)
self.entry_classes = [[tier.entry_class for tier in tg] for tg in self.tier_groups]
new_tgs = self.tier_groups + [tier_group]
self.tier_groups = new_tgs
self._set_group_names()

def cleanup(self)->None:
"""Cleanup gaps in AlignedTextGrid
Expand Down Expand Up @@ -579,17 +602,14 @@ def interleave_class(
new_class.set_subset_class(down_class)

new_tiergoups = []
new_entry_classes = []

for tg in self.tier_groups:
if not specified_class in tg.entry_classes:
new_tiergoups.append(tg)
new_entry_classes.append(tg.entry_classes)
else:
copy_tier = [tier for tier in tg if tier.entry_class is copy_class][0]
new_tier = SequenceTier(
[seq.return_interval() for seq in copy_tier],
entry_class = new_class
[new_class(seq) for seq in copy_tier]
)
if not copy_labels:
for seq in new_tier:
Expand All @@ -610,10 +630,8 @@ def interleave_class(
new_tg = TierGroup(tier_list)
new_tg.name = tg.name
new_tiergoups.append(new_tg)
new_entry_classes.append(new_tg.entry_classes)

self.tier_groups = new_tiergoups
self.entry_classes = new_entry_classes
for tgr in self.tier_groups:
tgr.within = self
self.contains = self.tier_groups
Expand Down Expand Up @@ -697,8 +715,6 @@ def pop_class(
new_tier_groups.append(new_tg)

self.tier_groups = new_tier_groups
new_entry_classes = [tg.entry_classes for tg in self.tier_groups]
self.entry_classes = new_entry_classes
for tgr in self.tier_groups:
tgr.within = self
self.contains = self.tier_groups
Expand Down
3 changes: 2 additions & 1 deletion src/aligned_textgrid/mixins/within.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import TypeVar, Sequence, TYPE_CHECKING
from collections.abc import Sequence
from typing import TypeVar, TYPE_CHECKING
if TYPE_CHECKING:
from aligned_textgrid import SequenceInterval, \
SequenceTier,\
Expand Down
6 changes: 4 additions & 2 deletions src/aligned_textgrid/points/tiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,18 @@ class PointsGroup(Sequence, TierGroupMixins, WithinMixins):
[](`~aligned_textgrid.mixins.within.WithinMixins`)
Args:
tiers (list[SequencePointTier]:
tiers (list[SequencePointTier]|PointsGroup):
A list of SequencePointTiers
Attributes:
A list of the entry classes
"""
def __init__(
self,
tiers: list[SequencePointTier] = [SequencePointTier()]
tiers: list[SequencePointTier]|Self = [SequencePointTier()]
):
if isinstance(tiers, PointsGroup):
tiers = [tier for tier in tiers]
self.tier_list = tiers
self.contains = self.tier_list
self._set_tier_names()
Expand Down
67 changes: 67 additions & 0 deletions tests/test_add_append.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,73 @@ def test_tier_group_concat(self):

assert word1.fol is word2

class TestATG:

def test_atg_append(self):
tg1 = TierGroup()
tg2 = TierGroup()
pg1 = PointsGroup()
atg = AlignedTextGrid()

atg.append(tg1)
atg.append(tg2)
atg.append(pg1)

assert tg1 in atg
assert tg2 in atg
assert pg1 in atg

assert tg1 in atg.contains
assert tg2 in atg.contains
assert pg1 in atg.contains

assert tg1.within is atg
assert tg2.within is atg
assert pg1.within is atg

def test_clone_on_append(self):
MyWord, = custom_classes(["MyWord"])
tgr = TierGroup([
SequenceTier([
MyWord((0,10,"test"))
])
])

atg = AlignedTextGrid()
atg.append(tgr)

assert tgr in atg

assert not tgr.entry_classes[0] is MyWord
assert tgr.entry_classes[0].__name__ == "MyWord"
assert isinstance(tgr.MyWord.first, MyWord)

def test_clone_reuse(self):
MyWord, = custom_classes(["MyWord"])
tgr1 = TierGroup([
SequenceTier([
MyWord((0,10,"test"))
])
])

tgr2 = TierGroup([
SequenceTier([
MyWord((0,10,"test"))
])
])

atg = AlignedTextGrid()

assert tgr1.entry_classes[0] is MyWord
assert tgr2.entry_classes[0] is MyWord

atg.append(tgr1)
atg.append(tgr2)

assert not tgr1.entry_classes[0] is MyWord
assert not tgr2.entry_classes[0] is MyWord

assert tgr1.entry_classes[0] is tgr2.entry_classes[0]

class TestCleanups:
def test_sequence_cleanup(self):
Expand Down
Loading

0 comments on commit 7ce4288

Please sign in to comment.