I'm having an issue restoring a pure dict from nnx.state:
Orbax changes keys of int type into string type.
Is there any way to fix it?
I've included sample code that reproduces the issue.
As a temporary fix, I'm using the function below to convert all str(int) keys into int:
def convert_keys(d):
    new_d = {}
    for key, value in d.items():
        # Convert key if it is a string representation of an integer
        if isinstance(key, str):
            try:
                new_key = int(key)
            except ValueError:
                new_key = key
        else:
            new_key = key
        # Recursively process nested dictionaries
        if isinstance(value, dict):
            new_value = convert_keys(value)
        else:
            new_value = value
        new_d[new_key] = new_value
    return new_d
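One caveat with the `int()`-based check: Python's `int()` also accepts strings such as "-1", " 7 " and "1_0", so keys like those would be converted as well. Below is a minimal sketch of a stricter variant that only converts plain ASCII digit strings. The name `convert_int_keys_strict` and the sample dict are made up for illustration and are not part of Flax or Orbax. Like the function above, it cannot tell a key that was originally the string "0" from one that was the int 0.

```python
def convert_int_keys_strict(d):
    """Recursively convert keys such as "0" or "12" to int; leave other keys as-is."""
    out = {}
    for key, value in d.items():
        # Only plain ASCII digit strings are converted, so "-1", "1_0"
        # and "1.0" stay strings.
        if isinstance(key, str) and key.isascii() and key.isdigit():
            key = int(key)
        # Recurse into nested dicts; copy every other value unchanged.
        out[key] = convert_int_keys_strict(value) if isinstance(value, dict) else value
    return out


# Hypothetical restored dict: integer layer indices came back as strings.
restored = {"layers": {"0": {"bias": [0.0]}, "1": {"bias": [1.0]}}, "1_0": "kept"}
fixed = convert_int_keys_strict(restored)
print(fixed)
```

Here the layer indices become `0` and `1`, while the key "1_0" stays a string.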
Hey @mmorinag127, thanks for posting this. I believe we knew about this edge case but forgot to upstream it.
I've created #4317 with a fix that passes your test. The solution is similar to what @OmerRochman is showing above (try casting to int) but I've integrated it into State.replace_by_pure_dict so it works out of the box.
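For readers who want a picture of what "try casting to int" inside the replace step could look like, here is a rough sketch on plain nested dicts. It is not the code from the PR and does not use Flax: `replace_leaves` and the sample data are hypothetical, and the real `State.replace_by_pure_dict` may differ in its details.

```python
def replace_leaves(target, pure_dict, path=()):
    """Copy leaves from pure_dict into target (nested dicts), in place."""
    for key, value in pure_dict.items():
        lookup = key
        # If the key is missing as given, retry it as an int ("0" -> 0).
        if lookup not in target and isinstance(lookup, str):
            try:
                lookup = int(lookup)
            except ValueError:
                pass
        if lookup not in target:
            raise ValueError(f"key in pure_dict not available in target: {path + (key,)}")
        if isinstance(value, dict):
            replace_leaves(target[lookup], value, path + (lookup,))
        else:
            target[lookup] = value


# Hypothetical example: the target uses int keys, the restored dict uses strings.
target = {"layers": {0: {"bias": 0.0}}}
replace_leaves(target, {"layers": {"0": {"bias": 1.5}}})
print(target)
```

Because the cast is only attempted when the original key is missing, a target that really uses string keys such as "0" is still matched as-is.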
Dear team,
The sample code throws the following error:
ValueError: key in pure_dict not available in state: ('layers', '0', 'bias')
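The path in the error message shows the mismatch: the restored dict contains the string '0' where the state has the int 0. The snippet below is a plain-Python illustration of that comparison, assuming a layout like the one in the error. `flatten_paths`, `stringify_keys` and the sample dict are hypothetical helpers that only mimic the key change; they are not Flax or Orbax code.

```python
def flatten_paths(d, prefix=()):
    """Return the set of key paths (tuples) that lead to leaves of a nested dict."""
    paths = set()
    for key, value in d.items():
        if isinstance(value, dict):
            paths |= flatten_paths(value, prefix + (key,))
        else:
            paths.add(prefix + (key,))
    return paths


def stringify_keys(d):
    """Mimic the round trip by turning every key into a string."""
    return {str(k): stringify_keys(v) if isinstance(v, dict) else v for k, v in d.items()}


state_like = {"layers": {0: {"bias": 0.0, "kernel": 1.0}}}
restored_like = stringify_keys(state_like)

# Paths present in the restored dict but unknown to the original structure.
missing = flatten_paths(restored_like) - flatten_paths(state_like)
print(sorted(missing))
```

Every path that passes through the integer index shows up as missing, which matches the `('layers', '0', 'bias')` path reported in the error.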