many2one #36
base: master
@@ -161,7 +161,7 @@ def _padding_time_stamp_mark(
padding_mark = get_time_mark(whole_time_stamp, 1, self.config.freq)
return padding_mark

- def validate(self, valid_data_loader, criterion):
+ def validate(self, valid_data_loader, covariate, criterion):
I suggest renaming this parameter to 'covariates', to be consistent with the corresponding parameter in ModelBase
: -covariate["exog"].shape[1]
if covariate["exog"].shape[1] > 0
else None,
]
Are we assuming that both target and output contain the target series as well as exog? This looks odd, since exog should be passed in as part of 'covariates'. Currently, the 'covariates' parameter is not actually used for anything other than its shape.
In that case, you should at least keep the parameters precise and clear: pass in 'series_dim: int' rather than 'covariates: Dict'.
@@ -194,7 +206,7 @@ def validate(self, valid_data_loader, criterion):
return total_loss

def forecast_fit(
- self, train_valid_data: pd.DataFrame, train_ratio_in_tv: float
+ self, train_valid_data: pd.DataFrame, covariate: dict, train_ratio_in_tv: float
The parameter name differs from the one defined in the base class. Please do a global check to make sure the interface is consistent.
@@ -203,6 +215,9 @@ def forecast_fit(
:param train_ratio_in_tv: Represents the splitting ratio of the training set validation set. If it is equal to 1, it means that the validation set is not partitioned.
:return: The fitted model object.
"""
+ train_valid_data = pd.concat(
+ [train_valid_data, covariate["exog"]], axis=1
+ )
please consider the case when exog does not exist
@@ -421,6 +448,9 @@ def batch_forecast(

input_data = batch_maker.make_batch(self.config.batch_size, self.config.seq_len)
input_np = input_data["input"]
+ input_np = np.concatenate(
+ (input_np, input_data["covariates"]["exog"]), axis=2
+ )
please consider the case when exog does not exist
Splits a DataFrame into target and remaining parts based on the target_channel configuration.

:param df: The input DataFrame to be split.
:param target_channel: Configuration for selecting target columns. It can include integers (positive or negative) and lists of two integers representing slices. If set to None, all columns are selected as target columns, and the remaining DataFrame is empty.
This line is too long.
Also, please provide some examples of what 'target_channel' can be.
def parse_target_channel(
target_channel: Optional[List], num_columns: int
) -> List[int]:
if this function is independent of the outer scope, I suggest moving it to the global scope, while setting it to be private.
remaining_df = df.iloc[:, remaining_columns]
else:
# Create an empty DataFrame with the same index as df and zero columns
remaining_df = pd.DataFrame(index=df.index)
`df.iloc[:, []]` works as expected, so this if-else is not necessary.
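A quick check of that behaviour, using a made-up two-column frame: selecting an empty list of column positions yields a frame with zero columns that keeps the original index.

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}, index=["r0", "r1"])

# An empty positional selection keeps the rows and drops every column.
empty = df.iloc[:, []]
```

So `remaining_df = df.iloc[:, remaining_columns]` should cover both branches, including the case where `remaining_columns` is empty.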
raise IndexError(
f"target_channel configuration error: Column index {item} is out of range (total columns: {num_columns})."
)
elif isinstance(item, list) and len(item) == 2:
I think we should also allow a tuple with two elements
)

# Remove duplicates while preserving order
target_columns_unique = list(dict.fromkeys(target_columns))
dict is only guaranteed to preserve insertion order from Python 3.7 onward (in CPython 3.6 it is an implementation detail). Shall we consider compatibility with older Python versions?
@qiu69 Do we apply end-to-end tests before submitting new PRs now? Please add some new test cases to the script whenever a new feature is added.
@@ -0,0 +1,91 @@
+ import pandas as pd
please reformat this file
Attempt to include many other variables to predict the current multivariate data.