-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: remove output conversions (#792)
Closes partially #732 ### Summary of Changes Output conversions are the exact inversion of the input conversion, so there is no need to specify them again. Now, a neural network only takes an input conversion and a list of layers. This also gets rid of several errors that could occur if input and output conversions did not fit together. In a later PR, the input conversion will also be removed, since they mirror datasets. --------- Co-authored-by: megalinter-bot <[email protected]>
- Loading branch information
1 parent
dd8394b
commit 46f2f5d
Showing
25 changed files
with
576 additions
and
876 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
src/safeds/ml/nn/converters/_input_converter_image_to_column.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
from safeds._config import _init_default_device | ||
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList | ||
from safeds.data.labeled.containers import ImageDataset | ||
from safeds.data.labeled.containers._image_dataset import _ColumnAsTensor | ||
from safeds.data.tabular.containers import Column | ||
from safeds.data.tabular.transformation import OneHotEncoder | ||
|
||
from ._input_converter_image import _InputConversionImage | ||
|
||
if TYPE_CHECKING: | ||
from torch import Tensor | ||
|
||
from safeds.data.image.containers import ImageList | ||
|
||
|
||
class InputConversionImageToColumn(_InputConversionImage): | ||
def _data_conversion_output( | ||
self, | ||
input_data: ImageList, | ||
output_data: Tensor, | ||
**kwargs: Any, | ||
) -> ImageDataset[Column]: | ||
import torch | ||
|
||
_init_default_device() | ||
|
||
if not isinstance(input_data, _SingleSizeImageList): | ||
raise ValueError("The given input ImageList contains images of different sizes.") # noqa: TRY004 | ||
if "column_name" not in kwargs or not isinstance(kwargs.get("column_name"), str): | ||
raise ValueError( | ||
"The column_name is not set. The data can only be converted if the column_name is provided as `str` in the kwargs.", | ||
) | ||
if "one_hot_encoder" not in kwargs or not isinstance(kwargs.get("one_hot_encoder"), OneHotEncoder): | ||
raise ValueError( | ||
"The one_hot_encoder is not set. The data can only be converted if the one_hot_encoder is provided as `OneHotEncoder` in the kwargs.", | ||
) | ||
one_hot_encoder: OneHotEncoder = kwargs["one_hot_encoder"] | ||
column_name: str = kwargs["column_name"] | ||
|
||
output = torch.zeros(len(input_data), len(one_hot_encoder._get_names_of_added_columns())) | ||
output[torch.arange(len(input_data)), output_data] = 1 | ||
|
||
im_dataset: ImageDataset[Column] = ImageDataset[Column].__new__(ImageDataset) | ||
im_dataset._output = _ColumnAsTensor._from_tensor(output, column_name, one_hot_encoder) | ||
im_dataset._shuffle_tensor_indices = torch.LongTensor(list(range(len(input_data)))) | ||
im_dataset._shuffle_after_epoch = False | ||
im_dataset._batch_size = 1 | ||
im_dataset._next_batch_index = 0 | ||
im_dataset._input_size = input_data.sizes[0] | ||
im_dataset._input = input_data | ||
return im_dataset |
36 changes: 36 additions & 0 deletions
36
src/safeds/ml/nn/converters/_input_converter_image_to_image.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
from safeds._config import _init_default_device | ||
from safeds.data.image.containers import ImageList | ||
from safeds.data.image.containers._single_size_image_list import _SingleSizeImageList | ||
from safeds.data.labeled.containers import ImageDataset | ||
|
||
from ._input_converter_image import _InputConversionImage | ||
|
||
if TYPE_CHECKING: | ||
from torch import Tensor | ||
|
||
|
||
class InputConversionImageToImage(_InputConversionImage): | ||
def _data_conversion_output( | ||
self, | ||
input_data: ImageList, | ||
output_data: Tensor, | ||
**_kwargs: Any, | ||
) -> ImageDataset[ImageList]: | ||
import torch | ||
|
||
_init_default_device() | ||
|
||
if not isinstance(input_data, _SingleSizeImageList): | ||
raise ValueError("The given input ImageList contains images of different sizes.") # noqa: TRY004 | ||
|
||
return ImageDataset[ImageList]( | ||
input_data, | ||
_SingleSizeImageList._create_from_tensor( | ||
(output_data * 255).to(torch.uint8), | ||
list(range(output_data.size(dim=0))), | ||
), | ||
) |
Oops, something went wrong.