From 4c1edf1b0541d525b7938d3b313e20f004081743 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 5 Nov 2020 18:09:35 -0600 Subject: [PATCH] Allow get_scalar_constant_value to get shape values from constants --- tests/tensor/test_basic.py | 13 +++++++++++++ theano/tensor/basic.py | 3 +++ 2 files changed, 16 insertions(+) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 0160b24b6f..192e736486 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -5999,6 +5999,19 @@ def test_get_scalar_constant_value(self): v = tt.row() assert get_scalar_constant_value(v.shape[0]) == 1 + res = tt.get_scalar_constant_value(tt.as_tensor([10, 20]).shape[0]) + assert isinstance(res, np.ndarray) + assert 2 == res + + res = tt.get_scalar_constant_value( + 9 + tt.as_tensor([1.0]).shape[0], + elemwise=True, + only_process_constants=False, + max_recur=9, + ) + assert isinstance(res, np.ndarray) + assert 10 == res + def test_subtensor_of_constant(self): c = constant(rand(5)) for i in range(c.value.shape[0]): diff --git a/theano/tensor/basic.py b/theano/tensor/basic.py index b530ebc561..6e3a3ab777 100644 --- a/theano/tensor/basic.py +++ b/theano/tensor/basic.py @@ -662,6 +662,9 @@ def get_scalar_constant_value( if gp_broadcastable[idx]: return np.asarray(1) + if isinstance(grandparent, Constant): + return np.asarray(grandparent.data.shape[idx]) + raise NotScalarConstantError(v)