Skip to content

Commit

Permalink
Merge remote-tracking branch 'Dan/fix-symbolic-pulse-jax' into fix/pu…
Browse files Browse the repository at this point in the history
…lse_to_signal
  • Loading branch information
to24toro committed Feb 21, 2024
2 parents 84c17f5 + 6749ec1 commit 2dba11b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- uses: actions/setup-python@v4
name: Install Python
with:
python-version: '3.8'
python-version: '3.10'
- name: Install Deps
run: pip install -U wheel
- name: Build Artifacts
Expand Down
19 changes: 12 additions & 7 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@

from qiskit_ibm_runtime.fake_provider import FakeQuito

try:
from jax import jit
except ImportError:
pass

from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.signals import DiscreteSignal

from ..common import QiskitDynamicsTestCase, TestJaxBase
from ..common import QiskitDynamicsTestCase, JAXTestBase


class TestPulseToSignals(QiskitDynamicsTestCase):
Expand Down Expand Up @@ -358,7 +363,7 @@ def test_barrier_instructions(self):
self.assertAllClose(sigs[1].samples, np.array([0.0, 0.0, 0.0, -0.5, -0.5, -0.5]))


class TestPulseToSignalsJAXTransformations(QiskitDynamicsTestCase, TestJaxBase):
class TestPulseToSignalsJAXTransformations(JAXTestBase):
"""Tests InstructionToSignals class by using Jax."""

def setUp(self):
Expand Down Expand Up @@ -393,9 +398,9 @@ def jit_func_instruction_to_signals(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.jit_grad_wrap(jit_func_instruction_to_signals)(0.1)
jit_samples = self.jit_wrap(jit_func_instruction_to_signals)(0.1)
jit(jit_func_instruction_to_signals)(0.1)
self.jit_grad(jit_func_instruction_to_signals)(0.1)
jit_samples = jit(jit_func_instruction_to_signals)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

def test_pulse_types_combination_with_jax(self):
Expand Down Expand Up @@ -430,8 +435,8 @@ def jit_func_symbolic_pulse(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_symbolic_pulse)(0.1)
self.jit_grad_wrap(jit_func_symbolic_pulse)(0.1)
jit(jit_func_symbolic_pulse)(0.1)
self.jit_grad(jit_func_symbolic_pulse)(0.1)


@ddt
Expand Down

0 comments on commit 2dba11b

Please sign in to comment.