Describe the bug
The jaxified IDAKLU solver reports different sensitivities from the standard IDAKLU solver when `t_eval` is different from `t_interp`. Might this be due to the `t_interp` addition in 24.9?
Steps to Reproduce
```python
import numpy as np
import pybamm

solver = pybamm.IDAKLUSolver(atol=1e-6, rtol=1e-6)
model = pybamm.lithium_ion.SPM()
parameters = model.default_parameter_values

# Fitting parameters
parameters["Negative electrode active material volume fraction"] = "[input]"
parameters["Positive electrode active material volume fraction"] = "[input]"
experiment = pybamm.Experiment(
    [("Charge at 0.5C for 3 minutes (3 second period)",)]
)

# Generate data
sigma = 0.002
sim = pybamm.Simulation(
    model, parameter_values=parameters, solver=solver, experiment=experiment
)
inputs = {
    "Negative electrode active material volume fraction": 0.55,
    "Positive electrode active material volume fraction": 0.55,
}
sim.build(initial_soc=0.5, inputs=inputs)
model = sim.built_model
solver = sim.solver

# t_eval holds only the start and end times; t_data holds the 20 output times
t_eval = np.array([0.0, 180.0])
t_data = np.linspace(0, 180, 20)

# Jaxified solver, with t_eval and t_interp passed separately
jax_idaklu = solver.jaxify(
    model,
    t_eval,
    output_variables=["Voltage [V]"],
    calculate_sensitivities=True,
    t_interp=t_data,
)


def jax_all_grad(inputs):
    # Voltage sensitivities from the jaxified solver, queried at t_data
    grads = jax_idaklu.jax_grad(t_data, inputs)["Voltage [V]"]
    return np.vstack(
        [
            grads["Negative electrode active material volume fraction"],
            grads["Positive electrode active material volume fraction"],
        ]
    ).T


def ida_all_grad(inputs):
    # Voltage sensitivities from a standard IDAKLU solve with the same times
    sol = solver.solve(
        model,
        t_eval=t_eval,
        t_interp=t_data,
        inputs=inputs,
        calculate_sensitivities=True,
    )
    voltage_grad = sol["Voltage [V]"].sensitivities["all"]
    return voltage_grad


jax_all_grad = jax_all_grad(inputs)
std_all_grad = ida_all_grad(inputs)
print("jax_all_grad", jax_all_grad)
print("std_all_grad", std_all_grad)
```
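To quantify the mismatch rather than comparing the printed arrays by eye, a small NumPy helper like the one below could be run on `jax_all_grad` and `std_all_grad`. `compare_sensitivities` is a hypothetical name, not part of PyBaMM, and it assumes both arrays already have the same shape (it raises otherwise). Besides the largest absolute difference, it counts distinct rows, which shows whether one of the outputs repeats the same few values across time.

```python
import numpy as np


def compare_sensitivities(a, b, atol=1e-6):
    """Summarise the difference between two sensitivity arrays.

    Hypothetical helper (not a PyBaMM function). Assumes both inputs are
    array-like with identical shapes, e.g. jax_all_grad and std_all_grad.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    return {
        # Largest element-wise disagreement between the two solvers
        "max_abs_diff": float(diff.max()),
        # True if every entry agrees to within the absolute tolerance
        "match": bool(np.allclose(a, b, rtol=0.0, atol=atol)),
        # Number of distinct rows; a small count means the values are
        # piecewise constant over the requested times
        "unique_rows_a": int(np.unique(a, axis=0).shape[0]),
        "unique_rows_b": int(np.unique(b, axis=0).shape[0]),
    }
```

Usage, after the reproduction script has run: `print(compare_sensitivities(jax_all_grad, std_all_grad))`.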
Hi, I gave this a quick once-over and have a thought. I note that `t_eval` has changed its interpretation in the `jaxify()` function. Its documentation used to read "The times at which to compute the solution. If None, the times in the model are used"; it now reads "The times at which to stop the integration due to a discontinuity in time". Jaxify (as I originally wrote it, before `t_interp` was introduced) solves for a multi-element time vector and caches the result, so that queries at individual time points can be returned rapidly without re-evaluation. To me this likely explains why `t_eval` and `t_data` give different results in your simulation, and why you are only seeing two unique values in the output. I suspect that all values (not just sensitivities) are affected in this way too.
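The suspected effect can be illustrated with a toy model in plain NumPy. This is a hypothetical sketch, not PyBaMM's implementation: it assumes the cached solution only stores values at the `t_eval` points and that a query at time t returns the stored value at the nearest cached time. `cached_lookup` and `fake_voltage` are made-up names, and the voltage signal is invented. With `t_eval = [0, 180]` and 20 query times, only two distinct values can come back, whereas a cache built on the `t_interp` grid returns one value per query time.

```python
import numpy as np


def cached_lookup(cache_times, cache_values, query_times):
    """Toy cache: return the stored value at the nearest cached time.

    Hypothetical illustration only; not how PyBaMM's IDAKLU/JAX bridge works.
    """
    cache_times = np.asarray(cache_times, dtype=float)
    query_times = np.asarray(query_times, dtype=float)
    # For each query, index of the closest cached time point
    idx = np.abs(query_times[:, None] - cache_times[None, :]).argmin(axis=1)
    return np.asarray(cache_values)[idx]


def fake_voltage(t):
    # Made-up smooth, increasing signal standing in for a charging voltage
    return 3.7 + 0.001 * np.asarray(t, dtype=float)


t_eval = np.array([0.0, 180.0])   # integration stop times only
t_data = np.linspace(0, 180, 20)  # times at which outputs are wanted

# Cache built only on t_eval: every query snaps to t=0 or t=180
coarse = cached_lookup(t_eval, fake_voltage(t_eval), t_data)
# Cache built on t_data (i.e. honouring t_interp): one value per query time
fine = cached_lookup(t_data, fake_voltage(t_data), t_data)

print(len(np.unique(coarse)), len(np.unique(fine)))  # prints: 2 20
```

In this toy, the coarse cache reproduces the "two unique values" symptom regardless of how many query times are requested.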
PyBaMM Version
24.11.1
Python Version
3.10