Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support C->R case #114

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open

Support C->R case #114

wants to merge 7 commits into from

Conversation

Randl
Copy link
Contributor

@Randl Randl commented Nov 17, 2024

@Randl Randl marked this pull request as ready for review November 19, 2024 04:54
@Randl
Copy link
Contributor Author

Randl commented Nov 19, 2024

@patrick-kidger I think this is a reasonable implementation of what we discussed. Some documentation changes are probably required.
Mixed complex-real outputs are not supported here (but inputs are). We could aggressively assume that any real output means real output, but that would complicate implementation. It also would be a breaking change for cases where the function is properly linear (e.g., when complex and real inputs are multiplied each by a separate matrix). I don't think we can really detect this case and give it special treatment, but suggestions are welcome. All in all, keeping in mind the optimization use case, I think this is a reasonable degree of generality.
Also, if you see any edge cases I missed, I'll be happy to add them to the tests.

@Randl Randl changed the title [WIP] Support C->R case Support C->R case Nov 19, 2024
@Randl
Copy link
Contributor Author

Randl commented Nov 27, 2024

ping @patrick-kidger

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, sorry for the delay! I've left one main thought on this. WDYT?

@@ -19,6 +19,7 @@
import jax.core
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import ShapeDtypeStruct
Copy link
Owner

@patrick-kidger patrick-kidger Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I'd just use jax.ShapeDtypeStruct directly for consistency with the rest of our imports.

Comment on lines +1328 to +1347
if is_complex_structure(operator.in_structure()) and not is_complex_structure(
operator.out_structure()
):
# We'll use R^2->R representation for C->R function.
in_structure = complex_to_real_structure(operator.in_structure())

map_to_original = lambda x: real_to_complex_tree(
x,
operator.in_structure(),
)
else:
map_to_original = lambda x: x
in_structure = operator.in_structure()
flat, unravel = strip_weak_dtype(
eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
eqx.filter_eval_shape(jfu.ravel_pytree, in_structure)
)
fn = lambda x: operator.fn(map_to_original(unravel(x)))
eye = jnp.eye(flat.size, dtype=flat.dtype)
jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye)

jac = jax.vmap(fn, out_axes=-1)(eye)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After this change: given a C->R FunctionLinearOperator, call it operator, then I think operator(some_complex_vector) would have worked but materialise(operator)(some_complex_vector) would not work, as it now expects something from R^2 instead?

The intention was that materialise would not change the observable input-output behaviour of an operator at all.

(FWIW I've just checked the current behaviour on FunctionLinearOperator(lambda x: x.real, jax.ShapeDtypeStruct((), jnp.complex64)) and this is also wrong in the expected way: there's no way to express 'take a real part' when multiplying against a pytree, so it's not like the current state of affairs is any better... !)

It seems to me like the AbstractLinearOperator abstraction might actually just be fundamentally incompatible with complex dtypes, due to the not-really-defined nature of linear operators over such spaces?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, as we discussed in patrick-kidger/optimistix#76 (comment) it is impossible to materialize C->R operator, since it is not complex linear but rather linear in imaginary and real parts. The two solutions I see is one in this PR (break the promise that materialize is noop) or just assert we do not support C->R operators (probably just give a warning) and suggest to user to make R^2->R operator out of it. The latter is definitely cleaner from the point of view of consistency, but may be less convenient to end user? Not sure about it, but luckily for me the decision is yours 🙃
The motivation was to support C->R stuff in optax so it can be a good testbed in terms of understanding how convenient this stuff is

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should follow JAX's lead on this one, with how jax.jac{fwd,rev} similarly error out for these kinds of operations.

So I think something like this might be the best choice, then!

class AbstractLinearOperator(eqx.Module):
    def __check_init__(self):
        if is_complex_structure(self.in_structure()) and not is_complex_structure(self.out_structure()):
            raise ValueError(...)

In terms of end user convenience: I care a lot about this! My usual rule for this is that it is better not to support something than to support it awkwardly or with edge-cases. I think this leads to less frustration and a better UX overall. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough! Note that your solution will break existing cases for mixed-typed operators (i.e., single operator built from C->C and R->R blocks).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, that is a good point about block-with-mixed-type operators. Ach, complex support is very complicated! Is there a clean way to detect and allow that case, that you can see?

I am concerned that our complex support isn't quite meeting the above UX standards I'd like us to have... (this has definitely been a learning experience for me on how complex autodiff/operators/etc work!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One possible way is to materialize the operator, but that can be expensive. I was thinking just issue a warning instead of exception.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I'm not sure I'm comfortable with the warning approach -- I try to avoid this kind of flaky maybe-right maybe-not behaviour. I'd rather just prohibit C->R altogether, even if it prohibits mixed C->C + R->R combinations. (Which can anway still be trivially supported by an end using by doing C=R^2 themselves.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants