Skip to content

Commit

Permalink
Fix failing facade tests (#113)
Browse files Browse the repository at this point in the history
* fix failing test

* trigger CI

* let me try this

* nope I meant this

---------

Co-authored-by: Astral Cai <[email protected]>
  • Loading branch information
Shiro-Raven and astralcai authored Aug 26, 2024
1 parent 267d34e commit 3cac163
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 34 deletions.
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ MarkupSafe==2.1.1
networkx==2.6.0
ninja==1.10.2.3
numpy==1.22.4
packaging==21.3
packaging>=24
PennyLane==0.24.0
PennyLane-Lightning==0.24.0
Pygments==2.15.0
Expand Down
45 changes: 12 additions & 33 deletions tests/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,11 @@ def test_generate_samples_qpu_device(self, wires, histogram):

unique_outcomes1 = np.unique(sample1, axis=0)
unique_outcomes2 = np.unique(sample2, axis=0)
assert np.all(
unique_outcomes1 == unique_outcomes2
) # possible outcomes are the same
assert np.all(unique_outcomes1 == unique_outcomes2) # possible outcomes are the same

sorted_outcomes1 = np.sort(sample1, axis=0)
sorted_outcomes2 = np.sort(sample2, axis=0)
assert np.all(
sorted_outcomes1 == sorted_outcomes2
) # set of outcomes is the same
assert np.all(sorted_outcomes1 == sorted_outcomes2) # set of outcomes is the same


class TestDeviceIntegration:
Expand All @@ -71,7 +67,7 @@ def test_load_device(self, d):
"""Test that the device loads correctly"""
dev = qml.device(d, wires=2, shots=1024)
assert dev.num_wires == 2
assert dev.shots == 1024
assert dev.shots.total_shots == 1024
assert dev.short_name == d

@pytest.mark.parametrize("d", shortnames)
Expand Down Expand Up @@ -100,9 +96,7 @@ def test_failedcircuit(self, monkeypatch):
monkeypatch.setattr(
requests, "post", lambda url, timeout, data, headers: (url, data, headers)
)
monkeypatch.setattr(
ResourceManager, "handle_response", lambda self, response: None
)
monkeypatch.setattr(ResourceManager, "handle_response", lambda self, response: None)
monkeypatch.setattr(Job, "is_complete", False)
monkeypatch.setattr(Job, "is_failed", True)

Expand All @@ -117,17 +111,13 @@ def test_shots(self, shots, monkeypatch, mocker, tol):
monkeypatch.setattr(
requests, "post", lambda url, timeout, data, headers: (url, data, headers)
)
monkeypatch.setattr(
ResourceManager, "handle_response", lambda self, response: None
)
monkeypatch.setattr(ResourceManager, "handle_response", lambda self, response: None)
monkeypatch.setattr(Job, "is_complete", True)

def fake_response(self, resource_id=None, params=None):
"""Return fake response data"""
fake_json = {"0": 1}
setattr(
self.resource, "data", type("data", tuple(), {"value": fake_json})()
)
setattr(self.resource, "data", type("data", tuple(), {"value": fake_json})())

monkeypatch.setattr(ResourceManager, "get", fake_response)

Expand All @@ -143,26 +133,20 @@ def circuit():
circuit()
assert json.loads(spy.call_args[1]["data"])["shots"] == shots

@pytest.mark.parametrize(
"error_mitigation", [None, {"debias": True}, {"debias": False}]
)
@pytest.mark.parametrize("error_mitigation", [None, {"debias": True}, {"debias": False}])
def test_error_mitigation(self, error_mitigation, monkeypatch, mocker):
"""Test that shots are correctly specified when submitting a job to the API."""

monkeypatch.setattr(
requests, "post", lambda url, timeout, data, headers: (url, data, headers)
)
monkeypatch.setattr(
ResourceManager, "handle_response", lambda self, response: None
)
monkeypatch.setattr(ResourceManager, "handle_response", lambda self, response: None)
monkeypatch.setattr(Job, "is_complete", True)

def fake_response(self, resource_id=None, params=None):
"""Return fake response data"""
fake_json = {"0": 1}
setattr(
self.resource, "data", type("data", tuple(), {"value": fake_json})()
)
setattr(self.resource, "data", type("data", tuple(), {"value": fake_json})())

monkeypatch.setattr(ResourceManager, "get", fake_response)

Expand All @@ -183,10 +167,7 @@ def circuit():
spy = mocker.spy(requests, "post")
circuit()
if error_mitigation is not None:
assert (
json.loads(spy.call_args[1]["data"])["error_mitigation"]
== error_mitigation
)
assert json.loads(spy.call_args[1]["data"])["error_mitigation"] == error_mitigation
else:
with pytest.raises(KeyError, match="error_mitigation"):
json.loads(spy.call_args[1]["data"])["error_mitigation"]
Expand Down Expand Up @@ -233,7 +214,7 @@ def test_prob_no_results(self, d):
def test_probability(self):
"""Test that device.probability works."""
dev = qml.device("ionq.simulator", wires=2)
dev._samples = np.array([[1, 1], [1, 1], [0, 0], [0, 0]])
dev.target_device._samples = np.array([[1, 1], [1, 1], [0, 0], [0, 0]])
assert np.array_equal(dev.probability(shot_range=(0, 2)), [0, 0, 0, 1])

uniform_prob = [0.25] * 4
Expand All @@ -243,9 +224,7 @@ def test_probability(self):
mock_prob.return_value = uniform_prob
assert np.array_equal(dev.probability(), uniform_prob)

@pytest.mark.parametrize(
"backend", ["harmony", "aria-1", "aria-2", "forte-1", None]
)
@pytest.mark.parametrize("backend", ["harmony", "aria-1", "aria-2", "forte-1", None])
def test_backend_initialization(self, backend):
"""Test that the device initializes with the correct backend."""
dev = qml.device(
Expand Down

0 comments on commit 3cac163

Please sign in to comment.