Skip to content

Commit

Permalink
Merge branch 'main' into rm-qiskit-links
Browse files Browse the repository at this point in the history
  • Loading branch information
DanPuzzuoli authored Feb 23, 2024
2 parents 33c1d11 + cd2f0fb commit 5b99ba8
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- uses: actions/setup-python@v4
name: Install Python
with:
python-version: '3.8'
python-version: '3.10'
- name: Install Deps
run: pip install -U wheel
- name: Build Artifacts
Expand Down
3 changes: 3 additions & 0 deletions docs/tutorials/optimizing_pulse_sequence.rst
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ entry on :ref:`JAX-compatible pulse schedules <how-to use pulse schedules for ja
)
)

# we need to set disable_validation True to enable jax-jitting.
pulse.ScalableSymbolicPulse.disable_validation = True

return pulse.ScalableSymbolicPulse(
pulse_type="GaussianSquare",
duration=230,
Expand Down
3 changes: 3 additions & 0 deletions docs/userguide/how_to_use_pulse_schedule_for_jax_jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ JAX-compiled (or more generally, JAX-transformed).
_amp * sym.exp(sym.I * _angle) * lifted_gaussian(_t, _center, _duration + 1, _sigma)
)

# we need to set disable_validation True to enable jax-jitting.
pulse.ScalableSymbolicPulse.disable_validation = True

gaussian_pulse = pulse.ScalableSymbolicPulse(
pulse_type="Gaussian",
duration=160,
Expand Down
12 changes: 10 additions & 2 deletions qiskit_dynamics/pulse/pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from qiskit.pulse.library import SymbolicPulse
from qiskit import QiskitError

from qiskit_dynamics.array import Array
from qiskit_dynamics import DYNAMICS_NUMPY as unp
from qiskit_dynamics import ArrayLike

Expand Down Expand Up @@ -349,6 +348,12 @@ def get_samples(pulse: SymbolicPulse) -> ArrayLike:
raise PulseError("Pulse envelope expression is not assigned.")

args = []
try:
backend = (
"jax" if any(isinstance(v, jax.core.Tracer) for v in pulse_params.values()) else "numpy"
)
except (ImportError, NameError):
backend = "numpy"
for symbol in sorted(envelope.free_symbols, key=lambda s: s.name):
if symbol.name == "t":
times = unp.arange(0, pulse_params["duration"]) + 1 / 2
Expand All @@ -361,7 +366,10 @@ def get_samples(pulse: SymbolicPulse) -> ArrayLike:
f"Pulse parameter '{symbol.name}' is not defined for this instance. "
"Please check your waveform expression is correct."
) from ex
return _lru_cache_expr(envelope, Array.default_backend())(*args)
return _lru_cache_expr(
envelope,
backend,
)(*args)


@functools.lru_cache(maxsize=None)
Expand Down
44 changes: 35 additions & 9 deletions test/dynamics/pulse/test_pulse_to_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,15 @@

from qiskit_ibm_runtime.fake_provider import FakeQuito

try:
from jax import jit
except ImportError:
pass

from qiskit_dynamics.pulse import InstructionToSignals
from qiskit_dynamics.signals import DiscreteSignal

from ..common import QiskitDynamicsTestCase, TestJaxBase
from ..common import QiskitDynamicsTestCase, JAXTestBase


class TestPulseToSignals(QiskitDynamicsTestCase):
Expand Down Expand Up @@ -358,14 +363,17 @@ def test_barrier_instructions(self):
self.assertAllClose(sigs[1].samples, np.array([0.0, 0.0, 0.0, -0.5, -0.5, -0.5]))


class TestPulseToSignalsJAXTransformations(QiskitDynamicsTestCase, TestJaxBase):
class TestPulseToSignalsJAXTransformations(JAXTestBase):
"""Tests InstructionToSignals class by using Jax."""

def setUp(self):
"""Set up gaussian waveform samples for comparison."""
self.constant_get_waveform_samples = (
pulse.Constant(duration=5, amp=0.1).get_waveform().samples
)
self.gaussian_get_waveform_samples = (
pulse.Gaussian(duration=5, amp=0.983, sigma=2.0).get_waveform().samples
)
self._dt = 0.222

def test_InstructionToSignals_with_JAX(self):
Expand All @@ -378,8 +386,8 @@ def jit_func_instruction_to_signals(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we can use only SymbolicPulse when jax-jitting
# bacause jax-jitting doesn't correspond to validate_parameters in qiskit.pulse.
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
instance = pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand All @@ -393,11 +401,27 @@ def jit_func_instruction_to_signals(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_instruction_to_signals)(0.1)
self.jit_grad_wrap(jit_func_instruction_to_signals)(0.1)
jit_samples = self.jit_wrap(jit_func_instruction_to_signals)(0.1)
def jit_func_gaussian_to_signals(amp):
pulse.Gaussian.disable_validation = True
instance = pulse.Gaussian(duration=5, amp=amp, sigma=2.0)
with pulse.build() as schedule:
pulse.play(instance, pulse.DriveChannel(0))

converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

jit(jit_func_instruction_to_signals)(0.1)
self.jit_grad(jit_func_instruction_to_signals)(0.1)
jit_samples = jit(jit_func_instruction_to_signals)(0.1)
self.assertAllClose(jit_samples, self.constant_get_waveform_samples, atol=1e-7, rtol=1e-7)

jit(jit_func_gaussian_to_signals)(0.983)
self.jit_grad(jit_func_gaussian_to_signals)(0.983)
jit_gaussian_samples = jit(jit_func_gaussian_to_signals)(0.983)
self.assertAllClose(
jit_gaussian_samples, self.gaussian_get_waveform_samples, atol=1e-7, rtol=1e-7
)

def test_pulse_types_combination_with_jax(self):
"""Test that converting schedule including some pulse types with Jax works well"""

Expand All @@ -408,6 +432,8 @@ def jit_func_symbolic_pulse(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
instance = pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand All @@ -430,8 +456,8 @@ def jit_func_symbolic_pulse(amp):
converter = InstructionToSignals(self._dt, carriers={"d0": 5})
return converter.get_signals(schedule)[0].samples

self.jit_wrap(jit_func_symbolic_pulse)(0.1)
self.jit_grad_wrap(jit_func_symbolic_pulse)(0.1)
jit(jit_func_symbolic_pulse)(0.1)
self.jit_grad(jit_func_symbolic_pulse)(0.1)


@ddt
Expand Down
2 changes: 2 additions & 0 deletions test/dynamics/solvers/test_solver_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,8 @@ def constant_pulse(amp):
(1, sym.And(_time >= 0, _time <= _duration)), (0, True)
)
valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0
# we need to set disable_validation True to enable jax-jitting.
pulse.SymbolicPulse.disable_validation = True
return pulse.SymbolicPulse(
pulse_type="Constant",
duration=5,
Expand Down

0 comments on commit 5b99ba8

Please sign in to comment.