chore: add complex_mode and fix supported dtypes in backends
jshepherd01 committed Aug 23, 2023
1 parent fe5a626 commit 10f622f
Showing 6 changed files with 18 additions and 4 deletions.
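For context: `complex_mode` is the argument this commit threads through the backend `softplus` signatures so ivy's frontend can specify how complex inputs should be handled, with `"jax"` (the default used here) meaning "follow JAX's behaviour", i.e. evaluate the usual formula directly on complex values. A minimal sketch of that shape in NumPy terms — an illustrative reimplementation for this page, not the committed backend code:

```python
import numpy as np

def softplus(x, /, *, beta=None, threshold=None, out=None, complex_mode="jax"):
    # softplus(x) = (1/beta) * log(1 + exp(beta * x)); log1p and exp are
    # complex-safe in NumPy, so "jax" mode can reuse the same formula.
    x_beta = x * beta if (beta is not None and beta != 1) else x
    res = np.log1p(np.exp(x_beta))
    if beta is not None and beta != 1:
        res = res / beta
    if threshold is not None:
        # NumPy cannot order complex values, so this sketch compares real
        # parts; that detail is an assumption, not ivy's actual rule.
        return np.where(np.real(x_beta) > threshold, x, res)
    return res
```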
1 change: 1 addition & 0 deletions ivy/functional/backends/jax/activations.py
@@ -51,6 +51,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[JaxArray] = None,
+    complex_mode="jax",
 ) -> JaxArray:
     if beta is not None and beta != 1:
         x_beta = x * beta
1 change: 1 addition & 0 deletions ivy/functional/backends/mxnet/activations.py
@@ -44,6 +44,7 @@ def softplus(
     beta: Optional[Union[(int, float)]] = None,
     threshold: Optional[Union[(int, float)]] = None,
     out: Optional[None] = None,
+    complex_mode="jax",
 ) -> None:
     if beta is not None and beta != 1:
         x_beta = x * beta
1 change: 1 addition & 0 deletions ivy/functional/backends/numpy/activations.py
@@ -59,6 +59,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[np.ndarray] = None,
+    complex_mode="jax",
 ) -> np.ndarray:
     if beta is not None and beta != 1:
         x_beta = x * beta
1 change: 1 addition & 0 deletions ivy/functional/backends/paddle/activations.py
@@ -121,6 +121,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[paddle.Tensor] = None,
+    complex_mode="jax",
 ) -> paddle.Tensor:
     if beta is not None and beta != 1:
         x_beta = x * beta
13 changes: 12 additions & 1 deletion ivy/functional/backends/tensorflow/activations.py
@@ -47,7 +47,17 @@ def softmax(


 @with_supported_dtypes(
-    {"2.13.0 and below": ("float16", "bfloat16", "float32", "float64")}, backend_version
+    {
+        "2.13.0 and below": (
+            "float16",
+            "bfloat16",
+            "float32",
+            "float64",
+            "complex64",
+            "complex128",
+        )
+    },
+    backend_version,
 )
 def softplus(
     x: Tensor,
@@ -56,6 +66,7 @@ def softplus(
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
     out: Optional[Tensor] = None,
+    complex_mode="jax",
 ) -> Tensor:
     if beta is not None and beta != 1:
         x_beta = x * beta
5 changes: 2 additions & 3 deletions ivy/functional/backends/torch/activations.py
@@ -65,15 +65,14 @@ def softmax(
     *,
     axis: Optional[int] = None,
     out: Optional[torch.Tensor] = None,
+    complex_mode="jax",
 ) -> torch.Tensor:
     if axis is None:
         axis = -1
     return torch.nn.functional.softmax(x, axis)


-@with_unsupported_dtypes(
-    {"2.0.1 and below": ("complex", "float16", "bfloat16")}, backend_version
-)
+@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version)
 def softplus(
     x: torch.Tensor,
     /,

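Taken together, the user-facing effect is that complex inputs no longer need a workaround on these backends. A hypothetical call after this commit — the array values are illustrative, and `complex_mode="jax"` simply passes the default explicitly:

```python
import ivy

ivy.set_backend("tensorflow")  # complex64/complex128 are now whitelisted
z = ivy.array([1.0 + 1.0j, -2.0 + 0.5j])
y = ivy.softplus(z, complex_mode="jax")  # jax-style: log(1 + exp(z)) on complex values
```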