From 0fd9b1f5946134960ed844157fba1f235aadd668 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Sun, 17 Nov 2024 14:14:18 +0000 Subject: [PATCH 1/7] Add test that checks C->R case --- lineax/_operator.py | 5 +++++ tests/helpers.py | 8 ++++++++ tests/test_jvp.py | 6 +++++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 53d6ccc..2c2f25f 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -1325,6 +1325,11 @@ def _(operator): flat, unravel = strip_weak_dtype( eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) ) + if jnp.result_type(operator.out_structure()) != jnp.result_type( + operator.in_structure() + ): + # We'll use R^2->R representation for C->R function. + pass eye = jnp.eye(flat.size, dtype=flat.dtype) jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) diff --git a/tests/helpers.py b/tests/helpers.py index bb2d396..915be12 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -180,6 +180,14 @@ def make_function_operator(getkey, matrix, tags): return lx.FunctionLinearOperator(fn, in_struct, tags) +@_operators_append +def make_real_function_operator(getkey, matrix, tags): + fn = lambda x: (matrix @ x).real + _, in_size = matrix.shape + in_struct = jax.ShapeDtypeStruct((in_size,), matrix.dtype) + return lx.FunctionLinearOperator(fn, in_struct, tags) + + @_operators_append def make_jac_operator(getkey, matrix, tags): out_size, in_size = matrix.shape diff --git a/tests/test_jvp.py b/tests/test_jvp.py index 592e530..85f3f61 100644 --- a/tests/test_jvp.py +++ b/tests/test_jvp.py @@ -27,12 +27,16 @@ has_tag, make_jac_operator, make_matrix_operator, + make_real_function_operator, solvers_tags_pseudoinverse, tree_allclose, ) -@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) +@pytest.mark.parametrize( + "make_operator", + (make_matrix_operator, make_jac_operator, make_real_function_operator), +) @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize( From 41791943add5cabf4b77dec544129ae9e280dd8f Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 18 Nov 2024 13:49:01 +0000 Subject: [PATCH 2/7] Materialize functional C->R operator --- lineax/_operator.py | 44 +++++++++++++++++++++++++++++++++++------- tests/test_operator.py | 23 ++++++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 2c2f25f..75b2b3f 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -28,6 +28,7 @@ import jax.tree_util as jtu import numpy as np from equinox.internal import ω +from jax import ShapeDtypeStruct from jaxtyping import ( Array, ArrayLike, @@ -39,6 +40,7 @@ from ._custom_types import sentinel from ._misc import ( + complex_to_real_dtype, default_floating_dtype, inexact_asarray, jacobian, @@ -1322,16 +1324,44 @@ def _(operator): @materialise.register(FunctionLinearOperator) def _(operator): - flat, unravel = strip_weak_dtype( - eqx.filter_eval_shape(jfu.ravel_pytree, operator.in_structure()) + complex_input = jnp.isdtype( + jnp.result_type(*(jax.tree.flatten(operator.in_structure())[0])), + "complex floating", ) - if jnp.result_type(operator.out_structure()) != jnp.result_type( - operator.in_structure() - ): + real_output = not jnp.isdtype( + jnp.result_type(*(jax.tree.flatten(operator.out_structure())[0])), + "complex floating", + ) + if complex_input and real_output: # We'll use R^2->R representation for C->R function. - pass + in_structure = jtu.tree_map( + lambda x: ShapeDtypeStruct( + tuple(x.shape) + (2,), complex_to_real_dtype(x.dtype) + ) + if jnp.isdtype(x.dtype, "complex floating") + else x, + operator.in_structure(), + ) + + def map_to_original(x): + with jax.numpy_dtype_promotion("standard"): + return jtu.tree_map( + lambda x, struct: x[..., 0] + 1.0j * x[..., 1] + if jnp.isdtype(struct.dtype, "complex floating") + else x, + x, + operator.in_structure(), + ) + else: + map_to_original = lambda x: x + in_structure = operator.in_structure() + flat, unravel = strip_weak_dtype( + eqx.filter_eval_shape(jfu.ravel_pytree, in_structure) + ) + fn = lambda x: operator.fn(map_to_original(unravel(x))) eye = jnp.eye(flat.size, dtype=flat.dtype) - jac = jax.vmap(lambda x: operator.fn(unravel(x)), out_axes=-1)(eye) + + jac = jax.vmap(fn, out_axes=-1)(eye) def batch_unravel(x): assert x.ndim > 0 diff --git a/tests/test_operator.py b/tests/test_operator.py index 2a10135..aa8140d 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -20,6 +20,7 @@ import jax.random as jr import lineax as lx import pytest +from lineax._misc import complex_to_real_dtype from .helpers import ( make_diagonal_operator, @@ -321,6 +322,28 @@ def test_materialise_function_linear_operator(dtype, getkey): assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct +def test_materialise_function_real_linear_operator(getkey): + dtype = jnp.complex128 + x = ( + jr.normal(getkey(), (5, 9), dtype=dtype), + jr.normal(getkey(), (3,), dtype=dtype), + ) + input_structure = jax.eval_shape(lambda: x) + fn = lambda x: {"a": jnp.broadcast_to(jnp.sum(x[0]).real, (1, 2))} + output_structure = jax.eval_shape(fn, input_structure) + operator = lx.FunctionLinearOperator(fn, input_structure) + materialised_operator = lx.materialise(operator) + assert materialised_operator.out_structure() == output_structure + assert isinstance(materialised_operator, lx.PyTreeLinearOperator) + expected_struct = { + "a": ( + jax.ShapeDtypeStruct((1, 2, 5, 9, 2), complex_to_real_dtype(dtype)), + jax.ShapeDtypeStruct((1, 2, 3, 2), complex_to_real_dtype(dtype)), + ) + } + assert jax.eval_shape(lambda: materialised_operator.pytree) == expected_struct + + @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_pytree_transpose(dtype, getkey): out_struct = jax.eval_shape( From 299d8acadc857f331f80c4fdc53530738e3e0ec0 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 18 Nov 2024 15:53:50 +0000 Subject: [PATCH 3/7] Start implementing solver for C->R --- lineax/_misc.py | 23 +++++++++++++++++++++++ lineax/_operator.py | 24 ++++++------------------ lineax/_solver/misc.py | 11 ++++++++--- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/lineax/_misc.py b/lineax/_misc.py index 3cc5117..dc47693 100644 --- a/lineax/_misc.py +++ b/lineax/_misc.py @@ -19,6 +19,7 @@ import jax.core import jax.numpy as jnp import jax.tree_util as jtu +from jax import ShapeDtypeStruct from jaxtyping import Array, ArrayLike, Bool, PyTree # pyright:ignore @@ -110,3 +111,25 @@ def structure_equal(x, y) -> bool: x = strip_weak_dtype(jax.eval_shape(lambda: x)) y = strip_weak_dtype(jax.eval_shape(lambda: y)) return eqx.tree_equal(x, y) is True + + +def complex_to_real_structure(in_structure): + return jtu.tree_map( + lambda x: ShapeDtypeStruct( + tuple(x.shape) + (2,), complex_to_real_dtype(x.dtype) + ) + if jnp.isdtype(x.dtype, "complex floating") + else x, + in_structure, + ) + + +def real_to_complex_tree(x, in_structure): + with jax.numpy_dtype_promotion("standard"): + return jtu.tree_map( + lambda x, struct: x[..., 0] + 1.0j * x[..., 1] + if jnp.isdtype(struct.dtype, "complex floating") + else x, + x, + in_structure, + ) diff --git a/lineax/_operator.py b/lineax/_operator.py index 75b2b3f..d4af182 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -28,7 +28,6 @@ import jax.tree_util as jtu import numpy as np from equinox.internal import ω -from jax import ShapeDtypeStruct from jaxtyping import ( Array, ArrayLike, @@ -40,11 +39,12 @@ from ._custom_types import sentinel from ._misc import ( - complex_to_real_dtype, + complex_to_real_structure, default_floating_dtype, inexact_asarray, jacobian, NoneAux, + real_to_complex_tree, strip_weak_dtype, ) from ._tags import ( @@ -1334,24 +1334,12 @@ def _(operator): ) if complex_input and real_output: # We'll use R^2->R representation for C->R function. - in_structure = jtu.tree_map( - lambda x: ShapeDtypeStruct( - tuple(x.shape) + (2,), complex_to_real_dtype(x.dtype) - ) - if jnp.isdtype(x.dtype, "complex floating") - else x, + in_structure = complex_to_real_structure(operator.in_structure()) + + map_to_original = lambda x: real_to_complex_tree( + x, operator.in_structure(), ) - - def map_to_original(x): - with jax.numpy_dtype_promotion("standard"): - return jtu.tree_map( - lambda x, struct: x[..., 0] + 1.0j * x[..., 1] - if jnp.isdtype(struct.dtype, "complex floating") - else x, - x, - operator.in_structure(), - ) else: map_to_original = lambda x: x in_structure = operator.in_structure() diff --git a/lineax/_solver/misc.py b/lineax/_solver/misc.py index b7e1a09..cfcfdb1 100644 --- a/lineax/_solver/misc.py +++ b/lineax/_solver/misc.py @@ -22,7 +22,12 @@ import numpy as np from jaxtyping import Array, PyTree, Shaped -from .._misc import strip_weak_dtype, structure_equal +from .._misc import ( + complex_to_real_structure, + real_to_complex_tree, + strip_weak_dtype, + structure_equal, +) from .._operator import ( AbstractLinearOperator, IdentityLinearOperator, @@ -96,14 +101,14 @@ def unravel_solution( ) -> PyTree[Array]: leaves, treedef = packed_structures.value _, in_structure = jtu.tree_unflatten(treedef, leaves) - leaves, treedef = jtu.tree_flatten(in_structure) + leaves, treedef = jtu.tree_flatten(complex_to_real_structure(in_structure)) sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]]) split = jnp.split(solution, sizes) assert len(split) == len(leaves) with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore complex-to-real cast warning shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)] - return jtu.tree_unflatten(treedef, shaped) + return real_to_complex_tree(jtu.tree_unflatten(treedef, shaped), in_structure) def transpose_packed_structures( From b0f0e291113128c6cef8d5c358ea999e7d088106 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Mon, 18 Nov 2024 16:16:49 +0000 Subject: [PATCH 4/7] add c->r jvp test --- tests/test_jvp.py | 141 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 137 insertions(+), 4 deletions(-) diff --git a/tests/test_jvp.py b/tests/test_jvp.py index 85f3f61..55da937 100644 --- a/tests/test_jvp.py +++ b/tests/test_jvp.py @@ -15,10 +15,12 @@ import functools as ft import equinox as eqx +import jax import jax.numpy as jnp import jax.random as jr import lineax as lx import pytest +from lineax._misc import complex_to_real_dtype from .helpers import ( construct_matrix, @@ -33,10 +35,7 @@ ) -@pytest.mark.parametrize( - "make_operator", - (make_matrix_operator, make_jac_operator, make_real_function_operator), -) +@pytest.mark.parametrize("make_operator", (make_matrix_operator, make_jac_operator)) @pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize( @@ -122,3 +121,137 @@ def test_jvp( assert tree_allclose(matrix @ t_vec_out, matrix @ t_expected_vec_out, rtol=1e-3) assert tree_allclose(t_op_out, t_expected_op_out, rtol=1e-3) assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3) + + +@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) +@pytest.mark.parametrize("use_state", (True, False)) +@pytest.mark.parametrize( + "make_matrix", + ( + construct_matrix, + construct_singular_matrix, + ), +) +def test_jvp_c_to_r(getkey, solver, tags, pseudoinverse, use_state, make_matrix): + dtype = jnp.complex128 + make_operator = make_real_function_operator + t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None + if solver not in ( + lx.QR(), + lx.SVD(), + ): + print(solver) + return + pytest.skip("Real function operators are only supported for QR and SVD.") + if (make_matrix is construct_matrix) or pseudoinverse: + matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype) + + out_size, _ = matrix.shape + out_dtype = ( + complex_to_real_dtype(matrix.dtype) + if make_operator == make_real_function_operator + else matrix.dtype + ) + vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) + t_vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) + + if has_tag(tags, lx.unit_diagonal_tag): + # For all the other tags, A + εB with A, B \in {matrices satisfying the tag} + # still satisfies the tag itself. + # This is the exception. + t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0) + + make_op = ft.partial(make_operator, getkey) + operator, t_operator = eqx.filter_jvp( + make_op, (matrix, tags), (t_matrix, t_tags) + ) + + if use_state: + state = solver.init(operator, options={}) + linear_solve = ft.partial(lx.linear_solve, state=state) + else: + linear_solve = lx.linear_solve + + solve_vec_only = lambda v: linear_solve(operator, v, solver).value + vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,)) + + solve_op_only = lambda op: linear_solve(op, vec, solver).value + solve_op_vec = lambda op, v: linear_solve(op, v, solver).value + + op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,)) + op_vec_out, t_op_vec_out = eqx.filter_jvp( + solve_op_vec, + (operator, vec), + (t_operator, t_vec), + ) + (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp( + lambda op: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec + ), # pyright: ignore + (matrix,), + (t_matrix,), + ) + (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp( + lambda op, v: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v + ), + (matrix, vec), + (t_matrix, t_vec), # pyright: ignore + ) + + # Work around JAX issue #14868. + if jnp.any(jnp.isnan(t_expected_op_out)): + _, (t_expected_op_out, *_) = finite_difference_jvp( + lambda op: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec + ), # pyright: ignore + (matrix,), + (t_matrix,), + ) + if jnp.any(jnp.isnan(t_expected_op_vec_out)): + _, (t_expected_op_vec_out, *_) = finite_difference_jvp( + lambda op, v: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v + ), + (matrix, vec), + (t_matrix, t_vec), # pyright: ignore + ) + + real_mat = jnp.concatenate([jnp.real(matrix), -jnp.imag(matrix)], axis=-1) + pinv_matrix = jnp.linalg.pinv(real_mat) # pyright: ignore + expected_vec_out = pinv_matrix @ vec + with jax.numpy_dtype_promotion("standard"): + expected_complex_vec_out = ( + expected_vec_out[:out_size] + 1.0j * expected_vec_out[out_size:] + ) + assert tree_allclose(vec_out, expected_complex_vec_out) + + with jax.numpy_dtype_promotion("standard"): + expected_complex_op_out = ( + expected_op_out[:out_size] + 1.0j * expected_op_out[out_size:] + ) + expected_complex_op_vec_out = ( + expected_op_vec_out[:out_size] + 1.0j * expected_op_vec_out[out_size:] + ) + + assert tree_allclose(op_out, expected_complex_op_out) + assert tree_allclose(op_vec_out, expected_complex_op_vec_out) + + t_expected_vec_out = pinv_matrix @ t_vec + + with jax.numpy_dtype_promotion("standard"): + t_expected_complex_vec_out = ( + t_expected_vec_out[:out_size] + 1.0j * t_expected_vec_out[out_size:] + ) + t_expected_complex_op_out = ( + t_expected_op_out[:out_size] + 1.0j * t_expected_op_out[out_size:] + ) + t_expected_complex_op_vec_out = ( + t_expected_op_vec_out[:out_size] + + 1.0j * t_expected_op_vec_out[out_size:] + ) + assert tree_allclose( + matrix @ t_vec_out, matrix @ t_expected_complex_vec_out, rtol=1e-3 + ) + assert tree_allclose(t_op_out, t_expected_complex_op_out, rtol=1e-3) + assert tree_allclose(t_op_vec_out, t_expected_complex_op_vec_out, rtol=1e-3) From b55ec0e42e0a4ea6ec974cc3903e3d9f5cbb0f9d Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Tue, 19 Nov 2024 04:33:42 +0000 Subject: [PATCH 5/7] Fix raveling --- lineax/_misc.py | 18 ++++ lineax/_operator.py | 13 +-- lineax/_solver/misc.py | 22 ++++- tests/test_jvp.py | 192 +++++++++++++++++++---------------------- 4 files changed, 129 insertions(+), 116 deletions(-) diff --git a/lineax/_misc.py b/lineax/_misc.py index dc47693..30bd665 100644 --- a/lineax/_misc.py +++ b/lineax/_misc.py @@ -113,6 +113,13 @@ def structure_equal(x, y) -> bool: return eqx.tree_equal(x, y) is True +def is_complex_structure(structure): + return jnp.isdtype( + jnp.result_type(*(jax.tree.flatten(structure)[0])), + "complex floating", + ) + + def complex_to_real_structure(in_structure): return jtu.tree_map( lambda x: ShapeDtypeStruct( @@ -124,6 +131,17 @@ def complex_to_real_structure(in_structure): ) +def complex_to_real_tree(x, in_structure): + with jax.numpy_dtype_promotion("standard"): + return jtu.tree_map( + lambda x, struct: jnp.stack([x.real, x.imag], axis=-1) + if jnp.isdtype(struct.dtype, "complex floating") + else x, + x, + in_structure, + ) + + def real_to_complex_tree(x, in_structure): with jax.numpy_dtype_promotion("standard"): return jtu.tree_map( diff --git a/lineax/_operator.py b/lineax/_operator.py index d4af182..d93cbbb 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -42,6 +42,7 @@ complex_to_real_structure, default_floating_dtype, inexact_asarray, + is_complex_structure, jacobian, NoneAux, real_to_complex_tree, @@ -1324,15 +1325,9 @@ def _(operator): @materialise.register(FunctionLinearOperator) def _(operator): - complex_input = jnp.isdtype( - jnp.result_type(*(jax.tree.flatten(operator.in_structure())[0])), - "complex floating", - ) - real_output = not jnp.isdtype( - jnp.result_type(*(jax.tree.flatten(operator.out_structure())[0])), - "complex floating", - ) - if complex_input and real_output: + if is_complex_structure(operator.in_structure()) and not is_complex_structure( + operator.out_structure() + ): # We'll use R^2->R representation for C->R function. in_structure = complex_to_real_structure(operator.in_structure()) diff --git a/lineax/_solver/misc.py b/lineax/_solver/misc.py index cfcfdb1..8bc982a 100644 --- a/lineax/_solver/misc.py +++ b/lineax/_solver/misc.py @@ -24,6 +24,8 @@ from .._misc import ( complex_to_real_structure, + complex_to_real_tree, + is_complex_structure, real_to_complex_tree, strip_weak_dtype, structure_equal, @@ -86,11 +88,14 @@ def ravel_vector( pytree: PyTree[Array], packed_structures: PackedStructures ) -> Shaped[Array, " size"]: leaves, treedef = packed_structures.value - out_structure, _ = jtu.tree_unflatten(treedef, leaves) + out_structure, in_structure = jtu.tree_unflatten(treedef, leaves) # `is` in case `tree_equal` returns a Tracer. if not structure_equal(pytree, out_structure): raise ValueError("pytree does not match out_structure") # not using `ravel_pytree` as that doesn't come with guarantees about order + + if is_complex_structure(out_structure) and not is_complex_structure(in_structure): + pytree = complex_to_real_tree(pytree, out_structure) leaves = jtu.tree_leaves(pytree) dtype = jnp.result_type(*leaves) return jnp.concatenate([x.astype(dtype).reshape(-1) for x in leaves]) @@ -100,15 +105,24 @@ def unravel_solution( solution: Shaped[Array, " size"], packed_structures: PackedStructures ) -> PyTree[Array]: leaves, treedef = packed_structures.value - _, in_structure = jtu.tree_unflatten(treedef, leaves) - leaves, treedef = jtu.tree_flatten(complex_to_real_structure(in_structure)) + out_structure, in_structure = jtu.tree_unflatten(treedef, leaves) + complex_real = is_complex_structure(in_structure) and not is_complex_structure( + out_structure + ) + if complex_real: + leaves, treedef = jtu.tree_flatten(complex_to_real_structure(in_structure)) + else: + leaves, treedef = jtu.tree_flatten(in_structure) sizes = np.cumsum([math.prod(x.shape) for x in leaves[:-1]]) split = jnp.split(solution, sizes) assert len(split) == len(leaves) with warnings.catch_warnings(): warnings.simplefilter("ignore") # ignore complex-to-real cast warning shaped = [x.reshape(y.shape).astype(y.dtype) for x, y in zip(split, leaves)] - return real_to_complex_tree(jtu.tree_unflatten(treedef, shaped), in_structure) + if complex_real: + return real_to_complex_tree(jtu.tree_unflatten(treedef, shaped), in_structure) + else: + return jtu.tree_unflatten(treedef, shaped) def transpose_packed_structures( diff --git a/tests/test_jvp.py b/tests/test_jvp.py index 55da937..8478574 100644 --- a/tests/test_jvp.py +++ b/tests/test_jvp.py @@ -123,7 +123,10 @@ def test_jvp( assert tree_allclose(t_op_vec_out, t_expected_op_vec_out, rtol=1e-3) -@pytest.mark.parametrize("solver, tags, pseudoinverse", solvers_tags_pseudoinverse) +@pytest.mark.parametrize( + "solver, tags, pseudoinverse", + [stp for stp in solvers_tags_pseudoinverse if stp[-1]], +) # only pseudoinverse @pytest.mark.parametrize("use_state", (True, False)) @pytest.mark.parametrize( "make_matrix", @@ -133,65 +136,68 @@ def test_jvp( ), ) def test_jvp_c_to_r(getkey, solver, tags, pseudoinverse, use_state, make_matrix): - dtype = jnp.complex128 - make_operator = make_real_function_operator t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None - if solver not in ( - lx.QR(), - lx.SVD(), - ): - print(solver) - return - pytest.skip("Real function operators are only supported for QR and SVD.") - if (make_matrix is construct_matrix) or pseudoinverse: - matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype) - - out_size, _ = matrix.shape - out_dtype = ( - complex_to_real_dtype(matrix.dtype) - if make_operator == make_real_function_operator - else matrix.dtype - ) - vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) - t_vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) - - if has_tag(tags, lx.unit_diagonal_tag): - # For all the other tags, A + εB with A, B \in {matrices satisfying the tag} - # still satisfies the tag itself. - # This is the exception. - t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0) - - make_op = ft.partial(make_operator, getkey) - operator, t_operator = eqx.filter_jvp( - make_op, (matrix, tags), (t_matrix, t_tags) - ) - - if use_state: - state = solver.init(operator, options={}) - linear_solve = ft.partial(lx.linear_solve, state=state) - else: - linear_solve = lx.linear_solve - solve_vec_only = lambda v: linear_solve(operator, v, solver).value - vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,)) - - solve_op_only = lambda op: linear_solve(op, vec, solver).value - solve_op_vec = lambda op, v: linear_solve(op, v, solver).value - - op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,)) - op_vec_out, t_op_vec_out = eqx.filter_jvp( - solve_op_vec, - (operator, vec), - (t_operator, t_vec), - ) - (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp( + matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=jnp.complex128) + + out_size, in_size = matrix.shape + out_dtype = complex_to_real_dtype(matrix.dtype) + vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) + t_vec = jr.normal(getkey(), (out_size,), dtype=out_dtype) + + if has_tag(tags, lx.unit_diagonal_tag): + # For all the other tags, A + εB with A, B \in {matrices satisfying the tag} + # still satisfies the tag itself. + # This is the exception. + t_matrix.at[jnp.arange(3), jnp.arange(3)].set(0) + + make_op = ft.partial(make_real_function_operator, getkey) + operator, t_operator = eqx.filter_jvp(make_op, (matrix, tags), (t_matrix, t_tags)) + + if use_state: + state = solver.init(operator, options={}) + linear_solve = ft.partial(lx.linear_solve, state=state) + else: + linear_solve = lx.linear_solve + + solve_vec_only = lambda v: linear_solve(operator, v, solver).value + vec_out, t_vec_out = eqx.filter_jvp(solve_vec_only, (vec,), (t_vec,)) + + solve_op_only = lambda op: linear_solve(op, vec, solver).value + solve_op_vec = lambda op, v: linear_solve(op, v, solver).value + + op_out, t_op_out = eqx.filter_jvp(solve_op_only, (operator,), (t_operator,)) + op_vec_out, t_op_vec_out = eqx.filter_jvp( + solve_op_vec, + (operator, vec), + (t_operator, t_vec), + ) + (expected_op_out, *_), (t_expected_op_out, *_) = eqx.filter_jvp( + lambda op: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec + ), # pyright: ignore + (matrix,), + (t_matrix,), + ) + (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp( + lambda op, v: jnp.linalg.lstsq( + jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v + ), + (matrix, vec), + (t_matrix, t_vec), # pyright: ignore + ) + + # Work around JAX issue #14868. + if jnp.any(jnp.isnan(t_expected_op_out)): + _, (t_expected_op_out, *_) = finite_difference_jvp( lambda op: jnp.linalg.lstsq( jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec ), # pyright: ignore (matrix,), (t_matrix,), ) - (expected_op_vec_out, *_), (t_expected_op_vec_out, *_) = eqx.filter_jvp( + if jnp.any(jnp.isnan(t_expected_op_vec_out)): + _, (t_expected_op_vec_out, *_) = finite_difference_jvp( lambda op, v: jnp.linalg.lstsq( jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v ), @@ -199,59 +205,39 @@ def test_jvp_c_to_r(getkey, solver, tags, pseudoinverse, use_state, make_matrix) (t_matrix, t_vec), # pyright: ignore ) - # Work around JAX issue #14868. - if jnp.any(jnp.isnan(t_expected_op_out)): - _, (t_expected_op_out, *_) = finite_difference_jvp( - lambda op: jnp.linalg.lstsq( - jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), vec - ), # pyright: ignore - (matrix,), - (t_matrix,), - ) - if jnp.any(jnp.isnan(t_expected_op_vec_out)): - _, (t_expected_op_vec_out, *_) = finite_difference_jvp( - lambda op, v: jnp.linalg.lstsq( - jnp.concatenate([jnp.real(op), -jnp.imag(op)], axis=-1), v - ), - (matrix, vec), - (t_matrix, t_vec), # pyright: ignore - ) + real_mat = jnp.concatenate([jnp.real(matrix), -jnp.imag(matrix)], axis=-1) + pinv_matrix = jnp.linalg.pinv(real_mat) # pyright: ignore + expected_vec_out = pinv_matrix @ vec + with jax.numpy_dtype_promotion("standard"): + expected_complex_vec_out = ( + expected_vec_out[:in_size] + 1.0j * expected_vec_out[in_size:] + ) + expected_complex_op_out = ( + expected_op_out[:in_size] + 1.0j * expected_op_out[in_size:] + ) + expected_complex_op_vec_out = ( + expected_op_vec_out[:in_size] + 1.0j * expected_op_vec_out[in_size:] + ) - real_mat = jnp.concatenate([jnp.real(matrix), -jnp.imag(matrix)], axis=-1) - pinv_matrix = jnp.linalg.pinv(real_mat) # pyright: ignore - expected_vec_out = pinv_matrix @ vec - with jax.numpy_dtype_promotion("standard"): - expected_complex_vec_out = ( - expected_vec_out[:out_size] + 1.0j * expected_vec_out[out_size:] - ) - assert tree_allclose(vec_out, expected_complex_vec_out) + assert tree_allclose(vec_out, expected_complex_vec_out) + assert tree_allclose(op_out, expected_complex_op_out) + assert tree_allclose(op_vec_out, expected_complex_op_vec_out) - with jax.numpy_dtype_promotion("standard"): - expected_complex_op_out = ( - expected_op_out[:out_size] + 1.0j * expected_op_out[out_size:] - ) - expected_complex_op_vec_out = ( - expected_op_vec_out[:out_size] + 1.0j * expected_op_vec_out[out_size:] - ) + t_expected_vec_out = pinv_matrix @ t_vec - assert tree_allclose(op_out, expected_complex_op_out) - assert tree_allclose(op_vec_out, expected_complex_op_vec_out) - - t_expected_vec_out = pinv_matrix @ t_vec + with jax.numpy_dtype_promotion("standard"): + t_expected_complex_vec_out = ( + t_expected_vec_out[:in_size] + 1.0j * t_expected_vec_out[in_size:] + ) + t_expected_complex_op_out = ( + t_expected_op_out[:in_size] + 1.0j * t_expected_op_out[in_size:] + ) - with jax.numpy_dtype_promotion("standard"): - t_expected_complex_vec_out = ( - t_expected_vec_out[:out_size] + 1.0j * t_expected_vec_out[out_size:] - ) - t_expected_complex_op_out = ( - t_expected_op_out[:out_size] + 1.0j * t_expected_op_out[out_size:] - ) - t_expected_complex_op_vec_out = ( - t_expected_op_vec_out[:out_size] - + 1.0j * t_expected_op_vec_out[out_size:] - ) - assert tree_allclose( - matrix @ t_vec_out, matrix @ t_expected_complex_vec_out, rtol=1e-3 + t_expected_complex_op_vec_out = ( + t_expected_op_vec_out[:in_size] + 1.0j * t_expected_op_vec_out[in_size:] ) - assert tree_allclose(t_op_out, t_expected_complex_op_out, rtol=1e-3) - assert tree_allclose(t_op_vec_out, t_expected_complex_op_vec_out, rtol=1e-3) + assert tree_allclose( + matrix @ t_vec_out, matrix @ t_expected_complex_vec_out, rtol=1e-3 + ) + assert tree_allclose(t_op_out, t_expected_complex_op_out, rtol=1e-3) + assert tree_allclose(t_op_vec_out, t_expected_complex_op_vec_out, rtol=1e-3) From 8bdf8952f47b240b21f8dd42bf2dfefe2c277ffd Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Tue, 19 Nov 2024 04:37:17 +0000 Subject: [PATCH 6/7] Don't add `make_real_function_operator` to the list of operators since it behaves differently --- tests/helpers.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 915be12..e3dd25e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -180,14 +180,6 @@ def make_function_operator(getkey, matrix, tags): return lx.FunctionLinearOperator(fn, in_struct, tags) -@_operators_append -def make_real_function_operator(getkey, matrix, tags): - fn = lambda x: (matrix @ x).real - _, in_size = matrix.shape - in_struct = jax.ShapeDtypeStruct((in_size,), matrix.dtype) - return lx.FunctionLinearOperator(fn, in_struct, tags) - - @_operators_append def make_jac_operator(getkey, matrix, tags): out_size, in_size = matrix.shape @@ -262,6 +254,13 @@ def make_composed_operator(getkey, matrix, tags): return lx.TaggedLinearOperator(operator1 @ operator2, tags) +def make_real_function_operator(getkey, matrix, tags): + fn = lambda x: (matrix @ x).real + _, in_size = matrix.shape + in_struct = jax.ShapeDtypeStruct((in_size,), matrix.dtype) + return lx.FunctionLinearOperator(fn, in_struct, tags) + + # Slightly sketchy approach to finite differences, in that this is pulled out of # Numerical Recipes. # I also don't know of a handling of the JVP case off the top of my head -- although From 8ce46c89718769ef2cf67dbb0a5b29aa28e7fe9a Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Tue, 19 Nov 2024 04:42:16 +0000 Subject: [PATCH 7/7] Fix mixed dtypes --- lineax/_misc.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lineax/_misc.py b/lineax/_misc.py index 30bd665..4deb170 100644 --- a/lineax/_misc.py +++ b/lineax/_misc.py @@ -114,10 +114,11 @@ def structure_equal(x, y) -> bool: def is_complex_structure(structure): - return jnp.isdtype( - jnp.result_type(*(jax.tree.flatten(structure)[0])), - "complex floating", - ) + with jax.numpy_dtype_promotion("standard"): + return jnp.isdtype( + jnp.result_type(*(jax.tree.flatten(structure)[0])), + "complex floating", + ) def complex_to_real_structure(in_structure):