-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support C->R case #114
base: main
Are you sure you want to change the base?
Support C->R case #114
Changes from all commits
0fd9b1f
4179194
299d8ac
b0f0e29
b55ec0e
8bdf895
8ce46c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,10 +39,13 @@ | |
|
||
from ._custom_types import sentinel | ||
from ._misc import ( | ||
complex_to_real_structure, | ||
default_floating_dtype, | ||
inexact_asarray, | ||
is_complex_structure, | ||
jacobian, | ||
NoneAux, | ||
real_to_complex_tree, | ||
strip_weak_dtype, | ||
) | ||
from ._tags import ( | ||
|
@@ -1322,11 +1325,26 @@ def _(operator): | |
|
||
@materialise.register(FunctionLinearOperator) | ||
def _(operator): | ||
if is_complex_structure(operator.in_structure()) and not is_complex_structure( | ||
operator.out_structure() | ||
): | ||
# We'll use R^2->R representation for C->R function. | ||
in_structure = complex_to_real_structure(operator.in_structure()) | ||
|
||
map_to_original = lambda x: real_to_complex_tree( | ||
x, | ||
operator.in_structure(), | ||
) | ||
else: | ||
map_to_original = lambda x: x | ||
in_structure = operator.in_structure() | ||
flat, unravel = strip_weak_dtype( | ||
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) | ||
eqx.filter_eval_shape(jfu.ravel_pytree, in_structure) | ||
) | ||
fn = lambda x: operator.fn(map_to_original(unravel(x))) | ||
eye = jnp.eye(flat.size, dtype=flat.dtype) | ||
jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) | ||
|
||
jac = jax.vmap(fn, out_axes=-1)(eye) | ||
Comment on lines
+1328
to
+1347
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After this change: given a C->R The intention was that (FWIW I've just checked the current behaviour on It seems to me like the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right, as we discussed in patrick-kidger/optimistix#76 (comment) it is impossible to materialize There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should follow JAX's lead on this one, with how So I think something like this might be the best choice, then! class AbstractLinearOperator(eqx.Module):
def __check_init__(self):
if is_complex_structure(self.in_structure()) and not is_complex_structure(self.out_structure()):
raise ValueError(...) In terms of end user convenience: I care a lot about this! My usual rule for this is that it is better not to support something than to support it awkwardly or with edge-cases. I think this leads to less frustration and a better UX overall. :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough! Note that your solution will break existing cases for mixed-typed operators (i.e., single operator built from C->C and R->R blocks). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, that is a good point about block-with-mixed-type operators. Ach, complex support is very complicated! Is there a clean way to detect and allow that case, that you can see? I am concerned that our complex support isn't quite meeting the above UX standards I'd like us to have... (this has definitely been a learning experience for me on how complex autodiff/operators/etc work!) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One possible way is to materialize the operator, but that can be expensive. I was thinking just issue a warning instead of exception. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm. I'm not sure I'm comfortable with the warning approach -- I try to avoid this kind of flaky maybe-right maybe-not behaviour. I'd rather just prohibit C->R altogether, even if it prohibits mixed C->C + R->R combinations. (Which can anway still be trivially supported by an end using by doing C=R^2 themselves.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough |
||
|
||
def batch_unravel(x): | ||
assert x.ndim > 0 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I'd just use
jax.ShapeDtypeStruct
directly for consistency with the rest of our imports.