Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update functools.py #71

Merged
merged 13 commits into from
Jun 14, 2024
Merged

Update functools.py #71

merged 13 commits into from
Jun 14, 2024

Conversation

timmens
Copy link
Member

@timmens timmens commented Apr 24, 2024

In this PR, I

The more general decorator allow_kwargs is replaced with the more specific allow_only_kwargs to reduce complexity when developing (i.e., I do not have to think about any mixed-case scenarios) and to throw errors more quickly and with more informative error messages. When working with functions internally, we either want to call them with keyword-only arguments (in which case I do not need the possibility to call them without keyword arguments) or with positional arguments. In the latter case we do not gain anything by implementing an allow_only_args decorator, which is why we stick with the more general one here.

@timmens timmens requested a review from hmgaudecker June 12, 2024 16:00
Copy link
Member

@hmgaudecker hmgaudecker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Just a few renaming suggestions and general style of tests.

src/lcm/functools.py Outdated Show resolved Hide resolved
@@ -445,7 +445,7 @@ def retrieve_non_sparse_choices(indices, grids, grid_shape):
return out


@partial(vmap_1d, variables=["index"])
@partial(vmap_1d, variables=["index"], apply_allow_kwargs=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@partial(vmap_1d, variables=["index"], apply_allow_kwargs=False)
@partial(vmap_1d, variables=["index"], callable_with="kwargs")

Argument will definitely have to be renamed in vmap_1d. This is just a suggestion -- instead of Boolean with an awkward name, use Literal that could be args or kwargs. But there could be better ideas, I did not think about it for very long.

tests/test_dispatchers.py Show resolved Hide resolved
tests/test_functools.py Outdated Show resolved Hide resolved
got = _retrieve_non_sparse_choices(
index=jnp.array([0, 3, 7]),
got = retrieve_non_sparse_choices(
indices=jnp.array([0, 3, 7]),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In simulate.py, is there a reason for having indices and index differently in the public and private version of retrieve_non_sparse_choices? If not, this might be a good occasion to harmonise them.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason was that the private function was written to work on a single index and then vmapped. Noticing that I only have to vmap jnp.unravel_index I rewrote this part for clarity.

@timmens timmens merged commit 133ed92 into main Jun 14, 2024
6 checks passed
@timmens timmens deleted the update-functools branch June 14, 2024 16:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants