diff --git a/exla/c_src/exla/custom_calls.cc b/exla/c_src/exla/custom_calls.cc index d5beee7958..8acd67cab6 100644 --- a/exla/c_src/exla/custom_calls.cc +++ b/exla/c_src/exla/custom_calls.cc @@ -4,6 +4,10 @@ void qr_cpu_custom_call_f32(void *out[], const void *in[]); void qr_cpu_custom_call_f64(void *out[], const void *in[]); void qr_cpu_custom_call_f16(void *out[], const void *in[]); void qr_cpu_custom_call_bf16(void *out[], const void *in[]); +void lu_cpu_custom_call_f32(void *out[], const void *in[]); +void lu_cpu_custom_call_f64(void *out[], const void *in[]); +void lu_cpu_custom_call_f16(void *out[], const void *in[]); +void lu_cpu_custom_call_bf16(void *out[], const void *in[]); void eigh_cpu_custom_call_f32(void *out[], const void *in[]); void eigh_cpu_custom_call_f64(void *out[], const void *in[]); @@ -12,4 +16,8 @@ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_cu XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16); XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16); XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64); -XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32); \ No newline at end of file +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f64", lu_cpu_custom_call_f64); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f32", lu_cpu_custom_call_f32); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_f16", lu_cpu_custom_call_f16); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("lu_cpu_custom_call_bf16", lu_cpu_custom_call_bf16); \ No newline at end of file diff --git a/exla/c_src/exla/custom_calls/lu.h b/exla/c_src/exla/custom_calls/lu.h new file mode 100644 index 0000000000..1c72565d4b --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu.h @@ -0,0 +1,95 @@ +#pragma once + +#include "Eigen/LU"; + +template +void single_matrix_lu_cpu_custom_call(uint8_t *p_out, DataType *l_out, DataType *u_out, DataType *in, uint64_t n) { + typedef Eigen::Matrix RowMajorMatrix; + + Eigen::Map input(in, n, n); + Eigen::PartialPivLU lu = input.partialPivLu(); + + // Get the permutation matrix P and convert to indices + Eigen::PermutationMatrix P = lu.permutationP(); + for (uint64_t i = 0; i < n; i++) { + for (uint64_t j = 0; j < n; j++) { + p_out[i * n + j] = static_cast(P.indices()[i] == j ? 1 : 0); + } + } + + // Get L and U matrices + RowMajorMatrix L = lu.matrixLU().template triangularView(); + RowMajorMatrix U = lu.matrixLU().template triangularView(); + + // Copy L matrix + for (uint64_t i = 0; i < n; i++) { + for (uint64_t j = 0; j < n; j++) { + + if (j < i) { + l_out[i * n + j] = static_cast(L(i, j)); + } else if (j == i) { + l_out[i * n + j] = static_cast(1.0); + } else { + l_out[i * n + j] = static_cast(0.0); + } + } + } + + // Copy U matrix + for (uint64_t i = 0; i < n; i++) { + for (uint64_t j = 0; j < n; j++) { + if (j >= i) { + u_out[i * n + j] = static_cast(U(i, j)); + } else { + u_out[i * n + j] = static_cast(0.0); + } + } + } +} + +template +void lu_cpu_custom_call(void *out[], const void *in[]) { + DataType *operand = (DataType *)in[0]; + + uint64_t *dim_sizes = (uint64_t *)in[1]; + uint64_t num_operand_dims = dim_sizes[0]; + uint64_t num_p_dims = dim_sizes[1]; + uint64_t num_l_dims = dim_sizes[2]; + uint64_t num_u_dims = dim_sizes[3]; + + uint64_t *operand_dims_ptr = (uint64_t *)in[2]; + std::vector operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims); + + uint64_t *p_dims_ptr = (uint64_t *)in[3]; + std::vector p_dims(p_dims_ptr, p_dims_ptr + num_p_dims); + + uint64_t *l_dims_ptr = (uint64_t *)in[4]; + std::vector l_dims(l_dims_ptr, l_dims_ptr + num_l_dims); + + uint64_t *u_dims_ptr = (uint64_t *)in[5]; + std::vector u_dims(u_dims_ptr, u_dims_ptr + num_u_dims); + + uint64_t n = l_dims[l_dims.size() - 1]; + + auto leading_dimensions = std::vector(operand_dims.begin(), operand_dims.end() - 2); + + uint64_t batch_items = 1; + for (uint64_t i = 0; i < leading_dimensions.size(); i++) { + batch_items *= leading_dimensions[i]; + } + + uint8_t *p = (uint8_t *)out[0]; + DataType *l = (DataType *)out[1]; + DataType *u = (DataType *)out[2]; + + uint64_t stride = n * n; + + for (uint64_t i = 0; i < batch_items; i++) { + single_matrix_lu_cpu_custom_call( + p + i * stride, + l + i * stride, + u + i * stride, + operand + i * stride, + n); + } +} \ No newline at end of file diff --git a/exla/c_src/exla/custom_calls/lu_bf16.cc b/exla/c_src/exla/custom_calls/lu_bf16.cc new file mode 100644 index 0000000000..806f886b4c --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_bf16.cc @@ -0,0 +1,6 @@ +#include "lu.h" +#include "../exla_types.h" + +void lu_cpu_custom_call_bf16(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/lu_f16.cc b/exla/c_src/exla/custom_calls/lu_f16.cc new file mode 100644 index 0000000000..81f6724e6e --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_f16.cc @@ -0,0 +1,6 @@ +#include "lu.h" +#include "../exla_types.h" + +void lu_cpu_custom_call_f16(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/lu_f32.cc b/exla/c_src/exla/custom_calls/lu_f32.cc new file mode 100644 index 0000000000..c506caab72 --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_f32.cc @@ -0,0 +1,5 @@ +#include "lu.h" + +void lu_cpu_custom_call_f32(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/lu_f64.cc b/exla/c_src/exla/custom_calls/lu_f64.cc new file mode 100644 index 0000000000..aed6ed2dab --- /dev/null +++ b/exla/c_src/exla/custom_calls/lu_f64.cc @@ -0,0 +1,5 @@ +#include "lu.h" + +void lu_cpu_custom_call_f64(void *out[], const void *in[]) { + lu_cpu_custom_call(out, in); +} diff --git a/exla/c_src/exla/custom_calls/qr.h b/exla/c_src/exla/custom_calls/qr.h index 3615353ddf..85e881447c 100644 --- a/exla/c_src/exla/custom_calls/qr.h +++ b/exla/c_src/exla/custom_calls/qr.h @@ -73,15 +73,15 @@ void qr_cpu_custom_call(void *out[], const void *in[]) { DataType *q = (DataType *)out[0]; DataType *r = (DataType *)out[1]; - uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2] * sizeof(DataType); - uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2] * sizeof(DataType); - uint64_t inner_stride = m * n * sizeof(DataType); + uint64_t r_stride = r_dims[r_dims.size() - 1] * r_dims[r_dims.size() - 2]; + uint64_t q_stride = q_dims[q_dims.size() - 1] * q_dims[q_dims.size() - 2]; + uint64_t inner_stride = m * n; for (uint64_t i = 0; i < batch_items; i++) { single_matrix_qr_cpu_custom_call( (DataType *)out[0] + i * q_stride, (DataType *)out[1] + i * r_stride, - operand + i * inner_stride * sizeof(DataType), + operand + i * inner_stride, m, k, n, complete); } } \ No newline at end of file diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index e7f7bfae47..5e8b94c5ab 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -544,6 +544,43 @@ defmodule EXLA.Defn do end end + defp cached_recur_operator( + :lu, + %T{data: %Expr{args: [{p_expr, l_expr, u_expr}, tensor, _opts]}}, + state, + cache + ) do + %{type: {p_type_kind, _}} = p_expr + %{type: {out_type_kind, _}} = l_expr + + if state.client.platform != :host do + raise ArgumentError, "XLA does not currently support the LU operation on non-host devices" + end + + if p_type_kind == :c or out_type_kind == :c do + raise ArgumentError, "XLA does not currently support the LU operation for complex inputs" + end + + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + tensor = + if op_type(tensor) != u_expr.type do + to_type(tensor, u_expr.type) + else + tensor + end + + {p, l, u} = + Value.lu( + tensor, + expr_to_typespec(p_expr), + expr_to_typespec(l_expr), + expr_to_typespec(u_expr) + ) + + {[p, l, u], cache} + end + defp cached_recur_operator(:attach_token, %T{data: %Expr{args: [token, expr]}}, state, cache) do {op, cache} = recur_operator(expr, state, cache) {_, cache} = recur_operator(token, state, cache) @@ -772,10 +809,6 @@ defmodule EXLA.Defn do end end - defp to_operator(:lu, [{_, _, _}, _tensor, _opts], _ans, _state) do - raise ArgumentError, "XLA does not currently support the LU operation" - end - ## to_operator element-wise defp to_operator(:negate, [%Value{} = op], ans, _state), diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 5dfd72ca23..e38d09fc0b 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -815,6 +815,81 @@ defmodule EXLA.MLIR.Value do {q, r} end + def lu(%Value{function: func} = value, p_typespec, l_typespec, u_typespec) do + %{type: op_type, shape: op_shape} = get_typespec(value) + %{type: _p_type, shape: p_shape} = p_typespec + %{type: l_type, shape: l_shape} = l_typespec + %{type: u_type, shape: u_shape} = u_typespec + + dim_sizes = [ + tuple_size(op_shape), + tuple_size(p_shape), + tuple_size(l_shape), + tuple_size(u_shape) + ] + + operand_dims = Tuple.to_list(op_shape) + p_dims = Tuple.to_list(p_shape) + l_dims = Tuple.to_list(l_shape) + u_dims = Tuple.to_list(u_shape) + + dim_sizes = constant(func, dim_sizes, Typespec.tensor({:u, 64}, {length(dim_sizes)})) + operand_dims = constant(func, operand_dims, Typespec.tensor({:u, 64}, {length(operand_dims)})) + p_dims = constant(func, p_dims, Typespec.tensor({:u, 64}, {length(p_dims)})) + l_dims = constant(func, l_dims, Typespec.tensor({:u, 64}, {length(l_dims)})) + u_dims = constant(func, u_dims, Typespec.tensor({:u, 64}, {length(u_dims)})) + operands = [value, dim_sizes, operand_dims, p_dims, l_dims, u_dims] + + # Force P to always b u8 to avoid requiring too many template instances during custom_call registration + p_result_type = type_tensor({:u, 8}, p_shape) + l_result_type = type_tensor(l_type, l_shape) + u_result_type = type_tensor(u_type, u_shape) + result_types = [type_tuple([p_result_type, l_result_type, u_result_type])] + + call_target_name = + case op_type do + {:f, 32} -> + "lu_cpu_custom_call_f32" + + {:f, 64} -> + "lu_cpu_custom_call_f64" + + {:f, 16} -> + "lu_cpu_custom_call_f16" + + {:bf, 16} -> + "lu_cpu_custom_call_bf16" + + type -> + # Due to matching on EXLA.Defn, we are sure that the device here is always :host + raise "LU decomposition not supported on :host device for type #{inspect(type)}" + end + + attributes = [ + call_target_name: attr_string(call_target_name), + backend_config: attr_string("Host") + ] + + result = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!() + + # This is not the best approach, but the alternative would require many more template instances + u8_typespec = Typespec.to_type(p_typespec, {:u, 8}) + p = get_tuple_element(result, 0, u8_typespec) + + p = + if u8_typespec != p_typespec do + convert(p, p_typespec) + else + p + end + + l = get_tuple_element(result, 1, l_typespec) + u = get_tuple_element(result, 2, u_typespec) + + {p, l, u} + end + def get_tuple_element(%Value{function: func} = operand, index, typespec) do result_types = typespecs_to_mlir_types([typespec]) attributes = [index: attr_i32(index)] diff --git a/exla/mix.exs b/exla/mix.exs index 17ef1915d4..8e9f66fd03 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -187,7 +187,7 @@ defmodule EXLA.MixProject do File.rm_rf!("cache/#{@version}/libexla.so") Mix.shell().info("Removing libexla.so cache at #{cached_so}") - File.rm!(cached_so) + File.rm_rf!(cached_so) end if cached? do diff --git a/exla/test/exla/nx_linalg_doctest_test.exs b/exla/test/exla/nx_linalg_doctest_test.exs index 10c2cbce05..09d60ba8f6 100644 --- a/exla/test/exla/nx_linalg_doctest_test.exs +++ b/exla/test/exla/nx_linalg_doctest_test.exs @@ -1,16 +1,24 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do use EXLA.Case, async: true - @invalid_type_error_doctests [svd: 2, pinv: 2, matrix_rank: 2] + @invalid_type_error_doctests [ + svd: 2, + pinv: 2 + ] + @function_clause_error_doctests [ - norm: 2, - lu: 2, - solve: 2, + solve: 2 + ] + + @rounding_error_doctests [ + triangular_solve: 3, + eigh: 2, + cholesky: 1, + least_squares: 3, determinant: 1, - invert: 1, - matrix_power: 2 + matrix_power: 2, + lu: 2 ] - @rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3] @excluded_doctests @function_clause_error_doctests ++ @rounding_error_doctests ++