Skip to content

Commit

Permalink
try and fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
ihincks committed Nov 1, 2024
1 parent 697a9f4 commit 73c7d95
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions qiskit_ibm_runtime/execution_span/twirled_slice_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ def size(self) -> int:
def mask(self, pub_idx: int) -> npt.NDArray[np.bool_]:
twirled_shape, at_front, shape_sl, shots_sl = self._data_slices[pub_idx]
mask = np.zeros(twirled_shape, dtype=np.bool_)
mask.reshape(np.prod(twirled_shape[:-1]), twirled_shape[-1])[(shape_sl, shots_sl)] = True
mask.reshape((np.prod(twirled_shape[:-1]), twirled_shape[-1]))[(shape_sl, shots_sl)] = True

if at_front:
# if the first axis is over twirling samples, push them right before shots
ndim = len(twirled_shape)
mask = mask.transpose(*range(1, ndim - 1), 0, ndim - 1)
mask = mask.transpose((*range(1, ndim - 1), 0, ndim - 1))
twirled_shape = twirled_shape[1:-1] + twirled_shape[:1] + twirled_shape[-1:]

# merge twirling axis and shots axis before returning
return mask.reshape(*twirled_shape[:-2], math.prod(twirled_shape[-2:]))
return mask.reshape((*twirled_shape[:-2], math.prod(twirled_shape[-2:])))

def filter_by_pub(self, pub_idx: int | Iterable[int]) -> "TwirledSliceSpan":
pub_idx = {pub_idx} if isinstance(pub_idx, int) else set(pub_idx)
Expand Down
6 changes: 3 additions & 3 deletions test/unit/test_execution_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,8 @@ def test_mask(self):
"""Test the mask() method"""
# reminder: ((3, 1, 5), True, slice(1), slice(2, 4))
mask1 = np.zeros((3, 1, 5), dtype=bool)
mask1.reshape(3, 5)[:1, 2:4] = True
mask1 = mask1.transpose(1, 0, 2).reshape(1, 15)
mask1.reshape((3, 5))[:1, 2:4] = True
mask1 = mask1.transpose((1, 0, 2)).reshape((1, 15))
npt.assert_array_equal(self.span1.mask(2), mask1)

# reminder: ((1, 5, 2, 3), False, slice(3,9), slice(1, 3)),
Expand All @@ -301,7 +301,7 @@ def test_mask(self):
[[[0, 1, 1], [0, 0, 0]]],
]
]
mask2 = np.array(mask2, dtype=bool).reshape(1, 5, 6)
mask2 = np.array(mask2, dtype=bool).reshape((1, 5, 6))
npt.assert_array_equal(self.span2.mask(1), mask2)

@ddt.data(
Expand Down

0 comments on commit 73c7d95

Please sign in to comment.