Skip to content

Commit

Permalink
fixed jax jitting tests for puse and added gassian pulse tests
Browse files Browse the repository at this point in the history
  • Loading branch information
to24toro committed Feb 21, 2024
1 parent 2dba11b commit c456875
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,10 @@

from qiskit_ibm_runtime.fake_provider import FakeQuito

try:
from jax import jit
except ImportError:
pass

from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.signals import DiscreteSignal

from ..common import QiskitDynamicsTestCase, JAXTestBase
from ..common import QiskitDynamicsTestCase, TestJaxBase


class TestPulseToSignals(QiskitDynamicsTestCase):
Expand Down Expand Up @@ -363,14 +358,17 @@ def test_barrier_instructions(self):
self.assertAllClose(sigs[1].samples, np.array([0.0, 0.0, 0.0, -0.5, -0.5, -0.5]))


class TestPulseToSignalsJAXTransformations(JAXTestBase):
class TestPulseToSignalsJAXTransformations(QiskitDynamicsTestCase, TestJaxBase):
"""Tests InstructionToSignals class by using Jax."""

def setUp(self):
"""Set up gaussian waveform samples for comparison."""
self.constant_get_waveform_samples = (
pulse.Constant(duration=5, amp=0.1).get_waveform().samples
)
self.gaussian_get_waveform_samples = (
pulse.Gaussian(duration=5, amp=0.983, sigma=2.0).get_waveform().samples
)
self._dt = 0.222

def test_InstructionToSignals_with_JAX(self):
Expand All @@ -383,8 +381,8 @@ def jit_func_instruction_to_signals(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we can use only SymbolicPulse when jax-jitting
# bacause jax-jitting doesn't correspond to validate_parameters in qiskit.pulse.
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
instance = pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand All @@ -398,11 +396,27 @@ def jit_func_instruction_to_signals(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

jit(jit_func_instruction_to_signals)(0.1)
self.jit_grad(jit_func_instruction_to_signals)(0.1)
jit_samples = jit(jit_func_instruction_to_signals)(0.1)
def jit_func_gaussian_to_signals(amp):
pulse.Gaussian.disable_validation = True
instance = pulse.Gaussian(duration=5, amp=amp, sigma=2.0)
with pulse.build() as schedule:
pulse.play(instance, pulse.DriveChannel(0))

converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.jit_grad_wrap(jit_func_instruction_to_signals)(0.1)
jit_samples = self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

self.jit_wrap(jit_func_gaussian_to_signals)(0.983)
self.jit_grad_wrap(jit_func_gaussian_to_signals)(0.983)
jit_gaussian_samples = self.jit_wrap(jit_func_gaussian_to_signals)(0.983)
self.assertAllClose(
jit_gaussian_samples, self.gaussian_get_waveform_samples, atol=1e-7, rtol=1e-7
)

def test_pulse_types_combination_with_jax(self):
"""Test that converting schedule including some pulse types with Jax works well"""

Expand Down Expand Up @@ -435,8 +449,8 @@ def jit_func_symbolic_pulse(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

jit(jit_func_symbolic_pulse)(0.1)
self.jit_grad(jit_func_symbolic_pulse)(0.1)
self.jit_wrap(jit_func_symbolic_pulse)(0.1)
self.jit_grad_wrap(jit_func_symbolic_pulse)(0.1)


@ddt
Expand Down

0 comments on commit c456875

Please sign in to comment.