With the recent change, JAX no longer automatically hashes an unhashable object by its identity. I have some huge NumPy arrays of trained weights whose values never change during the program, and I would like to pass them via `static_argnums`. To be on the safe side, I could also set each array's `WRITEABLE` flag to `False`.
Would current JAX run into any internal problem if I keep wrapping the array so that it hashes by identity? I understand that I take full responsibility for ensuring the array is "static" as far as JAX is concerned. Are there any caveats? I'm asking because of this line in the link above:

> Even JAX itself was having some issues with this (which shows the behaviour was non-trivial to reason about).

which makes me wonder whether this approach has some hidden problem.
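A minimal sketch of the wrapper described above, assuming a single-array wrapper and a toy matmul as the "model". `HashableByIdentity` and `apply_weights` are hypothetical names for illustration, not JAX APIs. The wrapper freezes the array, hashes by `id()`, and compares with `is`, so `jax.jit` treats two wrappers as the same static argument only when they wrap the very same array object.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


class HashableByIdentity:
    """Hypothetical wrapper: hashes and compares by the identity of the wrapped array."""

    def __init__(self, value: np.ndarray):
        value.flags.writeable = False  # refuse in-place writes through this array object
        self.value = value  # strong reference keeps id(value) from being reused

    def __hash__(self) -> int:
        return id(self.value)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashableByIdentity) and self.value is other.value


@functools.partial(jax.jit, static_argnums=1)
def apply_weights(x, weights: HashableByIdentity):
    # weights.value is a concrete NumPy array at trace time, so it is
    # embedded in the compiled computation as a constant.
    return jnp.dot(x, weights.value)


weights = HashableByIdentity(np.ones((3, 2), dtype=np.float32))
x = jnp.ones((4, 3), dtype=jnp.float32)
y = apply_weights(x, weights)
print(y.shape)  # (4, 2)
```

Because the weights become compile-time constants, an in-place change to the array after compilation would silently leave the cached computation using stale values; that is why the array must genuinely stay immutable. Note also that `flags.writeable = False` only applies to this array object: if it is a view, its base (or other views of the same memory) may remain writeable.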
This approach sounds safe to me. I think the issues alluded to in the text you quoted were just about excessive recompiles, i.e. places where we were accidentally passing unhashable objects (like lists) as static args and thus not getting cache hits.

In general, assuming the objects are immutable, excessive recompilation is the only kind of error that hashing on object identity can produce.
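To make the "excessive recompiles" point concrete, here is a sketch under the same assumptions as the question (the `HashableByIdentity` wrapper and `apply_weights` function are hypothetical, not JAX APIs). It counts traces with a Python-side counter: Python side effects inside a jitted function only run while JAX traces it, i.e. on a cache miss.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


class HashableByIdentity:
    """Hypothetical wrapper: hashes and compares by the identity of the wrapped array."""

    def __init__(self, value: np.ndarray):
        value.flags.writeable = False
        self.value = value

    def __hash__(self) -> int:
        return id(self.value)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, HashableByIdentity) and self.value is other.value


trace_count = 0  # incremented only while JAX traces, i.e. on a cache miss


@functools.partial(jax.jit, static_argnums=1)
def apply_weights(x, weights: HashableByIdentity):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.dot(x, weights.value)


x = jnp.ones((4, 3), dtype=jnp.float32)
a = np.ones((3, 2), dtype=np.float32)
b = np.ones((3, 2), dtype=np.float32)  # same values as `a`, different object

y_a1 = apply_weights(x, HashableByIdentity(a))  # first call: traces and compiles
y_a2 = apply_weights(x, HashableByIdentity(a))  # new wrapper, same array: cache hit
y_b = apply_weights(x, HashableByIdentity(b))   # equal values, new identity: recompiles

print(trace_count)  # 2
```

The arrays `a` and `b` hold identical values yet compile twice, and every call still returns the correct result: the cost of identity hashing is compile time and cache memory, not correctness. Because `__eq__` uses `is` on an array the wrapper keeps alive, a recycled `id()` cannot produce a false cache hit.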