-
Notifications
You must be signed in to change notification settings - Fork 124
/
Copy pathsimple_tsf_dataset.py
143 lines (117 loc) · 6.96 KB
/
simple_tsf_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import inspect
import json
import logging
from typing import List
import numpy as np
from .base_dataset import BaseDataset
class TimeSeriesForecastingDataset(BaseDataset):
    """
    A dataset class for time series forecasting problems, handling the loading, parsing, and partitioning
    of time series data into training, validation, and testing sets based on provided ratios.

    This class supports configurations where sequences may or may not overlap, accommodating scenarios
    where time series data is drawn from continuous periods or distinct episodes, affecting how
    the data is split into batches for model training or evaluation.

    Each sample is a sliding window over the first axis of the selected split: `input_len` steps of
    history followed by `output_len` steps to predict.

    Attributes:
        data_file_path (str): Path to the file containing the time series data.
        description_file_path (str): Path to the JSON file containing the description of the dataset.
        data (np.ndarray): The loaded time series data array, split according to the specified mode.
        description (dict): Metadata about the dataset, such as shape and other properties.
        logger (logging.Logger): Optional logger; when None, notices are printed to stdout instead.
    """

    def __init__(self, dataset_name: str, train_val_test_ratio: List[float], mode: str, input_len: int,
                 output_len: int, overlap: bool = False, logger: logging.Logger = None) -> None:
        """
        Initializes the TimeSeriesForecastingDataset by setting up paths, loading data, and
        preparing it according to the specified configurations.

        Args:
            dataset_name (str): The name of the dataset. Files are read from `datasets/{dataset_name}/`.
            train_val_test_ratio (List[float]): Ratios for splitting the dataset into train, validation, and test sets.
                Each value should be a float between 0 and 1, and their sum should ideally be 1.
                Only the validation and test ratios are used to size those splits; the training split
                receives whatever remains.
            mode (str): The operation mode of the dataset. Valid values are 'train', 'valid', or 'test'.
            input_len (int): The length of the input sequence (number of historical points).
            output_len (int): The length of the output sequence (number of future points to predict).
            overlap (bool): Flag to determine if training/validation/test splits should overlap.
                Defaults to False for strictly non-overlapping periods. Set to True to allow overlap.
                It is switched to True automatically when the selected split is shorter than
                `input_len + output_len`.
            logger (logging.Logger): logger. When None, notices are printed instead.

        Raises:
            AssertionError: If `mode` is not one of ['train', 'valid', 'test'].
        """
        # Raised explicitly (instead of via `assert`) so the check is not stripped under `python -O`,
        # while keeping the exception type existing callers may rely on.
        if mode not in ('train', 'valid', 'test'):
            raise AssertionError(f"Invalid mode: {mode}. Must be one of ['train', 'valid', 'test'].")

        # NOTE(review): BaseDataset is defined elsewhere; it is presumed to store the arguments as
        # attributes (this class reads self.mode, self.input_len, self.output_len, self.overlap and
        # self.train_val_test_ratio) -- confirm against base_dataset.py.
        super().__init__(dataset_name, train_val_test_ratio, mode, input_len, output_len, overlap)
        self.logger = logger

        self.data_file_path = f'datasets/{dataset_name}/data.dat'
        self.description_file_path = f'datasets/{dataset_name}/desc.json'
        # The description must be loaded first: _load_data() needs its 'shape' entry.
        self.description = self._load_description()
        self.data = self._load_data()

    def _load_description(self) -> dict:
        """
        Loads the description of the dataset from a JSON file.

        Returns:
            dict: A dictionary containing metadata about the dataset, such as its shape and other properties.

        Raises:
            FileNotFoundError: If the description file is not found.
            ValueError: If the file content is not valid JSON (chained from json.JSONDecodeError).
        """
        try:
            # JSON is UTF-8 by specification; do not depend on the platform default encoding.
            with open(self.description_file_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError as e:
            raise FileNotFoundError(f'Description file not found: {self.description_file_path}') from e
        except json.JSONDecodeError as e:
            raise ValueError(f'Error decoding JSON file: {self.description_file_path}') from e

    def _load_data(self) -> np.ndarray:
        """
        Loads the time series data from a file and splits it according to the selected mode.

        The file is memory-mapped as float32 with the shape given by `description['shape']`, and the
        slice belonging to the current mode is copied into memory. When overlap is enabled, a split is
        extended into its neighbours: the validation/test splits borrow up to `input_len - 1` preceding
        steps as history, and the training/validation splits borrow `output_len` following steps as targets.

        Side effects:
            Sets `self.overlap = True` (and emits a notice) when the selected split is shorter than
            `input_len + output_len`.

        Returns:
            np.ndarray: The data array for the specified mode (train, validation, or test).

        Raises:
            ValueError: If there is an issue with loading the data file or if the data shape is not as expected.
        """
        try:
            data = np.memmap(self.data_file_path, dtype='float32', mode='r', shape=tuple(self.description['shape']))
        except (FileNotFoundError, ValueError) as e:
            raise ValueError(f'Error loading data file: {self.data_file_path}') from e

        total_len = len(data)
        valid_len = int(total_len * self.train_val_test_ratio[1])
        test_len = int(total_len * self.train_val_test_ratio[2])
        # The training split takes the remainder, so rounding never drops a time step.
        train_len = total_len - valid_len - test_len

        # Automatically configure the overlap parameter
        minimal_len = self.input_len + self.output_len
        split_len = {'train': train_len, 'valid': valid_len, 'test': test_len}[self.mode]
        if minimal_len > split_len:
            self.overlap = True  # Enable overlap when the train, validation, or test set is too short

            # Point the user at this branch. The line is taken directly from the running frame (no
            # hand-maintained offset), and currentframe() may be None on interpreters without
            # stack-frame support.
            frame = inspect.currentframe()
            if frame is not None:
                file_name = inspect.getfile(frame)
                line_number = frame.f_lineno
            else:
                file_name, line_number = __file__, 'unknown'
            # Break the frame <-> local-variable reference cycle so the memmap is released promptly.
            del frame

            dataset = {'train': 'Training', 'valid': 'Validation', 'test': 'Test'}[self.mode]
            message = f'{dataset} dataset is too short, enabling overlap. See details in {file_name} at line {line_number}.'
            if self.logger is not None:
                self.logger.info(message)
            else:
                print(message)

        if self.mode == 'train':
            offset = self.output_len if self.overlap else 0
            return data[:train_len + offset].copy()

        if self.mode == 'valid':
            offset_left = self.input_len - 1 if self.overlap else 0
            offset_right = self.output_len if self.overlap else 0
            # Clamp at 0: a negative start would wrap around to the end of the array.
            start = max(train_len - offset_left, 0)
            return data[start:train_len + valid_len + offset_right].copy()

        # self.mode == 'test'
        offset = self.input_len - 1 if self.overlap else 0
        # Clamp at 0: a negative start would wrap around to the end of the array.
        start = max(train_len + valid_len - offset, 0)
        return data[start:].copy()

    def __getitem__(self, index: int) -> dict:
        """
        Retrieves a sample from the dataset at the specified index, considering both the input and output lengths.

        Args:
            index (int): The index of the desired sample in the dataset. Negative values count from the
                end, as with built-in sequences.

        Returns:
            dict: A dictionary containing 'inputs' and 'target', where both are slices of the dataset corresponding to
                the historical input data and future prediction data, respectively.

        Raises:
            IndexError: If `index` is outside `[-len(self), len(self))`. Without this check an
                out-of-range index would silently yield truncated or empty windows.
        """
        num_samples = len(self)
        position = index + num_samples if index < 0 else index
        if not 0 <= position < num_samples:
            raise IndexError(f'Index {index} is out of range for a dataset with {num_samples} samples.')

        split_point = position + self.input_len
        history_data = self.data[position:split_point]
        future_data = self.data[split_point:split_point + self.output_len]
        return {'inputs': history_data, 'target': future_data}

    def __len__(self) -> int:
        """
        Calculates the total number of samples available in the dataset, adjusted for the lengths of input and output sequences.

        Returns:
            int: The number of valid samples that can be drawn from the dataset, based on the configurations of input and output lengths.
                Zero when the data is shorter than one full `input_len + output_len` window.
        """
        # Clamp at 0: `len()` raises ValueError if __len__ returns a negative number.
        return max(len(self.data) - self.input_len - self.output_len + 1, 0)