Skip to content

Commit

Permalink
Merge pull request #211 from Forced-Alignment-and-Vowel-Extraction/ef…
Browse files Browse the repository at this point in the history
…ficiency

Efficiency
  • Loading branch information
JoFrhwld authored Nov 12, 2024
2 parents 5313a1d + 4f0a837 commit bce9ec7
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 25 deletions.
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ praatio = "^6.0.0"
numpy = "^1.24.2"
polars = "^0.20.18"
cloudpickle = "^3.0.0"
toml = "^0.10.2"

[tool.poetry.group.dev.dependencies]
jupyter = "^1.0.0"
Expand Down
14 changes: 12 additions & 2 deletions src/aligned_textgrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,19 @@
from aligned_textgrid.custom_classes import custom_classes
from aligned_textgrid.outputs.to_dataframe import to_df

from importlib.metadata import version
from importlib.metadata import version

__version__ = version("aligned_textgrid")
from pathlib import Path
import toml

__version__ = "unknown"
# adopt path to your pyproject.toml
pyproject_toml_file = Path(__file__).parent.parent.parent / "pyproject.toml"
if pyproject_toml_file.exists() and pyproject_toml_file.is_file():
data = toml.load(pyproject_toml_file)
# check project.version
if "tool" in data and "poetry" in data["tool"] and "version" in data["tool"]["poetry"]:
__version__ = data["tool"]["poetry"]["version"]

__all__ = [
"SequenceInterval",
Expand Down
47 changes: 35 additions & 12 deletions src/aligned_textgrid/sequence_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,18 @@ class SequenceList(Sequence):
def __init__(self, *args:SeqVar):
self._values = []
self.entry_class = None
for arg in args:
self.append(arg)
if len(args) > 0:
pass
else:
return
all_ecs = set([arg.entry_class for arg in args])

if len(all_ecs) > 1:
raise ValueError("All values must have the same class.")

self.entry_class = all_ecs.pop()
self._values += args
self._sort()

def __getitem__(self:Sequence[SeqVar], idx:int)->SeqVar:
return self._values[idx]
Expand Down Expand Up @@ -92,17 +102,20 @@ def __repr__(self):
return self._values.__repr__()

def _sort(self)->None:
if len(self._values) > 0:
if hasattr(self[0], "start"):
item_starts = np.array([x.start for x in self._values])
if hasattr(self[0], "time"):
item_starts = np.array([x.time for x in self._values])
item_order = np.argsort(item_starts)
self._values = [self._values[idx] for idx in item_order]

if len(self._values) < 1:
return

if np.all(self.starts[:-1] <= self.starts[1:]):
return

item_order = np.argsort(self.starts)
self._values = [self._values[idx] for idx in item_order]

#@wrap(log_class.entering, log_class.exiting)
def _entry_class_checker(self, value) -> None:
if self.entry_class is None:
self.entry_class = value.entry_class
return

if not (issubclass(self.entry_class, value.entry_class)
or issubclass(value.entry_class, self.entry_class)):
Expand Down Expand Up @@ -181,9 +194,19 @@ def append(self:Sequence[SeqVar], value:SeqVar, shift:bool = False, re_init = Fa
increment = self.ends[-1]
if shift:
value._shift(increment)


this_time = 0
if hasattr(value, "start"):
this_time = value.start
elif hasattr(value, "time"):
this_time = value.time

if len(self.starts) > 0 and all(self.starts):
insert_idx = self.starts.searchsorted(this_time)
self._values.insert(insert_idx, value)
return

self._values.append(value)
self._sort()

def concat(self:Sequence[SeqVar], intervals:Sequence[SeqVar])->None:
"""Concatenate two sequence lists
Expand Down
13 changes: 9 additions & 4 deletions src/aligned_textgrid/sequences/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
`Top` and `Bottom` classes.
"""

import aligned_textgrid
#import aligned_textgrid
from praatio.utilities.constants import Interval
import praatio
from praatio.data_classes.interval_tier import IntervalTier
Expand Down Expand Up @@ -145,13 +145,18 @@ def set_subset_list(self, subset_list:SequenceList['SequenceInterval'] = None)->
set as the `super_instance` of all objects in the list.
"""

self._subset_list = SequenceList()
if subset_list is None:
return

if all([isinstance(subint, self.subset_class) for subint in subset_list]):
self._subset_list = SequenceList(*subset_list)
for element in subset_list:
self.append_subset_list(element)
self._set_within()
if not self is element.super_instance:
element.remove_superset()
element.super_instance = self

#self._set_within()
self.contains = self._subset_list
self._set_subset_precedence()
#self.validate()
else:
Expand Down
11 changes: 5 additions & 6 deletions tests/test_sequences/test_sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,19 +175,18 @@ def test_super_setting(self):
new_instanceB = self.LocalClassB()
assert self.pre_instanceA.superset_class is self.LocalClassB
assert new_instanceA.superset_class is self.LocalClassB
assert self.pre_instanceA.superset_class is self.LocalClassB

assert self.pre_instanceB.subset_class is self.LocalClassA
assert new_instanceB.subset_class is self.LocalClassA


def test_postsetting_instances(self):
try:
self.pre_instanceA.set_super_instance(self.pre_instanceB)
except Exception as exc:
assert False, f"{exc}"

assert self.pre_instanceA.super_instance is self.pre_instanceB
assert self.pre_instanceA in self.pre_instanceB.subset_list
assert self.pre_instanceA in self.pre_instanceB.subset_list

assert self.pre_instanceB.subset_class is self.LocalClassA
assert new_instanceB.subset_class is self.LocalClassA

class TestPrecedence:
class LocalClassA(SequenceInterval):
Expand Down

0 comments on commit bce9ec7

Please sign in to comment.