Enable 64-bit MLIR computation #4

Open · 4 tasks
jsbrittain opened this issue Aug 14, 2024 · 0 comments
jsbrittain (Owner) commented Aug 14, 2024

Summary of below:

  • IREE does not support e2e 64-bit computation:
    • see https://github.com/iree-org/iree/issues/8826
    • removing the demotion code from PyBaMM results in a failure in the values returned by iree_hal_buffer_map_read (out of range when returning double; nonsense when returning float [but expecting double]; it also does not cast upwards to anything sensible, though this may be because of reinterprets in the inputs or elsewhere)
  • jax/jaxlib can be upgraded to at least 0.4.31 without issue
  • iree-compiler cannot be upgraded past 886 without faulting:
    • the problem is caused by the MLIR representation of fcn_jac_times_cjmass_sparse, more specifically the call to rhs_algebraic_eval.get_jacobian(), which produces MLIR that exceeds the allowed stack limit of 32 kB (sparsifying the output, etc., does not help)
    • reducing the model geometry does allow the MLIR to compile with iree-compiler 889, but produced NaNs
    • we will need to resolve this in order to keep iree-compiler and iree-src up to date for when 64-bit operations are natively supported
  • Nevertheless, there are some optimisations that can be applied to the idaklu_solver
    • replace dense mass matrix with jax.experimental.sparse.BCOO
    • compute bandwidth explicitly from the sparse representation (avoids densifying)
    • all explicit demotions in the idaklu_solver iree functions can be removed and replaced with .astype(f32) outputs (it is not clear whether this is entirely compatible given the iree 64-bit issue above, but it seems to work! — a minimal sketch follows this list)
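
For the last point, a minimal sketch of the .astype(f32) pattern, assuming 64-bit tracing is enabled; the function here is a hypothetical stand-in for an idaklu_solver callback, not the PyBaMM code:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # trace/compute in f64 internally

# Hypothetical stand-in for an idaklu_solver callback.
def fcn_rhs_algebraic(t, y, p):
    return p * y + t  # evaluated in f64 once x64 is enabled

# Instead of demoting inputs explicitly throughout the body,
# demote only the returned value.
def fcn_rhs_algebraic_f32(t, y, p):
    return fcn_rhs_algebraic(t, y, p).astype(jnp.float32)

out = jax.jit(fcn_rhs_algebraic_f32)(0.0, jnp.ones(4), 2.0)
assert out.dtype == jnp.float32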

The current implementation (jax, jaxlib 0.4.27) runs, but only with f32s. Allowing f64s to propagate through the code produces an error in iree_jit.cpp at the line

status = iree_hal_buffer_map_read(iree_hal_buffer_view_buffer(result_view), 0,
    &result[k][0], sizeof(double) * result[k].size());

When sizeof(double) is used, the buffer read exceeds the buffer bounds; when sizeof(float) is used, incorrect values are returned.

Upgrading jax, jaxlib and iree seems like a sensible way forwards, but is hampered by MLIR compilation. Specifically, installing 0.4.28 or higher results in compilation failure for several reasons, one of which is a stack overflow error (error: 'func.func' op exceeded stack allocation limit of 32768 bytes for function. Got 234752 bytes). Inspection of the MLIR (fcn_jac_times_cjmass_sparse.mlir) shows that the file in question is 7.7 MB, due to several 962x962 constant dense matrix declarations. Interestingly, the file is identical between 0.4.27 and 0.4.28, with the dense definitions persisting in the MLIR code. All MLIR is lowered by JAX, but it is unclear whether this requires a JAX-side fix or is due to changes in iree.
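
One way to see where the large constants come from (independent of the iree-compiler version) is to dump the MLIR that JAX lowers for a jitted function. The snippet below is a self-contained illustration with a hypothetical function, not the PyBaMM code:

import jax
import jax.numpy as jnp
import numpy as np

mass = np.eye(962)  # concrete array captured by the traced function

def fcn(t, y):
    # A closed-over concrete array is embedded in the lowered module as a
    # dense constant, which is what inflates the .mlir file.
    return jnp.dot(mass, y) * t

lowered = jax.jit(fcn).lower(0.0, jnp.zeros(962))
print(lowered.as_text()[:2000])  # inspect the emitted StableHLO/MLIR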

Note - there are 4 components to maintain: jax, jaxlib, iree-compiler and iree-src.

Upgrade table:

| jax/jaxlib      | date   | iree          |
| --------------- | ------ | ------------- |
| 0.4.27          | May-7  | 20240507.886  |
| 0.4.28          | May-10 | 20240510.889  |
| 0.4.29          | Jun-10 | 20240610.920  |
| 0.4.30          | Jun-18 | 20240618.928  |
| 0.4.31          | Jul-30 | 20240730.970  |
| 0.4.32 (yanked) | Sep-11 | 20240911.1013 |
| 0.4.33          | Sep-16 | 20240916.1018 |
| 0.4.34          | Oct-4  | 20241004.1036 |

Compatibility (tested to 0.4.31):

| jax/jaxlib      | iree-compiler | iree-src | status                                    |
| --------------- | ------------- | -------- | ----------------------------------------- |
| 0.4.27          | 886           | 886      |                                           |
| 0.4.28          | 889           | 889      | ❌ stack allocations are too large        |
|                 | 886           |          | ✅ compiler downgrade                     |
|                 | 887           |          | ❌ stack allocations are too large        |
| 0.4.29          |               |          |                                           |
|                 | 886           | 889      |                                           |
| 0.4.30          |               |          |                                           |
| 0.4.31 (latest) | 970           | 970      | ❌ multiple errors                        |
|                 | 886           | 889      | ✅ holding the compiler version back runs |

The main problem seems to originate from jit_fcn_jac_times_cjmass_sparse. Reducing the function body to just the call to get_jacobian() still triggers the issue, whereas returning only the other part (mass_matrix) permits compilation.

return model.rhs_algebraic_eval.get_jacobian()(t, y, p).astype(jnpfloat)

The following solutions are considered:

  • Force Heap Allocation in JAX
    • use global matrix definitions, rather than defining them in the jitted function (a related sketch follows this list)
  • Use Smaller Functions or Split Operations
  • Manual MLIR Optimization
  • Custom JAX Transformations
  • Inlining Constants
  • Reduce Tensor Sizes Temporarily
    • definitely worth a try as a debug tool
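
For the first option, a rough sketch of one way to keep a large matrix out of the compiled module body: pass it to the jitted function as an argument instead of closing over it. This is a variation on the "global definitions" idea rather than a confirmed fix, and the names are illustrative:

import jax
import jax.numpy as jnp
import numpy as np

mass_matrix = np.eye(962)  # illustrative large matrix

# Closed-over version: the matrix is baked into the lowered MLIR as a
# 962x962 dense constant.
@jax.jit
def f_closed(t, y):
    return jnp.dot(mass_matrix, y) * t

# Argument version: the matrix becomes a runtime input of the compiled
# function, so no dense constant appears in the module body.
@jax.jit
def f_arg(t, y, m):
    return jnp.dot(m, y) * t

out = f_arg(0.5, jnp.zeros(962), mass_matrix)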

The over-allocation may originate here:

def fcn_jac_times_cjmass(t, y, p, cj):
    return jac_rhs_algebraic_demoted(t, y, p) - cj * mass_matrix_demoted
def fcn_jac_times_cjmass_sparse(t, y, p, cj):
    return fcn_jac_times_cjmass(t, y, p, cj)[coo.row, coo.col]

since jac_rhs_algebraic_demoted(t, y, p), mass_matrix_demoted, and therefore fcn_jac_times_cjmass are all dense (962x962) matrices.
Updates:

  • we can update this with sparse.BCOO from jax.experimental and it greatly reduces the MLIR file size, but does not resolve the stack issue (see the sketch after this list)
  • downgrading to jax/jaxlib 0.4.27 [keeping iree at 889] does not resolve the issue, so this appears to pinpoint iree as the culprit
  • downgrading iree-compiler to 886 does not produce the stack error (using jax/jaxlib 0.4.27 and iree-src 889)
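
A possible sketch of the BCOO mass-matrix replacement, assuming the mass matrix is available in scipy COO form (the mass_coo construction below is a hypothetical placeholder):

import jax.numpy as jnp
import scipy.sparse
from jax.experimental import sparse as jax_sparse

# Hypothetical sparse mass matrix in scipy COO form.
mass_coo = scipy.sparse.eye(962, format="coo")

mass_matrix_bcoo = jax_sparse.BCOO(
    (jnp.asarray(mass_coo.data),
     jnp.column_stack([jnp.asarray(mass_coo.row), jnp.asarray(mass_coo.col)])),
    shape=mass_coo.shape,
)

# A dense array can still be recovered where an operation requires it.
mass_matrix_dense = mass_matrix_bcoo.todense()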

Compute bandwidth from sparse representation:

coo = sparse_eval.tocoo()  # convert to COOrdinate format for indexing
# band extents computed directly from the COO indices, without densifying
jac_bw_lower = max(coo.col - coo.row)
jac_bw_upper = max(coo.row - coo.col)

Replace coo with BCOO:

from jax.experimental import sparse as jax_sparse

def fcn_jac_times_cjmass_sparse(t, y, p, cj):
    # Sparsify the dense Jacobian result, keeping only the known
    # jac_times_cjmass_nnz structural non-zeros, and return their values.
    return jax_sparse.BCOO.fromdense(
        fcn_jac_times_cjmass(t, y, p, cj).T,
        nse=jac_times_cjmass_nnz
    ).data
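
Note that BCOO.fromdense still consumes the dense result of fcn_jac_times_cjmass, so the dense intermediate is still materialised; the conversion mainly changes what is returned (and shrinks the emitted MLIR) rather than removing the dense computation, which is consistent with the observation above that it does not resolve the stack issue.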
Task list:

  • Compute bandwidth from sparse representation (this works)
  • Replace dense mass matrix with BCOO
  • Remove all explicit demotion calls in idaklu_solver and replace with .astype() returns
  • Upgrade jax (this works to 0.4.31, but not with iree)