diff --git a/src/uni2ts/model/moirai/__init__.py b/src/uni2ts/model/moirai/__init__.py index 341b734..bb9623f 100644 --- a/src/uni2ts/model/moirai/__init__.py +++ b/src/uni2ts/model/moirai/__init__.py @@ -16,6 +16,8 @@ from .finetune import MoiraiFinetune from .forecast import MoiraiForecast from .module import MoiraiModule +from .multi_scale_finetune import MoiraiMultiScaleFinetune +from .multi_scale_forecast import MultiScaleMoiraiForecast from .pretrain import MoiraiPretrain -__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiPretrain"] +__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiMultiScaleFinetune", "MultiScaleMoiraiForecast", "MoiraiPretrain"] diff --git a/src/uni2ts/transform/__init__.py b/src/uni2ts/transform/__init__.py index 2e78ec6..349c201 100644 --- a/src/uni2ts/transform/__init__.py +++ b/src/uni2ts/transform/__init__.py @@ -18,6 +18,12 @@ from .feature import AddObservedMask, AddTimeIndex, AddVariateIndex from .field import LambdaSetFieldIfNotPresent, RemoveFields, SelectFields, SetValue from .imputation import DummyValueImputation, ImputeTimeSeries, LastValueImputation +from .multi_scale import ( + AddNewFreqScaleSeries, + AddNewScaleSeries, + MultiScaleMaskedPredictionGivenFixedConfig, + PadNewScaleSeries, +) from .pad import EvalPad, MaskOutRangePaddedTokens, Pad, PadFreq from .patch import ( DefaultPatchSizeConstraints, @@ -43,6 +49,8 @@ ) __all__ = [ + "AddNewScaleSeries", + "AddNewFreqScaleSeries", "AddObservedMask", "AddTimeIndex", "AddVariateIndex", @@ -64,10 +72,12 @@ "MaskedPrediction", "MaskedPredictionGivenFixedConfig", "MaskOutRangePaddedTokens", + "MultiScaleMaskedPredictionGivenFixedConfig", "PackCollection", "PackFields", "Pad", "PadFreq", + "PadNewScaleSeries", "PatchCrop", "PatchCropGivenFixedConfig", "PatchSizeConstraints", diff --git a/src/uni2ts/transform/_mixin.py b/src/uni2ts/transform/_mixin.py index c723a58..00577b2 100644 --- a/src/uni2ts/transform/_mixin.py +++ b/src/uni2ts/transform/_mixin.py @@ -121,3 +121,18 @@ def check_ndim(self, name: str, arr: np.ndarray, expected_ndim: int): f"has expected ndim: {expected_ndim}, " f"but got ndim: {arr.ndim} of shape {arr.shape}." ) + + +class AddNewArrMixin: + @staticmethod + def apply_func( + func: Callable[[dict[str, Any], str], np.ndarray], + data_entry: dict[str, Any], + fields: tuple[str, ...], + optional_fields: tuple[str, ...] = (), + ): + for field in fields: + return func(data_entry, field) + for field in optional_fields: + if field in data_entry: + func(data_entry, field)