How to do parallel solver #25
Solving a linear system involves iterative steps, so it is not something that can simply be "batched". JAX-FEM uses the JAX version of SciPy, not the original SciPy, but I still doubt that batched operation is even possible for the linear solvers.
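For small, dense systems, one way to probe the batching question is to wrap jax.scipy.sparse.linalg.cg in jax.vmap. This is a toy sketch, not JAX-FEM code: make_spd and solve_one are hypothetical helpers, and the matrices are random 8x8 symmetric positive-definite matrices rather than assembled FEM stiffness matrices. It only checks that each batched CG solution matches a direct solve. Whether the same approach scales to JAX-FEM's large sparse systems is a separate question.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def make_spd(key, n):
    # Random symmetric positive-definite matrix: M M^T + n I
    m = jax.random.normal(key, (n, n))
    return m @ m.T + n * jnp.eye(n)


def solve_one(A, b):
    # cg takes a matvec callable and returns (x, info); info is unused here
    x, _ = cg(lambda v: A @ v, b)
    return x


# vmap maps over the leading axis of both A and b
batched_solve = jax.vmap(solve_one)

keys = jax.random.split(jax.random.PRNGKey(0), 4)
A_batch = jax.vmap(lambda k: make_spd(k, 8))(keys)  # shape (4, 8, 8)
b_batch = jnp.ones((4, 8))
x_batch = batched_solve(A_batch, b_batch)           # shape (4, 8)

# Reference: direct dense solves, also batched with vmap
x_ref = jax.vmap(jnp.linalg.solve)(A_batch, b_batch)
```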
In the solver, the function get_A_fn uses SciPy, doesn't it? Here is the error I got when I tried to use vmap:
The above exception was the direct cause of the following exception: Traceback (most recent call last):
Is there a reason that you are using onp instead of np? I couldn't figure out why.
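The vmap failure can be reproduced outside JAX-FEM with a toy problem. In this sketch, solve_with_numpy and solve_with_jax are hypothetical functions that follow the onp (plain NumPy) / np (jax.numpy) naming used in this thread. Under vmap, plain NumPy receives a JAX tracer and raises jax.errors.TracerArrayConversionError, while the jax.numpy version traces and batches normally.

```python
import numpy as onp
import jax
import jax.numpy as np

B = onp.array([1.0, 2.0, 3.0])


def solve_with_numpy(kappa):
    # Toy system scaled by a parameter; onp.linalg.solve needs concrete arrays
    A = kappa * onp.eye(3)
    return onp.linalg.solve(A, B)


def solve_with_jax(kappa):
    # Same system with jax.numpy, which vmap can trace
    A = kappa * np.eye(3)
    return np.linalg.solve(A, B)


kappas = np.array([1.0, 2.0, 4.0])

try:
    jax.vmap(solve_with_numpy)(kappas)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True

x = jax.vmap(solve_with_jax)(kappas)  # row i solves (kappas[i] * I) x = B
```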
It is to save GPU memory. JAX-FEM can solve a problem with over 1 million DOFs on a single GPU, and to make that possible we keep arrays as onp (plain NumPy) rather than np in many places. As for odeint, I don't think it solves linear systems; it performs more or less explicit updates.
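To illustrate why explicit time stepping batches easily, here is a sketch of forward Euler for the 1D heat equation u_t = kappa * u_xx with zero Dirichlet boundaries. It is neither odeint nor JAX-FEM code: explicit_heat, the grid spacing, the time step, and the step count are assumptions chosen so that dt * kappa / dx^2 stays below the stability limit of 0.5. Each step is a stencil update with no linear solve, so jax.vmap can map it over per-sample diffusion coefficients and initial conditions.

```python
import jax
import jax.numpy as jnp


def explicit_heat(kappa, u0, dx=0.1, dt=1e-3, n_steps=100):
    # Forward Euler for u_t = kappa * u_xx; boundary values are held at zero.
    def step(_, u):
        # Second-difference stencil; the wrapped values from jnp.roll only
        # affect the two boundary nodes, which are reset to zero below.
        lap = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2
        u_new = u + dt * kappa * lap
        return u_new.at[0].set(0.0).at[-1].set(0.0)

    return jax.lax.fori_loop(0, n_steps, step, u0)


x = jnp.linspace(0.0, 1.0, 11)        # 11 nodes, so dx = 0.1
kappas = jnp.array([0.5, 1.0, 2.0])    # dt * kappa / dx**2 <= 0.2
u0s = jnp.stack([
    jnp.sin(jnp.pi * x),
    0.5 * jnp.sin(jnp.pi * x),
    jnp.sin(2.0 * jnp.pi * x),
])

# One call advances all three samples; no per-sample Python loop
u_final = jax.vmap(explicit_heat)(kappas, u0s)  # shape (3, 11)
```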
OK, thanks! I understand now. So hypothetically, I could replace onp with np if I don't have a large number of DOFs, right? Would it be a problem to replace it with the JAX experimental sparse version? OK, I will take a closer look then.
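A minimal sketch of the sparse route, assuming jax.experimental.sparse is acceptable for your problem (the module is experimental and its API may change): the matrix is stored as a BCOO array, passed to jax.scipy.sparse.linalg.cg through a matvec closure, and the solve is vmapped over a batch of right-hand sides. laplacian_1d and solve are hypothetical helpers, and a 1D Laplacian stands in for an assembled FEM matrix.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jax.scipy.sparse.linalg import cg


def laplacian_1d(n):
    # Tridiagonal SPD matrix: 2 on the diagonal, -1 on both off-diagonals
    return 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)


A_dense = laplacian_1d(8)
A = sparse.BCOO.fromdense(A_dense)  # stores only the nonzero entries


def solve(b):
    # The BCOO matrix is shared by every sample; only b is batched
    x, _ = cg(lambda v: A @ v, b)
    return x


B = jnp.stack([jnp.ones(8), jnp.arange(8.0)])  # two right-hand sides
X = jax.vmap(solve)(B)                         # shape (2, 8)
X_ref = jax.vmap(lambda b: jnp.linalg.solve(A_dense, b))(B)
```

Here only the right-hand side varies. For a spatially uniform scalar coefficient, scaling the matvec (kappa * (A @ v)) would keep one sparsity pattern shared across the whole batch.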
Hi, I'm trying to create a hybrid model that uses your FEM solver and a neural network.
To do that, I need to solve the same equations with different parameters (e.g., heat diffusion, where the diffusion coefficient and the initial condition differ for each sample).
I can't use vmap because the solver uses SciPy and NumPy, which aren't compatible with it.
Do you think the solver could be adapted so that it can handle batches or be passed to vmap?
Thanks in advance for any ideas!
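For the small-DOF case discussed in this thread, a dense jax.numpy solver can be mapped over both parameters directly. The sketch below is hypothetical: it uses backward Euler for the 1D heat equation on a finite-difference grid instead of JAX-FEM's finite-element assembly, and N, laplacian, and implicit_heat are illustrative names. Each sample gets its own diffusion coefficient and initial condition, and the whole batch is compiled once with jax.jit(jax.vmap(...)).

```python
import jax
import jax.numpy as jnp

N = 9                  # interior nodes (hypothetical mesh size)
DX = 1.0 / (N + 1)     # grid spacing on [0, 1]


def laplacian(n, dx):
    # Second-difference matrix with zero Dirichlet boundaries
    return (jnp.eye(n, k=1) - 2.0 * jnp.eye(n) + jnp.eye(n, k=-1)) / dx**2


L = laplacian(N, DX)


def implicit_heat(kappa, u0, dt=1e-2, n_steps=20):
    # Backward Euler: (I - dt * kappa * L) u_next = u
    A = jnp.eye(N) - dt * kappa * L

    def step(_, u):
        return jnp.linalg.solve(A, u)

    return jax.lax.fori_loop(0, n_steps, step, u0)


# Compile once for the whole batch of (kappa, initial condition) pairs
batched_heat = jax.jit(jax.vmap(implicit_heat))

xs = jnp.linspace(DX, 1.0 - DX, N)           # interior node coordinates
kappas = jnp.array([0.1, 0.5, 1.0])          # one diffusion coefficient per sample
u0s = jnp.stack([jnp.sin(jnp.pi * xs)] * 3)  # one initial condition per sample (identical here)
u_final = batched_heat(kappas, u0s)          # shape (3, 9)
```

Because A does not change between time steps, factoring it once with jax.scipy.linalg.lu_factor and reusing the factors through lu_solve would avoid repeating the factorization in every step. Dense storage limits this approach to modest problem sizes, which is the GPU-memory trade-off described above.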