diff --git a/qiskit_dynamics/signals/signals.py b/qiskit_dynamics/signals/signals.py index efd118993..493702227 100644 --- a/qiskit_dynamics/signals/signals.py +++ b/qiskit_dynamics/signals/signals.py @@ -96,8 +96,12 @@ def __init__( if isinstance(envelope, Array): # if envelope is constant and the carrier is zero, this is a constant signal - if carrier_freq == 0.0: - self._is_constant = True + try: + # try block is for catching JAX tracer errors + if carrier_freq == 0.0: + self._is_constant = True + except Exception: # pylint: disable=broad-except + pass if envelope.backend == "jax": self._envelope = lambda t: envelope * jnp.ones_like(Array(t).data) diff --git a/releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml b/releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml new file mode 100644 index 000000000..10aa674b7 --- /dev/null +++ b/releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + In the case that ``envelope`` is a constant, the :meth:`Signal.__init__` method has been updated + to not attempt to evaluate ``carrier_freq == 0.0`` if ``carrier_freq`` is a JAX tracer. In this + case, it is not possible to determine if the :class:`Signal` instance is constant. This resolves + an error that was being raised during JAX tracing if ``carrier_freq`` is abstract. \ No newline at end of file diff --git a/test/dynamics/signals/test_signals.py b/test/dynamics/signals/test_signals.py index e8991d34c..cd1b107a2 100644 --- a/test/dynamics/signals/test_signals.py +++ b/test/dynamics/signals/test_signals.py @@ -933,6 +933,35 @@ def eval_const(a): jit_grad_eval = jit(grad(eval_const)) self.assertAllClose(jit_grad_eval(3.0), 1.0) + # validate that is_constant is being properly set + def eval_const_conditional(a): + a = Array(a) + sig = Signal(a) + + if sig.is_constant: + return 5.0 + else: + return 3.0 + + jit_eval = jit(eval_const_conditional) + self.assertAllClose(jit_eval(1.0), 5.0) + + def test_jit_grad_carrier_freq_construct(self): + """Test jit/gradding through a function that constructs a signal and takes carrier frequency + as an argument. + """ + + def eval_sig(a, v, t): + a = Array(a) + v = Array(v) + return Array(Signal(a, v)(t)).data + + jit_eval = jit(eval_sig) + self.assertAllClose(jit_eval(1.0, 1.0, 1.0), 1.0) + + jit_grad_eval = jit(grad(eval_sig)) + self.assertAllClose(jit_grad_eval(1.0, 1.0, 1.0), 1.0) + def test_signal_list_jit_eval(self): """Test jit-compilation of SignalList evaluation.""" call_jit = jit(lambda t: Array(self.signal_list(t)).data)