Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with carrier_freq being a JAX tracer if envelope is constant in Signal #247

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions qiskit_dynamics/signals/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ def __init__(

if isinstance(envelope, Array):
# if envelope is constant and the carrier is zero, this is a constant signal
if carrier_freq == 0.0:
self._is_constant = True
try:
# try block is for catching JAX tracer errors
if carrier_freq == 0.0:
self._is_constant = True
except Exception: # pylint: disable=broad-except
pass

if envelope.backend == "jax":
self._envelope = lambda t: envelope * jnp.ones_like(Array(t).data)
Expand Down
7 changes: 7 additions & 0 deletions releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
In the case that ``envelope`` is a constant, the :meth:`Signal.__init__` method has been updated
to not attempt to evaluate ``carrier_freq == 0.0`` if ``carrier_freq`` is a JAX tracer. In this
case, it is not possible to determine if the :class:`Signal` instance is constant. This resolves
an error that was being raised during JAX tracing if ``carrier_freq`` is abstract.
29 changes: 29 additions & 0 deletions test/dynamics/signals/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,35 @@ def eval_const(a):
jit_grad_eval = jit(grad(eval_const))
self.assertAllClose(jit_grad_eval(3.0), 1.0)

# validate that is_constant is being properly set
def eval_const_conditional(a):
a = Array(a)
sig = Signal(a)

if sig.is_constant:
return 5.0
else:
return 3.0

jit_eval = jit(eval_const_conditional)
self.assertAllClose(jit_eval(1.0), 5.0)

def test_jit_grad_carrier_freq_construct(self):
"""Test jit/gradding through a function that constructs a signal and takes carrier frequency
as an argument.
"""

def eval_sig(a, v, t):
a = Array(a)
v = Array(v)
return Array(Signal(a, v)(t)).data

jit_eval = jit(eval_sig)
self.assertAllClose(jit_eval(1.0, 1.0, 1.0), 1.0)

jit_grad_eval = jit(grad(eval_sig))
self.assertAllClose(jit_grad_eval(1.0, 1.0, 1.0), 1.0)

def test_signal_list_jit_eval(self):
"""Test jit-compilation of SignalList evaluation."""
call_jit = jit(lambda t: Array(self.signal_list(t)).data)
Expand Down