-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3260c10
commit d9629a2
Showing
3 changed files
with
450 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import os | ||
import sys | ||
import math | ||
from easydict import EasyDict | ||
sys.path.append(os.path.abspath(__file__ + '/../../..')) | ||
|
||
from basicts.metrics import masked_mae, masked_mse, masked_mape, masked_rmse | ||
from basicts.data import TimeSeriesForecastingDataset | ||
from basicts.runners import SimpleTimeSeriesForecastingRunner | ||
from basicts.scaler import ZScoreScaler | ||
from basicts.utils import get_regular_settings, load_dataset_desc | ||
|
||
from .arch import CATS | ||
|
||
############################## Hot Parameters ############################## | ||
# Dataset & Metrics configuration | ||
DATA_NAME = 'Weather' # Dataset name | ||
regular_settings = get_regular_settings(DATA_NAME) | ||
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence | ||
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence | ||
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios | ||
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data | ||
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data | ||
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data | ||
# Model architecture and parameters | ||
MODEL_ARCH = CATS | ||
MODEL_PARAM = { | ||
"seq_len": INPUT_LEN, | ||
"pred_len": OUTPUT_LEN, | ||
"dec_in": 21, | ||
"d_model": 256, | ||
"d_layers": 3, | ||
"n_heads": 32, | ||
"d_ff": 512, | ||
"dropout": 0.2, | ||
"query_independence": True, | ||
"patch_len": 24, | ||
"stride": 24, | ||
"padding_patch": 'end', | ||
"QAM_start":0.1, | ||
"QAM_end": 0.5, | ||
"store_attn": False | ||
} | ||
NUM_EPOCHS = 100 | ||
|
||
############################## General Configuration ############################## | ||
CFG = EasyDict() | ||
# General settings | ||
CFG.DESCRIPTION = 'An Example Config' | ||
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode) | ||
# Runner | ||
CFG.RUNNER = SimpleTimeSeriesForecastingRunner | ||
|
||
############################## Environment Configuration ############################## | ||
CFG.ENV = EasyDict() # Environment settings. Default: None | ||
CFG.ENV.SEED = 1 # Random seed. Default: None | ||
|
||
############################## Dataset Configuration ############################## | ||
CFG.DATASET = EasyDict() | ||
# Dataset settings | ||
CFG.DATASET.NAME = DATA_NAME | ||
CFG.DATASET.TYPE = TimeSeriesForecastingDataset | ||
CFG.DATASET.PARAM = EasyDict({ | ||
'dataset_name': DATA_NAME, | ||
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO, | ||
'input_len': INPUT_LEN, | ||
'output_len': OUTPUT_LEN, | ||
# 'mode' is automatically set by the runner | ||
}) | ||
|
||
############################## Scaler Configuration ############################## | ||
CFG.SCALER = EasyDict() | ||
# Scaler settings | ||
CFG.SCALER.TYPE = ZScoreScaler # Scaler class | ||
CFG.SCALER.PARAM = EasyDict({ | ||
'dataset_name': DATA_NAME, | ||
'train_ratio': TRAIN_VAL_TEST_RATIO[0], | ||
'norm_each_channel': NORM_EACH_CHANNEL, | ||
'rescale': RESCALE, | ||
}) | ||
|
||
############################## Model Configuration ############################## | ||
CFG.MODEL = EasyDict() | ||
# Model settings | ||
CFG.MODEL.NAME = MODEL_ARCH.__name__ | ||
CFG.MODEL.ARCH = MODEL_ARCH | ||
CFG.MODEL.PARAM = MODEL_PARAM | ||
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2] | ||
CFG.MODEL.TARGET_FEATURES = [0] | ||
|
||
############################## Metrics Configuration ############################## | ||
|
||
CFG.METRICS = EasyDict() | ||
# Metrics settings | ||
CFG.METRICS.FUNCS = EasyDict({ | ||
'MAE': masked_mae, | ||
'MSE': masked_mse, | ||
'RMSE': masked_rmse, | ||
'MAPE': masked_mape | ||
}) | ||
CFG.METRICS.TARGET = 'MSE' | ||
CFG.METRICS.NULL_VAL = NULL_VAL | ||
|
||
############################## Training Configuration ############################## | ||
CFG.TRAIN = EasyDict() | ||
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS | ||
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( | ||
'checkpoints', | ||
MODEL_ARCH.__name__, | ||
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)]) | ||
) | ||
CFG.TRAIN.LOSS = masked_mae | ||
# Optimizer settings | ||
CFG.TRAIN.OPTIM = EasyDict() | ||
CFG.TRAIN.OPTIM.TYPE = "Adam" | ||
CFG.TRAIN.OPTIM.PARAM = { | ||
"lr": 0.01 | ||
} | ||
# Learning rate scheduler settings | ||
CFG.TRAIN.LR_SCHEDULER = EasyDict() | ||
# CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR" | ||
# CFG.TRAIN.LR_SCHEDULER.PARAM = { | ||
# "milestones": [1, 25, 50], | ||
# "gamma": 0.5 | ||
# } | ||
desc = load_dataset_desc(DATA_NAME) | ||
train_steps = math.ceil(desc["num_time_steps"] * TRAIN_VAL_TEST_RATIO[0]) | ||
CFG.TRAIN.LR_SCHEDULER.TYPE = "OneCycleLR" | ||
CFG.TRAIN.LR_SCHEDULER.PARAM = { | ||
"pct_start": 0.3, | ||
"epochs": NUM_EPOCHS, | ||
"steps_per_epoch": train_steps, | ||
"max_lr": CFG.TRAIN.OPTIM.PARAM["lr"] | ||
} | ||
CFG.TRAIN.CLIP_GRAD_PARAM = { | ||
'max_norm': 5.0 | ||
} | ||
# Train data loader settings | ||
CFG.TRAIN.DATA = EasyDict() | ||
CFG.TRAIN.DATA.BATCH_SIZE = 256 | ||
CFG.TRAIN.DATA.SHUFFLE = True | ||
|
||
############################## Validation Configuration ############################## | ||
CFG.VAL = EasyDict() | ||
CFG.VAL.INTERVAL = 1 | ||
CFG.VAL.DATA = EasyDict() | ||
CFG.VAL.DATA.BATCH_SIZE = 256 | ||
|
||
############################## Test Configuration ############################## | ||
CFG.TEST = EasyDict() | ||
CFG.TEST.INTERVAL = 1 | ||
CFG.TEST.DATA = EasyDict() | ||
CFG.TEST.DATA.BATCH_SIZE = 256 | ||
|
||
############################## Evaluation Configuration ############################## | ||
|
||
CFG.EVAL = EasyDict() | ||
|
||
# Evaluation parameters | ||
CFG.EVAL.HORIZONS = [12, 24, 48, 96] | ||
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .cats_arch import CATS |
Oops, something went wrong.