From c0fa3e874a960cdae489940790b59b79eb626ed1 Mon Sep 17 00:00:00 2001 From: DanPuzzuoli Date: Mon, 17 Jul 2023 14:39:51 -0400 Subject: [PATCH 1/4] modifying Signal.__init__ to not check if carrier_freq==0 if carrier_freq is a tracer --- qiskit_dynamics/signals/signals.py | 10 +++++++++- test/dynamics/signals/test_signals.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/qiskit_dynamics/signals/signals.py b/qiskit_dynamics/signals/signals.py index efd118993..37852b87a 100644 --- a/qiskit_dynamics/signals/signals.py +++ b/qiskit_dynamics/signals/signals.py @@ -26,6 +26,7 @@ try: import jax.numpy as jnp + import jax except ImportError: pass @@ -95,8 +96,15 @@ def __init__( envelope = Array(complex(envelope)) if isinstance(envelope, Array): + carrier_freq = Array(carrier_freq) + # if envelope is constant and the carrier is zero, this is a constant signal - if carrier_freq == 0.0: + if ( + not ( + carrier_freq.backend == "jax" and isinstance(carrier_freq.data, jax.core.Tracer) + ) + and carrier_freq == 0.0 + ): self._is_constant = True if envelope.backend == "jax": diff --git a/test/dynamics/signals/test_signals.py b/test/dynamics/signals/test_signals.py index e8991d34c..89b6ef4d3 100644 --- a/test/dynamics/signals/test_signals.py +++ b/test/dynamics/signals/test_signals.py @@ -933,6 +933,22 @@ def eval_const(a): jit_grad_eval = jit(grad(eval_const)) self.assertAllClose(jit_grad_eval(3.0), 1.0) + def test_jit_grad_carrier_freq_construct(self): + """Test jit/gradding through a function that constructs a signal and takes carrier frequency + as an argument. + """ + + def eval_sig(a, v, t): + a = Array(a) + v = Array(v) + return Array(Signal(a, v)(t)).data + + jit_eval = jit(eval_sig) + self.assertAllClose(jit_eval(1.0, 1.0, 1.0), 1.0) + + jit_grad_eval = jit(grad(eval_sig)) + self.assertAllClose(jit_grad_eval(1.0, 1.0, 1.0), 1.0) + def test_signal_list_jit_eval(self): """Test jit-compilation of SignalList evaluation.""" call_jit = jit(lambda t: Array(self.signal_list(t)).data) From 3ea43354d9a73d2d737051935988bc22e31dce9c Mon Sep 17 00:00:00 2001 From: DanPuzzuoli Date: Mon, 17 Jul 2023 14:44:13 -0400 Subject: [PATCH 2/4] adding bug fix release note --- releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml diff --git a/releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml b/releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml new file mode 100644 index 000000000..10aa674b7 --- /dev/null +++ b/releasenotes/notes/carrier-freq-0-19ad4362c874944f.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + In the case that ``envelope`` is a constant, the :meth:`Signal.__init__` method has been updated + to not attempt to evaluate ``carrier_freq == 0.0`` if ``carrier_freq`` is a JAX tracer. In this + case, it is not possible to determine if the :class:`Signal` instance is constant. This resolves + an error that was being raised during JAX tracing if ``carrier_freq`` is abstract. \ No newline at end of file From 0c58518a851dde2dab10e5cd1b0ea0ec2da0c19a Mon Sep 17 00:00:00 2001 From: DanPuzzuoli Date: Mon, 17 Jul 2023 14:59:46 -0400 Subject: [PATCH 3/4] moving carrier_freq==0.0 into try block --- qiskit_dynamics/signals/signals.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/qiskit_dynamics/signals/signals.py b/qiskit_dynamics/signals/signals.py index 37852b87a..71e3e8f03 100644 --- a/qiskit_dynamics/signals/signals.py +++ b/qiskit_dynamics/signals/signals.py @@ -96,16 +96,13 @@ def __init__( envelope = Array(complex(envelope)) if isinstance(envelope, Array): - carrier_freq = Array(carrier_freq) # if envelope is constant and the carrier is zero, this is a constant signal - if ( - not ( - carrier_freq.backend == "jax" and isinstance(carrier_freq.data, jax.core.Tracer) - ) - and carrier_freq == 0.0 - ): - self._is_constant = True + try: + if carrier_freq == 0.0: + self._is_constant = True + except: + pass if envelope.backend == "jax": self._envelope = lambda t: envelope * jnp.ones_like(Array(t).data) From b45731a686350bf7d3c3fa5dc2f5ee42c1f9b2f9 Mon Sep 17 00:00:00 2001 From: DanPuzzuoli Date: Tue, 18 Jul 2023 10:30:07 -0400 Subject: [PATCH 4/4] changing check to use a try block --- qiskit_dynamics/signals/signals.py | 5 ++--- test/dynamics/signals/test_signals.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/qiskit_dynamics/signals/signals.py b/qiskit_dynamics/signals/signals.py index 71e3e8f03..493702227 100644 --- a/qiskit_dynamics/signals/signals.py +++ b/qiskit_dynamics/signals/signals.py @@ -26,7 +26,6 @@ try: import jax.numpy as jnp - import jax except ImportError: pass @@ -96,12 +95,12 @@ def __init__( envelope = Array(complex(envelope)) if isinstance(envelope, Array): - # if envelope is constant and the carrier is zero, this is a constant signal try: + # try block is for catching JAX tracer errors if carrier_freq == 0.0: self._is_constant = True - except: + except Exception: # pylint: disable=broad-except pass if envelope.backend == "jax": diff --git a/test/dynamics/signals/test_signals.py b/test/dynamics/signals/test_signals.py index 89b6ef4d3..cd1b107a2 100644 --- a/test/dynamics/signals/test_signals.py +++ b/test/dynamics/signals/test_signals.py @@ -933,6 +933,19 @@ def eval_const(a): jit_grad_eval = jit(grad(eval_const)) self.assertAllClose(jit_grad_eval(3.0), 1.0) + # validate that is_constant is being properly set + def eval_const_conditional(a): + a = Array(a) + sig = Signal(a) + + if sig.is_constant: + return 5.0 + else: + return 3.0 + + jit_eval = jit(eval_const_conditional) + self.assertAllClose(jit_eval(1.0), 5.0) + def test_jit_grad_carrier_freq_construct(self): """Test jit/gradding through a function that constructs a signal and takes carrier frequency as an argument.