diff --git a/dwave/cloud/api/models.py b/dwave/cloud/api/models.py index f3a38c61..4cc8bfbf 100644 --- a/dwave/cloud/api/models.py +++ b/dwave/cloud/api/models.py @@ -16,7 +16,7 @@ from typing import List, Union, Optional, Dict, Any from typing_extensions import Annotated # backport for py37, py38 -from pydantic import BaseModel, RootModel, Field +from pydantic import BaseModel, RootModel, ConfigDict from pydantic.functional_validators import AfterValidator from dwave.cloud.api import constants @@ -27,7 +27,7 @@ AnyIncludingNumpy = Annotated[Any, AfterValidator(coerce_numpy_to_python)] -class SolverConfiguration(BaseModel): +class SolverCompleteConfiguration(BaseModel): id: str status: str description: str @@ -35,6 +35,18 @@ class SolverConfiguration(BaseModel): avg_load: float +class SolverFilteredConfiguration(BaseModel): + # no required fields, and no ignored fields + model_config = ConfigDict(extra='allow') + + +class SolverConfiguration(RootModel): + root: Union[SolverCompleteConfiguration, SolverFilteredConfiguration] + + def __getattr__(self, item): + return getattr(self.root, item) + + class ProblemInitialStatus(BaseModel): id: str type: constants.ProblemType diff --git a/dwave/cloud/api/resources.py b/dwave/cloud/api/resources.py index 99fa62a3..1c30406a 100644 --- a/dwave/cloud/api/resources.py +++ b/dwave/cloud/api/resources.py @@ -126,16 +126,18 @@ class Solvers(ResourceBase): client_class = SolverAPIClient @accepts(media_type='application/vnd.dwave.sapi.solver-definition-list+json', version='~=2.0') - def list_solvers(self) -> List[models.SolverConfiguration]: + def list_solvers(self, filter: Optional[str] = None) -> List[models.SolverConfiguration]: path = 'remote/' - response = self.session.get(path) + params = {'filter': filter} if filter is not None else None + response = self.session.get(path, params=params) solvers = response.json() return TypeAdapter(List[models.SolverConfiguration]).validate_python(solvers) @accepts(media_type='application/vnd.dwave.sapi.solver-definition+json', version='~=2.0') - def get_solver(self, solver_id: str) -> models.SolverConfiguration: + def get_solver(self, solver_id: str, filter: Optional[str] = None) -> models.SolverConfiguration: path = 'remote/{}'.format(solver_id) - response = self.session.get(path) + params = {'filter': filter} if filter is not None else None + response = self.session.get(path, params=params) solver = response.json() return models.SolverConfiguration.model_validate(solver) diff --git a/releasenotes/notes/add-solver-metadata-filtering-8e8ce0ad665091b3.yaml b/releasenotes/notes/add-solver-metadata-filtering-8e8ce0ad665091b3.yaml new file mode 100644 index 00000000..9177a2dd --- /dev/null +++ b/releasenotes/notes/add-solver-metadata-filtering-8e8ce0ad665091b3.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Add support for retrieving filtered solver configuration to + ``dwave.cloud.api.resources.Solvers`` methods. diff --git a/tests/api/resources/test_solvers.py b/tests/api/resources/test_solvers.py index c1b7daf7..5c14365c 100644 --- a/tests/api/resources/test_solvers.py +++ b/tests/api/resources/test_solvers.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io import uuid import unittest -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse, parse_qsl +import json +import requests import requests_mock from dwave.cloud.api.resources import Solvers @@ -101,8 +104,102 @@ def test_invalid_token(self): api.list_solvers() +class FilteringTestsMixin: + additive_filter = 'none,+id' + subtractive_filter = 'all,-properties.couplers' + + # assume self.api is initialized + + def test_solver_collection_property_filtering(self): + with self.subTest('SAPI additive filtering'): + solvers = self.api.list_solvers(filter=self.additive_filter) + for item in solvers: + self.assertIsInstance(item.root, models.SolverFilteredConfiguration) + self.assertEqual(item.model_dump().keys(), {'id'}) + + with self.subTest('SAPI subtractive filtering'): + solvers = self.api.list_solvers(filter=self.subtractive_filter) + for item in solvers: + self.assertIsInstance(item.root, models.SolverCompleteConfiguration) + self.assertNotIn('couplers', item.properties) + if item.properties.get('category') == 'qpu': + self.assertIn('qubits', item.properties) + + def test_solver_property_filtering(self): + # find a QPU solver to query + qpu = [solver.id for solver in self.api.list_solvers() + if solver.properties.get('category') == 'qpu'] + solver_id = qpu.pop() + + with self.subTest('SAPI additive filtering'): + solver = self.api.get_solver(solver_id=solver_id, filter=self.additive_filter) + self.assertIsInstance(solver.root, models.SolverFilteredConfiguration) + self.assertEqual(solver.model_dump().keys(), {'id'}) + + with self.subTest('SAPI subtractive filtering'): + solver = self.api.get_solver(solver_id=solver_id, filter=self.subtractive_filter) + self.assertIsInstance(solver.root, models.SolverCompleteConfiguration) + self.assertNotIn('couplers', solver.properties) + self.assertIn('qubits', solver.properties) + + +class TestFiltering(FilteringTestsMixin, unittest.TestCase): + + token = str(uuid.uuid4()) + endpoint = 'http://test.com/path/' + + def setUp(self): + self.mocker = requests_mock.Mocker() + + self.solver_data = qpu_clique_solver_data(3) + self.solver_id = self.solver_data['id'] + + self.solver_uri = urljoin(self.endpoint, 'solvers/remote/{}'.format(self.solver_id)) + self.list_uri = urljoin(self.endpoint, 'solvers/remote/') + + self.additive_filter_data = {"id": self.solver_id} + + self.subtractive_filter_data = self.solver_data.copy() + del self.subtractive_filter_data['properties']['couplers'] + + def custom_matcher(request): + url = urlparse(request.path_url) + filter = dict(parse_qsl(url.query)).get('filter', '') + + if filter == self.additive_filter: + data = self.additive_filter_data + elif filter == self.subtractive_filter: + data = self.subtractive_filter_data + elif filter == '': + data = self.solver_data + else: + return None + + def reply(data): + resp = requests.Response() + resp.status_code = 200 + resp.raw = io.BytesIO(json.dumps(data).encode('ascii')) + return resp + + if url.path == urlparse(self.solver_uri).path: + return reply(data) + elif url.path == urlparse(self.list_uri).path: + return reply([data]) + else: + return None + + self.mocker.add_matcher(custom_matcher) + + self.mocker.start() + self.api = Solvers(token=self.token, endpoint=self.endpoint, version_strict_mode=False) + + def tearDown(self): + self.api.close() + self.mocker.stop() + + @unittest.skipUnless(config, "SAPI access not configured.") -class TestCloudSolvers(unittest.TestCase): +class TestLiveSolvers(FilteringTestsMixin, unittest.TestCase): @classmethod def setUpClass(cls):