Merge remote-tracking branch 'origin/889-feat-allow-activation-function-parameter-for-forward-layer' into 889-feat-allow-activation-function-parameter-for-forward-layer
Gerhardsa0 committed Jul 15, 2024
2 parents 6fa6740 + ed7abb0 commit 488af2c
Showing 2 changed files with 15 additions and 7 deletions.
21 changes: 15 additions & 6 deletions src/safeds/ml/nn/layers/_forward_layer.py
@@ -32,8 +32,11 @@ class ForwardLayer(Layer):
         If the given activation function does not exist
     """
 
-    def __init__(self, neuron_count: int,
-                 overwrite_activation_function: Literal["sigmoid", "relu", "softmax", "none", "notset"] = "notset"):
+    def __init__(
+        self,
+        neuron_count: int,
+        overwrite_activation_function: Literal["sigmoid", "relu", "softmax", "none", "notset"] = "notset",
+    ):
         _check_bounds("neuron_count", neuron_count, lower_bound=_ClosedBound(1))
 
         self._input_size: int | None = None
@@ -98,11 +101,17 @@ def __eq__(self, other: object) -> bool:
             return NotImplemented
         if self is other:
             return True
-        return (self._input_size == other._input_size and self._output_size == other._output_size
-                and self._activation_function == other._activation_function)
+        return (
+            self._input_size == other._input_size
+            and self._output_size == other._output_size
+            and self._activation_function == other._activation_function
+        )
 
     def __sizeof__(self) -> int:
         import sys
 
-        return (sys.getsizeof(self._input_size) + sys.getsizeof(self._output_size)
-                + sys.getsizeof(self._activation_function))
+        return (
+            sys.getsizeof(self._input_size)
+            + sys.getsizeof(self._output_size)
+            + sys.getsizeof(self._activation_function)
+        )
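
A minimal usage sketch based on the updated signature above. The class name, parameter names, and allowed literals come from this diff; how "notset" is resolved at runtime is not shown here and is assumed to be handled by the surrounding network.

from safeds.ml.nn.layers import ForwardLayer

# Default: the activation function stays "notset" and is presumably
# chosen by the enclosing model.
hidden = ForwardLayer(neuron_count=64)

# Explicit override with one of the allowed literals:
# "sigmoid", "relu", "softmax", "none", "notset".
output = ForwardLayer(neuron_count=10, overwrite_activation_function="softmax")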
1 change: 0 additions & 1 deletion tests/safeds/ml/nn/layers/test_forward_layer.py
@@ -7,7 +7,6 @@
 from safeds.ml.nn.layers import ForwardLayer
 from torch import nn
 
-
 # TODO: Should be tested on a model, not a layer, since input size gets inferred
 # @pytest.mark.parametrize(
 #     "input_size",
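
A hypothetical test sketch (not part of this commit) illustrating the reformatted __eq__ above: two layers constructed with identical arguments should compare equal, since __eq__ compares _input_size, _output_size, and _activation_function.

from safeds.ml.nn.layers import ForwardLayer

def test_identically_constructed_layers_are_equal() -> None:
    # Both layers start with the same neuron count and the default
    # activation setting, so all compared attributes match.
    assert ForwardLayer(neuron_count=3) == ForwardLayer(neuron_count=3)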
