Skip to content

Commit

Permalink
Add Indexer documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
gorold committed Jun 12, 2024
1 parent db72959 commit 1ee0790
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/uni2ts/data/indexer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,27 @@


class Indexer(abc.ABC, Sequence):
"""
Base class for all Indexers.
An Indexer is responsible for extracting data from an underlying file format.
"""

def __init__(self, uniform: bool = False) -> None:
    """
    Store the length-uniformity flag for this indexer.

    :param uniform: whether the underlying data has uniform length; when true,
        get_proportional_probabilities returns the uniform distribution
    """
    self.uniform = uniform

def check_index(self, idx: int | slice | Iterable[int]):
"""
Check the validity of a given index.
:param idx: index to check
:return: None
:raises IndexError: if idx is out of bounds
:raises NotImplementedError: if idx is not a valid type
"""
if isinstance(idx, int):
if idx < 0 or idx >= len(self):
raise IndexError(f"Index {idx} out of bounds for length {len(self)}")
Expand All @@ -48,6 +65,12 @@ def check_index(self, idx: int | slice | Iterable[int]):
def __getitem__(
self, idx: int | slice | Iterable[int]
) -> dict[str, Data | BatchedData]:
"""
Retrieve the data from the underlying storage in dictionary format.
:param idx: index to retrieve
:return: underlying data with given index
"""
self.check_index(idx)

if isinstance(idx, int):
Expand All @@ -72,9 +95,20 @@ def _getitem_int(self, idx: int) -> dict[str, Data]: ...
def _getitem_iterable(self, idx: Iterable[int]) -> dict[str, BatchedData]: ...

def get_uniform_probabilities(self) -> np.ndarray:
    """
    Build a uniform probability distribution over all time series.

    :return: array of length len(self) whose entries all equal 1 / len(self)
    """
    num_series = len(self)
    return np.ones(num_series) / num_series

def get_proportional_probabilities(self, field: str = "target") -> np.ndarray:
"""
Obtain proportion of each time series based on number of time steps.
:param field: field name to measure time series length
:return: proportional probabilities
"""
if self.uniform:
return self.get_uniform_probabilities()

Expand Down
16 changes: 16 additions & 0 deletions src/uni2ts/data/indexer/hf_dataset_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@


class HuggingFaceDatasetIndexer(Indexer):
"""
Indexer for Hugging Face Datasets
"""

def __init__(self, dataset: Dataset, uniform: bool = False):
"""
:param dataset: underlying Hugging Face Dataset
:param uniform: whether the underlying data has uniform length
"""
super().__init__(uniform=uniform)
self.dataset = dataset
self.features = dict(self.dataset.features)
Expand Down Expand Up @@ -109,6 +117,14 @@ def _pa_column_to_numpy(
return array

def get_proportional_probabilities(self, field: str = "target") -> np.ndarray:
"""
Obtain proportion of each time series based on number of time steps.
Leverages pyarrow.compute for fast implementation.
:param field: field name to measure time series length
:return: proportional probabilities
"""

if self.uniform:
return self.get_uniform_probabilities()

Expand Down

0 comments on commit 1ee0790

Please sign in to comment.