Skip to content

Commit

Permalink
feat: remove window size and forecast horizon from converter
Browse files Browse the repository at this point in the history
  • Loading branch information
lars-reimann committed May 21, 2024
1 parent 9efbe2b commit 62c06f6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 27 deletions.
19 changes: 5 additions & 14 deletions src/safeds/ml/nn/converters/_input_converter_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,15 @@


class InputConversionTimeSeries(InputConversion[TimeSeriesDataset, TimeSeriesDataset]):
"""
The input conversion for a neural network, defines the input parameters for the neural network.
Parameters
----------
window_size:
The size of the created windows
forecast_horizon:
The forecast horizon defines the future lag of the predicted values
"""
"""The input conversion for a neural network, defines the input parameters for the neural network."""

def __init__(
self,
window_size: int,
forecast_horizon: int,
*,
prediction_name: str = "prediction_nn",
) -> None:
self._window_size = window_size
self._forecast_horizon = forecast_horizon
self._window_size = 0
self._forecast_horizon = 0
self._first = True
self._target_name: str = ""
self._time_name: str = ""
Expand Down Expand Up @@ -101,6 +90,8 @@ def _data_conversion_output(

def _is_fit_data_valid(self, input_data: TimeSeriesDataset) -> bool:
if self._first:
self._window_size = input_data.window_size
self._forecast_horizon = input_data.forecast_horizon
self._time_name = input_data.time.name
self._feature_names = input_data.features.column_names
self._target_name = input_data.target.name
Expand Down
24 changes: 12 additions & 12 deletions tests/safeds/ml/nn/converters/test_input_converter_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def test_should_raise_if_is_fitted_is_set_correctly_lstm() -> None:
model = NeuralNetworkRegressor(
InputConversionTimeSeries(1, 1, prediction_name="predicted"),
InputConversionTimeSeries(prediction_name="predicted"),
[LSTMLayer(input_size=2, output_size=1)],
)
ts = Table.from_dict({"target": [1, 1, 1, 1], "time": [0, 0, 0, 0], "feat": [0, 0, 0, 0]}).to_time_series_dataset(
Expand All @@ -33,7 +33,7 @@ class TestEq:
@pytest.mark.parametrize(
("output_conversion_ts1", "output_conversion_ts2"),
[
(InputConversionTimeSeries(1, 1), InputConversionTimeSeries(1, 1)),
(InputConversionTimeSeries(), InputConversionTimeSeries()),
],
)
def test_should_be_equal(
Expand All @@ -47,12 +47,12 @@ def test_should_be_equal(
("output_conversion_ts1", "output_conversion_ts2"),
[
(
InputConversionTimeSeries(1, 1),
InputConversionTimeSeries(),
Table(),
),
(
InputConversionTimeSeries(1, 1, prediction_name="2"),
InputConversionTimeSeries(1, 1, prediction_name="1"),
InputConversionTimeSeries( prediction_name="2"),
InputConversionTimeSeries( prediction_name="1"),
),
],
)
Expand All @@ -68,7 +68,7 @@ class TestHash:
@pytest.mark.parametrize(
("output_conversion_ts1", "output_conversion_ts2"),
[
(InputConversionTimeSeries(1, 1), InputConversionTimeSeries(1, 1)),
(InputConversionTimeSeries(), InputConversionTimeSeries()),
],
)
def test_hash_should_be_equal(
Expand All @@ -79,9 +79,9 @@ def test_hash_should_be_equal(
assert hash(output_conversion_ts1) == hash(output_conversion_ts2)

def test_hash_should_not_be_equal(self) -> None:
output_conversion_ts1 = InputConversionTimeSeries(1, 1, prediction_name="1")
output_conversion_ts2 = InputConversionTimeSeries(1, 1, prediction_name="2")
output_conversion_ts3 = InputConversionTimeSeries(1, 1, prediction_name="3")
output_conversion_ts1 = InputConversionTimeSeries(prediction_name="1")
output_conversion_ts2 = InputConversionTimeSeries(prediction_name="2")
output_conversion_ts3 = InputConversionTimeSeries(prediction_name="3")
assert hash(output_conversion_ts1) != hash(output_conversion_ts3)
assert hash(output_conversion_ts2) != hash(output_conversion_ts1)
assert hash(output_conversion_ts3) != hash(output_conversion_ts2)
Expand All @@ -91,9 +91,9 @@ class TestSizeOf:
@pytest.mark.parametrize(
"output_conversion_ts",
[
InputConversionTimeSeries(1, 1, prediction_name="1"),
InputConversionTimeSeries(1, 1, prediction_name="2"),
InputConversionTimeSeries(1, 1, prediction_name="3"),
InputConversionTimeSeries(prediction_name="1"),
InputConversionTimeSeries(prediction_name="2"),
InputConversionTimeSeries(prediction_name="3"),
],
)
def test_should_size_be_greater_than_normal_object(
Expand Down
2 changes: 1 addition & 1 deletion tests/safeds/ml/nn/test_lstm_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_lstm_model(device: Device) -> None:
train_table, test_table = table.split_rows(0.8)

model = NeuralNetworkRegressor(
InputConversionTimeSeries(window_size=7, forecast_horizon=12, prediction_name="predicted"),
InputConversionTimeSeries(prediction_name="predicted"),
[ForwardLayer(input_size=7, output_size=256), LSTMLayer(input_size=256, output_size=1)],
)
trained_model = model.fit(
Expand Down

0 comments on commit 62c06f6

Please sign in to comment.