Skip to content

Commit

Permalink
fix: torch frontend item() working with tf.function (#28729)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattbarrett98 authored Apr 9, 2024
1 parent 7a0af13 commit decb87f
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def dtype(self):

@property
def shape(self):
return Size(ivy.shape(self.ivy_array, as_array=True))
return Size(self.ivy_array.shape)

@property
def real(self):
Expand Down Expand Up @@ -1431,7 +1431,12 @@ def bitwise_xor_(self, other):

def item(self):
if all(dim == 1 for dim in self.shape):
return self.ivy_array.to_scalar()
if ivy.current_backend_str() == "tensorflow":
import tensorflow as tf

return tf.squeeze(self.ivy_array.data)
else:
return self.ivy_array.to_scalar()
else:
raise ValueError(
"only one element tensors can be converted to Python scalars"
Expand Down

0 comments on commit decb87f

Please sign in to comment.