diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 0160b24b6f..192e736486 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -5999,6 +5999,19 @@ def test_get_scalar_constant_value(self): v = tt.row() assert get_scalar_constant_value(v.shape[0]) == 1 + res = tt.get_scalar_constant_value(tt.as_tensor([10, 20]).shape[0]) + assert isinstance(res, np.ndarray) + assert 2 == res + + res = tt.get_scalar_constant_value( + 9 + tt.as_tensor([1.0]).shape[0], + elemwise=True, + only_process_constants=False, + max_recur=9, + ) + assert isinstance(res, np.ndarray) + assert 10 == res + def test_subtensor_of_constant(self): c = constant(rand(5)) for i in range(c.value.shape[0]): diff --git a/theano/tensor/basic.py b/theano/tensor/basic.py index b530ebc561..6e3a3ab777 100644 --- a/theano/tensor/basic.py +++ b/theano/tensor/basic.py @@ -662,6 +662,9 @@ def get_scalar_constant_value( if gp_broadcastable[idx]: return np.asarray(1) + if isinstance(grandparent, Constant): + return np.asarray(grandparent.data.shape[idx]) + raise NotScalarConstantError(v)