diff --git a/src/uni2ts/data/indexer/_base.py b/src/uni2ts/data/indexer/_base.py index e326463..5675293 100644 --- a/src/uni2ts/data/indexer/_base.py +++ b/src/uni2ts/data/indexer/_base.py @@ -22,10 +22,27 @@ class Indexer(abc.ABC, Sequence): + """ + Base class for all Indexers. + + An Indexer is responsible for extracting data from an underlying file format. + """ + def __init__(self, uniform: bool = False): + """ + :param uniform: whether the underlying data has uniform length + """ self.uniform = uniform def check_index(self, idx: int | slice | Iterable[int]): + """ + Check the validity of a given index. + + :param idx: index to check + :return: None + :raises IndexError: if idx is out of bounds + :raises NotImplementedError: if idx is not a valid type + """ if isinstance(idx, int): if idx < 0 or idx >= len(self): raise IndexError(f"Index {idx} out of bounds for length {len(self)}") @@ -48,6 +65,12 @@ def check_index(self, idx: int | slice | Iterable[int]): def __getitem__( self, idx: int | slice | Iterable[int] ) -> dict[str, Data | BatchedData]: + """ + Retrive the data from the underlying storage in dictionary format. + + :param idx: index to retrieve + :return: underlying data with given index + """ self.check_index(idx) if isinstance(idx, int): @@ -72,9 +95,20 @@ def _getitem_int(self, idx: int) -> dict[str, Data]: ... def _getitem_iterable(self, idx: Iterable[int]) -> dict[str, BatchedData]: ... def get_uniform_probabilities(self) -> np.ndarray: + """ + Obtains uniform probability distribution over all time series. + + :return: uniform probability distribution + """ return np.ones(len(self)) / len(self) def get_proportional_probabilities(self, field: str = "target") -> np.ndarray: + """ + Obtain proportion of each time series based on number of time steps. + + :param field: field name to measure time series length + :return: proportional probabilities + """ if self.uniform: return self.get_uniform_probabilities() diff --git a/src/uni2ts/data/indexer/hf_dataset_indexer.py b/src/uni2ts/data/indexer/hf_dataset_indexer.py index a7f5d07..487ad6f 100644 --- a/src/uni2ts/data/indexer/hf_dataset_indexer.py +++ b/src/uni2ts/data/indexer/hf_dataset_indexer.py @@ -28,7 +28,15 @@ class HuggingFaceDatasetIndexer(Indexer): + """ + Indexer for Hugging Face Datasets + """ + def __init__(self, dataset: Dataset, uniform: bool = False): + """ + :param dataset: underlying Hugging Face Dataset + :param uniform: whether the underlying data has uniform length + """ super().__init__(uniform=uniform) self.dataset = dataset self.features = dict(self.dataset.features) @@ -109,6 +117,14 @@ def _pa_column_to_numpy( return array def get_proportional_probabilities(self, field: str = "target") -> np.ndarray: + """ + Obtain proportion of each time series based on number of time steps. + Leverages pyarrow.compute for fast implementation. + + :param field: field name to measure time series length + :return: proportional probabilities + """ + if self.uniform: return self.get_uniform_probabilities()