-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update functools.py #71
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Just a few renaming suggestions and general style of tests.
src/lcm/simulate.py
Outdated
@@ -445,7 +445,7 @@ def retrieve_non_sparse_choices(indices, grids, grid_shape): | |||
return out | |||
|
|||
|
|||
@partial(vmap_1d, variables=["index"]) | |||
@partial(vmap_1d, variables=["index"], apply_allow_kwargs=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@partial(vmap_1d, variables=["index"], apply_allow_kwargs=False) | |
@partial(vmap_1d, variables=["index"], callable_with="kwargs") |
Argument will definitely have to be renamed in vmap_1d
. This is just a suggestion -- instead of Boolean with an awkward name, use Literal that could be args
or kwargs
. But there could be better ideas, I did not think about it for very long.
got = _retrieve_non_sparse_choices( | ||
index=jnp.array([0, 3, 7]), | ||
got = retrieve_non_sparse_choices( | ||
indices=jnp.array([0, 3, 7]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In simulate.py, is there a reason for having indices
and index
differently in the public and private version of retrieve_non_sparse_choices
? If not, this might be a good occasion to harmonise them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reason was that the private function was written to work on a single index and then vmapped. Noticing that I only have to vmap jnp.unravel_index
I rewrote this part for clarity.
In this PR, I
functools.py
allow_kwargs
decorator with theallow_only_kwargs
decoratorThe more general decorator
allow_kwargs
is replaced with the more specificallow_only_kwargs
to reduce complexity when developing (i.e., I do not have to think about any mixed-case scenarios) and to throw errors more quickly and with more informative error messages. When working with functions internally, we either want to call them with keyword-only arguments (in which case I do not need the possibility to call them without keyword arguments) or with positional arguments. In the latter case we do not gain anything by implementing anallow_only_args
decorator, which is why we stick with the more general one here.