Skip to content

Commit

Permalink
Don't mask out zero elements on the diagonal of the matrix when inver…
Browse files Browse the repository at this point in the history
…ting triangular matrices.

The intent of this code seems to have been to mask out zeros that were part of padding on the diagonal. However, this isn't correct: if there is a zero on the diagonal, we very much want to get an inf or nan! We also appear to now pad with the identity matrix.

Fixes jax-ml/jax#3589
Fixes jax-ml/jax#15429

PiperOrigin-RevId: 653274967
  • Loading branch information
hawkinsp authored and copybara-github committed Jul 18, 2024
1 parent 3e87c94 commit 598d9ad
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 276
_version = 277

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down
7 changes: 5 additions & 2 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2407,20 +2407,23 @@ cc_library(
srcs = ["triangular_solve_expander.cc"],
hdrs = ["triangular_solve_expander.h"],
deps = [
":hlo_module_config",
":op_expander_pass",
"//xla:literal",
"//xla:shape_util",
"//xla:status_macros",
"//xla:util",
"//xla/client:xla_builder",
"//xla/client:xla_computation",
"//xla/client/lib:constants",
"//xla/client/lib:math",
"//xla/client/lib:matrix",
"//xla/client/lib:slicing",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
22 changes: 15 additions & 7 deletions xla/service/triangular_solve_expander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,31 @@ limitations under the License.

#include "xla/service/triangular_solve_expander.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/math.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/lib/slicing.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/literal.h"
#include "xla/hlo/ir/hlo_clone_context.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/util.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -250,12 +260,10 @@ XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
// multiplications later on
diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);

// Rescale blocks to be unit triangular, but avoid dividing by
// zero (which can happen if the last block was padded) otherwise it will
// introduce nans which will propagate
// Rescale blocks to be unit triangular. We were careful to pad the last
// block with the identity matrix, which means we won't introduce NaNs by
// doing this (unless the matrix is singular, in which case that's ok).
auto diags = GetMatrixDiagonal(diag_blocks);
auto ones = FullLike(diags, 1);
diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});

// We can now use the fact that for an upper triangular matrix
Expand Down

0 comments on commit 598d9ad

Please sign in to comment.