From 598d9ad72d114881b8115febe998fcc380c1cffc Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 17 Jul 2024 10:14:35 -0700 Subject: [PATCH] Don't mask out zero elements on the diagonal of the matrix when inverting triangular matrices. The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix. Fixes https://github.com/google/jax/issues/3589 Fixes https://github.com/google/jax/issues/15429 PiperOrigin-RevId: 653274967 --- xla/python/xla_client.py | 2 +- xla/service/BUILD | 7 +++++-- xla/service/triangular_solve_expander.cc | 22 +++++++++++++++------- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py index 13e1fe3d7371ac..e63b058f5ab24c 100644 --- a/xla/python/xla_client.py +++ b/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 276 +_version = 277 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/xla/service/BUILD b/xla/service/BUILD index 80b8dea89b3b59..fab0681b6c0234 100644 --- a/xla/service/BUILD +++ b/xla/service/BUILD @@ -2407,10 +2407,9 @@ cc_library( srcs = ["triangular_solve_expander.cc"], hdrs = ["triangular_solve_expander.h"], deps = [ + ":hlo_module_config", ":op_expander_pass", - "//xla:literal", "//xla:shape_util", - "//xla:status_macros", "//xla:util", "//xla/client:xla_builder", "//xla/client:xla_computation", @@ -2418,9 +2417,13 @@ cc_library( "//xla/client/lib:math", "//xla/client/lib:matrix", "//xla/client/lib:slicing", + "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", ], ) diff --git a/xla/service/triangular_solve_expander.cc b/xla/service/triangular_solve_expander.cc index 8d58b5a3d46828..c61dc148c0ec33 100644 --- a/xla/service/triangular_solve_expander.cc +++ b/xla/service/triangular_solve_expander.cc @@ -15,10 +15,15 @@ limitations under the License. #include "xla/service/triangular_solve_expander.h" +#include +#include #include +#include +#include #include #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/client/lib/constants.h" #include "xla/client/lib/math.h" @@ -26,10 +31,15 @@ limitations under the License. #include "xla/client/lib/slicing.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" -#include "xla/literal.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/util.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -250,12 +260,10 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks( // multiplications later on diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular); - // Rescale blocks to be unit triangular, but avoid dividing by - // zero (which can happen if the last block was padded) otherwise it will - // introduce nans which will propagate + // Rescale blocks to be unit triangular. We were careful to pad the last + // block with the identity matrix, which means we won't introduce NaNs by + // doing this (unless the matrix is singular, in which case that's ok). auto diags = GetMatrixDiagonal(diag_blocks); - auto ones = FullLike(diags, 1); - diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags); auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2}); // We can now use the fact that for an upper triangular matrix