
Commit: Move to A800

zqiao11 committed Dec 25, 2024
1 parent 09d4e94 commit ee41984
Showing 6 changed files with 30 additions and 15 deletions.

project/pf/multi_scale/finetune/base/turkey_power.sh (2 changes: 1 addition & 1 deletion)

@@ -29,4 +29,4 @@ val_data=${data} \
 val_data.patch_size=${ps} \
 val_data.context_length=$cl \
 val_data.prediction_length=$pl \
-trainer.callbacks.'2'.patience=10
+model.lr=5e-6

project/pf/multi_scale/finetune/small/turkey_power.sh (3 changes: 2 additions & 1 deletion)

@@ -28,4 +28,5 @@ data.prediction_length=$pl \
 val_data=${data} \
 val_data.patch_size=${ps} \
 val_data.context_length=$cl \
-val_data.prediction_length=$pl
+val_data.prediction_length=$pl \
+model.lr=5e-6

project/pf/single_scale/finetune/base/turkey_power.sh (2 changes: 1 addition & 1 deletion)

@@ -29,4 +29,4 @@ val_data=${data} \
 val_data.patch_size=${ps} \
 val_data.context_length=$cl \
 val_data.prediction_length=$pl \
-trainer.callbacks.'2'.patience=10
+model.lr=5e-6

project/pf/single_scale/finetune/small/turkey_power.sh (2 changes: 1 addition & 1 deletion)

@@ -29,4 +29,4 @@ val_data=${data} \
 val_data.patch_size=${ps} \
 val_data.context_length=$cl \
 val_data.prediction_length=$pl \
-trainer.callbacks.'2'.patience=10
+model.lr=5e-6
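
All four fine-tuning scripts make the same kind of change: the three scripts that previously set trainer.callbacks.'2'.patience=10 replace that override with a pinned learning rate, and the small multi-scale script appends the same model.lr=5e-6 override. These are dotted key=value overrides in the Hydra/OmegaConf style. As an illustration only (assuming ordinary OmegaConf semantics for such overrides; the base values below are made up and this is not the project's entry point), the sketch shows how a dotted override lands in a nested config:

# Illustrative sketch: merge a dotted override like model.lr=5e-6 into a
# nested config. Only the override string comes from the scripts above.
from omegaconf import OmegaConf

base = OmegaConf.create({"model": {"lr": 1e-5}, "trainer": {"max_epochs": 100}})
override = OmegaConf.from_dotlist(["model.lr=5e-6"])
cfg = OmegaConf.merge(base, override)

print(float(cfg.model.lr))  # 5e-06, the value the fine-tuning runs now use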

src/uni2ts/model/multi_scale_moirai/finetune.py (13 changes: 13 additions & 0 deletions)

@@ -42,6 +42,7 @@
 from uni2ts.module.ts_embed import MultiInSizeLinear, MultiOutSizeLinear
 from uni2ts.optim import SchedulerType, get_scheduler
 from uni2ts.transform import (
+    AddNewScaleSeries,
     AddNewScaleContextSeries,
     AddObservedMask,
     AddSampleIndex,
@@ -485,6 +486,12 @@ def default_train_transform(
                 optional_fields=("past_feat_dynamic_real",),
             )
             # QZ: Apply downsample to target. Create a new field 'target{i}' for each scale.
+            # + AddNewScaleSeries(
+            #     target_field="target",
+            #     ds_factor=self.ds_factor,
+            #     new_scales_target_fields=self.new_scales_target_fields,
+            #     expected_ndim=2,
+            # )
             + AddNewScaleContextSeries(
                 target_field="target",
                 ds_factor=self.ds_factor,
@@ -624,6 +631,12 @@ def default_val_transform(
                 fields=("target",),
                 optional_fields=("past_feat_dynamic_real",),
             )
+            # + AddNewScaleSeries(
+            #     target_field="target",
+            #     ds_factor=self.ds_factor,
+            #     new_scales_target_fields=self.new_scales_target_fields,
+            #     expected_ndim=2,
+            # )
             + AddNewScaleContextSeries(
                 target_field="target",
                 ds_factor=self.ds_factor,
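
In both default_train_transform and default_val_transform, the new AddNewScaleSeries import is kept commented out and AddNewScaleContextSeries remains the active step, so switching between the two transforms is a matter of toggling the comment. The steps are chained with the + operator. The following is a minimal sketch of that composition pattern, under the assumption that each transform maps a data_entry dict to a new dict; it is a simplified stand-in, not the actual uni2ts Transformation classes:

# Minimal sketch (assumption: transforms map dict -> dict); simplified
# relative to the real uni2ts Transformation/Chain machinery.
from dataclasses import dataclass
from typing import Any


class Transformation:
    def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError

    def __add__(self, other: "Transformation") -> "Chain":
        return Chain([self, other])


@dataclass
class Chain(Transformation):
    transformations: list[Transformation]

    def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]:
        # Apply each transformation in order, threading the dict through.
        for t in self.transformations:
            data_entry = t(data_entry)
        return data_entry

With this pattern, (A() + B() + C()) builds a chain that applies A, then B, then C to each data entry, which is why commenting a single "+ AddNewScaleSeries(...)" block in or out adds or removes that step.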

src/uni2ts/transform/multi_scale.py (23 changes: 12 additions & 11 deletions)

@@ -101,7 +101,8 @@ class AddNewScaleSeries(CheckArrNDimMixin, Transformation):
     """
 
     target_field: str
-    num_new_scales: int
+    ds_factor: int
+    new_scales_target_fields: tuple[str, ...]
     expected_ndim: int = 2
 
     def __post_init__(self):
@@ -114,19 +115,19 @@ def __post_init__(self):
 
     def __call__(self, data_entry: dict[str, Any]) -> dict[str, Any]:
         self.__post_init__()
-        for i in range(self.num_new_scales):
-            data_entry[f"target{i+1}"] = self._downsample(
+        for field in self.new_scales_target_fields:
+            data_entry[field] = self._downsample(
                 data_entry,
                 self.target_field,
             )
 
-        for i in range(self.num_new_scales):
-            self.context_length_new_scales[f"target{i+1}"] = (
-                self.new_context_length_list[i]
-            )
-            self.prediction_length_new_scales[f"target{i+1}"] = (
-                self.new_prediction_length_list[i]
-            )
+        for field in self.new_scales_target_fields:
+            self.context_length_new_scales[field] = self.new_context_length_list[
+                int(field[-1])
+            ]
+            self.prediction_length_new_scales[field] = self.new_prediction_length_list[
+                int(field[-1])
+            ]
         data_entry["context_length_new_scales"] = self.context_length_new_scales
         data_entry["prediction_length_new_scales"] = self.prediction_length_new_scales
 
@@ -142,7 +143,7 @@ def _downsample(self, data_entry: dict[str, Any], field: str) -> np.ndarray:
 
         self.check_ndim(field, arr, self.expected_ndim)
         dim, time = arr.shape[:2]
-        ds_factor = 2
+        ds_factor = self.ds_factor
 
         if len(self.new_context_length_list) == 0:
             context_length = data_entry["context_length"]
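
The refactor replaces the num_new_scales count and the hard-coded downsampling factor of 2 with an explicit ds_factor attribute and an explicit tuple of new_scales_target_fields, and the per-scale context/prediction lengths are now looked up via the scale index encoded in the last character of each field name ("target1", "target2", ...). Since the body of _downsample is only partly visible in this diff, the following is an assumed mean-pooling illustration of the idea, not the library's implementation:

# Assumed illustration only: downsample a (dim, time) array along time by
# ds_factor with mean pooling, and derive the scale index from the field
# name the way the updated __call__ does with int(field[-1]).
import numpy as np


def downsample(arr: np.ndarray, ds_factor: int) -> np.ndarray:
    dim, time = arr.shape[:2]
    pad = (-time) % ds_factor  # pad so the time axis divides evenly
    if pad:
        arr = np.concatenate([arr, np.full((dim, pad), np.nan)], axis=1)
    return np.nanmean(arr.reshape(dim, -1, ds_factor), axis=2)


prev = np.arange(16.0).reshape(1, 16)
for field in ("target1", "target2"):
    scale_idx = int(field[-1])            # "target1" -> 1, "target2" -> 2
    prev = downsample(prev, ds_factor=2)  # assumption: each new scale halves time
    print(field, scale_idx, prev.shape)   # (1, 8) then (1, 4)

Promoting ds_factor to a dataclass field is what lets the fine-tuning module above pass it in as ds_factor=self.ds_factor instead of relying on the previous hard-coded value inside _downsample.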
