Skip to content

Commit

Permalink
Fix bug with carrier_freq being a JAX tracer if envelope is constant …
Browse files Browse the repository at this point in the history
…in Signal (#247)
  • Loading branch information
DanPuzzuoli authored Jul 21, 2023
1 parent ea941cb commit 221d09b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
8 changes: 6 additions & 2 deletions qiskit_dynamics/signals/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ def __init__(

if isinstance(envelope, Array):
# if envelope is constant and the carrier is zero, this is a constant signal
if carrier_freq == 0.0:
self._is_constant = True
try:
# try block is for catching JAX tracer errors
if carrier_freq == 0.0:
self._is_constant = True
except Exception: # pylint: disable=broad-except
pass

if envelope.backend == "jax":
self._envelope = lambda t: envelope * jnp.ones_like(Array(t).data)
Expand Down
7 changes: 7 additions & 0 deletions releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
In the case that ``envelope`` is a constant, the :meth:`Signal.__init__` method has been updated
to not attempt to evaluate ``carrier_freq == 0.0`` if ``carrier_freq`` is a JAX tracer. In this
case, it is not possible to determine if the :class:`Signal` instance is constant. This resolves
an error that was being raised during JAX tracing if ``carrier_freq`` is abstract.
29 changes: 29 additions & 0 deletions test/dynamics/signals/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,35 @@ def eval_const(a):
jit_grad_eval = jit(grad(eval_const))
self.assertAllClose(jit_grad_eval(3.0), 1.0)

# validate that is_constant is being properly set
def eval_const_conditional(a):
a = Array(a)
sig = Signal(a)

if sig.is_constant:
return 5.0
else:
return 3.0

jit_eval = jit(eval_const_conditional)
self.assertAllClose(jit_eval(1.0), 5.0)

def test_jit_grad_carrier_freq_construct(self):
"""Test jit/gradding through a function that constructs a signal and takes carrier frequency
as an argument.
"""

def eval_sig(a, v, t):
a = Array(a)
v = Array(v)
return Array(Signal(a, v)(t)).data

jit_eval = jit(eval_sig)
self.assertAllClose(jit_eval(1.0, 1.0, 1.0), 1.0)

jit_grad_eval = jit(grad(eval_sig))
self.assertAllClose(jit_grad_eval(1.0, 1.0, 1.0), 1.0)

def test_signal_list_jit_eval(self):
"""Test jit-compilation of SignalList evaluation."""
call_jit = jit(lambda t: Array(self.signal_list(t)).data)
Expand Down

0 comments on commit 221d09b

Please sign in to comment.