Skip to content

Commit

Permalink
fix: Conversion of tabular dataset to tensors (#757)
Browse files Browse the repository at this point in the history
### Summary of Changes

Fixed conversion of tabular dataset to tensors and associated tests.

---------

Co-authored-by: megalinter-bot <[email protected]>
Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
3 people authored May 13, 2024
1 parent 92622fb commit 9e40b65
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 15 deletions.
5 changes: 4 additions & 1 deletion src/safeds/data/labeled/containers/_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,10 @@ def __init__(self, table: Table) -> None:
_init_default_device()

self._column_names = table.column_names
self._tensor = torch.Tensor(table._data_frame.to_torch()).to(_get_device())
if table.number_of_rows == 0:
self._tensor = torch.empty((0, table.number_of_columns), dtype=torch.float32).to(_get_device())
else:
self._tensor = table._data_frame.to_torch().to(_get_device())

if not torch.all(self._tensor.sum(dim=1) == torch.ones(self._tensor.size(dim=0))):
raise ValueError(
Expand Down
8 changes: 4 additions & 4 deletions src/safeds/data/labeled/containers/_time_series_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ def _create_dataset(features: torch.Tensor, target: torch.Tensor) -> Dataset:

class _CustomDataset(Dataset):
def __init__(self, features_dataset: torch.Tensor, target_dataset: torch.Tensor):
self.X = features_dataset
self.Y = target_dataset.unsqueeze(-1)
self.X = features_dataset.float()
self.Y = target_dataset.unsqueeze(-1).float()
self.len = self.X.shape[0]

def __getitem__(self, item: int) -> tuple[torch.Tensor, torch.Tensor]:
Expand All @@ -341,8 +341,8 @@ def _create_dataset_predict(features: torch.Tensor) -> Dataset:
_init_default_device()

class _CustomDataset(Dataset):
def __init__(self, features: torch.Tensor):
self.X = features
def __init__(self, datas: torch.Tensor):
self.X = datas.float()
self.len = self.X.shape[0]

def __getitem__(self, item: int) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/ml/nn/_output_conversion_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _data_conversion(self, input_data: TimeSeriesDataset, output_data: Tensor, *
window_size: int = kwargs["window_size"]
forecast_horizon: int = kwargs["forecast_horizon"]
input_data_table = input_data.to_table()
input_data_table = input_data_table.slice_rows(window_size + forecast_horizon)
input_data_table = input_data_table.slice_rows(start=window_size + forecast_horizon)

return input_data_table.add_columns(
[Column(self._prediction_name, output_data.tolist())],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_should_create_dataloader_invalid(
1,
0,
OutOfBoundsError,
r"forecast_horizon \(=0\) is not inside \[1, \u221e\).",
None,
),
(
Table(
Expand All @@ -189,7 +189,7 @@ def test_should_create_dataloader_invalid(
0,
1,
OutOfBoundsError,
r"window_size \(=0\) is not inside \[1, \u221e\).",
None,
),
],
ids=[
Expand All @@ -204,7 +204,7 @@ def test_should_create_dataloader_predict_invalid(
window_size: int,
forecast_horizon: int,
error_type: type[ValueError],
error_msg: str,
error_msg: str | None,
device: Device,
) -> None:
configure_test_with_device(device)
Expand Down
8 changes: 2 additions & 6 deletions tests/safeds/ml/nn/test_forward_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@ def test_forward_model(device: Device) -> None:
path=resolve_resource_path(_inflation_path),
)
table_1 = table_1.remove_columns(["date"])
table_2 = table_1.slice_rows(length=table_1.number_of_rows - 14)
table_2 = table_2.add_columns(
[
table_1.slice_rows(start=14).get_column("value").rename("target"),
]
)
table_2 = table_1.slice_rows(start=0, length=table_1.number_of_rows - 14)
table_2 = table_2.add_columns([(table_1.slice_rows(start=14)).get_column("value").rename("target")])
train_table, test_table = table_2.split_rows(0.8)

ss = StandardScaler()
Expand Down

0 comments on commit 9e40b65

Please sign in to comment.