Skip to content

Commit

Permalink
fixed jax jitting tests for puse and added gassian pulse tests
Browse files Browse the repository at this point in the history
  • Loading branch information
to24toro committed Feb 21, 2024
1 parent 2dba11b commit 967147f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
5 changes: 4 additions & 1 deletion qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,10 @@ def get_samples(pulse: SymbolicPulse) -> ArrayLike:
f"Pulse parameter '{symbol.name}' is not defined for this instance. "
"Please check your waveform expression is correct."
) from ex
return _lru_cache_expr(envelope, "jax" if any(isinstance(v, jax.core.Tracer) for v in pulse_params.values()) else "numpy")(*args)
return _lru_cache_expr(
envelope,
"jax" if any(isinstance(v, jax.core.Tracer) for v in pulse_params.values()) else "numpy",
)(*args)


@functools.lru_cache(maxsize=None)
Expand Down
20 changes: 18 additions & 2 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,8 @@ def jit_func_instruction_to_signals(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we can use only SymbolicPulse when jax-jitting
# bacause jax-jitting doesn't correspond to validate_parameters in qiskit.pulse.
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
instance = pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand All @@ -398,11 +398,27 @@ def jit_func_instruction_to_signals(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

def jit_func_gaussian_to_signals(amp):
pulse.Gaussian.disable_validation = True
instance = pulse.Gaussian(duration=5, amp=amp, sigma=2.0)
with pulse.build() as schedule:
pulse.play(instance, pulse.DriveChannel(0))

converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

jit(jit_func_instruction_to_signals)(0.1)
self.jit_grad(jit_func_instruction_to_signals)(0.1)
jit_samples = jit(jit_func_instruction_to_signals)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

jit(jit_func_gaussian_to_signals)(0.983)
self.jit_grad(jit_func_gaussian_to_signals)(0.983)
jit_samples = jit(jit_func_gaussian_to_signals)(0.983)
self.assertAllClose(
jit_gaussian_samples, self.gaussian_get_waveform_samples, atol=1e-7, rtol=1e-7
)

def test_pulse_types_combination_with_jax(self):
"""Test that converting schedule including some pulse types with Jax works well"""

Expand Down

0 comments on commit 967147f

Please sign in to comment.