
Parameter-dependent simulation for optimal Bayesian experimentation - JAX @jit compatibility #243

Closed
abailly-at-anl opened this issue Jul 12, 2023 · 10 comments
Labels
enhancement New feature or request

Comments

@abailly-at-anl

What is the expected behavior?

I am trying to parameterize simulations by decoherence rates and Hamiltonian parameters. In implementing this, we have run into many compatibility issues with JAX just-in-time compilation. Notably, it is not possible to return Solver or Signal objects from a jit-compiled function, or to pass them in as arguments.
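For background, jax.jit only accepts and returns pytrees of arrays, so an arbitrary Python object fails at the jit boundary with a TypeError. Below is a minimal sketch of that general mechanism using hypothetical PlainSignal and ToySignal classes (not qiskit-dynamics' Signal, whose internals aren't described in this thread): registering a class with jax.tree_util.register_pytree_node_class lets jitted functions both return it and accept it.

```python
import jax
import jax.numpy as jnp


class PlainSignal:
    """Hypothetical signal container that is NOT registered as a pytree."""

    def __init__(self, amp, carrier_freq):
        self.amp = amp
        self.carrier_freq = carrier_freq


@jax.tree_util.register_pytree_node_class
class ToySignal:
    """Hypothetical signal container registered as a pytree, so jit can trace it."""

    def __init__(self, amp, carrier_freq):
        self.amp = amp
        self.carrier_freq = carrier_freq

    def tree_flatten(self):
        # Children are the array-valued fields; there is no static auxiliary data.
        return (self.amp, self.carrier_freq), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, t):
        # Real part of amp * exp(i 2 pi f t): a simple carrier-modulated signal.
        return jnp.real(self.amp * jnp.exp(2j * jnp.pi * self.carrier_freq * t))


@jax.jit
def make_plain(params):
    return PlainSignal(params[0], params[1])


@jax.jit
def make_toy(params):
    return ToySignal(params[0], params[1])


@jax.jit
def evaluate(signal, t):
    # A pytree-registered object can also be passed *into* a jitted function.
    return signal(t)


params = jnp.array([0.5, 2.0])
try:
    make_plain(params)
except TypeError as err:
    print("unregistered object rejected:", type(err).__name__)

sig = make_toy(params)
print(evaluate(sig, 0.0))  # 0.5
```

Whether this pattern fits the library's own classes is a separate question; the sketch only shows why unregistered objects can't cross the jit boundary.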

I attached a Jupyter notebook that outlines what we have tried so far. Since it does not appear possible to create a new Solver instance inside a jit-compiled function, we instead edit the static_hamiltonian of an existing Solver. However, the same approach does not seem to work for the hamiltonian_operators or lindblad_dissipators.

Is there a better way to implement these parameter-dependent simulations? If not, it would be greatly appreciated if the Qiskit team built in greater compatibility with JAX just-in-time compilation in this regard.
Parameter Dependent Simulation.pdf

@abailly-at-anl abailly-at-anl added the enhancement New feature or request label Jul 12, 2023
@DanPuzzuoli
Collaborator

Hi, thanks for the inquiry! Can you attach the original ipynb file rather than a pdf? I can show you some options for how to get this to work.

@abailly-at-anl
Author

abailly-at-anl commented Jul 13, 2023 via email

@DanPuzzuoli
Collaborator

I'm not sure if I'm missing the attachment in the email, but I don't see the notebook anywhere. Are you not able to attach a notebook in a comment? Maybe try zipping it?

@abailly-at-anl
Author

I had attached the notebook to the email; I've also zipped it and attached it here. Let me know if that works!

Parameter Dependent Simulation.zip

@DanPuzzuoli
Collaborator

I've attached an updated version of the notebook, with some alternate code you can use to compile these functions in the section labelled "Dan Alternate code suggestion".

Parameter Dependent Simulation.ipynb.zip

I've written some explanation in there, but in short, I'd suggest dropping the RWA when constructing the Solver. With that, you can build the Solver directly within the function you want to compile, as long as you set validate=False. Neither the RWA code nor the validation code is JAX compatible, because both depend on the values in the model operator arrays, and a compiled function can't make Python-level decisions based on the values of its traced inputs (the RWA could probably be made JAX compatible, but it's extremely low priority).
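To illustrate both points in plain JAX, without qiskit-dynamics, here is a minimal sketch under stated assumptions: a single-qubit toy model with H = (w/2) Z + r X and a decay operator sqrt(gamma) σ₋, evolved with the Lindblad equation dρ/dt = −i[H, ρ] + LρL† − ½{L†L, ρ} via a matrix exponential of the vectorized generator. The function names and the validate flag are hypothetical stand-ins for the Solver's behaviour, not its implementation. The model is rebuilt from the parameters inside the function, and a value-dependent Hermiticity check only works outside jit.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

X = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
SM = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=jnp.complex64)  # lowering op |0><1|
I2 = jnp.eye(2, dtype=jnp.complex64)


def lindbladian(H, L):
    """Superoperator of the Lindblad equation in column-stacking convention.

    Uses vec(A @ rho @ B) = kron(B.T, A) @ vec(rho).
    """
    LdL = L.conj().T @ L
    return (
        -1j * (jnp.kron(I2, H) - jnp.kron(H.T, I2))
        + jnp.kron(L.conj(), L)
        - 0.5 * (jnp.kron(I2, LdL) + jnp.kron(LdL.T, I2))
    )


def check_hermitian(H):
    # Python control flow on array *values*: fine on concrete arrays, but under
    # jit H is an abstract tracer and the bool conversion raises an error.
    if not jnp.allclose(H, H.conj().T):
        raise ValueError("Hamiltonian is not Hermitian")


def population_one(params, t, validate=True):
    """Population of |1> at time t, starting from |1> (hypothetical toy model)."""
    w, r, gamma = params[0], params[1], params[2]
    # The model is rebuilt from the (possibly traced) parameters on every call.
    H = 0.5 * w * Z + r * X
    if validate:
        check_hermitian(H)
    L = jnp.sqrt(gamma) * SM
    rho0 = jnp.array([[0.0, 0.0], [0.0, 1.0]], dtype=jnp.complex64)
    vec_rho_t = expm(t * lindbladian(H, L)) @ rho0.T.reshape(-1)  # column-stack vec
    rho_t = vec_rho_t.reshape(2, 2).T  # undo the column stacking
    return jnp.real(rho_t[1, 1])


params = jnp.array([1.0, 0.0, 0.5])  # w, r = 0 (no drive), gamma
print(population_one(params, 2.0))  # approximately exp(-1) = 0.368

jitted = jax.jit(population_one, static_argnames="validate")
try:
    jitted(params, 2.0, validate=True)
except jax.errors.ConcretizationTypeError as err:
    print("validation fails under jit:", type(err).__name__)
print(jitted(params, 2.0, validate=False))  # same value, now compiled
```

With r = 0 the dynamics is pure decay, so the result can be checked against exp(−γt).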

As an aside: in this case the simulation is just as performant with or without the RWA. What I've observed with this package is that much of the numerical benefit of the RWA is already obtained just by entering the rotating frame (without actually making the approximation).
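To make the distinction concrete, here is an illustrative single-qubit calculation (an assumed model, not the notebook's): a qubit of frequency ω driven on resonance with strength r, using the frame operator F = (ω/2)Z and the standard frame transformation H ↦ e^{iFt}(H − F)e^{−iFt}.

$$
H(t) = \frac{\omega}{2} Z + r\cos(\omega t)\, X, \qquad F = \frac{\omega}{2} Z
$$

$$
H_F(t) = e^{iFt}\big(H(t) - F\big)e^{-iFt} = r\cos(\omega t)\big(\cos(\omega t)\,X - \sin(\omega t)\,Y\big) = \frac{r}{2}\Big(X + \cos(2\omega t)\,X - \sin(2\omega t)\,Y\Big)
$$

$$
H_{\mathrm{RWA}} = \frac{r}{2}\,X
$$

Entering the frame already removes the large static term (ω/2)Z, so every remaining term has magnitude of order r. The RWA only additionally drops the terms oscillating at 2ω, which plausibly explains why the frame alone captures much of the numerical benefit.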

One issue I couldn't resolve while playing with your notebook is that there's a discrepancy between the output of your function and the one I've just made, even if I drop the RWA from yours. I haven't been able to figure this out, though I haven't dug too deeply. I'm more inclined to trust my version, as modifying the Solver after the fact could result in sketchy behaviour. (In fact, I forgot these setter functions for the operators even existed. I would need to think about it again, but it might make sense to remove them entirely to discourage this behaviour. It's been a while since I've worked on this code, but I typically treat the Solver as immutable once I've created it.) If you're able to determine why there is a discrepancy, I'd be interested to know.
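This isn't claimed to be the cause of the discrepancy here, but it is one general way that mutating an object after a jitted function has been traced can silently give stale results. jax.jit caches its trace keyed on argument shapes and dtypes, so any array the function reads from an outer object is baked in as a constant on the first call. A minimal sketch with a hypothetical MutableModel class (not the Solver):

```python
import jax
import jax.numpy as jnp


class MutableModel:
    """Hypothetical stand-in for an object whose operators are changed via setters."""

    def __init__(self, static_hamiltonian):
        self.static_hamiltonian = static_hamiltonian


model = MutableModel(jnp.array([[1.0, 0.0], [0.0, -1.0]]))


@jax.jit
def energy_closure(psi):
    # model.static_hamiltonian is read while tracing and baked into the
    # compiled function as a constant.
    return jnp.vdot(psi, model.static_hamiltonian @ psi).real


@jax.jit
def energy_explicit(H, psi):
    # The operator is an argument, so every call sees its current value.
    return jnp.vdot(psi, H @ psi).real


psi = jnp.array([1.0, 0.0])
print(energy_closure(psi))  # 1.0

# "Update" the model after the jitted function has already been traced.
model.static_hamiltonian = 2.0 * model.static_hamiltonian

print(energy_closure(psi))  # still 1.0: the cached trace is reused
print(energy_explicit(model.static_hamiltonian, psi))  # 2.0
```

Passing everything that may change as an explicit argument avoids this class of bug entirely.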

An alternative to what I've shown here is to still create the Solver outside of the jitted function, but to put the parts of the model you want to modify into hamiltonian_operators and dissipator_operators. These are the terms whose coefficients are meant to be updated on the fly. However, I think what I've written is a bit more natural.
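The structure of that alternative can be sketched in plain JAX: the operators are fixed once outside the compiled function and only their coefficients are traced inputs. This is a toy under assumptions (a single qubit, H = c0·(Z/2) + c1·X, and a ground-state energy instead of a time evolution); in the Solver those coefficients are supplied through the signals rather than a plain array.

```python
import jax
import jax.numpy as jnp

# Operators are fixed once, outside any jitted function (assumed toy model:
# H = c0 * (Z / 2) + c1 * X for a single qubit).
X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
hamiltonian_operators = jnp.stack([0.5 * Z, X])  # shape (2, 2, 2)


@jax.jit
def ground_energy(coefficients):
    # Only the coefficients are traced; the operator array is a compile-time constant.
    H = jnp.tensordot(coefficients, hamiltonian_operators, axes=1)
    return jnp.linalg.eigvalsh(H)[0]  # eigenvalues come back in ascending order


# New coefficient values reuse the same compiled function (same shape and dtype).
print(ground_energy(jnp.array([6.0, 4.0])))  # approximately -5.0
print(jax.vmap(ground_energy)(jnp.array([[6.0, 4.0], [2.0, 0.0]])))  # approximately [-5. -1.]
```

Because the operators never change, this pattern is safe to combine with jit and vmap over many parameter sets.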

Lastly, I was amazed to discover that:

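```python
from qiskit_dynamics import Signal
from qiskit_dynamics.array import Array

def signal_from_input(pulse_input):
    amp = Array(pulse_input[0])
    w = pulse_input[1]
    # Construct with a placeholder carrier frequency, then overwrite it.
    signal = [Signal(amp, carrier_freq=1.0)]
    signal[0].carrier_freq = w
    return signal
```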

can't be changed to the following and still be JAX-compilation compatible:

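```python
from qiskit_dynamics import Signal
from qiskit_dynamics.array import Array

def signal_from_input(pulse_input):
    amp = Array(pulse_input[0])
    w = pulse_input[1]
    # Pass the (traced) carrier frequency directly to the constructor.
    signal = [Signal(amp, carrier_freq=w)]
    return signal
```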

I'm going to create an issue that this should be fixed.
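The thread doesn't say what inside Signal caused this (it was tracked in #245 and fixed in #247). One common way this kind of asymmetry arises is shown below with a hypothetical class: if a constructor coerces an argument with Python's float(), it fails on a tracer, whereas a setter that stores the value with jnp.asarray does not. This is an illustration of the failure mode, not a description of Signal's actual code.

```python
import jax
import jax.numpy as jnp


class HypotheticalSignal:
    """Not qiskit-dynamics' Signal: a toy illustrating one possible failure mode."""

    def __init__(self, amp, carrier_freq=0.0):
        self.amp = amp
        # Python-level coercion: needs a concrete number, so it fails on a tracer.
        self._carrier_freq = float(carrier_freq)

    @property
    def carrier_freq(self):
        return self._carrier_freq

    @carrier_freq.setter
    def carrier_freq(self, value):
        # Array-level coercion: works for both concrete values and tracers.
        self._carrier_freq = jnp.asarray(value)

    def __call__(self, t):
        return self.amp * jnp.cos(2 * jnp.pi * self.carrier_freq * t)


@jax.jit
def via_setter(pulse_input, t):
    signal = HypotheticalSignal(pulse_input[0], carrier_freq=1.0)
    signal.carrier_freq = pulse_input[1]
    return signal(t)


@jax.jit
def via_constructor(pulse_input, t):
    signal = HypotheticalSignal(pulse_input[0], carrier_freq=pulse_input[1])
    return signal(t)


pulse_input = jnp.array([0.5, 2.0])
print(via_setter(pulse_input, 0.0))  # 0.5
try:
    via_constructor(pulse_input, 0.0)
except jax.errors.ConcretizationTypeError as err:
    print("constructor path fails under jit:", type(err).__name__)
```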

@abailly-at-anl
Author

Thanks for the suggestions; I'll experiment with constructing the Solver directly in the method to see if it's more performant.

I also noticed a discrepancy similar to the one you mentioned this morning. In the original notebook I uploaded, I use a separate function to update the Hamiltonian. Copying the same code from update_static_hamiltonian into the simulator function itself produces yet another, different set of results. I'm not sure what's going on there. I've updated the notebook again so you can see what I mean.

I was likewise surprised by the carrier_freq issue, which I only noticed while writing this example notebook.
Parameter Dependent Simulation again.zip

@DanPuzzuoli
Collaborator

DanPuzzuoli commented Jul 14, 2023

Your new function extracts r, w, B from parameters, but it needs to use hamiltonian_parameters instead. After I make this change, it agrees with your original function.

Also, by the way, you set up the initial construction with

```python
r, w, B = parameters
```

but later your functions use a different order:

```python
w = parameters[0]
r = parameters[1]
B = parameters[2]
```

I'm not sure if this will change the comparison, but it could be another source of mix-ups.
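One way to rule out this kind of ordering mix-up is to pass parameters by name. A minimal sketch, assuming a dictionary keyed by the names r, w, B (dicts are pytrees, so jax.jit accepts them directly). The Hamiltonian below is a hypothetical placeholder, not the notebook's model.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])


@jax.jit
def build_hamiltonian(hamiltonian_parameters):
    # Look parameters up by name, so their order in the container is irrelevant.
    w = hamiltonian_parameters["w"]
    r = hamiltonian_parameters["r"]
    B = hamiltonian_parameters["B"]
    # Placeholder model (assumption): (w/2) Z + r X + B Z.
    return 0.5 * w * Z + r * X + B * Z


params_a = {"r": 0.1, "w": 5.0, "B": 0.2}
params_b = {"w": 5.0, "B": 0.2, "r": 0.1}  # same values, different insertion order
print(jnp.allclose(build_hamiltonian(params_a), build_hamiltonian(params_b)))  # True
```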

@DanPuzzuoli
Collaborator

DanPuzzuoli commented Jul 14, 2023

Okay, so if I remove the setting of the Hamiltonian parameters from your function and pass hamiltonian_parameters=parameters into my version of the function, the results agree 🎉. If I change the parameters and walk through the whole notebook again, they keep agreeing.

So, something fishy is definitely going on with updating model operators. If you don't mind, I'll change the name of this issue to point to this specific bug. Unless you disagree, I feel your issue has more or less been resolved, and what remains as far as Dynamics development is concerned is this problem (along with the issue about the Signal carrier frequency). I'm thinking it may make sense to simply make the models immutable.
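A minimal sketch of the immutable-model idea in plain JAX, using a hypothetical ToyModel (not the qiskit-dynamics API): a NamedTuple is both immutable and a pytree, so "updating" it means building a new instance with _replace and passing that instance into the jitted function, rather than mutating an existing object.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyModel(NamedTuple):
    """Hypothetical immutable model: fields cannot be reassigned after creation."""

    static_hamiltonian: jnp.ndarray
    hamiltonian_operators: jnp.ndarray  # shape (k, n, n)


@jax.jit
def generator(model, coefficients):
    # The model is a (pytree) argument, so its current values are always used.
    return model.static_hamiltonian + jnp.tensordot(
        coefficients, model.hamiltonian_operators, axes=1
    )


X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
model = ToyModel(static_hamiltonian=0.5 * Z, hamiltonian_operators=jnp.stack([X]))

try:
    model.static_hamiltonian = Z  # NamedTuple fields are read-only
except AttributeError:
    print("in-place mutation is rejected")

new_model = model._replace(static_hamiltonian=Z)  # a new object; `model` is unchanged
coeffs = jnp.array([0.3])
print(generator(model, coeffs)[0, 0], generator(new_model, coeffs)[0, 0])  # 0.5 1.0
```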

@DanPuzzuoli
Collaborator

DanPuzzuoli commented Jul 18, 2023

@abailly-at-anl just FYI, PR #247 will fix the carrier frequency tracing issue #245.

@DanPuzzuoli
Collaborator

Closing this issue as the discussion is out of date.
