Skip to content

Commit

Permalink
Initial version of multiscale
Browse files Browse the repository at this point in the history
  • Loading branch information
zqiao11 committed Oct 4, 2024
1 parent cb136d4 commit e6c2cf1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/uni2ts/model/moirai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from .finetune import MoiraiFinetune
from .forecast import MoiraiForecast
from .module import MoiraiModule
from .multi_scale_finetune import MoiraiMultiScaleFinetune
from .multi_scale_forecast import MultiScaleMoiraiForecast
from .pretrain import MoiraiPretrain

__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiPretrain"]
__all__ = ["MoiraiFinetune", "MoiraiForecast", "MoiraiModule", "MoiraiMultiScaleFinetune", "MultiScaleMoiraiForecast", "MoiraiPretrain"]
10 changes: 10 additions & 0 deletions src/uni2ts/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
from .feature import AddObservedMask, AddTimeIndex, AddVariateIndex
from .field import LambdaSetFieldIfNotPresent, RemoveFields, SelectFields, SetValue
from .imputation import DummyValueImputation, ImputeTimeSeries, LastValueImputation
from .multi_scale import (
AddNewFreqScaleSeries,
AddNewScaleSeries,
MultiScaleMaskedPredictionGivenFixedConfig,
PadNewScaleSeries,
)
from .pad import EvalPad, MaskOutRangePaddedTokens, Pad, PadFreq
from .patch import (
DefaultPatchSizeConstraints,
Expand All @@ -43,6 +49,8 @@
)

__all__ = [
"AddNewScaleSeries",
"AddNewFreqScaleSeries",
"AddObservedMask",
"AddTimeIndex",
"AddVariateIndex",
Expand All @@ -64,10 +72,12 @@
"MaskedPrediction",
"MaskedPredictionGivenFixedConfig",
"MaskOutRangePaddedTokens",
"MultiScaleMaskedPredictionGivenFixedConfig",
"PackCollection",
"PackFields",
"Pad",
"PadFreq",
"PadNewScaleSeries",
"PatchCrop",
"PatchCropGivenFixedConfig",
"PatchSizeConstraints",
Expand Down
15 changes: 15 additions & 0 deletions src/uni2ts/transform/_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,18 @@ def check_ndim(self, name: str, arr: np.ndarray, expected_ndim: int):
f"has expected ndim: {expected_ndim}, "
f"but got ndim: {arr.ndim} of shape {arr.shape}."
)


class AddNewArrMixin:
@staticmethod
def apply_func(
func: Callable[[dict[str, Any], str], np.ndarray],
data_entry: dict[str, Any],
fields: tuple[str, ...],
optional_fields: tuple[str, ...] = (),
):
for field in fields:
return func(data_entry, field)
for field in optional_fields:
if field in data_entry:
func(data_entry, field)

0 comments on commit e6c2cf1

Please sign in to comment.