Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JAX pulse -> signals issue #326

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- uses: actions/setup-python@v4
name: Install Python
with:
python-version: '3.8'
python-version: '3.10'
- name: Install Deps
run: pip install -U wheel
- name: Build Artifacts
Expand Down
19 changes: 12 additions & 7 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@

from qiskit_ibm_runtime.fake_provider import FakeQuito

try:
from jax import jit
except ImportError:
pass

from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.signals import DiscreteSignal

from ..common import QiskitDynamicsTestCase, TestJaxBase
from ..common import QiskitDynamicsTestCase, JAXTestBase


class TestPulseToSignals(QiskitDynamicsTestCase):
Expand Down Expand Up @@ -358,7 +363,7 @@ def test_barrier_instructions(self):
self.assertAllClose(sigs[1].samples, np.array([0.0, 0.0, 0.0, -0.5, -0.5, -0.5]))


class TestPulseToSignalsJAXTransformations(QiskitDynamicsTestCase, TestJaxBase):
class TestPulseToSignalsJAXTransformations(JAXTestBase):
"""Tests InstructionToSignals class by using Jax."""

def setUp(self):
Expand Down Expand Up @@ -393,9 +398,9 @@ def jit_func_instruction_to_signals(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.jit_grad_wrap(jit_func_instruction_to_signals)(0.1)
jit_samples = self.jit_wrap(jit_func_instruction_to_signals)(0.1)
jit(jit_func_instruction_to_signals)(0.1)
self.jit_grad(jit_func_instruction_to_signals)(0.1)
jit_samples = jit(jit_func_instruction_to_signals)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

def test_pulse_types_combination_with_jax(self):
Expand Down Expand Up @@ -430,8 +435,8 @@ def jit_func_symbolic_pulse(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_symbolic_pulse)(0.1)
self.jit_grad_wrap(jit_func_symbolic_pulse)(0.1)
jit(jit_func_symbolic_pulse)(0.1)
self.jit_grad(jit_func_symbolic_pulse)(0.1)


@ddt
Expand Down
Loading