Support C->R case #114

Open · wants to merge 7 commits into main
42 changes: 42 additions & 0 deletions lineax/_misc.py
@@ -19,6 +19,7 @@
import jax.core
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import ShapeDtypeStruct
@patrick-kidger (Owner) commented on Nov 29, 2024:
Nit: I'd just use jax.ShapeDtypeStruct directly for consistency with the rest of our imports.

from jaxtyping import Array, ArrayLike, Bool, PyTree # pyright:ignore


Expand Down Expand Up @@ -110,3 +111,44 @@ def structure_equal(x, y) -> bool:
x = strip_weak_dtype(jax.eval_shape(lambda: x))
y = strip_weak_dtype(jax.eval_shape(lambda: y))
return eqx.tree_equal(x, y) is True


def is_complex_structure(structure):
with jax.numpy_dtype_promotion("standard"):
return jnp.isdtype(
jnp.result_type(*(jax.tree.flatten(structure)[0])),
"complex floating",
)
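
Since jnp.result_type promotes across every leaf, a structure with any complex leaf is reported as complex. A quick sketch using the helper above (assuming lineax's usual pytree-of-ShapeDtypeStruct structures):

import jax
import jax.numpy as jnp

mixed = (
    jax.ShapeDtypeStruct((3,), jnp.float32),
    jax.ShapeDtypeStruct((2,), jnp.complex64),
)
assert is_complex_structure(mixed)  # float32 + complex64 promote to complex64
assert not is_complex_structure((mixed[0],))  # purely real structure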


def complex_to_real_structure(in_structure):
return jtu.tree_map(
lambda x: ShapeDtypeStruct(
tuple(x.shape) + (2,), complex_to_real_dtype(x.dtype)
)
if jnp.isdtype(x.dtype, "complex floating")
else x,
in_structure,
)


def complex_to_real_tree(x, in_structure):
with jax.numpy_dtype_promotion("standard"):
return jtu.tree_map(
lambda x, struct: jnp.stack([x.real, x.imag], axis=-1)
if jnp.isdtype(struct.dtype, "complex floating")
else x,
x,
in_structure,
)


def real_to_complex_tree(x, in_structure):
with jax.numpy_dtype_promotion("standard"):
return jtu.tree_map(
lambda x, struct: x[..., 0] + 1.0j * x[..., 1]
if jnp.isdtype(struct.dtype, "complex floating")
else x,
x,
in_structure,
)
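
Together, these three helpers implement the identification C = R^2: each complex leaf of shape (..., n) becomes a real leaf of shape (..., n, 2) holding stacked (real, imaginary) parts, and the two tree maps invert one another. A minimal round-trip sketch using the helpers above:

import jax
import jax.numpy as jnp

x = (jnp.array([1.0 + 2.0j, 3.0 - 4.0j]), jnp.array([5.0, 6.0]))
structure = jax.eval_shape(lambda: x)

x_real = complex_to_real_tree(x, structure)
# The complex leaf gains a trailing axis of size 2: shapes (2, 2) and (2,).
assert x_real[0].shape == (2, 2) and x_real[1].shape == (2,)

x_back = real_to_complex_tree(x_real, structure)  # recovers x exactly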
22 changes: 20 additions & 2 deletions lineax/_operator.py
@@ -39,10 +39,13 @@

from ._custom_types import sentinel
from ._misc import (
complex_to_real_structure,
default_floating_dtype,
inexact_asarray,
is_complex_structure,
jacobian,
NoneAux,
real_to_complex_tree,
strip_weak_dtype,
)
from ._tags import (
@@ -1322,11 +1325,26 @@ def _(operator):

@materialise.register(FunctionLinearOperator)
def _(operator):
if is_complex_structure(operator.in_structure()) and not is_complex_structure(
operator.out_structure()
):
# We'll use an R^2->R representation for a C->R function.
in_structure = complex_to_real_structure(operator.in_structure())

map_to_original = lambda x: real_to_complex_tree(
x,
operator.in_structure(),
)
else:
map_to_original = lambda x: x
in_structure = operator.in_structure()
flat, unravel = strip_weak_dtype(
-       eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure())
+       eqx.filter_eval_shape(jfu.ravel_pytree, in_structure)
)
fn = lambda x: operator.fn(map_to_original(unravel(x)))
eye = jnp.eye(flat.size, dtype=flat.dtype)
-   jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye)
+   jac = jax.vmap(fn, out_axes=-1)(eye)
@patrick-kidger (Owner) commented on lines +1328 to +1347:
After this change: given a C->R FunctionLinearOperator, call it operator, I think operator(some_complex_vector) would previously have worked, but materialise(operator)(some_complex_vector) would not, as it now expects something from R^2 instead?

The intention was that materialise would not change the observable input-output behaviour of an operator at all.

(FWIW I've just checked the current behaviour on FunctionLinearOperator(lambda x: x.real, jax.ShapeDtypeStruct((), jnp.complex64)) and this is also wrong in the expected way: there's no way to express 'take a real part' when multiplying against a pytree, so it's not like the current state of affairs is any better... !)

It seems to me like the AbstractLinearOperator abstraction might actually just be fundamentally incompatible with complex dtypes, due to the not-really-defined nature of linear operators over such spaces?
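
(To make the concern concrete, a hypothetical sketch against this branch, not code from the PR:)

import jax
import jax.numpy as jnp
import lineax as lx

op = lx.FunctionLinearOperator(
    lambda x: x.real, jax.ShapeDtypeStruct((), jnp.complex64)
)
v = jnp.array(1.0 + 2.0j, dtype=jnp.complex64)

op.mv(v)  # works: returns 1.0
# With this branch, materialise(op) instead expects the stacked real
# input jnp.array([1.0, 2.0]), so the following would fail:
# lx.materialise(op).mv(v)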

Contributor (author) replied:

You're right: as we discussed in patrick-kidger/optimistix#76 (comment), it is impossible to materialize a C->R operator, since it is not complex-linear but rather linear in the real and imaginary parts separately (e.g. f(z) = Re(z) satisfies f(iz) = -Im(z), which is not i*f(z)). The two solutions I see are the one in this PR (break the promise that materialize is a no-op), or to assert that we do not support C->R operators (probably just give a warning) and suggest the user make an R^2->R operator out of it. The latter is definitely cleaner from the point of view of consistency, but may be less convenient for the end user? Not sure about it, but luckily for me the decision is yours 🙃
The motivation was to support C->R stuff in optax, so it can be a good testbed for understanding how convenient this is.

@patrick-kidger (Owner) replied:

I think we should follow JAX's lead on this one, matching how jax.jac{fwd,rev} similarly error out for these kinds of operations.

So I think something like this might be the best choice, then!

class AbstractLinearOperator(eqx.Module):
    def __check_init__(self):
        if is_complex_structure(self.in_structure()) and not is_complex_structure(self.out_structure()):
            raise ValueError(...)

In terms of end user convenience: I care a lot about this! My usual rule for this is that it is better not to support something than to support it awkwardly or with edge-cases. I think this leads to less frustration and a better UX overall. :)
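
(For reference, the JAX behaviour being mirrored; a standalone sketch of how jax.jacfwd/jacrev reject C->R:)

import jax
import jax.numpy as jnp

f = lambda z: z.real  # C -> R
z = jnp.array(1.0 + 2.0j, dtype=jnp.complex64)

# jax.jacfwd(f)(z) raises a TypeError: complex inputs require
# holomorphic=True, while jax.jacrev(f, holomorphic=True)(z) raises a
# TypeError because a holomorphic function must have complex outputs.
# So a C->R map is rejected either way.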

Contributor (author) replied:

Fair enough! Note that your solution will break existing cases for mixed-type operators (i.e., a single operator built from C->C and R->R blocks).

@patrick-kidger (Owner) replied:

Hmm, that is a good point about block-with-mixed-type operators. Ach, complex support is very complicated! Is there a clean way to detect and allow that case, that you can see?

I am concerned that our complex support isn't quite meeting the above UX standards I'd like us to have... (this has definitely been a learning experience for me on how complex autodiff/operators/etc work!)

Contributor (author) replied:

One possible way is to materialize the operator, but that can be expensive. I was thinking of just issuing a warning instead of an exception.

@patrick-kidger (Owner) replied:

Hmm. I'm not sure I'm comfortable with the warning approach -- I try to avoid this kind of flaky maybe-right maybe-not behaviour. I'd rather just prohibit C->R altogether, even if it prohibits mixed C->C + R->R combinations. (Which can anyway still be trivially supported by an end user doing C=R^2 themselves.)
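
(For concreteness, the end-user C=R^2 workaround might look like the following sketch, with a_matrix a stand-in for whatever complex matrix underlies the C->R map:)

import jax
import jax.numpy as jnp
import lineax as lx

a_matrix = jnp.ones((3, 5), dtype=jnp.complex64)  # placeholder kernel

def f_real(x):
    # x is real, shape (5, 2): real and imaginary parts stacked last.
    z = x[..., 0] + 1j * x[..., 1]
    return (a_matrix @ z).real

# A plain R^2->R operator: materialise etc. behave as usual.
op = lx.FunctionLinearOperator(f_real, jax.ShapeDtypeStruct((5, 2), jnp.float32))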

Contributor (author) replied:

Fair enough


def batch_unravel(x):
assert x.ndim > 0
29 changes: 24 additions & 5 deletions lineax/_solver/misc.py
@@ -22,7 +22,14 @@
import numpy as np
from jaxtyping import Array, PyTree, Shaped

-from .._misc import strip_weak_dtype, structure_equal
+from .._misc import (
+    complex_to_real_structure,
+    complex_to_real_tree,
+    is_complex_structure,
+    real_to_complex_tree,
+    strip_weak_dtype,
+    structure_equal,
+)
from .._operator import (
AbstractLinearOperator,
IdentityLinearOperator,
@@ -81,11 +88,14 @@ def ravel_vector(
pytree: PyTree[Array], packed_structures: PackedStructures
) -> Shaped[Array, " size"]:
leaves, treedef = packed_structures.value
-   out_structure, _ = jtu.tree_unflatten(treedef, leaves)
+   out_structure, in_structure = jtu.tree_unflatten(treedef, leaves)
# `is` in case `tree_equal` returns a Tracer.
if not structure_equal(pytree, out_structure):
raise ValueError("pytree does not match out_structure")
# not using `ravel_pytree` as that doesn't come with guarantees about order

if is_complex_structure(out_structure) and not is_complex_structure(in_structure):
pytree = complex_to_real_tree(pytree, out_structure)
leaves = jtu.tree_leaves(pytree)
dtype = jnp.result_type(*leaves)
return jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves])
@@ -95,15 +105,24 @@ def unravel_solution(
solution: Shaped[Array, " size"], packed_structures: PackedStructures
) -> PyTree[Array]:
leaves, treedef = packed_structures.value
-   _, in_structure = jtu.tree_unflatten(treedef, leaves)
-   leaves, treedef = jtu.tree_flatten(in_structure)
+   out_structure, in_structure = jtu.tree_unflatten(treedef, leaves)
+   complex_real = is_complex_structure(in_structure) and not is_complex_structure(
+       out_structure
+   )
+   if complex_real:
+       leaves, treedef = jtu.tree_flatten(complex_to_real_structure(in_structure))
+   else:
+       leaves, treedef = jtu.tree_flatten(in_structure)
sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]])
split = jnp.split(solution, sizes)
assert len(split) == len(leaves)
with warnings.catch_warnings():
warnings.simplefilter("ignore") # ignore complex-to-real cast warning
shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)]
-   return jtu.tree_unflatten(treedef, shaped)
+   if complex_real:
+       return real_to_complex_tree(jtu.tree_unflatten(treedef, shaped), in_structure)
+   else:
+       return jtu.tree_unflatten(treedef, shaped)


def transpose_packed_structures(
7 changes: 7 additions & 0 deletions tests/helpers.py
@@ -254,6 +254,13 @@ def make_composed_operator(getkey, matrix, tags):
return lx.TaggedLinearOperator(operator1 @ operator2, tags)


def make_real_function_operator(getkey, matrix, tags):
fn = lambda x: (matrix @ x).real
_, in_size = matrix.shape
in_struct = jax.ShapeDtypeStruct((in_size,), matrix.dtype)
return lx.FunctionLinearOperator(fn, in_struct, tags)
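
A usage sketch for this helper (getkey is accepted for signature compatibility with the other operator factories but is unused here):

import jax.numpy as jnp
import jax.random as jr

matrix = jr.normal(jr.PRNGKey(0), (3, 5), dtype=jnp.complex64)
operator = make_real_function_operator(None, matrix, ())
vec = jr.normal(jr.PRNGKey(1), (5,), dtype=jnp.complex64)
out = operator.mv(vec)  # C->R: complex input, real (float32) output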


# Slightly sketchy approach to finite differences, in that this is pulled out of
# Numerical Recipes.
# I also don't know of a handling of the JVP case off the top of my head -- although
123 changes: 123 additions & 0 deletions tests/test_jvp.py
@@ -15,10 +15,12 @@
import functools as ft

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import pytest
from lineax._misc import complex_to_real_dtype

from .helpers import (
construct_matrix,
@@ -27,6 +29,7 @@
has_tag,
make_jac_operator,
make_matrix_operator,
make_real_function_operator,
solvers_tags_pseudoinverse,
tree_allclose,
)
@@ -118,3 +121,123 @@ def test_jvp(
assert tree_allclose(matrix @ t_vec_out, matrix @ t_expected_vec_out, rtol=1e-3)
assert tree_allclose(t_op_out, t_expected_op_out, rtol=1e-3)
assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3)


@pytest.mark.parametrize(
"solver, tags, pseudoinverse",
[stp for stp in solvers_tags_pseudoinverse if stp[-1]],
) # only pseudoinverse
@pytest.mark.parametrize("use_state", (True, False))
@pytest.mark.parametrize(
"make_matrix",
(
construct_matrix,
construct_singular_matrix,
),
)
def test_jvp_c_to_r(getkey, solver, tags, pseudoinverse, use_state, make_matrix):
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None

matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=jnp.complex128)

out_size, in_size = matrix.shape
out_dtype = complex_to_real_dtype(matrix.dtype)
vec = jr.normal(getkey(), (out_size,), dtype=out_dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=out_dtype)

if has_tag(tags, lx.unit_diagonal_tag):
# For all the other tags, A + εB with A, B \in {matrices satisfying the tag}
# still satisfies the tag itself.
# This is the exception.
t_matrix = t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0)

make_op = ft.partial(make_real_function_operator, getkey)
operator, t_operator = eqx.filter_jvp(make_op, (matrix, tags), (t_matrix, t_tags))

if use_state:
state = solver.init(operator, options={})
linear_solve = ft.partial(lx.linear_solve, state=state)
else:
linear_solve = lx.linear_solve

solve_vec_only = lambda v: linear_solve(operator, v, solver).value
vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,))

solve_op_only = lambda op: linear_solve(op, vec, solver).value
solve_op_vec = lambda op, v: linear_solve(op, v, solver).value

op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,))
op_vec_out, t_op_vec_out = eqx.filter_jvp(
solve_op_vec,
(operator, vec),
(t_operator, t_vec),
)
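# The real representation of z -> Re(A z) acting on concatenated [Re(z); Im(z)]
# is [Re(A), -Im(A)], since Re(A (x + i y)) = Re(A) x - Im(A) y; hence the
# concatenated matrices in the reference solutions below.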
(expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp(
lambda op: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec
), # pyright: ignore
(matrix,),
(t_matrix,),
)
(expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp(
lambda op, v: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v
),
(matrix, vec),
(t_matrix, t_vec), # pyright: ignore
)

# Work around JAX issue #14868.
if jnp.any(jnp.isnan(t_expected_op_out)):
_, (t_expected_op_out, *_) = finite_difference_jvp(
lambda op: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec
), # pyright: ignore
(matrix,),
(t_matrix,),
)
if jnp.any(jnp.isnan(t_expected_op_vec_out)):
_, (t_expected_op_vec_out, *_) = finite_difference_jvp(
lambda op, v: jnp.linalg.lstsq(
jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v
),
(matrix, vec),
(t_matrix, t_vec), # pyright: ignore
)

real_mat = jnp.concatenate([jnp.real(matrix), -jnp.imag(matrix)], axis=-1)
pinv_matrix = jnp.linalg.pinv(real_mat) # pyright: ignore
expected_vec_out = pinv_matrix @ vec
with jax.numpy_dtype_promotion("standard"):
expected_complex_vec_out = (
expected_vec_out[:in_size] + 1.0j * expected_vec_out[in_size:]
)
expected_complex_op_out = (
expected_op_out[:in_size] + 1.0j * expected_op_out[in_size:]
)
expected_complex_op_vec_out = (
expected_op_vec_out[:in_size] + 1.0j * expected_op_vec_out[in_size:]
)

assert tree_allclose(vec_out, expected_complex_vec_out)
assert tree_allclose(op_out, expected_complex_op_out)
assert tree_allclose(op_vec_out, expected_complex_op_vec_out)

t_expected_vec_out = pinv_matrix @ t_vec

with jax.numpy_dtype_promotion("standard"):
t_expected_complex_vec_out = (
t_expected_vec_out[:in_size] + 1.0j * t_expected_vec_out[in_size:]
)
t_expected_complex_op_out = (
t_expected_op_out[:in_size] + 1.0j * t_expected_op_out[in_size:]
)

t_expected_complex_op_vec_out = (
t_expected_op_vec_out[:in_size] + 1.0j * t_expected_op_vec_out[in_size:]
)
assert tree_allclose(
matrix @ t_vec_out, matrix @ t_expected_complex_vec_out, rtol=1e-3
)
assert tree_allclose(t_op_out, t_expected_complex_op_out, rtol=1e-3)
assert tree_allclose(t_op_vec_out, t_expected_complex_op_vec_out, rtol=1e-3)
23 changes: 23 additions & 0 deletions tests/test_operator.py
@@ -20,6 +20,7 @@
import jax.random as jr
import lineax as lx
import pytest
from lineax._misc import complex_to_real_dtype

from .helpers import (
make_diagonal_operator,
@@ -321,6 +322,28 @@ def test_materialise_function_linear_operator(dtype, getkey):
assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct


def test_materialise_function_real_linear_operator(getkey):
dtype = jnp.complex128
x = (
jr.normal(getkey(), (5, 9), dtype=dtype),
jr.normal(getkey(), (3,), dtype=dtype),
)
input_structure = jax.eval_shape(lambda: x)
fn = lambda x: {"a": jnp.broadcast_to(jnp.sum(x[0]).real, (1, 2))}
output_structure = jax.eval_shape(fn, input_structure)
operator = lx.FunctionLinearOperator(fn, input_structure)
materialised_operator = lx.materialise(operator)
assert materialised_operator.out_structure() == output_structure
assert isinstance(materialised_operator, lx.PyTreeLinearOperator)
expected_struct = {
"a": (
jax.ShapeDtypeStruct((1, 2, 5, 9, 2), complex_to_real_dtype(dtype)),
jax.ShapeDtypeStruct((1, 2, 3, 2), complex_to_real_dtype(dtype)),
)
}
assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct


@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_pytree_transpose(dtype, getkey):
out_struct = jax.eval_shape(