[Bug]: strange behaviour of jaxified idaklu solver wrt t_eval #4697

Open
martinjrobins opened this issue Dec 19, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@martinjrobins
Contributor

PyBaMM Version

24.11.1

Python Version

3.10

Describe the bug

The jaxified IDAKLU solver reports different sensitivities from the standard IDAKLU solver when t_eval differs from t_interp. Might this be due to the t_interp addition in 24.9?

Steps to Reproduce

import numpy as np
import pybamm

solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybamm.lithium_ion.SPM()
parameters = model.default_parameter_values

# Fitting parameters
parameters["Negative electrode active material volume fraction"] = "[input]"
parameters["Positive electrode active material volume fraction"] = "[input]"

experiment = pybamm.Experiment([("Charge at 0.5C for 3 minutes (3 second period)",)])


# Generate data
sigma = 0.002
sim = pybamm.Simulation(
    model, parameter_values=parameters, solver=solver, experiment=experiment
)
inputs = {
    "Negative electrode active material volume fraction": 0.55,
    "Positive electrode active material volume fraction": 0.55,
}
sim.build(initial_soc=0.5, inputs=inputs)
model = sim.built_model
solver = sim.solver

t_eval = np.array([0.0, 180.0])
t_data = np.linspace(0, 180, 20)

jax_idaklu = solver.jaxify(
    model,
    t_eval,
    output_variables=["Voltage [V]"],
    calculate_sensitivities=True,
    t_interp=t_data,
)


def jax_all_grad(inputs):
    grads = jax_idaklu.jax_grad(t_data, inputs)["Voltage [V]"]
    return np.vstack(
        [
            grads["Negative electrode active material volume fraction"],
            grads["Positive electrode active material volume fraction"],
        ]
    ).T


def ida_all_grad(inputs):
    sol = solver.solve(
        model,
        t_eval=t_eval,
        t_interp=t_data,
        inputs=inputs,
        calculate_sensitivities=True,
    )
    voltage_grad = sol["Voltage [V]"].sensitivities["all"]
    return voltage_grad


# Store results under new names so the functions above are not shadowed
jax_grad_result = jax_all_grad(inputs)
std_grad_result = ida_all_grad(inputs)
print("jax_all_grad", jax_grad_result)
print("std_all_grad", std_grad_result)

Relevant log output

jax_all_grad [[0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.0085677  0.08058828]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]
 [0.01686415 0.08268555]]
std_all_grad 
[[0.0085677, 0.0805883], 
 [0.0168642, 0.0826856], 
 [0.0214956, 0.0836814], 
 [0.0246738, 0.0844627], 
 [0.0269041, 0.0851368], 
 [0.0284219, 0.0857429], 
 [0.0293858, 0.0863012], 
 [0.0299154, 0.0868237], 
 [0.030106, 0.0873178], 
 [0.0300357, 0.0877888], 
 [0.0297687, 0.0882406], 
 [0.0293574, 0.0886761], 
 [0.0288444, 0.0890972], 
 [0.0282643, 0.0895058], 
 [0.0276444, 0.0899032], 
 [0.0270062, 0.0902907], 
 [0.0263664, 0.090669], 
 [0.0257379, 0.091039], 
 [0.0251302, 0.0914014], 
 [0.0245502, 0.0917566]]
@martinjrobins added the bug label on Dec 19, 2024
@martinjrobins
Contributor Author

In the above code, if you replace t_eval with t_data in the call to solver.jaxify, the jaxified sensitivities closely match those from the standard solver:

jax_all_grad [[0.0085677  0.08058828]
 [0.01686409 0.08268554]
 [0.02149744 0.0836817 ]
 [0.02467562 0.08446308]
 [0.02690489 0.08513703]
 [0.02842267 0.08574318]
 [0.02938645 0.08630157]
 [0.02991571 0.08682403]
 [0.03010602 0.08731817]
 [0.03003551 0.08778924]
 [0.02976831 0.08824102]
 [0.02935691 0.08867635]
 [0.02884398 0.08909744]
 [0.02826386 0.089506  ]
 [0.02764394 0.08990342]
 [0.02700576 0.09029083]
 [0.02636606 0.09066915]
 [0.02573761 0.09103914]
 [0.02512992 0.09140146]
 [0.02454992 0.09175666]]
std_all_grad 
[[0.0085677, 0.0805883], 
 [0.0168642, 0.0826856], 
 [0.0214956, 0.0836814], 
 [0.0246738, 0.0844627], 
 [0.0269041, 0.0851368], 
 [0.0284219, 0.0857429], 
 [0.0293858, 0.0863012], 
 [0.0299154, 0.0868237], 
 [0.030106, 0.0873178], 
 [0.0300357, 0.0877888], 
 [0.0297687, 0.0882406], 
 [0.0293574, 0.0886761], 
 [0.0288444, 0.0890972], 
 [0.0282643, 0.0895058], 
 [0.0276444, 0.0899032], 
 [0.0270062, 0.0902907], 
 [0.0263664, 0.090669], 
 [0.0257379, 0.091039], 
 [0.0251302, 0.0914014], 
 [0.0245502, 0.0917566]]
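
To put a rough number on "closely match", here is a minimal check using only the first four rows of each printout above. The values are copied from the logs; the tolerance of 1e-3 is an arbitrary choice for illustration, not a PyBaMM threshold.

```python
import numpy as np

# First four rows of jax_all_grad (with t_data passed to solver.jaxify),
# copied from the printout above.
jax_rows = np.array(
    [
        [0.0085677, 0.08058828],
        [0.01686409, 0.08268554],
        [0.02149744, 0.0836817],
        [0.02467562, 0.08446308],
    ]
)

# First four rows of std_all_grad, copied from the printout above.
std_rows = np.array(
    [
        [0.0085677, 0.0805883],
        [0.0168642, 0.0826856],
        [0.0214956, 0.0836814],
        [0.0246738, 0.0844627],
    ]
)

# Element-wise relative difference between the two solvers.
rel_diff = np.abs(jax_rows - std_rows) / np.abs(std_rows)
max_rel_diff = rel_diff.max()

# Assumed tolerance for this sketch only.
agree = np.allclose(jax_rows, std_rows, rtol=1e-3, atol=0.0)
print("max relative difference:", max_rel_diff)
print("agree within rtol=1e-3:", agree)
```

The same comparison applied to the first printout (with t_eval passed to solver.jaxify) would fail from the third row onwards, since the jaxified values repeat there.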

@jsbrittain
Contributor

Hi - I gave this a quick once-over and have a thought. The interpretation of t_eval in the jaxify() function has changed: it used to read "The times at which to compute the solution. If None, the times in the model are used", whereas it now reads "The times at which to stop the integration due to a discontinuity in time". Jaxify (as I originally wrote it, before t_interp was introduced) solves for a multi-element time vector and caches the result, so that queries at individual time points can be returned rapidly without re-evaluation. To me, this likely explains why passing t_eval and passing t_data give different results in your simulation, and why you are seeing only two unique values in the output. I suspect that all output values (not just sensitivities) are affected in the same way.
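
One way to picture how a cached lookup could yield exactly two repeated rows is sketched below. This is a hypothetical illustration, not PyBaMM's implementation: it assumes the cached solution holds one row per t_interp point, while queries are mapped to a row index by finding the nearest entry in the two-element t_eval. That assumption is chosen only because it reproduces the pattern in the first log (ten copies of row 0 followed by ten copies of row 1).

```python
import numpy as np

t_eval = np.array([0.0, 180.0])
t_data = np.linspace(0, 180, 20)

# Stand-in for the cached solution: one row per t_interp point.
# The row number itself is stored so the lookup result is easy to read.
cached_rows = np.arange(len(t_data))

# Assumed (hypothetical) lookup: for each query time, pick the index of
# the nearest time in t_eval, then use that index into the cached rows.
nearest_idx = np.abs(t_data[:, None] - t_eval[None, :]).argmin(axis=1)
returned_rows = cached_rows[nearest_idx]

print(returned_rows)
```

Under this assumption, query times below 90 s map to index 0 and the rest map to index 1, so only the first two cached rows are ever returned.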
