From f0d02dcb8fc150735772b52d98fdeb77faac8a77 Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Wed, 4 Sep 2024 21:10:02 +0100 Subject: [PATCH] fix: torch frontend bernoulli functions and methods (#28815) --- ivy/functional/frontends/torch/random_sampling.py | 4 ++-- ivy/functional/frontends/torch/tensor.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ivy/functional/frontends/torch/random_sampling.py b/ivy/functional/frontends/torch/random_sampling.py index 105e983209e64..bc6d381665985 100644 --- a/ivy/functional/frontends/torch/random_sampling.py +++ b/ivy/functional/frontends/torch/random_sampling.py @@ -13,9 +13,9 @@ "torch", ) @to_ivy_arrays_and_back -def bernoulli(input, p, *, generator=None, out=None): +def bernoulli(input, *, generator=None, out=None): seed = generator.initial_seed() if generator is not None else None - return ivy.bernoulli(p, logits=input, seed=seed, out=out) + return ivy.bernoulli(input, seed=seed, out=out) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 01d465e565549..2bb2e106fa6b7 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1249,14 +1249,14 @@ def dot(self, tensor): return torch_frontend.dot(self, tensor) @with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch") - def bernoulli(self, p, *, generator=None, out=None): + def bernoulli(self, *, generator=None, out=None): return torch_frontend.bernoulli( - self._ivy_array, p, generator=generator, out=out + self._ivy_array, generator=generator, out=out ) @with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch") def bernoulli_(self, p, *, generator=None, out=None): - self.ivy_array = self.bernoulli(p, generator=generator, out=out).ivy_array + self.ivy_array = torch_frontend.bernoulli(torch_frontend.full(self.shape, p, dtype=torch_frontend.float64), generator=generator, out=out).ivy_array return self def numel(self):