Skip to content

Commit

Permalink
update XLA revision (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley authored Nov 23, 2024
1 parent 3e14f64 commit b68d523
Show file tree
Hide file tree
Showing 35 changed files with 101 additions and 180 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/c-xla-version.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ jobs:
with:
fetch-depth: 2
- name: Check backend version is updated when necessary
run: sh -c "
run: |
sh -c "
(git diff --quiet HEAD^ XLA_VERSION && git diff --quiet HEAD^ spidr/backend/**) || \
! git diff --quiet HEAD^ spidr/backend/VERSION
"
7 changes: 4 additions & 3 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ jobs:
run: |
apt-get update && apt-get install -y curl
pack switch HEAD
pack --no-prompt build xla-cpu.ipkg
SPIDR_INSTALL_SUPPORT_LIBS=false pack --no-prompt build xla-cpu.ipkg
tar cfz tests-xla-cpu.tar.gz -C build/exec .
- name: Upload tests
uses: actions/upload-artifact@v4
Expand All @@ -124,7 +124,7 @@ jobs:
run: |
apt-get update && apt-get install -y curl
pack switch HEAD
pack --no-prompt build xla-cuda.ipkg
SPIDR_INSTALL_SUPPORT_LIBS=false pack --no-prompt build xla-cuda.ipkg
tar cfz tests-xla-cuda.tar.gz -C build/exec .
- name: Upload tests
uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -178,7 +178,7 @@ jobs:
run: |
apt-get update && apt-get install -y curl
pack switch HEAD
pack --no-prompt typecheck readme.ipkg
SPIDR_INSTALL_SUPPORT_LIBS=false pack --no-prompt typecheck readme.ipkg
tutorials:
runs-on: ubuntu-latest
container: ghcr.io/stefan-hoeck/idris2-pack
Expand All @@ -188,4 +188,5 @@ jobs:
run: |
apt-get update && apt-get install -y curl
pack switch HEAD
export SPIDR_INSTALL_SUPPORT_LIBS=false
res=0; for f in tutorials/*.ipkg; do pack --no-prompt typecheck $f || res=$?; done; $(exit $res)
2 changes: 1 addition & 1 deletion XLA_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
79d745f5873c0c93bcb43b25436c212da1f2b997
43517d94ad963a96a2a308b7b33d77ecd7de4b4a
2 changes: 2 additions & 0 deletions pjrt-plugins/xla-cpu/postinstall.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/bin/sh -e

if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi

script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."
. ./dev.sh
Expand Down
2 changes: 2 additions & 0 deletions pjrt-plugins/xla-cuda/postinstall.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/bin/sh -e

if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi

script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
cd "$script_dir/../.."
. ./dev.sh
Expand Down
6 changes: 4 additions & 2 deletions spidr/backend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@ cc_binary(
srcs = [
"//src/xla",
"//src/xla/client",
"//src/xla/client/lib",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src",
],
deps = [
"//src/xla",
"//src/xla/client",
"//src/xla/client/lib",
"//src/xla/hlo/builder",
"//src/xla/hlo/builder/lib",
"//src/xla/pjrt",
"//src/xla/pjrt/c",
"//src",
Expand Down
2 changes: 1 addition & 1 deletion spidr/backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.12
0.0.13
12 changes: 12 additions & 0 deletions spidr/backend/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
# so we can run ./configure.py before invoking bazel
local_repository(name = "xla", path = "xla")

load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
python_init_rules()

load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
python_init_repositories(requirements = {"3.11": "@xla//:requirements_lock_3_11.txt"})

load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains")
python_init_toolchains()

load("@xla//:workspace4.bzl", "xla_workspace4")
xla_workspace4()

Expand All @@ -16,3 +25,6 @@ xla_workspace1()

load("@xla//:workspace0.bzl", "xla_workspace0")
xla_workspace0()

load("@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure")
cuda_configure(name = "local_config_cuda")
3 changes: 0 additions & 3 deletions spidr/backend/src/xla/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ cc_library(
hdrs = glob(["*.h"]),
deps = [
"@xla//xla/client:executable_build_options",
"@xla//xla/client:xla_builder",
"//src",
"//src/xla",
],
visibility = ["//visibility:public"],
)
13 changes: 13 additions & 0 deletions spidr/backend/src/xla/hlo/builder/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
cc_library(
name = "builder",
linkstatic = True,
alwayslink = True,
srcs = glob(["*.cpp"]),
hdrs = glob(["*.h"]),
deps = [
"@xla//xla/hlo/builder:xla_builder",
"//src",
"//src/xla",
],
visibility = ["//visibility:public"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ cc_library(
srcs = glob(["*.cpp"]),
hdrs = glob(["*.h"]),
deps = [
"@xla//xla/client/lib:math",
"@xla//xla/client/lib:matrix",
"@xla//xla/client/lib:prng",
"@xla//xla/hlo/builder/lib:math",
"@xla//xla/hlo/builder/lib:matrix",
"@xla//xla/hlo/builder/lib:prng",
"//src/xla",
"//src/xla/client",
"//src/xla/hlo/builder",
],
visibility = ["//visibility:public"],
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "xla/client/lib/arithmetic.h"
#include "xla/hlo/builder/lib/arithmetic.h"

#include "../xla_builder.h"

Expand All @@ -23,10 +23,4 @@ extern "C" {
xla::XlaOp res = xla::ArgMax(input_, (xla::PrimitiveType) output_type, axis);
return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
}

XlaOp* ArgMin(XlaOp& input, int output_type, int axis) {
auto& input_ = reinterpret_cast<xla::XlaOp&>(input);
xla::XlaOp res = xla::ArgMin(input_, (xla::PrimitiveType) output_type, axis);
return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "xla/client/lib/constants.h"
#include "xla/hlo/builder/lib/constants.h"

#include "../xla_builder.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "xla/client/lib/math.h"
#include "xla/hlo/builder/lib/math.h"

#include "../xla_builder.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "xla/client/lib/matrix.h"
#include "xla/hlo/builder/lib/matrix.h"

#include "../xla_builder.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "xla/client/lib/prng.h"
#include "xla/hlo/builder/lib/prng.h"

#include "../../shape.h"
#include "../../../shape.h"
#include "../xla_builder.h"

xla::BitGeneratorTy BitGenerator(int bit_generator) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ limitations under the License.
#include <string>

#include "absl/types/span.h"
#include "xla/client/xla_builder.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/xla_data.pb.h"

#include "../literal.h"
#include "../shape.h"
#include "../xla_data.pb.h"
#include "../../literal.h"
#include "../../shape.h"
#include "../../xla_data.pb.h"
#include "xla_builder.h"
#include "xla_computation.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ limitations under the License.
*/
#include <functional>

#include "xla/client/xla_builder.h"
#include "xla/hlo/builder/xla_builder.h"

extern "C" {
struct XlaOp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "xla/client/xla_computation.h"
#include "xla/hlo/builder/xla_computation.h"

#include "../../ffi.h"
#include "../../../ffi.h"
#include "xla_computation.h"

extern "C" {
Expand Down
28 changes: 0 additions & 28 deletions spidr/backend/src/xla/status.cpp

This file was deleted.

18 changes: 0 additions & 18 deletions spidr/backend/src/xla/status.h

This file was deleted.

2 changes: 2 additions & 0 deletions spidr/postinstall.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/bin/sh -e

if [ "$SPIDR_INSTALL_SUPPORT_LIBS" = false ]; then exit 0; fi

script_dir=$(CDPATH="" cd -- "$(dirname -- "$0")" && pwd)
xla_ext_version=$(cat "$script_dir/backend/VERSION")

Expand Down
15 changes: 7 additions & 8 deletions spidr/spidr.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@ modules =
BayesianOptimization,
BayesianOptimization.Acquisition,

Compiler.Xla.Client.Lib.Arithmetic,
Compiler.Xla.Client.Lib.Constants,
Compiler.Xla.Client.Lib.Math,
Compiler.Xla.Client.Lib.Matrix,
Compiler.Xla.Client.Lib.PRNG,
Compiler.Xla.Client.XlaBuilder,
Compiler.Xla.Client.ExecutableBuildOptions,
Compiler.Xla.Client.XlaComputation,
Compiler.Xla.HLO.Builder.Lib.Arithmetic,
Compiler.Xla.HLO.Builder.Lib.Constants,
Compiler.Xla.HLO.Builder.Lib.Math,
Compiler.Xla.HLO.Builder.Lib.Matrix,
Compiler.Xla.HLO.Builder.Lib.PRNG,
Compiler.Xla.HLO.Builder.XlaBuilder,
Compiler.Xla.HLO.Builder.XlaComputation,
Compiler.Xla.PJRT.C.PjrtCApi,
Compiler.Xla.PJRT.PjrtExecutable,
Compiler.Xla.Literal,
Compiler.Xla.Shape,
Compiler.Xla.ShapeUtil,
Compiler.Xla.XlaData,
Compiler.Xla.Status,

Compiler.Eval,
Compiler.Expr,
Expand Down
15 changes: 7 additions & 8 deletions spidr/src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ import Data.List.Elem
import Compiler.Expr
import Compiler.FFI
import Compiler.LiteralRW
import Compiler.Xla.Client.Lib.Arithmetic
import Compiler.Xla.Client.Lib.Constants
import Compiler.Xla.Client.Lib.Math
import Compiler.Xla.Client.Lib.Matrix
import Compiler.Xla.Client.Lib.PRNG
import Compiler.Xla.Client.ExecutableBuildOptions
import Compiler.Xla.Client.XlaBuilder
import Compiler.Xla.Client.XlaComputation
import Compiler.Xla.HLO.Builder.Lib.Arithmetic
import Compiler.Xla.HLO.Builder.Lib.Constants
import Compiler.Xla.HLO.Builder.Lib.Math
import Compiler.Xla.HLO.Builder.Lib.Matrix
import Compiler.Xla.HLO.Builder.Lib.PRNG
import Compiler.Xla.HLO.Builder.XlaBuilder
import Compiler.Xla.HLO.Builder.XlaComputation
import Compiler.Xla.PJRT.C.PjrtCApi
import Compiler.Xla.PJRT.PjrtExecutable
import Compiler.Xla.Literal
Expand Down Expand Up @@ -183,7 +183,6 @@ interpret @{cache} xlaBuilder (MkFn params root env) = do
Asinh => asinh
Acosh => acosh
Atanh => atanh
interpretE (Argmin {out} axis x) = argMin {outputType = out} !(interpretE x) axis
interpretE (Argmax {out} axis x) = argMax {outputType = out} !(interpretE x) axis
interpretE (Select pred true false) =
select !(interpretE pred) !(interpretE true) !(interpretE false)
Expand Down
3 changes: 0 additions & 3 deletions spidr/src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ data Expr : Type where
Reverse : (axes : List Nat) -> Expr -> Expr
BinaryElementwise : BinaryOp -> Expr -> Expr -> Expr
UnaryElementwise : UnaryOp -> Expr -> Expr
Argmin : Primitive out => (axis : Nat) -> Expr -> Expr
Argmax : Primitive out => (axis : Nat) -> Expr -> Expr
Select : (predicate, onTrue, onFalse : Expr) -> Expr
Cond : (pred : Expr) -> (onTrue : Fn 1) -> (onTrueArg : Expr) ->
Expand Down Expand Up @@ -217,8 +216,6 @@ showExpr indent (Reverse axes x) = "Reverse \{axes} (\{showExpr indent x})"
showExpr indent (BinaryElementwise op x y) =
"\{show op} (\{showExpr indent x}) (\{showExpr indent y})"
showExpr indent (UnaryElementwise op x) = "\{show op} (\{showExpr indent x})"
showExpr indent (Argmin {out} axis x) =
"Argmin {outType = \{xlaIdentifier {dtype = out}}} \{axis} (\{showExpr indent x})"
showExpr indent (Argmax {out} axis x) =
"Argmax {outType = \{xlaIdentifier {dtype = out}}} \{axis} (\{showExpr indent x})"
showExpr indent (Select p t f) =
Expand Down
Loading

0 comments on commit b68d523

Please sign in to comment.