Skip to content

Commit

Permalink
Allow get_scalar_constant_value to get shape values from constants
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 6, 2020
1 parent b8916f5 commit 4c1edf1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5999,6 +5999,19 @@ def test_get_scalar_constant_value(self):
v = tt.row()
assert get_scalar_constant_value(v.shape[0]) == 1

res = tt.get_scalar_constant_value(tt.as_tensor([10, 20]).shape[0])
assert isinstance(res, np.ndarray)
assert 2 == res

res = tt.get_scalar_constant_value(
9 + tt.as_tensor([1.0]).shape[0],
elemwise=True,
only_process_constants=False,
max_recur=9,
)
assert isinstance(res, np.ndarray)
assert 10 == res

def test_subtensor_of_constant(self):
c = constant(rand(5))
for i in range(c.value.shape[0]):
Expand Down
3 changes: 3 additions & 0 deletions theano/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,9 @@ def get_scalar_constant_value(
if gp_broadcastable[idx]:
return np.asarray(1)

if isinstance(grandparent, Constant):
return np.asarray(grandparent.data.shape[idx])

raise NotScalarConstantError(v)


Expand Down

0 comments on commit 4c1edf1

Please sign in to comment.