Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arraylias integration - Signal class - #269

Merged
merged 45 commits into from
Oct 30, 2023

Conversation

to24toro
Copy link
Contributor

Summary

I started to integrate arraylias to dynamics.
This PR includes the integration of the Signal class in dynamics.

Details and comments

@@ -728,6 +738,7 @@ def test_conjugate(self):
)


@partial(test_array_backends, array_libraries=["numpy", "jax", "array_numpy", "array_jax"])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This decorator works as creating TestDiscreteSignalSum of "numpy", "jax", "array_numpy", "array_jax", which inherits TestSignalSum. But there is no class whose name is TestSignalSum because it has already been decorated by test_array_backends.

@to24toro
Copy link
Contributor Author

to24toro commented Oct 19, 2023

The current implementation fails the tutorial in optimizing_pulse_sequence.rst. The main reason for this failure is that numpy.array and jax.numpy.array sometimes gets in the mix in the code. This would be no problem if we don't use jax.jit. However, this failure is found within jit(value_and_grad(objective)) in this tutorial.

if isinstance(signal, DiscreteSignal):
  # Perform a discrete time convolution.
  dt = signal.dt
  func_samples = np.asarray([self._func(dt * i) for i in range(signal.duration)])
  func_samples = func_samples / sum(func_samples)
  sig_samples = signal(dt * np.arange(signal.duration))

even if np -> unp is changed, unp.arange(signal.duration) becomes a numpy.array because signal.duration is assumed to be an int. The point is that types like int or float get coverted to numpy by unp.arange(). In the other words, numpy takes precedence over jax, leading to the issue.

@to24toro to24toro changed the title [WIP] Arraylias integration - Signal class - Arraylias integration - Signal class - Oct 20, 2023
@to24toro to24toro marked this pull request as ready for review October 20, 2023 05:25
Copy link
Collaborator

@DanPuzzuoli DanPuzzuoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The signals and test/dynamics/signals folders look good to me aside from the two comments I made:

  • change alias to numpy_alias
  • Remove all redundant type hinting for "numbers" and List where ArrayLike appears.

from qiskit import QiskitError
from qiskit_dynamics.array import Array
from qiskit_dynamics.arraylias import ArrayLike
from qiskit_dynamics.arraylias import DYNAMICS_NUMPY_ALIAS as alias
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think maybe we should call this numpy_alias instead of just alias now that there will be one for NumPy and one for SciPy.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed at 79c45ce .

Comment on lines 72 to 74
envelope: Union[Callable, complex, float, int, ArrayLike],
carrier_freq: Union[float, List, ArrayLike] = 0.0,
phase: Union[float, List, ArrayLike] = 0.0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the type hints are redundant as ArrayLike includes float, int, complex. Should we remove everything but ArrayLike here?

I see ArrayLike doesn't include List, but maybe we should add that, as it does technically act on lists as well.

Edit: Ah yes I see later on in the file for DiscreteSignalSum you do conversions like Union[float, List, ArrayLike] -> ArrayLike. This should be done for all similar type hints throughout the whole file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. I fixed at a3509f5 and 921126b .

@DanPuzzuoli
Copy link
Collaborator

DanPuzzuoli commented Oct 24, 2023

I'm trying to figure out how to get the tests to pass in other submodules - the issue with a lot of the failures seems to be that:

unp.asarray(Array(x, backend='jax'))

ends up calling numpy.array(x).

This issue comes up in tests like test.dynamics.models.test_generator_model.TestGeneratorModelSparseJax.test_jit_grad in which a value a is being traced, and the function contains the code:

Signal(Array(a))

This test (and I'm guessing many others) can be fixed by simply changing the above to:

Signal(a)

This is how we'll want to do things going forward anyway, but this doesn't help with the fact that Signal(Array(x, backend='jax')) is broken when x is a tracer.

Fixing this by modifying Array appears to be nontrivial. I haven't been able to find a quick fix as numpy gets upset when np.array returns something that is not an array. As such, it seems that maybe asarray cannot work how we'd want for Array with backend='jax'.

@DanPuzzuoli
Copy link
Collaborator

DanPuzzuoli commented Oct 24, 2023

I've figured out how to bypass the above issue and directly fix most tests here. The only thing I haven't fixed yet is stuff to do with pulse -> signal conversion and JAX compatibility - but I think it may just be the same issue.

I'll attempt again to see if we can "properly" fix this problem (i.e. get unp.asarray to work with Array with backend=="jax"). It would be nice to have backwards compatibility with code that uses Array, but I'm also okay with it breaking if we can't find a clean solution.

@to24toro
Copy link
Contributor Author

The tests which have not been passed are two types:

Copy link
Collaborator

@DanPuzzuoli DanPuzzuoli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All looking good to me, but some final nitpicking details to sort out.

qiskit_dynamics/pulse/pulse_to_signals.py Outdated Show resolved Hide resolved
qiskit_dynamics/signals/signals.py Outdated Show resolved Hide resolved
qiskit_dynamics/signals/signals.py Outdated Show resolved Hide resolved
qiskit_dynamics/signals/signals.py Outdated Show resolved Hide resolved
qiskit_dynamics/signals/signals.py Outdated Show resolved Hide resolved
qiskit_dynamics/signals/signals.py Outdated Show resolved Hide resolved
Comment on lines 861 to 863
samples = unp.append(sig1.samples, sig2.samples, axis=1)
carrier_freq = unp.append(sig1.carrier_freq, sig2.carrier_freq)
phase = unp.append(sig1.phase, sig2.phase)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect these lines may be prone to the dispatching errors we had to deal with in other places. Should these be changed to use _numpy_multi_dispatch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot to change it to _numpy_multi_dispatch.
I fixed at 2e1436e.

0.5
* (sig1.samples[:, :, None] * sig2.samples[:, None, :].conj()).reshape(
(sig1.samples.shape[0], sig1.samples.shape[1] * sig2.samples.shape[1]),
order="C",
)
)
samples = np.append(new_samples, new_samples_conj, axis=1)
samples = unp.append(new_samples, new_samples_conj, axis=1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_numpy_multi_dispatch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed at 2e1436e.


new_freqs = sig1.carrier_freq + sig2.carrier_freq
new_freqs_conj = sig1.carrier_freq - sig2.carrier_freq
freqs = np.append(Array(new_freqs), Array(new_freqs_conj))
freqs = unp.append(unp.asarray(new_freqs), unp.asarray(new_freqs_conj))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_numpy_multi_dispatch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed at 2e1436e.


new_phases = sig1.phase + sig2.phase
new_phases_conj = sig1.phase - sig2.phase
phases = np.append(Array(new_phases), Array(new_phases_conj))
phases = unp.append(unp.asarray(new_phases), unp.asarray(new_phases_conj))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_numpy_multi_dispatch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed at 2e1436e.

@DanPuzzuoli DanPuzzuoli self-requested a review October 30, 2023 14:39
DanPuzzuoli
DanPuzzuoli previously approved these changes Oct 30, 2023
@DanPuzzuoli DanPuzzuoli merged commit 7333814 into qiskit-community:main Oct 30, 2023
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants