From b68d523a4d0e751949b28aa1ba2cf95aec178277 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 23 Nov 2024 18:12:13 +0000 Subject: [PATCH] update XLA revision (#426) --- .github/workflows/c-xla-version.yml | 3 +- .github/workflows/checks.yml | 7 ++-- XLA_VERSION | 2 +- pjrt-plugins/xla-cpu/postinstall.sh | 2 + pjrt-plugins/xla-cuda/postinstall.sh | 2 + spidr/backend/BUILD | 6 ++- spidr/backend/VERSION | 2 +- spidr/backend/WORKSPACE | 12 ++++++ spidr/backend/src/xla/client/BUILD | 3 -- spidr/backend/src/xla/hlo/builder/BUILD | 13 +++++++ .../src/xla/{client => hlo/builder}/lib/BUILD | 8 ++-- .../builder}/lib/arithmetic.cpp | 8 +--- .../{client => hlo/builder}/lib/constants.cpp | 2 +- .../xla/{client => hlo/builder}/lib/math.cpp | 2 +- .../{client => hlo/builder}/lib/matrix.cpp | 2 +- .../xla/{client => hlo/builder}/lib/prng.cpp | 4 +- .../{client => hlo/builder}/xla_builder.cpp | 7 ++-- .../xla/{client => hlo/builder}/xla_builder.h | 2 +- .../builder}/xla_computation.cpp | 4 +- .../{client => hlo/builder}/xla_computation.h | 0 spidr/backend/src/xla/status.cpp | 28 -------------- spidr/backend/src/xla/status.h | 18 --------- spidr/postinstall.sh | 2 + spidr/spidr.ipkg | 15 ++++---- spidr/src/Compiler/Eval.idr | 15 ++++---- spidr/src/Compiler/Expr.idr | 3 -- .../Builder}/Lib/Arithmetic.idr | 14 +------ .../{Client => HLO/Builder}/Lib/Constants.idr | 4 +- .../Xla/{Client => HLO/Builder}/Lib/Math.idr | 4 +- .../{Client => HLO/Builder}/Lib/Matrix.idr | 4 +- .../Xla/{Client => HLO/Builder}/Lib/PRNG.idr | 4 +- .../{Client => HLO/Builder}/XlaBuilder.idr | 4 +- .../Builder}/XlaComputation.idr | 2 +- spidr/src/Compiler/Xla/Status.idr | 37 ------------------- spidr/src/Tensor.idr | 36 ++++++++---------- 35 files changed, 101 insertions(+), 180 deletions(-) create mode 100644 spidr/backend/src/xla/hlo/builder/BUILD rename spidr/backend/src/xla/{client => hlo/builder}/lib/BUILD (57%) rename spidr/backend/src/xla/{client => hlo/builder}/lib/arithmetic.cpp (73%) rename spidr/backend/src/xla/{client => hlo/builder}/lib/constants.cpp (97%) rename spidr/backend/src/xla/{client => hlo/builder}/lib/math.cpp (97%) rename spidr/backend/src/xla/{client => hlo/builder}/lib/matrix.cpp (97%) rename spidr/backend/src/xla/{client => hlo/builder}/lib/prng.cpp (97%) rename spidr/backend/src/xla/{client => hlo/builder}/xla_builder.cpp (99%) rename spidr/backend/src/xla/{client => hlo/builder}/xla_builder.h (94%) rename spidr/backend/src/xla/{client => hlo/builder}/xla_computation.cpp (93%) rename spidr/backend/src/xla/{client => hlo/builder}/xla_computation.h (100%) delete mode 100644 spidr/backend/src/xla/status.cpp delete mode 100644 spidr/backend/src/xla/status.h rename spidr/src/Compiler/Xla/{Client => HLO/Builder}/Lib/Arithmetic.idr (69%) rename spidr/src/Compiler/Xla/{Client => HLO/Builder}/Lib/Constants.idr (95%) rename spidr/src/Compiler/Xla/{Client => HLO/Builder}/Lib/Math.idr (96%) rename spidr/src/Compiler/Xla/{Client => HLO/Builder}/Lib/Matrix.idr (95%) rename spidr/src/Compiler/Xla/{Client => HLO/Builder}/Lib/PRNG.idr (96%) rename spidr/src/Compiler/Xla/{Client => HLO/Builder}/XlaBuilder.idr (99%) rename spidr/src/Compiler/Xla/{Client => HLO/Builder}/XlaComputation.idr (96%) delete mode 100644 spidr/src/Compiler/Xla/Status.idr diff --git a/.github/workflows/c-xla-version.yml b/.github/workflows/c-xla-version.yml index 9c41e4419..9ec2368dc 100644 --- a/.github/workflows/c-xla-version.yml +++ b/.github/workflows/c-xla-version.yml @@ -10,7 +10,8 @@ jobs: with: fetch-depth: 2 - name: Check backend version is updated when necessary - run: sh -c " + run: | + sh -c " (git diff --quiet HEAD^ XLA_VERSION && git diff --quiet HEAD^ spidr/backend/**) || \ ! git diff --quiet HEAD^ spidr/backend/VERSION " diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index a956de3b3..87496bcec 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -106,7 +106,7 @@ jobs: run: | apt-get update && apt-get install -y curl pack switch HEAD - pack --no-prompt build xla-cpu.ipkg + SPIDR_INSTALL_SUPPORT_LIBS=false pack --no-prompt build xla-cpu.ipkg tar cfz tests-xla-cpu.tar.gz -C build/exec . - name: Upload tests uses: actions/upload-artifact@v4 @@ -124,7 +124,7 @@ jobs: run: | apt-get update && apt-get install -y curl pack switch HEAD - pack --no-prompt build xla-cuda.ipkg + SPIDR_INSTALL_SUPPORT_LIBS=false pack --no-prompt build xla-cuda.ipkg tar cfz tests-xla-cuda.tar.gz -C build/exec . - name: Upload tests uses: actions/upload-artifact@v4 @@ -178,7 +178,7 @@ jobs: run: | apt-get update && apt-get install -y curl pack switch HEAD - pack --no-prompt typecheck readme.ipkg + SPIDR_INSTALL_SUPPORT_LIBS=false pack --no-prompt typecheck readme.ipkg tutorials: runs-on: ubuntu-latest container: ghcr.io/stefan-hoeck/idris2-pack @@ -188,4 +188,5 @@ jobs: run: | apt-get update && apt-get install -y curl pack switch HEAD + export SPIDR_INSTALL_SUPPORT_LIBS=false res=0; for f in tutorials/*.ipkg; do pack --no-prompt typecheck $f || res=$?; done; $(exit $res) diff --git a/XLA_VERSION b/XLA_VERSION index b7be4718b..66066d25b 100644 --- a/XLA_VERSION +++ b/XLA_VERSION @@ -1 +1 @@ -79d745f5873c0c93bcb43b25436c212da1f2b997 \ No newline at end of file +43517d94ad963a96a2a308b7b33d77ecd7de4b4a \ No newline at end of file diff --git a/pjrt-plugins/xla-cpu/postinstall.sh b/pjrt-plugins/xla-cpu/postinstall.sh index f3e1f2110..c7de35295 100755 --- a/pjrt-plugins/xla-cpu/postinstall.sh +++ b/pjrt-plugins/xla-cpu/postinstall.sh @@ -1,5 +1,7 @@ #!/bin/sh -e +if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi + script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd) cd "$script_dir/../.." . ./dev.sh diff --git a/pjrt-plugins/xla-cuda/postinstall.sh b/pjrt-plugins/xla-cuda/postinstall.sh index 9c5a0e28f..841ad4d0d 100755 --- a/pjrt-plugins/xla-cuda/postinstall.sh +++ b/pjrt-plugins/xla-cuda/postinstall.sh @@ -1,5 +1,7 @@ #!/bin/sh -e +if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi + script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd) cd "$script_dir/../.." . ./dev.sh diff --git a/spidr/backend/BUILD b/spidr/backend/BUILD index 85ad76bd8..ba1a39bfa 100644 --- a/spidr/backend/BUILD +++ b/spidr/backend/BUILD @@ -14,7 +14,8 @@ cc_binary( srcs = [ "//src/xla", "//src/xla/client", - "//src/xla/client/lib", + "//src/xla/hlo/builder", + "//src/xla/hlo/builder/lib", "//src/xla/pjrt", "//src/xla/pjrt/c", "//src", @@ -22,7 +23,8 @@ cc_binary( deps = [ "//src/xla", "//src/xla/client", - "//src/xla/client/lib", + "//src/xla/hlo/builder", + "//src/xla/hlo/builder/lib", "//src/xla/pjrt", "//src/xla/pjrt/c", "//src", diff --git a/spidr/backend/VERSION b/spidr/backend/VERSION index 8cbf02c39..43b296183 100644 --- a/spidr/backend/VERSION +++ b/spidr/backend/VERSION @@ -1 +1 @@ -0.0.12 +0.0.13 diff --git a/spidr/backend/WORKSPACE b/spidr/backend/WORKSPACE index d1333866d..991ab29d9 100644 --- a/spidr/backend/WORKSPACE +++ b/spidr/backend/WORKSPACE @@ -2,6 +2,15 @@ # so we can run ./configure.py before invoking bazel local_repository(name = "xla", path = "xla") +load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") +python_init_rules() + +load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") +python_init_repositories(requirements = {"3.11": "@xla//:requirements_lock_3_11.txt"}) + +load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") +python_init_toolchains() + load("@xla//:workspace4.bzl", "xla_workspace4") xla_workspace4() @@ -16,3 +25,6 @@ xla_workspace1() load("@xla//:workspace0.bzl", "xla_workspace0") xla_workspace0() + +load("@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") +cuda_configure(name = "local_config_cuda") diff --git a/spidr/backend/src/xla/client/BUILD b/spidr/backend/src/xla/client/BUILD index b0f072f0a..dd03ccb61 100644 --- a/spidr/backend/src/xla/client/BUILD +++ b/spidr/backend/src/xla/client/BUILD @@ -6,9 +6,6 @@ cc_library( hdrs = glob(["*.h"]), deps = [ "@xla//xla/client:executable_build_options", - "@xla//xla/client:xla_builder", - "//src", - "//src/xla", ], visibility = ["//visibility:public"], ) diff --git a/spidr/backend/src/xla/hlo/builder/BUILD b/spidr/backend/src/xla/hlo/builder/BUILD new file mode 100644 index 000000000..e729f1eef --- /dev/null +++ b/spidr/backend/src/xla/hlo/builder/BUILD @@ -0,0 +1,13 @@ +cc_library( + name = "builder", + linkstatic = True, + alwayslink = True, + srcs = glob(["*.cpp"]), + hdrs = glob(["*.h"]), + deps = [ + "@xla//xla/hlo/builder:xla_builder", + "//src", + "//src/xla", + ], + visibility = ["//visibility:public"], +) diff --git a/spidr/backend/src/xla/client/lib/BUILD b/spidr/backend/src/xla/hlo/builder/lib/BUILD similarity index 57% rename from spidr/backend/src/xla/client/lib/BUILD rename to spidr/backend/src/xla/hlo/builder/lib/BUILD index a9d5a47ad..263298d9d 100644 --- a/spidr/backend/src/xla/client/lib/BUILD +++ b/spidr/backend/src/xla/hlo/builder/lib/BUILD @@ -5,11 +5,11 @@ cc_library( srcs = glob(["*.cpp"]), hdrs = glob(["*.h"]), deps = [ - "@xla//xla/client/lib:math", - "@xla//xla/client/lib:matrix", - "@xla//xla/client/lib:prng", + "@xla//xla/hlo/builder/lib:math", + "@xla//xla/hlo/builder/lib:matrix", + "@xla//xla/hlo/builder/lib:prng", "//src/xla", - "//src/xla/client", + "//src/xla/hlo/builder", ], visibility = ["//visibility:public"], ) diff --git a/spidr/backend/src/xla/client/lib/arithmetic.cpp b/spidr/backend/src/xla/hlo/builder/lib/arithmetic.cpp similarity index 73% rename from spidr/backend/src/xla/client/lib/arithmetic.cpp rename to spidr/backend/src/xla/hlo/builder/lib/arithmetic.cpp index 83e50c39d..6f43938c5 100644 --- a/spidr/backend/src/xla/client/lib/arithmetic.cpp +++ b/spidr/backend/src/xla/hlo/builder/lib/arithmetic.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "xla/client/lib/arithmetic.h" +#include "xla/hlo/builder/lib/arithmetic.h" #include "../xla_builder.h" @@ -23,10 +23,4 @@ extern "C" { xla::XlaOp res = xla::ArgMax(input_, (xla::PrimitiveType) output_type, axis); return reinterpret_cast(new xla::XlaOp(res)); } - - XlaOp* ArgMin(XlaOp& input, int output_type, int axis) { - auto& input_ = reinterpret_cast(input); - xla::XlaOp res = xla::ArgMin(input_, (xla::PrimitiveType) output_type, axis); - return reinterpret_cast(new xla::XlaOp(res)); - } } diff --git a/spidr/backend/src/xla/client/lib/constants.cpp b/spidr/backend/src/xla/hlo/builder/lib/constants.cpp similarity index 97% rename from spidr/backend/src/xla/client/lib/constants.cpp rename to spidr/backend/src/xla/hlo/builder/lib/constants.cpp index be34960c5..be199d690 100644 --- a/spidr/backend/src/xla/client/lib/constants.cpp +++ b/spidr/backend/src/xla/hlo/builder/lib/constants.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "xla/client/lib/constants.h" +#include "xla/hlo/builder/lib/constants.h" #include "../xla_builder.h" diff --git a/spidr/backend/src/xla/client/lib/math.cpp b/spidr/backend/src/xla/hlo/builder/lib/math.cpp similarity index 97% rename from spidr/backend/src/xla/client/lib/math.cpp rename to spidr/backend/src/xla/hlo/builder/lib/math.cpp index 741cbbbeb..b86cc8691 100644 --- a/spidr/backend/src/xla/client/lib/math.cpp +++ b/spidr/backend/src/xla/hlo/builder/lib/math.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "xla/client/lib/math.h" +#include "xla/hlo/builder/lib/math.h" #include "../xla_builder.h" diff --git a/spidr/backend/src/xla/client/lib/matrix.cpp b/spidr/backend/src/xla/hlo/builder/lib/matrix.cpp similarity index 97% rename from spidr/backend/src/xla/client/lib/matrix.cpp rename to spidr/backend/src/xla/hlo/builder/lib/matrix.cpp index 382f5604d..9acdd476d 100644 --- a/spidr/backend/src/xla/client/lib/matrix.cpp +++ b/spidr/backend/src/xla/hlo/builder/lib/matrix.cpp @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "xla/client/lib/matrix.h" +#include "xla/hlo/builder/lib/matrix.h" #include "../xla_builder.h" diff --git a/spidr/backend/src/xla/client/lib/prng.cpp b/spidr/backend/src/xla/hlo/builder/lib/prng.cpp similarity index 97% rename from spidr/backend/src/xla/client/lib/prng.cpp rename to spidr/backend/src/xla/hlo/builder/lib/prng.cpp index 20d98e4f8..b273ea212 100644 --- a/spidr/backend/src/xla/client/lib/prng.cpp +++ b/spidr/backend/src/xla/hlo/builder/lib/prng.cpp @@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "xla/client/lib/prng.h" +#include "xla/hlo/builder/lib/prng.h" -#include "../../shape.h" +#include "../../../shape.h" #include "../xla_builder.h" xla::BitGeneratorTy BitGenerator(int bit_generator) { diff --git a/spidr/backend/src/xla/client/xla_builder.cpp b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp similarity index 99% rename from spidr/backend/src/xla/client/xla_builder.cpp rename to spidr/backend/src/xla/hlo/builder/xla_builder.cpp index a0f8c03ce..195562645 100644 --- a/spidr/backend/src/xla/client/xla_builder.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_builder.cpp @@ -18,14 +18,13 @@ limitations under the License. #include #include "absl/types/span.h" -#include "xla/client/xla_builder.h" #include "xla/literal.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" -#include "../literal.h" -#include "../shape.h" -#include "../xla_data.pb.h" +#include "../../literal.h" +#include "../../shape.h" +#include "../../xla_data.pb.h" #include "xla_builder.h" #include "xla_computation.h" diff --git a/spidr/backend/src/xla/client/xla_builder.h b/spidr/backend/src/xla/hlo/builder/xla_builder.h similarity index 94% rename from spidr/backend/src/xla/client/xla_builder.h rename to spidr/backend/src/xla/hlo/builder/xla_builder.h index a3eb524fc..cb2e9e192 100644 --- a/spidr/backend/src/xla/client/xla_builder.h +++ b/spidr/backend/src/xla/hlo/builder/xla_builder.h @@ -15,7 +15,7 @@ limitations under the License. */ #include -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" extern "C" { struct XlaOp; diff --git a/spidr/backend/src/xla/client/xla_computation.cpp b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp similarity index 93% rename from spidr/backend/src/xla/client/xla_computation.cpp rename to spidr/backend/src/xla/hlo/builder/xla_computation.cpp index a8281cc60..1cba3a527 100644 --- a/spidr/backend/src/xla/client/xla_computation.cpp +++ b/spidr/backend/src/xla/hlo/builder/xla_computation.cpp @@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" -#include "../../ffi.h" +#include "../../../ffi.h" #include "xla_computation.h" extern "C" { diff --git a/spidr/backend/src/xla/client/xla_computation.h b/spidr/backend/src/xla/hlo/builder/xla_computation.h similarity index 100% rename from spidr/backend/src/xla/client/xla_computation.h rename to spidr/backend/src/xla/hlo/builder/xla_computation.h diff --git a/spidr/backend/src/xla/status.cpp b/spidr/backend/src/xla/status.cpp deleted file mode 100644 index a3cb91a6b..000000000 --- a/spidr/backend/src/xla/status.cpp +++ /dev/null @@ -1,28 +0,0 @@ -/* -Copyright 2022 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "xla/status.h" - -#include "status.h" - -extern "C" { - void Status_delete(Status* status) { - delete reinterpret_cast(status); - } - - int Status_ok(Status& status) { - return (int) reinterpret_cast(status).ok(); - } -} diff --git a/spidr/backend/src/xla/status.h b/spidr/backend/src/xla/status.h deleted file mode 100644 index 25eeedddc..000000000 --- a/spidr/backend/src/xla/status.h +++ /dev/null @@ -1,18 +0,0 @@ -/* -Copyright 2022 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -extern "C" { - struct Status; -} diff --git a/spidr/postinstall.sh b/spidr/postinstall.sh index 49ea9591f..cf9e83853 100755 --- a/spidr/postinstall.sh +++ b/spidr/postinstall.sh @@ -1,5 +1,7 @@ #!/bin/sh -e +if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi + script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd) xla_ext_version=$(cat "$script_dir/backend/VERSION") diff --git a/spidr/spidr.ipkg b/spidr/spidr.ipkg index 3ebc14eab..fa7b670c0 100644 --- a/spidr/spidr.ipkg +++ b/spidr/spidr.ipkg @@ -8,21 +8,20 @@ modules = BayesianOptimization, BayesianOptimization.Acquisition, - Compiler.Xla.Client.Lib.Arithmetic, - Compiler.Xla.Client.Lib.Constants, - Compiler.Xla.Client.Lib.Math, - Compiler.Xla.Client.Lib.Matrix, - Compiler.Xla.Client.Lib.PRNG, - Compiler.Xla.Client.XlaBuilder, Compiler.Xla.Client.ExecutableBuildOptions, - Compiler.Xla.Client.XlaComputation, + Compiler.Xla.HLO.Builder.Lib.Arithmetic, + Compiler.Xla.HLO.Builder.Lib.Constants, + Compiler.Xla.HLO.Builder.Lib.Math, + Compiler.Xla.HLO.Builder.Lib.Matrix, + Compiler.Xla.HLO.Builder.Lib.PRNG, + Compiler.Xla.HLO.Builder.XlaBuilder, + Compiler.Xla.HLO.Builder.XlaComputation, Compiler.Xla.PJRT.C.PjrtCApi, Compiler.Xla.PJRT.PjrtExecutable, Compiler.Xla.Literal, Compiler.Xla.Shape, Compiler.Xla.ShapeUtil, Compiler.Xla.XlaData, - Compiler.Xla.Status, Compiler.Eval, Compiler.Expr, diff --git a/spidr/src/Compiler/Eval.idr b/spidr/src/Compiler/Eval.idr index e35ba1a3c..6c26186f5 100644 --- a/spidr/src/Compiler/Eval.idr +++ b/spidr/src/Compiler/Eval.idr @@ -26,14 +26,14 @@ import Data.List.Elem import Compiler.Expr import Compiler.FFI import Compiler.LiteralRW -import Compiler.Xla.Client.Lib.Arithmetic -import Compiler.Xla.Client.Lib.Constants -import Compiler.Xla.Client.Lib.Math -import Compiler.Xla.Client.Lib.Matrix -import Compiler.Xla.Client.Lib.PRNG import Compiler.Xla.Client.ExecutableBuildOptions -import Compiler.Xla.Client.XlaBuilder -import Compiler.Xla.Client.XlaComputation +import Compiler.Xla.HLO.Builder.Lib.Arithmetic +import Compiler.Xla.HLO.Builder.Lib.Constants +import Compiler.Xla.HLO.Builder.Lib.Math +import Compiler.Xla.HLO.Builder.Lib.Matrix +import Compiler.Xla.HLO.Builder.Lib.PRNG +import Compiler.Xla.HLO.Builder.XlaBuilder +import Compiler.Xla.HLO.Builder.XlaComputation import Compiler.Xla.PJRT.C.PjrtCApi import Compiler.Xla.PJRT.PjrtExecutable import Compiler.Xla.Literal @@ -183,7 +183,6 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do Asinh => asinh Acosh => acosh Atanh => atanh - interpretE (Argmin {out} axis x) = argMin {outputType = out} !(interpretE x) axis interpretE (Argmax {out} axis x) = argMax {outputType = out} !(interpretE x) axis interpretE (Select pred true false) = select !(interpretE pred) !(interpretE true) !(interpretE false) diff --git a/spidr/src/Compiler/Expr.idr b/spidr/src/Compiler/Expr.idr index 9908caada..c5c4ff68d 100644 --- a/spidr/src/Compiler/Expr.idr +++ b/spidr/src/Compiler/Expr.idr @@ -127,7 +127,6 @@ data Expr : Type where Reverse : (axes : List Nat) -> Expr -> Expr BinaryElementwise : BinaryOp -> Expr -> Expr -> Expr UnaryElementwise : UnaryOp -> Expr -> Expr - Argmin : Primitive out => (axis : Nat) -> Expr -> Expr Argmax : Primitive out => (axis : Nat) -> Expr -> Expr Select : (predicate, onTrue, onFalse : Expr) -> Expr Cond : (pred : Expr) -> (onTrue : Fn 1) -> (onTrueArg : Expr) -> @@ -217,8 +216,6 @@ showExpr indent (Reverse axes x) = "Reverse \{axes} (\{showExpr indent x})" showExpr indent (BinaryElementwise op x y) = "\{show op} (\{showExpr indent x}) (\{showExpr indent y})" showExpr indent (UnaryElementwise op x) = "\{show op} (\{showExpr indent x})" -showExpr indent (Argmin {out} axis x) = - "Argmin {outType = \{xlaIdentifier {dtype = out}}} \{axis} (\{showExpr indent x})" showExpr indent (Argmax {out} axis x) = "Argmax {outType = \{xlaIdentifier {dtype = out}}} \{axis} (\{showExpr indent x})" showExpr indent (Select p t f) = diff --git a/spidr/src/Compiler/Xla/Client/Lib/Arithmetic.idr b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Arithmetic.idr similarity index 69% rename from spidr/src/Compiler/Xla/Client/Lib/Arithmetic.idr rename to spidr/src/Compiler/Xla/HLO/Builder/Lib/Arithmetic.idr index ae2cb6c6b..c367d4fdf 100644 --- a/spidr/src/Compiler/Xla/Client/Lib/Arithmetic.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Arithmetic.idr @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.Client.Lib.Arithmetic +module Compiler.Xla.HLO.Builder.Lib.Arithmetic import Compiler.FFI -import Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.XlaData %foreign (libxla "ArgMax") @@ -29,13 +29,3 @@ argMax (MkXlaOp input) axis = do opPtr <- primIO $ prim__argMax input (xlaIdentifier {dtype = outputType}) (cast axis) opPtr <- onCollectAny opPtr XlaOp.delete pure (MkXlaOp opPtr) - -%foreign (libxla "ArgMin") -prim__argMin : GCAnyPtr -> Int -> Int -> PrimIO AnyPtr - -export -argMin : (HasIO io, Primitive outputType) => XlaOp -> Nat -> io XlaOp -argMin (MkXlaOp input) axis = do - opPtr <- primIO $ prim__argMin input (xlaIdentifier {dtype = outputType}) (cast axis) - opPtr <- onCollectAny opPtr XlaOp.delete - pure (MkXlaOp opPtr) diff --git a/spidr/src/Compiler/Xla/Client/Lib/Constants.idr b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Constants.idr similarity index 95% rename from spidr/src/Compiler/Xla/Client/Lib/Constants.idr rename to spidr/src/Compiler/Xla/HLO/Builder/Lib/Constants.idr index 164ce6068..05a05f723 100644 --- a/spidr/src/Compiler/Xla/Client/Lib/Constants.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Constants.idr @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.Client.Lib.Constants +module Compiler.Xla.HLO.Builder.Lib.Constants import Compiler.FFI -import Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.XlaData %foreign (libxla "MinValue") diff --git a/spidr/src/Compiler/Xla/Client/Lib/Math.idr b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Math.idr similarity index 96% rename from spidr/src/Compiler/Xla/Client/Lib/Math.idr rename to spidr/src/Compiler/Xla/HLO/Builder/Lib/Math.idr index 779a5c777..095a11e96 100644 --- a/spidr/src/Compiler/Xla/Client/Lib/Math.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Math.idr @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.Client.Lib.Math +module Compiler.Xla.HLO.Builder.Lib.Math import Compiler.FFI -import Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.HLO.Builder.XlaBuilder %foreign (libxla "Square") prim__square : GCAnyPtr -> PrimIO AnyPtr diff --git a/spidr/src/Compiler/Xla/Client/Lib/Matrix.idr b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Matrix.idr similarity index 95% rename from spidr/src/Compiler/Xla/Client/Lib/Matrix.idr rename to spidr/src/Compiler/Xla/HLO/Builder/Lib/Matrix.idr index e2d3714dd..a3a3bbda3 100644 --- a/spidr/src/Compiler/Xla/Client/Lib/Matrix.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/Lib/Matrix.idr @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.Client.Lib.Matrix +module Compiler.Xla.HLO.Builder.Lib.Matrix import Compiler.FFI -import Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.XlaData %foreign (libxla "IdentityMatrix") diff --git a/spidr/src/Compiler/Xla/Client/Lib/PRNG.idr b/spidr/src/Compiler/Xla/HLO/Builder/Lib/PRNG.idr similarity index 96% rename from spidr/src/Compiler/Xla/Client/Lib/PRNG.idr rename to spidr/src/Compiler/Xla/HLO/Builder/Lib/PRNG.idr index 2c3a3ada2..c4a34ca6e 100644 --- a/spidr/src/Compiler/Xla/Client/Lib/PRNG.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/Lib/PRNG.idr @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.Client.Lib.PRNG +module Compiler.Xla.HLO.Builder.Lib.PRNG import Compiler.FFI -import Compiler.Xla.Client.XlaBuilder +import Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.Xla.Shape public export diff --git a/spidr/src/Compiler/Xla/Client/XlaBuilder.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr similarity index 99% rename from spidr/src/Compiler/Xla/Client/XlaBuilder.idr rename to spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr index 0d018375e..c85bc08bf 100644 --- a/spidr/src/Compiler/Xla/Client/XlaBuilder.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaBuilder.idr @@ -14,10 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.Client.XlaBuilder +module Compiler.Xla.HLO.Builder.XlaBuilder import Compiler.FFI -import Compiler.Xla.Client.XlaComputation +import Compiler.Xla.HLO.Builder.XlaComputation import Compiler.Xla.XlaData import Compiler.Xla.Literal import Compiler.Xla.Shape diff --git a/spidr/src/Compiler/Xla/Client/XlaComputation.idr b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr similarity index 96% rename from spidr/src/Compiler/Xla/Client/XlaComputation.idr rename to spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr index 8fd7d9c04..1e35ba4dc 100644 --- a/spidr/src/Compiler/Xla/Client/XlaComputation.idr +++ b/spidr/src/Compiler/Xla/HLO/Builder/XlaComputation.idr @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. --} ||| For internal spidr use only. -module Compiler.Xla.Client.XlaComputation +module Compiler.Xla.HLO.Builder.XlaComputation import Compiler.FFI diff --git a/spidr/src/Compiler/Xla/Status.idr b/spidr/src/Compiler/Xla/Status.idr deleted file mode 100644 index 4bc90a16d..000000000 --- a/spidr/src/Compiler/Xla/Status.idr +++ /dev/null @@ -1,37 +0,0 @@ -{-- -Copyright 2022 Joel Berkeley - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. ---} -||| For internal spidr use only. -module Compiler.Xla.Status - -import Compiler.FFI - -public export -data Status : Type where - MkStatus : GCAnyPtr -> Status - -%foreign (libxla "Status_delete") -prim__delete : AnyPtr -> PrimIO () - -export -delete : AnyPtr -> IO () -delete = primIO . prim__delete - -%foreign (libxla "Status_ok") -prim__ok : GCAnyPtr -> Int - -export -ok : Status -> Bool -ok (MkStatus ptr) = cIntToBool (prim__ok ptr) diff --git a/spidr/src/Tensor.idr b/spidr/src/Tensor.idr index c2d4c087f..24272644c 100644 --- a/spidr/src/Tensor.idr +++ b/spidr/src/Tensor.idr @@ -1440,20 +1440,29 @@ namespace Monoid Monoid (Tensor shape dtype) using Semigroup.Max where neutral = fill (- 1.0 / 0.0) -highlightNan : Primitive.Ord dtype => Bool -> Tensor [S n] dtype -> Tag $ Tensor [S n] dtype -highlightNan minimize x with (x) +||| The first index of the maximum value in a vector. For example, +||| `argmax (tensor [-1, 3, -2, -2, 3])` produces `tensor 1`. If the vector contains NaN values, +||| `argmax` returns the index of the first NaN. +||| +||| **Note:** `argmax` uses `Tag` to work around what we believe to be an inconsistency in the XLA +||| compiler's handling of NaN. Specifically, we have modified `argmax` to return the first index of +||| the value returned by `reduce @{Max}`. +export +argmax : Primitive.Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64 +argmax x with (x) _ | (MkTensor {shape = _} _) = do x <- tag x - cond !(reduce @{All} [0] (x == x)) pure x extremizeNan x + MkTensor x <- cond !(reduce @{All} [0] (x == x)) pure x extremizeNan x + pure $ MkTensor $ Argmax {out = U64} 0 x where - extremizeNan : {n : _} -> Tensor [S n] dtype -> Tag $ Tensor [S n] dtype + extremizeNan : {m : _} -> Tensor [S m] dtype -> Tag $ Tensor [S m] dtype extremizeNan x = do x <- tag x let min' = broadcast $ Types.min @{NonFinite} max' = broadcast $ Types.max @{NonFinite} - pure $ select (if minimize then x == x else x /= x) max' min' + pure $ select (x /= x) max' min' ||| The first index of the minimum value in a vector. For example, ||| `argmin (tensor [-1, 3, -2, -2, 3])` produces `tensor 2`. If the vector contains NaN values, @@ -1464,22 +1473,7 @@ highlightNan minimize x with (x) ||| the value returned by `reduce @{Min}`. export argmin : Primitive.Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64 -argmin x = do - MkTensor x <- highlightNan True x - pure $ MkTensor $ Argmin {out = U64} 0 x - -||| The first index of the maximum value in a vector. For example, -||| `argmax (tensor [-1, 3, -2, -2, 3])` produces `tensor 1`. If the vector contains NaN values, -||| `argmax` returns the index of the first NaN. -||| -||| **Note:** `argmax` uses `Tag` to work around what we believe to be an inconsistency in the XLA -||| compiler's handling of NaN. Specifically, we have modified `argmax` to return the first index of -||| the value returned by `reduce @{Max}`. -export -argmax : Primitive.Ord dtype => Tensor [S n] dtype -> Tag $ Tensor [] U64 -argmax x = do - MkTensor x <- highlightNan False x - pure $ MkTensor $ Argmax {out = U64} 0 x +argmin (MkTensor x) = argmax (MkTensor {shape = [S n], dtype} $ UnaryElementwise Neg x) ---------------------------- other ----------------------------------