Skip to content

Commit

Permalink
fix: torch frontend bernoulli functions and methods (#28815)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong authored Sep 4, 2024
1 parent 9e07e83 commit f0d02dc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions ivy/functional/frontends/torch/random_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
"torch",
)
@to_ivy_arrays_and_back
def bernoulli(input, p, *, generator=None, out=None):
def bernoulli(input, *, generator=None, out=None):
seed = generator.initial_seed() if generator is not None else None
return ivy.bernoulli(p, logits=input, seed=seed, out=out)
return ivy.bernoulli(input, seed=seed, out=out)


@to_ivy_arrays_and_back
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,14 +1249,14 @@ def dot(self, tensor):
return torch_frontend.dot(self, tensor)

@with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch")
def bernoulli(self, p, *, generator=None, out=None):
def bernoulli(self, *, generator=None, out=None):
return torch_frontend.bernoulli(
self._ivy_array, p, generator=generator, out=out
self._ivy_array, generator=generator, out=out
)

@with_supported_dtypes({"2.2 and below": ("float32", "float64")}, "torch")
def bernoulli_(self, p, *, generator=None, out=None):
self.ivy_array = self.bernoulli(p, generator=generator, out=out).ivy_array
self.ivy_array = torch_frontend.bernoulli(torch_frontend.full(self.shape, p, dtype=torch_frontend.float64), generator=generator, out=out).ivy_array
return self

def numel(self):
Expand Down

0 comments on commit f0d02dc

Please sign in to comment.