Skip to content

Commit

Permalink
MarkovChain.get_index: Accept array_like of values
Browse files Browse the repository at this point in the history
  • Loading branch information
oyamad committed Mar 28, 2016
1 parent 5bc78d5 commit bfc91bb
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
42 changes: 39 additions & 3 deletions quantecon/markov/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,43 @@ def state_values(self, values):

def get_index(self, value):
"""
Return the index of the given value in state_values.
Return the index (or indices) of the given value (or values) in
`state_values`.
Parameters
----------
value
Value(s) to get the index (indices) for.
Returns
-------
idx : int or ndarray(int)
Index of `value` if `value` is a single state value; array
of indices if `value` is an array_like of state values.
"""
if self.state_values is None:
state_values_ndim = 1
else:
state_values_ndim = self.state_values.ndim

values = np.asarray(value)

if values.ndim <= state_values_ndim - 1:
return self._get_index(value)
elif values.ndim == state_values_ndim: # array of values
k = values.shape[0]
idx = np.empty(k, dtype=int)
for i in range(k):
idx[i] = self._get_index(values[i])
return idx
else:
raise ValueError('invalid value')


def _get_index(self, value):
"""
Return the index of the given value in `state_values`.
Parameters
----------
Expand All @@ -239,10 +275,10 @@ def get_index(self, value):
Returns
-------
idx : int
Index of the value.
Index of `value`.
"""
error_msg = 'value {0} not found'.format(repr(value))
error_msg = 'value {0} not found'.format(value)

if self.state_values is None:
if isinstance(value, numbers.Integral) and (0 <= value < self.n):
Expand Down
5 changes: 5 additions & 0 deletions quantecon/markov/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,15 +483,20 @@ def test_get_index():
eq_(mc.get_index(0), 0)
eq_(mc.get_index(1), 1)
assert_raises(ValueError, mc.get_index, 2)
assert_array_equal(mc.get_index([1, 0]), [1, 0])
assert_raises(ValueError, mc.get_index, [[1]])

mc.state_values = [1, 2]
eq_(mc.get_index(1), 0)
eq_(mc.get_index(2), 1)
assert_raises(ValueError, mc.get_index, 0)
assert_array_equal(mc.get_index([2, 1]), [1, 0])
assert_raises(ValueError, mc.get_index, [[1]])

mc.state_values = [[1, 2], [3, 4]]
eq_(mc.get_index([1, 2]), 0)
assert_raises(ValueError, mc.get_index, 1)
assert_array_equal(mc.get_index([[3, 4], [1, 2]]), [1, 0])


@raises(ValueError)
Expand Down

0 comments on commit bfc91bb

Please sign in to comment.