From bfc91bb69443e2e234643ecc26fb2c05fe24f320 Mon Sep 17 00:00:00 2001 From: Daisuke Oyama Date: Tue, 29 Mar 2016 00:20:05 +0900 Subject: [PATCH] MarkovChain.get_index: Accept array_like of values --- quantecon/markov/core.py | 42 ++++++++++++++++++++++++++--- quantecon/markov/tests/test_core.py | 5 ++++ 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/quantecon/markov/core.py b/quantecon/markov/core.py index c25cc3cac..2a0d49d37 100644 --- a/quantecon/markov/core.py +++ b/quantecon/markov/core.py @@ -229,7 +229,43 @@ def state_values(self, values): def get_index(self, value): """ - Return the index of the given value in state_values. + Return the index (or indices) of the given value (or values) in + `state_values`. + + Parameters + ---------- + value + Value(s) to get the index (indices) for. + + Returns + ------- + idx : int or ndarray(int) + Index of `value` if `value` is a single state value; array + of indices if `value` is an array_like of state values. + + """ + if self.state_values is None: + state_values_ndim = 1 + else: + state_values_ndim = self.state_values.ndim + + values = np.asarray(value) + + if values.ndim <= state_values_ndim - 1: + return self._get_index(value) + elif values.ndim == state_values_ndim: # array of values + k = values.shape[0] + idx = np.empty(k, dtype=int) + for i in range(k): + idx[i] = self._get_index(values[i]) + return idx + else: + raise ValueError('invalid value') + + + def _get_index(self, value): + """ + Return the index of the given value in `state_values`. Parameters ---------- @@ -239,10 +275,10 @@ def get_index(self, value): Returns ------- idx : int - Index of the value. + Index of `value`. """ - error_msg = 'value {0} not found'.format(repr(value)) + error_msg = 'value {0} not found'.format(value) if self.state_values is None: if isinstance(value, numbers.Integral) and (0 <= value < self.n): diff --git a/quantecon/markov/tests/test_core.py b/quantecon/markov/tests/test_core.py index 512e3c4b8..44410b593 100644 --- a/quantecon/markov/tests/test_core.py +++ b/quantecon/markov/tests/test_core.py @@ -483,15 +483,20 @@ def test_get_index(): eq_(mc.get_index(0), 0) eq_(mc.get_index(1), 1) assert_raises(ValueError, mc.get_index, 2) + assert_array_equal(mc.get_index([1, 0]), [1, 0]) + assert_raises(ValueError, mc.get_index, [[1]]) mc.state_values = [1, 2] eq_(mc.get_index(1), 0) eq_(mc.get_index(2), 1) assert_raises(ValueError, mc.get_index, 0) + assert_array_equal(mc.get_index([2, 1]), [1, 0]) + assert_raises(ValueError, mc.get_index, [[1]]) mc.state_values = [[1, 2], [3, 4]] eq_(mc.get_index([1, 2]), 0) assert_raises(ValueError, mc.get_index, 1) + assert_array_equal(mc.get_index([[3, 4], [1, 2]]), [1, 0]) @raises(ValueError)