Skip to content

Commit

Permalink
Update EstimatorV2 result decoder (Qiskit#1461)
Browse files Browse the repository at this point in the history
* Update EstimatorV2 result decoder

* Remove unused code

* Remove trailing spaces

* Linting
  • Loading branch information
mberna authored Mar 4, 2024
1 parent 11f7793 commit e722c97
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 29 deletions.
9 changes: 1 addition & 8 deletions qiskit_ibm_runtime/runtime_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,7 @@ def result( # pylint: disable=arguments-differ
self._api_client.job_results(job_id=self.job_id())
)

version_param = {}
# TODO: Remove getting/setting version once it's in result metadata
if _decoder.__name__ == EstimatorResultDecoder.__name__:
if not self._version:
self._version = self.inputs.get("version", 1)
version_param["version"] = self._version

return _decoder.decode(result_raw, **version_param) if result_raw else None # type: ignore
return _decoder.decode(result_raw) if result_raw else None # type: ignore

def cancel(self) -> None:
"""Cancel the job.
Expand Down
32 changes: 11 additions & 21 deletions qiskit_ibm_runtime/utils/estimator_result_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

"""Estimator result decoder."""

from typing import Dict
from typing import Dict, Union
import numpy as np

from qiskit.primitives import EstimatorResult
from qiskit.primitives.containers import PrimitiveResult, make_data_bin, PubResult
from qiskit.primitives.containers import PrimitiveResult

from .result_decoder import ResultDecoder

Expand All @@ -26,24 +26,14 @@ class EstimatorResultDecoder(ResultDecoder):

@classmethod
def decode( # type: ignore # pylint: disable=arguments-differ
cls, raw_result: str, version: int
) -> EstimatorResult:
cls, raw_result: str
) -> Union[EstimatorResult, PrimitiveResult]:
"""Convert the result to EstimatorResult."""
decoded: Dict = super().decode(raw_result)
if version == 2:
out_results = []
for val, meta in zip(decoded["values"], decoded["metadata"]):
if not isinstance(val, np.ndarray):
val = np.asarray(val)
data_bin_cls = make_data_bin(
[("evs", np.ndarray), ("stds", np.ndarray)], shape=val.shape
)
out_results.append(
PubResult(data=data_bin_cls(val, meta.pop("standard_error")), metadata=meta)
)
# TODO what metadata should be passed in to PrimitiveResult?
return PrimitiveResult(out_results, metadata=decoded["metadata"])
return EstimatorResult(
values=np.asarray(decoded["values"]),
metadata=decoded["metadata"],
)
if isinstance(decoded, PrimitiveResult):
return decoded
else:
return EstimatorResult(
values=np.asarray(decoded["values"]),
metadata=decoded["metadata"],
)

0 comments on commit e722c97

Please sign in to comment.