diff --git a/src/braket/circuits/basis_state.py b/src/braket/circuits/basis_state.py index 814444e75..b6ce11bc8 100644 --- a/src/braket/circuits/basis_state.py +++ b/src/braket/circuits/basis_state.py @@ -33,6 +33,18 @@ def __iter__(self): def __eq__(self, other): return self.state == other.state + def __bool__(self): + return any(self.state) + + def __str__(self): + return self.as_string + + def __repr__(self): + return f'BasisState("{self.as_string}")' + + def __getitem__(self, item): + return BasisState(self.state[item]) + BasisStateInput = Union[int, list[int], str, BasisState] diff --git a/test/unit_tests/braket/circuits/test_basis_state.py b/test/unit_tests/braket/circuits/test_basis_state.py index 023494fae..166e7c8fc 100644 --- a/test/unit_tests/braket/circuits/test_basis_state.py +++ b/test/unit_tests/braket/circuits/test_basis_state.py @@ -51,6 +51,58 @@ ), ) def test_as_props(basis_state_input, size, as_tuple, as_int, as_string): - assert BasisState(basis_state_input, size).as_tuple == as_tuple - assert BasisState(basis_state_input, size).as_int == as_int - assert BasisState(basis_state_input, size).as_string == as_string + basis_state = BasisState(basis_state_input, size) + assert basis_state.as_tuple == as_tuple + assert basis_state.as_int == as_int + assert basis_state.as_string == as_string == str(basis_state) + assert repr(basis_state) == f'BasisState("{as_string}")' + + +@pytest.mark.parametrize( + "basis_state_input, index, substate_input", + ( + ( + "1001", + slice(None), + "1001", + ), + ( + "1001", + 3, + "1", + ), + ( + "1010", + slice(None, None, 2), + "11", + ), + ( + "1010", + slice(1, None, 2), + "00", + ), + ( + "1010", + slice(None, -2), + "10", + ), + ( + "1010", + -1, + "0", + ), + ), +) +def test_indexing(basis_state_input, index, substate_input): + assert BasisState(basis_state_input)[index] == BasisState(substate_input) + + +def test_bool(): + assert all( + [ + BasisState("100"), + BasisState("111"), + BasisState("1"), + ] + ) + assert not BasisState("0")