Summary of below:

- Removing the demotion code from PyBaMM causes `iree_hal_buffer_map_read` to return bad values: the read is out of range when returning double, and returns nonsense when returning float (while double is expected). Casting upwards doesn't produce anything sensible either, but this may be because of reinterpret casts in the inputs or elsewhere.
- jax/jaxlib can be upgraded to at least 0.4.31 without issue.
- iree-compiler cannot be upgraded past 886 without faulting. The problem is caused by the MLIR representation of `fcn_jac_times_cjmass_sparse`, more specifically the call to `rhs_algebraic_eval.get_jacobian()`, which produces MLIR that exceeds the allowed stack allocation limit of 32 KB (sparsifying the output, etc., doesn't help). Reducing the model geometry does allow the MLIR to compile with iree-compiler 889, but produces NaNs. We will need to resolve this in order to keep iree-compiler and iree-src up to date for when 64-bit operations are natively supported.
Nevertheless, there are some optimisations that can be applied to `idaklu_solver`:

- replace the dense mass matrix with `jax.experimental.sparse.BCOO`
- compute the bandwidth explicitly from the sparse representation (avoids densifying)
- all explicit demotions in the `idaklu_solver` IREE functions can be removed and replaced with `.astype(f32)` outputs (note that it is not clear whether this is entirely compatible given the IREE 64-bit issue above, but it seems to work!)
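A minimal sketch of the last point, assuming the wrapped function returns an array or a pytree of arrays: the body runs at whatever precision its inputs carry, and only the outputs are cast. `demote_outputs` and `residual` are hypothetical names, not PyBaMM functions.

```python
import jax
import jax.numpy as jnp


def demote_outputs(fn, dtype=jnp.float32):
    """Wrap fn so that every array it returns is cast to dtype.

    Inputs are passed through untouched, so the body runs at the inputs'
    precision and only the results are demoted (the `.astype(f32)` idea).
    """

    def wrapped(*args, **kwargs):
        out = fn(*args, **kwargs)
        # tree_map handles a single array as well as tuples/dicts of arrays
        return jax.tree_util.tree_map(lambda x: jnp.asarray(x).astype(dtype), out)

    return wrapped


# Hypothetical residual standing in for an idaklu_solver callback
def residual(t, y, p):
    return y * p - t


rhs_f32 = jax.jit(demote_outputs(residual))
out = rhs_f32(0.5, jnp.array([1.0, 2.0, 3.0]), 2.0)
print(out.dtype, out)
```

With `jax_enable_x64` enabled, the body of `residual` would then run in float64 and only the returned values would be demoted.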
The current implementation (jax, jaxlib 0.4.27) runs, but only with f32s. Allowing f64s to propagate through the code produces an error in `iree_jit.cpp` at the line

```cpp
status = iree_hal_buffer_map_read(iree_hal_buffer_view_buffer(result_view), 0,
                                  &result[k][0], sizeof(double) * result[k].size());
```

With `sizeof(double)` the buffer read exceeds its limits; with `sizeof(float)` incorrect values are returned.
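One plausible reading of these two symptoms, assuming the compiled module still writes float32 results while the host reads into a vector of doubles, can be reproduced on the host with NumPy: the requested byte count is twice the buffer size, and the bytes that do fit are reinterpreted rather than converted. The array names are illustrative only.

```python
import numpy as np

n = 4
# Assumption: the device buffer holds n float32 values (4 bytes each)
device_bytes = np.arange(n, dtype=np.float32).tobytes()

# sizeof(double) * n asks for 8 * n bytes from a 4 * n byte buffer: out of range
requested_as_double = np.dtype(np.float64).itemsize * n
print(requested_as_double, len(device_bytes))

# sizeof(float) * n fits, but the raw float32 bytes land in a double array,
# so pairs of floats are reinterpreted as one double instead of converted
reinterpreted = np.frombuffer(device_bytes, dtype=np.float64)
print(reinterpreted.size, reinterpreted)

# Reading at the buffer's real element type and widening explicitly is correct
widened = np.frombuffer(device_bytes, dtype=np.float32).astype(np.float64)
print(widened)
```

If this is what is happening, reading with the buffer's actual element size and converting on the host afterwards, rather than changing the byte count alone, would avoid both symptoms.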
Upgrading jax, jaxlib and IREE seems like a sensible way forward, but is hampered by MLIR compilation. Specifically, installing 0.4.28 or higher results in compilation failures for several reasons, one of which is a stack overflow error (`error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 234752 bytes`). Inspection of the MLIR (`fcn_jac_times_cjmass_sparse.mlir`) shows that the file is 7.7 MB long, due to several 962x962 constant dense matrix declarations. Interestingly, the file in question is identical between 0.4.27 and 0.4.28, with the dense definitions persisting in the MLIR code. All MLIR is lowered by JAX, but it is unclear whether this requires a JAX-side solution or is due to changes in IREE.
The main problem seems to originate from `jit_fcn_jac_times_cjmass_sparse`. Replacing the function body with just a call to `get_jacobian()` triggers the issue, but returning only the other part (`mass_matrix`) permits compilation:

```python
return model.rhs_algebraic_eval.get_jacobian()(t, y, p).astype(jnpfloat)
```
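To bisect which part of the function blows up, each candidate body can be lowered on its own and the size of the resulting StableHLO text compared. A minimal sketch using `jax.jit(...).lower(...).as_text()`; `mlir_sizes` and the two toy functions are hypothetical stand-ins for the Jacobian and mass-matrix parts. The same text can also be written to a `.mlir` file and passed to the IREE compiler separately.

```python
import jax
import jax.numpy as jnp


def mlir_sizes(candidates, *args):
    """Lower each candidate with jax.jit and return the length of its MLIR text."""
    sizes = {}
    for name, fn in candidates.items():
        text = jax.jit(fn).lower(*args).as_text()
        sizes[name] = len(text)
    return sizes


# Hypothetical stand-ins: a Jacobian-producing body and a mass-matrix-only body
def rhs(y):
    return jnp.sin(y) * jnp.sum(y)


def jacobian_part(t, y):
    return jax.jacfwd(rhs)(y) * t


def mass_part(t, y):
    return jnp.eye(y.shape[0], dtype=y.dtype) * t


y = jnp.linspace(0.0, 1.0, 8)
sizes = mlir_sizes({"jacobian": jacobian_part, "mass": mass_part}, 0.5, y)
print(sizes)
```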
The following solutions are considered:

- Force Heap Allocation in JAX
  - use global matrix definitions, rather than defining them in the jitted function
- Use Smaller Functions or Split Operations
- Manual MLIR Optimization
- Custom JAX Transformations
- Inlining Constants
- Reduce Tensor Sizes Temporarily
  - definitely worth a try as a debug tool
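For the "smaller functions or split operations" option, one way to avoid building the whole Jacobian in a single compiled function is to assemble it a block of columns at a time from forward-mode JVPs. A sketch assuming the residual can be written as a function of `y` alone; `blocked_jacobian` and `block_size` are hypothetical names, and whether this keeps IREE under its 32 KB stack limit is untested.

```python
import jax
import jax.numpy as jnp


def blocked_jacobian(f, y, block_size):
    """Dense Jacobian of f at y, computed block_size columns at a time.

    The jitted helper only handles a (block_size, n) batch of tangent
    vectors, so its intermediates scale with block_size rather than n.
    """
    n = y.shape[0]

    @jax.jit
    def jac_columns(y, tangents):
        def column(v):
            # jvp along basis vector v gives the Jacobian column J @ v
            return jax.jvp(f, (y,), (v,))[1]

        return jax.vmap(column)(tangents)  # one row per requested column

    basis = jnp.eye(n, dtype=y.dtype)
    blocks = []
    for start in range(0, n, block_size):
        # The last block may be narrower, which triggers one extra trace
        cols = jac_columns(y, basis[start:start + block_size])
        blocks.append(cols.T)
    return jnp.concatenate(blocks, axis=1)


def rhs(y):
    return jnp.sin(y) * jnp.sum(y)


y = jnp.linspace(0.0, 1.0, 10)
J_blocked = blocked_jacobian(rhs, y, block_size=4)
J_full = jax.jacfwd(rhs)(y)
print(jnp.allclose(J_blocked, J_full, atol=1e-5))
```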
Over-allocation may be occurring here:

```python
def fcn_jac_times_cjmass(t, y, p, cj):
    return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted

def fcn_jac_times_cjmass_sparse(t, y, p, cj):
    return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col]
```
Here `jac_rhs_algebraic_demoted(t, y, p)`, `mass_matrix_demoted` and therefore `fcn_jac_times_cjmass` are all dense (962x962).

Updates:
- We can replace this with `sparse.BCOO` from `jax.experimental`; it greatly reduces the MLIR file size, but does not resolve the stack issue.
- Downgrading jax/jaxlib to 0.4.27 (keeping IREE at 889) does not resolve the issue, which points to IREE as the culprit.
- Downgrading iree-compiler to 886 (with jax/jaxlib 0.4.27 and iree-src 889) does not produce the stack error.
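Since only the entries on the sparsity pattern are returned, the dense mass matrix need not appear in the jitted function at all: its values at `(coo.row, coo.col)` can be gathered once on the host and subtracted as an nnz-length vector. A minimal sketch, assuming the pattern covers the nonzeros of both the Jacobian and the mass matrix; `make_jac_times_cjmass_sparse` and `jac_fn` are hypothetical names. This removes the 962x962 mass-matrix constant but still forms the dense Jacobian, so on its own it is not expected to fix the stack error.

```python
import numpy as np
import jax
import jax.numpy as jnp


def make_jac_times_cjmass_sparse(jac_fn, mass_matrix, rows, cols):
    """Return a jitted f(t, y, p, cj) = (J - cj * M)[rows, cols].

    M is only used here, on the host, to extract its values on the pattern;
    the jitted function captures an nnz-length vector, not an (n, n) constant.
    """
    mass_on_pattern = jnp.asarray(np.asarray(mass_matrix)[rows, cols])
    rows_j, cols_j = jnp.asarray(rows), jnp.asarray(cols)

    @jax.jit
    def fcn(t, y, p, cj):
        jac = jac_fn(t, y, p)  # still dense (n, n)
        return jac[rows_j, cols_j] - cj * mass_on_pattern

    return fcn


# Hypothetical test problem: rhs(y) = p * y**2 with a diagonal mass matrix
def jac_fn(t, y, p):
    return jax.jacfwd(lambda yy: p * yy**2)(y)


n = 5
mass = np.diag([1.0, 1.0, 0.0, 1.0, 0.0]).astype(np.float32)  # zeros: algebraic rows
y0 = jnp.arange(1.0, n + 1.0)
# Pattern of J - cj * M: union of both nonzero structures
rows, cols = np.nonzero((np.asarray(jac_fn(0.0, y0, 2.0)) != 0) | (mass != 0))

sparse_fcn = make_jac_times_cjmass_sparse(jac_fn, mass, rows, cols)
dense_ref = (jac_fn(0.0, y0, 2.0) - 0.1 * mass)[rows, cols]
print(jnp.allclose(sparse_fcn(0.0, y0, 2.0, 0.1), dense_ref))
```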
Compute bandwidth from the sparse representation:

```python
coo = sparse_eval.tocoo()  # convert to COOrdinate format for indexing
jac_bw_lower = max(coo.col - coo.row)
jac_bw_upper = max(coo.row - coo.col)
```
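A runnable version of the bandwidth computation on a small SciPy matrix. It follows the usual banded-solver convention (lower bandwidth = number of sub-diagonals, `max(row - col)`; upper = number of super-diagonals, `max(col - row)`), clamped at zero. The snippet above uses the opposite pairing, which gives the same numbers only if `coo` holds the transposed matrix, so the convention `idaklu` expects is worth checking. `banded_bandwidths` is a hypothetical helper.

```python
import numpy as np
from scipy import sparse


def banded_bandwidths(matrix):
    """Return (lower, upper) bandwidths of a sparse matrix without densifying it."""
    coo = sparse.coo_matrix(matrix)  # COOrdinate format exposes .row and .col
    if coo.nnz == 0:
        return 0, 0
    offsets = coo.col.astype(np.int64) - coo.row.astype(np.int64)
    lower = max(0, int(-offsets.min()))  # furthest entry below the diagonal
    upper = max(0, int(offsets.max()))   # furthest entry above the diagonal
    return lower, upper


# Tridiagonal matrix plus one entry two places below the diagonal
A = sparse.diags([1.0, -2.0, 1.0], offsets=[-1, 0, 1], shape=(6, 6), format="lil")
A[4, 2] = 3.0
print(banded_bandwidths(A))
```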
Replace `coo` with BCOO:

```python
from jax.experimental import sparse as jax_sparse

def fcn_jac_times_cjmass_sparse(t, y, p, cj):
    return jax_sparse.BCOO.fromdense(
        fcn_jac_times_cjmass(t, y, p, cj).T,
        nse=jac_times_cjmass_nnz
    ).data
```
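A small check of what `BCOO.fromdense(....T, nse=...).data` returns. BCOO finds the nonzeros of its input in row-major order, so transposing first yields the nonzeros of the original matrix in column-major (CSC) order. One caveat worth testing against the `[coo.row, coo.col]` gather: `fromdense` locates nonzeros from the values at run time, so if a structurally nonzero entry evaluates to exactly zero, the remaining data shift position relative to the fixed sparsity pattern. The matrices below are illustrative only.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse as jax_sparse

A = np.array([[4.0, 1.0, 0.0],
              [2.0, 5.0, 0.0],
              [0.0, 3.0, 6.0]], dtype=np.float32)
nnz = int(np.count_nonzero(A))

# Nonzeros of A in column-major (CSC) order, via the transpose trick
data_bcoo = np.asarray(jax_sparse.BCOO.fromdense(jnp.asarray(A).T, nse=nnz).data)
print(data_bcoo)

# Fixed-pattern gather, as in the original [coo.row, coo.col] indexing
rows, cols = np.nonzero(A.T)
print(A.T[rows, cols])

# Same pattern, but one structural entry evaluates to zero at this point
A_zero = A.copy()
A_zero[1, 0] = 0.0
shifted = np.asarray(jax_sparse.BCOO.fromdense(jnp.asarray(A_zero).T, nse=nnz).data)
gathered = A_zero.T[rows, cols]  # keeps every value in its pattern slot
print(shifted[:5], gathered)
```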
- Compute bandwidth from the sparse representation (this works)
- Replace the dense mass matrix with BCOO
- Remove all explicit demotion calls in `idaklu_solver` and replace them with `.astype()` returns
- Upgrade jax (this works up to 0.4.31, but not with IREE)