diff --git a/qiskit_dynamics/pulse/pulse_to_signals.py b/qiskit_dynamics/pulse/pulse_to_signals.py index 946f70935..d13d9bc86 100644 --- a/qiskit_dynamics/pulse/pulse_to_signals.py +++ b/qiskit_dynamics/pulse/pulse_to_signals.py @@ -360,7 +360,10 @@ def get_samples(pulse: SymbolicPulse) -> ArrayLike: f"Pulse parameter '{symbol.name}' is not defined for this instance. " "Please check your waveform expression is correct." ) from ex - return _lru_cache_expr(envelope, "jax" if any(isinstance(v, jax.core.Tracer) for v in pulse_params.values()) else "numpy")(*args) + return _lru_cache_expr( + envelope, + "jax" if any(isinstance(v, jax.core.Tracer) for v in pulse_params.values()) else "numpy", + )(*args) @functools.lru_cache(maxsize=None) diff --git a/test/dynamics/pulse/test_pulse_to_signals.py b/test/dynamics/pulse/test_pulse_to_signals.py index 88bd2d4c3..20434a37a 100644 --- a/test/dynamics/pulse/test_pulse_to_signals.py +++ b/test/dynamics/pulse/test_pulse_to_signals.py @@ -383,8 +383,8 @@ def jit_func_instruction_to_signals(amp): (1, sym.And(_time >= 0, _time <= _duration)), (0, True) ) valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 - # we can use only SymbolicPulse when jax-jitting - # bacause jax-jitting doesn't correspond to validate_parameters in qiskit.pulse. + # we need to set disable_validation True to enable jax-jitting. + pulse.SymbolicPulse.disable_validation = True instance = pulse.SymbolicPulse( pulse_type="Constant", duration=5, @@ -398,11 +398,27 @@ def jit_func_instruction_to_signals(amp): converter = InstructionToSignals(self._dt, carriers={"d0": 5}) return converter.get_signals(schedule)[0].samples + def jit_func_gaussian_to_signals(amp): + pulse.Gaussian.disable_validation = True + instance = pulse.Gaussian(duration=5, amp=amp, sigma=2.0) + with pulse.build() as schedule: + pulse.play(instance, pulse.DriveChannel(0)) + + converter = InstructionToSignals(self._dt, carriers={"d0": 5}) + return converter.get_signals(schedule)[0].samples + jit(jit_func_instruction_to_signals)(0.1) self.jit_grad(jit_func_instruction_to_signals)(0.1) jit_samples = jit(jit_func_instruction_to_signals)(0.1) self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7) + jit(jit_func_gaussian_to_signals)(0.983) + self.jit_grad(jit_func_gaussian_to_signals)(0.983) + jit_samples = jit(jit_func_gaussian_to_signals)(0.983) + self.assertAllClose( + jit_gaussian_samples, self.gaussian_get_waveform_samples, atol=1e-7, rtol=1e-7 + ) + def test_pulse_types_combination_with_jax(self): """Test that converting schedule including some pulse types with Jax works well"""