IndicesBoundaryMasker in JAX causes halting when used on multiple GPUs #93

Open
mehdiataei opened this issue Dec 3, 2024 · 3 comments

@mehdiataei
Contributor

The current implementation occasionally causes the GPUs to hang when padding is applied to the bmap. The issue arises primarily because the function cannot be JIT-compiled.

The JAX implementation of IndicesBoundaryMasker contains several operations, such as Python conditional statements on array values, that cannot be traced by JAX's JIT compiler. The previous implementation was JIT-compatible and did not run into these issues, so it should serve as the reference for resolving this problem.
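To illustrate the JIT constraint, here is a minimal JAX sketch of the general pattern. The helpers `mask_with_python_if` and `mask_jit_friendly` are hypothetical and are not the IndicesBoundaryMasker API. The padding convention (a halo of `pad` cells around a 2-D grid) is also an assumption. The first version branches in Python on array values, which fails under `jax.jit`. The second expresses the same bounds check with `jnp.where` and a scatter whose out-of-bounds updates are dropped, so it traces cleanly. This shows the technique only and is not a confirmed fix for the multi-GPU hang.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper (not the XLB API): marks the grid cells listed in
# `indices` (shape (2, N)) on a 2-D grid padded by `pad` cells on every side.
def mask_with_python_if(indices, shape, pad):
    mask = jnp.zeros((shape[0] + 2 * pad, shape[1] + 2 * pad), dtype=bool)
    for k in range(indices.shape[1]):
        i, j = indices[0, k], indices[1, k]
        # Python `if` on array values: works eagerly, but under jax.jit the
        # condition is an abstract tracer and cannot be converted to a bool.
        if (i >= 0) & (i < shape[0]) & (j >= 0) & (j < shape[1]):
            mask = mask.at[i + pad, j + pad].set(True)
    return mask


# JIT-friendly variant: the same bounds check expressed as array operations.
def mask_jit_friendly(indices, shape, pad):
    mask = jnp.zeros((shape[0] + 2 * pad, shape[1] + 2 * pad), dtype=bool)
    i, j = indices[0], indices[1]
    inside = (i >= 0) & (i < shape[0]) & (j >= 0) & (j < shape[1])
    # Send points outside the unpadded domain to an out-of-range index;
    # mode="drop" discards those updates instead of writing them.
    ii = jnp.where(inside, i + pad, mask.shape[0])
    jj = jnp.where(inside, j + pad, mask.shape[1])
    return mask.at[ii, jj].set(True, mode="drop")


# `shape` and `pad` are static, so one compilation is cached per (shape, pad).
mask_jit = jax.jit(mask_jit_friendly, static_argnums=(1, 2))
```

With `idx = jnp.array([[0, 1, 5], [0, 2, 0]])`, `mask_jit(idx, (3, 3), 1)` marks the cells at padded positions (1, 1) and (2, 3) and silently drops the out-of-domain point (5, 0). Wrapping `mask_with_python_if` in `jax.jit` with the same static arguments instead raises a `jax.errors.ConcretizationTypeError` (or a subclass of it) at trace time.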

@hsalehipour
Collaborator

Can you create a repro, please?

@mehdiataei
Contributor Author

mehdiataei commented Dec 3, 2024

Just run any example (e.g., flow over a sphere) several times with both GPUs visible. It hangs about 90% of the time.

Please use the old implementation to fix this. Thanks.

@mehdiataei
Contributor Author

(This is not related to the latest PR; it happens in the old version as well.)
