From 865f2be105987a8e236ffa1a669aa323ce89fd7f Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 23 Oct 2023 19:41:20 -0700 Subject: [PATCH] Testing openxla/stablehlo#1810 build changes PiperOrigin-RevId: 575997096 --- third_party/stablehlo/temporary.patch | 1470 +++++++++++++++++ .../stablehlo_legalize_to_hlo.cc | 1 + xla/service/all_reduce_promotion.cc | 1 + xla/service/float_normalization.cc | 11 + xla/service/float_normalization_test.cc | 163 ++ xla/service/reduce_scatter_decomposer.cc | 6 +- xla/service/spmd/spmd_partitioner.cc | 29 +- 7 files changed, 1671 insertions(+), 10 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 820d2dea0fb42..c66699bf6c0bf 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -34,6 +34,1355 @@ diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel ":stablehlo_ops", ":stablehlo_ops_inc_gen", ":stablehlo_pass_inc_gen", +@@ -1311,3 +1330,12 @@ + ":vhlo_ops_td_files", + ], + ) ++ ++test_suite( ++ name = "all_tests", ++ tests = [ ++ "//stablehlo/tests:stablehlo_tests", ++ "//stablehlo/testdata:stablehlo_data_tests", ++ "//stablehlo/conversions/tosa/tests:stablehlo_tosa_tests" ++ ], ++) +diff --ruN a/stablehlo/BUILD.bazel.orig b/stablehlo/BUILD.bazel.orig +--- stablehlo/BUILD.bazel.orig ++++ stablehlo/BUILD.bazel.orig +@@ -0,0 +1,1332 @@ ++# Copyright 2023 The StableHLO Authors. All Rights Reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# https://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "gentbl_filegroup", "td_library") ++ ++package( ++ default_visibility = ["//visibility:public"], ++ licenses = ["notice"], ++) ++ ++exports_files([ ++ "LICENSE", ++ "stablehlo/integrations/python/ChloModule.cpp", ++ "stablehlo/integrations/python/PortableApi.cpp", ++ "stablehlo/integrations/python/PortableApi.h", ++ "stablehlo/integrations/python/StablehloModule.cpp", ++ "stablehlo/integrations/python/VhloModule.cpp", ++]) ++ ++filegroup( ++ name = "stablehlo_ops_td_filegroup", ++ srcs = glob(["stablehlo/dialect/*.td"]), ++) ++ ++cc_library( ++ name = "base", ++ srcs = [ ++ "stablehlo/dialect/Base.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/Base.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":base_attr_interfaces_inc_gen", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:BytecodeReader", ++ "@llvm-project//mlir:BytecodeWriter", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InferTypeOpInterface", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "base_attr_interfaces_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-attr-interface-decls"], ++ "stablehlo/dialect/BaseAttrInterfaces.h.inc", ++ ), ++ ( ++ ["-gen-attr-interface-defs"], ++ "stablehlo/dialect/BaseAttrInterfaces.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/Base.td", ++ deps = [":stablehlo_ops_td_files"], ++) ++ ++td_library( ++ name = "base_td_files", ++ srcs = [ ++ "stablehlo/dialect/Base.td", ++ ], ++ includes = ["."], ++ deps = [ ++ "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ "@llvm-project//mlir:QuantizationOpsTdFiles", ++ ], ++) ++ ++cc_library( ++ name = "broadcast_utils", ++ srcs = [ ++ "stablehlo/dialect/BroadcastUtils.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/BroadcastUtils.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:ShapeDialect", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "chlo_attrs_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-attrdef-decls"], ++ "stablehlo/dialect/ChloAttrs.h.inc", ++ ), ++ ( ++ ["-gen-attrdef-defs"], ++ "stablehlo/dialect/ChloAttrs.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/ChloOps.td", ++ deps = [ ++ ":chlo_ops_td_files", ++ ], ++) ++ ++CHLO_CAPI_SOURCES = [ ++ "stablehlo/integrations/c/ChloAttributes.cpp", ++ "stablehlo/integrations/c/ChloDialect.cpp", ++] ++ ++CHLO_CAPI_HEADERS = [ ++ "stablehlo/integrations/c/ChloAttributes.h", ++ "stablehlo/integrations/c/ChloDialect.h", ++] ++ ++cc_library( ++ name = "chlo_capi", ++ srcs = CHLO_CAPI_SOURCES, ++ hdrs = CHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ ":chlo_ops", ++ "@llvm-project//mlir:CAPIIR", ++ ], ++) ++ ++# Header-only target, used when using the C API from a separate shared library. ++cc_library( ++ name = "chlo_capi_headers", ++ hdrs = CHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//mlir:CAPIIRHeaders", ++ ], ++) ++ ++# Alwayslink target, used when exporting the C API from a shared library. ++cc_library( ++ name = "chlo_capi_objects", ++ srcs = CHLO_CAPI_SOURCES, ++ hdrs = CHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ ":chlo_ops", ++ "@llvm-project//mlir:CAPIIRObjects", ++ ], ++ alwayslink = True, ++) ++ ++gentbl_cc_library( ++ name = "chlo_enums_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-enum-decls"], ++ "stablehlo/dialect/ChloEnums.h.inc", ++ ), ++ ( ++ ["-gen-enum-defs"], ++ "stablehlo/dialect/ChloEnums.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/ChloOps.td", ++ deps = [ ++ ":chlo_ops_td_files", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "chlo_ops_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-op-decls"], ++ "stablehlo/dialect/ChloOps.h.inc", ++ ), ++ ( ++ ["-gen-op-defs"], ++ "stablehlo/dialect/ChloOps.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/ChloOps.td", ++ deps = [ ++ ":chlo_ops_td_files", ++ ], ++) ++ ++filegroup( ++ name = "chlo_ops_py_files", ++ srcs = [ ++ "stablehlo/integrations/python/mlir/dialects/chlo.py", ++ ":chlo_ops_py_gen", ++ ], ++) ++ ++gentbl_filegroup( ++ name = "chlo_ops_py_gen", ++ tbl_outs = [ ++ ( ++ [ ++ "-gen-python-op-bindings", ++ "-bind-dialect=chlo", ++ ], ++ "stablehlo/integrations/python/mlir/dialects/_chlo_ops_gen.py", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/integrations/python/mlir/dialects/ChloOps.td", ++ deps = [ ++ ":chlo_ops_td_files", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ ], ++) ++ ++td_library( ++ name = "chlo_ops_td_files", ++ srcs = [ ++ "stablehlo/dialect/ChloEnums.td", ++ "stablehlo/dialect/ChloOps.td", ++ ], ++ includes = ["."], ++ deps = [ ++ ":base_td_files", ++ "@llvm-project//mlir:BuiltinDialectTdFiles", ++ "@llvm-project//mlir:ControlFlowInterfacesTdFiles", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ ], ++) ++ ++cc_library( ++ name = "chlo_ops", ++ srcs = [ ++ "stablehlo/dialect/ChloBytecode.cpp", ++ "stablehlo/dialect/ChloOps.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/ChloBytecode.h", ++ "stablehlo/dialect/ChloOps.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":base", ++ ":broadcast_utils", ++ ":chlo_attrs_inc_gen", ++ ":chlo_enums_inc_gen", ++ ":chlo_ops_inc_gen", ++ ":stablehlo_type_inference", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:BytecodeReader", ++ "@llvm-project//mlir:BytecodeWriter", ++ "@llvm-project//mlir:ComplexDialect", ++ "@llvm-project//mlir:ControlFlowInterfaces", ++ "@llvm-project//mlir:Dialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InferTypeOpInterface", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:TransformUtils", ++ ], ++) ++ ++cc_library( ++ name = "experimental_ops", ++ srcs = [ ++ "stablehlo/dialect/ExperimentalOps.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/ExperimentalOps.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":stablehlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "interpreter_ops", ++ srcs = [ ++ "stablehlo/reference/InterpreterOps.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/InterpreterOps.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":interpreter_ops_inc_gen", ++ ":reference_interpretervalue", ++ ":reference_numpy", ++ ":reference_ops", ++ ":reference_process_grid", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "interpreter_ops_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-op-decls"], ++ "stablehlo/reference/InterpreterOps.h.inc", ++ ), ++ ( ++ ["-gen-op-defs"], ++ "stablehlo/reference/InterpreterOps.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/reference/InterpreterOps.td", ++ deps = [ ++ ":interpreter_ops_td_files", ++ ], ++) ++ ++td_library( ++ name = "interpreter_ops_td_files", ++ srcs = [ ++ "stablehlo/reference/InterpreterOps.td", ++ ], ++ includes = ["."], ++ deps = [ ++ ":base_td_files", ++ ], ++) ++ ++cc_library( ++ name = "reference_axes", ++ srcs = [ ++ "stablehlo/reference/Axes.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Axes.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ ], ++) ++ ++cc_library( ++ name = "reference_element", ++ srcs = [ ++ "stablehlo/reference/Element.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Element.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_errors", ++ ":reference_types", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:ComplexDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "reference_errors", ++ hdrs = [ ++ "stablehlo/reference/Errors.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//llvm:Support", ++ ], ++) ++ ++cc_library( ++ name = "reference_index", ++ srcs = [ ++ "stablehlo/reference/Index.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Index.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "reference_interpretervalue", ++ srcs = [ ++ "stablehlo/reference/InterpreterValue.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/InterpreterValue.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_errors", ++ ":reference_tensor", ++ ":reference_token", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ ], ++) ++ ++cc_library( ++ name = "reference_numpy", ++ srcs = [ ++ "stablehlo/reference/NumPy.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/NumPy.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_tensor", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ ], ++) ++ ++cc_library( ++ name = "reference_ops", ++ srcs = [ ++ "stablehlo/reference/Ops.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Ops.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_axes", ++ ":reference_element", ++ ":reference_errors", ++ ":reference_index", ++ ":reference_interpretervalue", ++ ":reference_process", ++ ":reference_process_grid", ++ ":reference_scope", ++ ":reference_tensor", ++ ":reference_token", ++ ":reference_types", ++ ":stablehlo_ops", ++ ":stablehlo_type_inference", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "reference_process", ++ srcs = [ ++ "stablehlo/reference/Process.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Process.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_process_grid", ++ ":reference_tensor", ++ ], ++) ++ ++cc_library( ++ name = "reference_process_grid", ++ srcs = [ ++ "stablehlo/reference/ProcessGrid.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/ProcessGrid.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_tensor", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "reference_scope", ++ srcs = [ ++ "stablehlo/reference/Scope.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Scope.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_interpretervalue", ++ ":reference_tensor", ++ ":reference_token", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "reference_tensor", ++ srcs = [ ++ "stablehlo/reference/Tensor.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Tensor.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_axes", ++ ":reference_element", ++ ":reference_errors", ++ ":reference_index", ++ ":reference_types", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "reference_token", ++ srcs = [ ++ "stablehlo/reference/Token.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Token.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":reference_errors", ++ ":stablehlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ ], ++) ++ ++cc_library( ++ name = "reference_types", ++ srcs = [ ++ "stablehlo/reference/Types.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/reference/Types.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//mlir:IR", ++ ], ++) ++ ++cc_library( ++ name = "register", ++ srcs = [ ++ "stablehlo/dialect/Register.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/Register.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":chlo_ops", ++ ":interpreter_ops", ++ ":stablehlo_ops", ++ ":vhlo_ops", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:SparseTensorDialect", ++ "@llvm-project//mlir:TensorDialect", ++ ], ++) ++ ++cc_library( ++ name = "stablehlo_assembly_format", ++ srcs = [ ++ "stablehlo/dialect/AssemblyFormat.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/AssemblyFormat.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":base", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "stablehlo_attrs_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-attrdef-decls"], ++ "stablehlo/dialect/StablehloAttrs.h.inc", ++ ), ++ ( ++ ["-gen-attrdef-defs"], ++ "stablehlo/dialect/StablehloAttrs.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/StablehloOps.td", ++ deps = [ ++ ":stablehlo_ops_td_files", ++ ], ++) ++ ++STABLEHLO_CAPI_SOURCES = [ ++ "stablehlo/integrations/c/StablehloAttributes.cpp", ++ "stablehlo/integrations/c/StablehloDialect.cpp", ++ "stablehlo/integrations/c/StablehloTypes.cpp", ++] ++ ++STABLEHLO_CAPI_HEADERS = [ ++ "stablehlo/integrations/c/StablehloAttributes.h", ++ "stablehlo/integrations/c/StablehloDialect.h", ++ "stablehlo/integrations/c/StablehloTypes.h", ++] ++ ++cc_library( ++ name = "stablehlo_capi", ++ srcs = STABLEHLO_CAPI_SOURCES, ++ hdrs = STABLEHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ ":stablehlo_ops", ++ "@llvm-project//mlir:CAPIIR", ++ ], ++) ++ ++# Header-only target, used when using the C API from a separate shared library. ++cc_library( ++ name = "stablehlo_capi_headers", ++ hdrs = STABLEHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//mlir:CAPIIRHeaders", ++ ], ++) ++ ++# Alwayslink target, used when exporting the C API from a shared library. ++cc_library( ++ name = "stablehlo_capi_objects", ++ srcs = STABLEHLO_CAPI_SOURCES, ++ hdrs = STABLEHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ ":stablehlo_ops", ++ "@llvm-project//mlir:CAPIIRObjects", ++ ], ++ alwayslink = True, ++) ++ ++gentbl_cc_library( ++ name = "stablehlo_enums_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-enum-decls"], ++ "stablehlo/dialect/StablehloEnums.h.inc", ++ ), ++ ( ++ ["-gen-enum-defs"], ++ "stablehlo/dialect/StablehloEnums.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/StablehloOps.td", ++ deps = [ ++ ":stablehlo_ops_td_files", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "stablehlo_ops_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-op-decls"], ++ "stablehlo/dialect/StablehloOps.h.inc", ++ ), ++ ( ++ ["-gen-op-defs"], ++ "stablehlo/dialect/StablehloOps.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/StablehloOps.td", ++ deps = [ ++ ":stablehlo_ops_td_files", ++ ], ++) ++ ++filegroup( ++ name = "stablehlo_ops_py_files", ++ srcs = [ ++ "stablehlo/integrations/python/mlir/dialects/stablehlo.py", ++ ":stablehlo_ops_py_gen", ++ ], ++) ++ ++gentbl_filegroup( ++ name = "stablehlo_ops_py_gen", ++ tbl_outs = [ ++ ( ++ [ ++ "-gen-python-op-bindings", ++ "-bind-dialect=stablehlo", ++ ], ++ "stablehlo/integrations/python/mlir/dialects/_stablehlo_ops_gen.py", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/integrations/python/mlir/dialects/StablehloOps.td", ++ deps = [ ++ ":stablehlo_ops_td_files", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ ], ++) ++ ++td_library( ++ name = "stablehlo_ops_td_files", ++ srcs = [ ++ "stablehlo/dialect/Base.td", ++ "stablehlo/dialect/StablehloAttrs.td", ++ "stablehlo/dialect/StablehloEnums.td", ++ "stablehlo/dialect/StablehloOps.td", ++ ], ++ includes = ["."], ++ deps = [ ++ ":base_td_files", ++ "@llvm-project//mlir:BuiltinDialectTdFiles", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ "@llvm-project//mlir:ShapeOpsTdFiles", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "stablehlo_pass_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ [ ++ "-gen-pass-decls", ++ ], ++ "stablehlo/transforms/Passes.h.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/transforms/Passes.td", ++ deps = ["@llvm-project//mlir:PassBaseTdFiles"], ++) ++ ++cc_library( ++ name = "stablehlo_passes", ++ srcs = [ ++ "stablehlo/transforms/PassPipelines.cpp", ++ "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", ++ "stablehlo/transforms/StablehloLegalizeToVhlo.cpp", ++ "stablehlo/transforms/StablehloRefineShapes.cpp", ++ "stablehlo/transforms/VhloLegalizeToStablehlo.cpp", ++ "stablehlo/transforms/VhloToVersion.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/transforms/MapStablehloToVhlo.h", ++ "stablehlo/transforms/Passes.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":base", ++ ":chlo_ops", ++ ":experimental_ops", ++ ":stablehlo_ops", ++ ":stablehlo_ops_inc_gen", ++ ":stablehlo_pass_inc_gen", ++ ":stablehlo_type_inference", ++ ":version", ++ ":vhlo_ops", ++ ":vhlo_types", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InferTypeOpInterface", ++ "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TensorDialect", ++ "@llvm-project//mlir:TransformUtils", ++ "@llvm-project//mlir:Transforms", ++ ], ++) ++ ++cc_library( ++ name = "stablehlo_portable_api", ++ srcs = [ ++ "stablehlo/api/PortableApi.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/api/PortableApi.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":stablehlo_ops", ++ ":stablehlo_serialization", ++ ":version", ++ ":vhlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:BytecodeWriter", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Parser", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "stablehlo_serialization", ++ srcs = [ ++ "stablehlo/dialect/Serialization.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/Serialization.h", ++ ], ++ deps = [ ++ ":stablehlo_ops", ++ ":stablehlo_passes", ++ ":version", ++ ":vhlo_ops", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:BytecodeReader", ++ "@llvm-project//mlir:BytecodeWriter", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Parser", ++ "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "stablehlo_type_inference", ++ srcs = [ ++ "stablehlo/dialect/TypeInference.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/TypeInference.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":base", ++ ":stablehlo_assembly_format", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InferTypeOpInterface", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++cc_library( ++ name = "stablehlo_ops", ++ srcs = [ ++ "stablehlo/dialect/StablehloBytecode.cpp", ++ "stablehlo/dialect/StablehloOps.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/StablehloBytecode.h", ++ "stablehlo/dialect/StablehloOps.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":base", ++ ":stablehlo_assembly_format", ++ ":stablehlo_attrs_inc_gen", ++ ":stablehlo_enums_inc_gen", ++ ":stablehlo_ops_inc_gen", ++ ":stablehlo_type_inference", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:ArithDialect", ++ "@llvm-project//mlir:ComplexDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:InferTypeOpInterface", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:SparseTensorDialect", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:TensorDialect", ++ ], ++) ++ ++cc_binary( ++ name = "stablehlo-lsp-server", ++ srcs = [ ++ "stablehlo/tools/StablehloLspServerMain.cpp", ++ ], ++ deps = [ ++ ":register", ++ "@llvm-project//mlir:AllExtensions", ++ "@llvm-project//mlir:AllPassesAndDialects", ++ "@llvm-project//mlir:MlirLspServerLib", ++ ], ++) ++ ++cc_binary( ++ name = "stablehlo-translate", ++ srcs = [ ++ "stablehlo/tools/StablehloTranslateMain.cpp", ++ ], ++ deps = [ ++ ":interpreter_ops", ++ ":reference_errors", ++ ":reference_interpretervalue", ++ ":reference_ops", ++ ":reference_process_grid", ++ ":reference_scope", ++ ":reference_tensor", ++ ":register", ++ ":stablehlo_ops", ++ ":stablehlo_serialization", ++ ":vhlo_ops", ++ "//stablehlo/tests:check_ops", ++ "//stablehlo/tests:test_utils", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:AllPassesAndDialects", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:Transforms", ++ "@llvm-project//mlir:TranslateLib", ++ ], ++) ++ ++cc_binary( ++ name = "stablehlo-opt", ++ srcs = [ ++ "stablehlo/tools/StablehloOptMain.cpp", ++ ], ++ deps = [ ++ ":interpreter_ops", ++ ":register", ++ ":stablehlo_passes", ++ ":tosa_passes", ++ "//stablehlo/tests:test_utils", ++ "@llvm-project//mlir:AllExtensions", ++ "@llvm-project//mlir:AllPassesAndDialects", ++ "@llvm-project//mlir:MlirOptLib", ++ "@llvm-project//mlir:TosaDialect", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "tosa_pass_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ [ ++ "-gen-pass-decls", ++ "-name=StablehloTOSATransforms", ++ ], ++ "stablehlo/conversions/tosa/transforms/Passes.h.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/conversions/tosa/transforms/Passes.td", ++ deps = ["@llvm-project//mlir:PassBaseTdFiles"], ++) ++ ++cc_library( ++ name = "tosa_passes", ++ srcs = [ ++ "stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp", ++ "stablehlo/conversions/tosa/transforms/StablehloPrepareForTosa.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/conversions/tosa/transforms/Passes.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":stablehlo_ops", ++ ":tosa_pass_inc_gen", ++ ":tosa_pdll_inc_gen", ++ "@llvm-project//mlir:FuncDialect", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Parser", ++ "@llvm-project//mlir:Pass", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:TosaDialect", ++ "@llvm-project//mlir:Transforms", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "tosa_pdll_inc_gen", ++ tbl_outs = [ ++ ( ++ ["-x=cpp"], ++ "stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll.h.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-pdll", ++ td_file = "stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll", ++ deps = [ ++ ":stablehlo_ops_td_files", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ "@llvm-project//mlir:TosaDialectTdFiles", ++ ], ++) ++ ++cc_library( ++ name = "version", ++ srcs = [ ++ "stablehlo/dialect/Version.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/Version.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++VHLO_CAPI_SOURCES = [ ++ "stablehlo/integrations/c/VhloDialect.cpp", ++] ++ ++VHLO_CAPI_HEADERS = [ ++ "stablehlo/integrations/c/VhloDialect.h", ++] ++ ++cc_library( ++ name = "vhlo_capi", ++ srcs = VHLO_CAPI_SOURCES, ++ hdrs = VHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ ":vhlo_ops", ++ "@llvm-project//mlir:CAPIIR", ++ ], ++) ++ ++# Header-only target, used when using the C API from a separate shared library. ++cc_library( ++ name = "vhlo_capi_headers", ++ hdrs = VHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ "@llvm-project//mlir:CAPIIRHeaders", ++ ], ++) ++ ++# Alwayslink target, used when exporting the C API from a shared library. ++cc_library( ++ name = "vhlo_capi_objects", ++ srcs = VHLO_CAPI_SOURCES, ++ hdrs = VHLO_CAPI_HEADERS, ++ strip_include_prefix = ".", ++ deps = [ ++ ":vhlo_ops", ++ "@llvm-project//mlir:CAPIIRObjects", ++ ], ++ alwayslink = True, ++) ++ ++filegroup( ++ name = "vhlo_ops_py_files", ++ srcs = [ ++ "stablehlo/integrations/python/mlir/dialects/vhlo.py", ++ ":vhlo_ops_py_gen", ++ ], ++) ++ ++gentbl_filegroup( ++ name = "vhlo_ops_py_gen", ++ tbl_outs = [ ++ ( ++ [ ++ "-gen-python-op-bindings", ++ "-bind-dialect=vhlo", ++ ], ++ "stablehlo/integrations/python/mlir/dialects/_vhlo_ops_gen.py", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/integrations/python/mlir/dialects/VhloOps.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "vhlo_attr_interfaces_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-attr-interface-decls"], ++ "stablehlo/dialect/VhloAttrInterfaces.h.inc", ++ ), ++ ( ++ ["-gen-attr-interface-defs"], ++ "stablehlo/dialect/VhloAttrInterfaces.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/VhloAttrs.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "vhlo_attrs_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-attrdef-decls"], ++ "stablehlo/dialect/VhloAttrs.h.inc", ++ ), ++ ( ++ ["-gen-attrdef-defs"], ++ "stablehlo/dialect/VhloAttrs.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/VhloOps.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "vhlo_enums_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-enum-decls"], ++ "stablehlo/dialect/VhloEnums.h.inc", ++ ), ++ ( ++ ["-gen-enum-defs"], ++ "stablehlo/dialect/VhloEnums.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/VhloEnums.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "vhlo_op_interfaces_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-op-interface-decls"], ++ "stablehlo/dialect/VhloOpInterfaces.h.inc", ++ ), ++ ( ++ ["-gen-op-interface-defs"], ++ "stablehlo/dialect/VhloOpInterfaces.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/VhloOps.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ ], ++) ++ ++cc_library( ++ name = "vhlo_ops", ++ srcs = [ ++ "stablehlo/dialect/VhloBytecode.cpp", ++ "stablehlo/dialect/VhloOps.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/VhloBytecode.h", ++ "stablehlo/dialect/VhloOps.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":base", ++ ":stablehlo_assembly_format", ++ ":version", ++ ":vhlo_attr_interfaces_inc_gen", ++ ":vhlo_attrs_inc_gen", ++ ":vhlo_enums_inc_gen", ++ ":vhlo_op_interfaces_inc_gen", ++ ":vhlo_ops_inc_gen", ++ ":vhlo_types", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:Support", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "vhlo_ops_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-op-decls"], ++ "stablehlo/dialect/VhloOps.h.inc", ++ ), ++ ( ++ ["-gen-op-defs"], ++ "stablehlo/dialect/VhloOps.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/VhloOps.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ ], ++) ++ ++td_library( ++ name = "vhlo_ops_td_files", ++ srcs = [ ++ "stablehlo/dialect/VhloAttrs.td", ++ "stablehlo/dialect/VhloBase.td", ++ "stablehlo/dialect/VhloDialect.td", ++ "stablehlo/dialect/VhloEnums.td", ++ "stablehlo/dialect/VhloOps.td", ++ "stablehlo/dialect/VhloTypes.td", ++ ], ++ includes = ["."], ++ deps = [ ++ "@llvm-project//mlir:BuiltinDialectTdFiles", ++ "@llvm-project//mlir:OpBaseTdFiles", ++ "@llvm-project//mlir:ShapeOpsTdFiles", ++ ], ++) ++ ++cc_library( ++ name = "vhlo_types", ++ srcs = [ ++ "stablehlo/dialect/VhloTypes.cpp", ++ ], ++ hdrs = [ ++ "stablehlo/dialect/VhloTypes.h", ++ ], ++ strip_include_prefix = ".", ++ deps = [ ++ ":stablehlo_assembly_format", ++ ":version", ++ ":vhlo_type_interfaces_inc_gen", ++ ":vhlo_types_inc_gen", ++ "@llvm-project//llvm:Support", ++ "@llvm-project//mlir:IR", ++ "@llvm-project//mlir:QuantOps", ++ "@llvm-project//mlir:ShapeDialect", ++ "@llvm-project//mlir:Support", ++ "@llvm-project//mlir:Transforms", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "vhlo_type_interfaces_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-type-interface-decls"], ++ "stablehlo/dialect/VhloTypeInterfaces.h.inc", ++ ), ++ ( ++ ["-gen-type-interface-defs"], ++ "stablehlo/dialect/VhloTypeInterfaces.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/VhloTypes.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ ], ++) ++ ++gentbl_cc_library( ++ name = "vhlo_types_inc_gen", ++ strip_include_prefix = ".", ++ tbl_outs = [ ++ ( ++ ["-gen-typedef-decls"], ++ "stablehlo/dialect/VhloTypeDefs.h.inc", ++ ), ++ ( ++ ["-gen-typedef-defs"], ++ "stablehlo/dialect/VhloTypeDefs.cpp.inc", ++ ), ++ ], ++ tblgen = "@llvm-project//mlir:mlir-tblgen", ++ td_file = "stablehlo/dialect/VhloOps.td", ++ deps = [ ++ ":vhlo_ops_td_files", ++ ], ++) diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt --- stablehlo/CMakeLists.txt +++ stablehlo/CMakeLists.txt @@ -181,6 +1530,54 @@ diff --ruN a/stablehlo/CMakeLists.txt b/stablehlo/CMakeLists.txt #------------------------------------------------------------------------------- # Directory setup +diff --ruN a/stablehlo/WORKSPACE.bazel b/stablehlo/WORKSPACE.bazel +--- stablehlo/WORKSPACE.bazel ++++ stablehlo/WORKSPACE.bazel +@@ -12,6 +12,8 @@ + # See the License for the specific language governing permissions and + # limitations under the License. + """Workspace for StableHLO.""" ++ ++workspace(name = "stablehlo") + + load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel b/stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel +--- stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel ++++ stablehlo/stablehlo/conversions/tosa/tests/BUILD.bazel +@@ -27,8 +27,8 @@ + substitutions = { + "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", + "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"), +- "@STABLEHLO_TOOLS_DIR@": ".", +- "@STABLEHLO_SOURCE_DIR@": ".", ++ "\"@STABLEHLO_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", ++ "\"@STABLEHLO_SOURCE_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", + }, + template = "lit.site.cfg.py.in", + ) +@@ -45,6 +45,12 @@ + "@llvm-project//llvm:FileCheck", + ], + size = "small", ++ tags = ["stablehlo_tosa_tests"], + ) + for src in glob(["**/*.mlir"]) + ] ++ ++test_suite( ++ name = "stablehlo_tosa_tests", ++ tags = ["stablehlo_tosa_tests"], ++) +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/lit.site.cfg.py.in b/stablehlo/stablehlo/conversions/tosa/tests/lit.site.cfg.py.in +--- stablehlo/stablehlo/conversions/tosa/tests/lit.site.cfg.py.in ++++ stablehlo/stablehlo/conversions/tosa/tests/lit.site.cfg.py.in +@@ -17,4 +17,4 @@ + lit.llvm.initialize(lit_config, config) + config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" + config.stablehlo_tools_dir = "@STABLEHLO_TOOLS_DIR@" +-lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@/stablehlo/conversions/tosa/tests/lit.cfg.py") ++lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@" + "/stablehlo/conversions/tosa/tests/lit.cfg.py") diff --ruN a/stablehlo/stablehlo/dialect/Base.cpp b/stablehlo/stablehlo/dialect/Base.cpp --- stablehlo/stablehlo/dialect/Base.cpp +++ stablehlo/stablehlo/dialect/Base.cpp @@ -999,6 +2396,70 @@ diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/ parser.parseColon() || parser.parseType(reduceOpFnType) || parser.parseOptionalLocationSpecifier(explicitLoc)) return failure(); +diff --ruN a/stablehlo/stablehlo/testdata/BUILD.bazel b/stablehlo/stablehlo/testdata/BUILD.bazel +--- stablehlo/stablehlo/testdata/BUILD.bazel ++++ stablehlo/stablehlo/testdata/BUILD.bazel +@@ -27,8 +27,8 @@ + substitutions = { + "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", + "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"), +- "@STABLEHLO_TOOLS_DIR@": ".", +- "@STABLEHLO_SOURCE_DIR@": ".", ++ "\"@STABLEHLO_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", ++ "\"@STABLEHLO_SOURCE_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", + }, + template = "lit.site.cfg.py.in", + ) +@@ -46,6 +46,13 @@ + "@llvm-project//llvm:FileCheck", + ], + size = "small", ++ tags = ["stablehlo_data_tests"], ++ + ) + for src in glob(["**/*.mlir"]) + ] ++ ++test_suite( ++ name = "stablehlo_data_tests", ++ tags = ["stablehlo_data_tests"], ++) +diff --ruN a/stablehlo/stablehlo/testdata/lit.site.cfg.py.in b/stablehlo/stablehlo/testdata/lit.site.cfg.py.in +--- stablehlo/stablehlo/testdata/lit.site.cfg.py.in ++++ stablehlo/stablehlo/testdata/lit.site.cfg.py.in +@@ -18,4 +18,4 @@ + lit.llvm.initialize(lit_config, config) + config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" + config.stablehlo_tools_dir = "@STABLEHLO_TOOLS_DIR@" +-lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@/stablehlo/testdata/lit.cfg.py") ++lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@" + "/stablehlo/testdata/lit.cfg.py") +diff --ruN a/stablehlo/stablehlo/tests/BUILD.bazel b/stablehlo/stablehlo/tests/BUILD.bazel +--- stablehlo/stablehlo/tests/BUILD.bazel ++++ stablehlo/stablehlo/tests/BUILD.bazel +@@ -130,8 +130,8 @@ + substitutions = { + "@LIT_SITE_CFG_IN_HEADER@": "# Autogenerated, do not edit.", + "@LLVM_TOOLS_DIR@": package_path("@llvm-project//llvm:BUILD"), +- "@STABLEHLO_TOOLS_DIR@": ".", +- "@STABLEHLO_SOURCE_DIR@": ".", ++ "\"@STABLEHLO_TOOLS_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", ++ "\"@STABLEHLO_SOURCE_DIR@\"": "os.path.join(os.environ['TEST_SRCDIR'], 'stablehlo')", + }, + template = "lit.site.cfg.py.in", + ) +@@ -150,6 +150,12 @@ + "@llvm-project//llvm:not", + ] + glob(["%s.bc" % src]), + size = "small", ++ tags = ["stablehlo_tests"], + ) + for src in glob(["**/*.mlir"]) + ] ++ ++test_suite( ++ name = "stablehlo_tests", ++ tags = ["stablehlo_tests"], ++) diff --ruN a/stablehlo/stablehlo/tests/infer_stablehlo.mlir b/stablehlo/stablehlo/tests/infer_stablehlo.mlir --- stablehlo/stablehlo/tests/infer_stablehlo.mlir +++ stablehlo/stablehlo/tests/infer_stablehlo.mlir @@ -1063,6 +2524,15 @@ diff --ruN a/stablehlo/stablehlo/tests/legalize_stablehlo_to_vhlo_invalid.mlir b func.return %0 : tensor<16xf32> } // CHECK-DISABLED-LABEL: "type_sparsity" +diff --ruN a/stablehlo/stablehlo/tests/lit.site.cfg.py.in b/stablehlo/stablehlo/tests/lit.site.cfg.py.in +--- stablehlo/stablehlo/tests/lit.site.cfg.py.in ++++ stablehlo/stablehlo/tests/lit.site.cfg.py.in +@@ -18,4 +18,4 @@ + lit.llvm.initialize(lit_config, config) + config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" + config.stablehlo_tools_dir = "@STABLEHLO_TOOLS_DIR@" +-lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@/stablehlo/tests/lit.cfg.py") ++lit_config.load_config(config, "@STABLEHLO_SOURCE_DIR@" + "/stablehlo/tests/lit.cfg.py") diff --ruN a/stablehlo/stablehlo/tests/ops_sparse.mlir b/stablehlo/stablehlo/tests/ops_sparse.mlir --- stablehlo/stablehlo/tests/ops_sparse.mlir +++ stablehlo/stablehlo/tests/ops_sparse.mlir diff --git a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index a3c401a7b1abe..df888011de44d 100644 --- a/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -15,6 +15,7 @@ limitations under the License. #include +// FIXME: Test to get GH building #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/xla/service/all_reduce_promotion.cc b/xla/service/all_reduce_promotion.cc index 00469a2a6c9e7..30965128a8152 100644 --- a/xla/service/all_reduce_promotion.cc +++ b/xla/service/all_reduce_promotion.cc @@ -49,6 +49,7 @@ std::unique_ptr CloneAllReduce( return inst->GetModule()->AddEmbeddedComputation(promoted.Build()); }(); new_inst->set_to_apply(to_apply_promoted); + to_apply_promoted->SetCollectiveCallInstruction(new_inst.get()); return new_inst; } diff --git a/xla/service/float_normalization.cc b/xla/service/float_normalization.cc index 267a92383aab1..84774a3b4884e 100644 --- a/xla/service/float_normalization.cc +++ b/xla/service/float_normalization.cc @@ -335,6 +335,9 @@ Status FloatNormalizationVisitor::HandleMultipleOutputs(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { + if (comp->IsCollectiveCalledComputation()) { + continue; + } bool comp_has_low_precision = false; if (comp->root_instruction()->shape().element_type() == HighPrecisionType()) { @@ -411,6 +414,9 @@ Status FloatNormalizationVisitor::HandleInstruction(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { + if (comp->IsCollectiveCalledComputation()) { + continue; + } bool comp_has_low_precision = false; high_prec_count += CountSubshapesWithMatchingType( comp->root_instruction()->shape(), HighPrecisionType()); @@ -549,6 +555,11 @@ StatusOr FloatNormalization::Run( ", before:\n" + module->ToString()); FloatNormalizationVisitor visitor(float_support_, this); for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { + if (comp->IsCollectiveCalledComputation()) { + XLA_VLOG_LINES(2, "Skip processing collective called computation: " + + comp->ToString()); + continue; + } TF_RETURN_IF_ERROR(comp->Accept(&visitor)); } XLA_VLOG_LINES(2, "FloatNormalization::Run() for " + diff --git a/xla/service/float_normalization_test.cc b/xla/service/float_normalization_test.cc index 3a41960bad932..2d6a976ff59df 100644 --- a/xla/service/float_normalization_test.cc +++ b/xla/service/float_normalization_test.cc @@ -76,6 +76,38 @@ class TestFloatSupport : public FloatSupport { } }; +// The test float class that doesn't support any compute ops for low-precision +// but supports some collectives. +class TestFloatNoComputeSupport : public FloatSupport { + public: + explicit TestFloatNoComputeSupport(PrimitiveType low_precision_type) + : FloatSupport(low_precision_type) {} + ~TestFloatNoComputeSupport() override = default; + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kAllToAll || + hlo.opcode() == HloOpcode::kAllReduce || + hlo.opcode() == HloOpcode::kReduceScatter) { + return true; + } + return false; + } + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kAllToAll || + hlo.opcode() == HloOpcode::kAllReduce || + hlo.opcode() == HloOpcode::kReduceScatter) { + return true; + } + return false; + } +}; + class FloatNormalizationTest : public HloTestBase { protected: FloatNormalizationTest() @@ -485,4 +517,135 @@ TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) { EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); } +class FloatNormalizationNoComputeSupportTest : public FloatNormalizationTest { + protected: + bool Normalize(HloModule* module, PrimitiveType low_precision_type = BF16) { + TestFloatNoComputeSupport float_support(low_precision_type); + FloatNormalization normalization(&float_support); + + StatusOr result = normalization.Run(module); + EXPECT_IS_OK(result.status()); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); + EXPECT_IS_OK(verifier.Run(module).status()); + + return result.value(); + } +}; + +TEST_F(FloatNormalizationNoComputeSupportTest, + NoNormalizationForToApplyMultiOuputAllReduce) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape bf16_shape_b = ShapeUtil::MakeShape(BF16, {16, 16}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape_b, "b")); + + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({bf16_shape_a, bf16_shape_b}), {a, b}, + reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false)); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape_b, crs, 1)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Since we skip processing to_apply region, nothing should change in the + // original HLO. + EXPECT_FALSE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(1)->shape().element_type(), BF16); + EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); +} + +TEST_F(FloatNormalizationNoComputeSupportTest, + NoNormalizationForToApplyAllReduce) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + + HloInstruction* crs = builder.AddInstruction( + HloInstruction::CreateAllReduce(bf16_shape_a, {a}, reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Since we skip processing to_apply region, nothing should change in the + // original HLO. + EXPECT_FALSE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0)->shape().element_type(), BF16); + EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd); +} + +TEST_F(FloatNormalizationNoComputeSupportTest, + NoNormalizationForToApplyReduceScatter) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape bf16_shape_scattered = ShapeUtil::MakeShape(BF16, {1, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateReduceScatter( + bf16_shape_scattered, {a}, reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false, /*scatter_dimension*/ 0)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Since we skip processing to_apply region, nothing should change in the + // original HLO. + EXPECT_FALSE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0)->shape().element_type(), BF16); + EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd); +} + } // namespace xla diff --git a/xla/service/reduce_scatter_decomposer.cc b/xla/service/reduce_scatter_decomposer.cc index 1fb197bcea8b8..59366639b84e2 100644 --- a/xla/service/reduce_scatter_decomposer.cc +++ b/xla/service/reduce_scatter_decomposer.cc @@ -55,11 +55,15 @@ StatusOr ReduceScatterDecomposer::Run( } // Create an all-reduce + HloComputation *apply_clone = module->AddComputationAndUnifyNamesAndIds( + rs->to_apply()->Clone(), /*is_entry=*/false); HloInstruction *ar = computation->AddInstruction(HloInstruction::CreateAllReduce( - rs->operand(0)->shape(), rs->operands(), rs->to_apply(), + rs->operand(0)->shape(), rs->operands(), apply_clone, rs->replica_groups(), rs->constrain_layout(), channel_id, rs->use_global_device_ids())); + apply_clone->SetCollectiveCallInstruction(ar); + // Create start indices for a dynamic slice to decompose the all-reduce // results. TF_ASSIGN_OR_RETURN( diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index 429181dbecfe7..98438aa08ae25 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -48,7 +48,6 @@ limitations under the License. #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_pass_pipeline.h" -#include "xla/service/pattern_matcher.h" #include "xla/service/shape_inference.h" #include "xla/service/spmd/custom_call_handler.h" #include "xla/service/spmd/spmd_partitioner_util.h" @@ -4730,10 +4729,16 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions, for (int64_t i = 0; i < num_replicas; ++i) { groups[i].add_replica_ids(i); } - return b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction, groups, - /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/false)); + HloComputation* reduction_clone = + reduction->parent()->AddComputationAndUnifyNamesAndIds( + reduction->Clone(), false); + HloInstruction* all_reduce = + b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction_clone, groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + reduction_clone->SetCollectiveCallInstruction(all_reduce); + return all_reduce; } std::vector device_groups; @@ -4746,10 +4751,16 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions, } } } - return b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction, device_groups, - /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/true)); + HloComputation* reduction_clone = + reduction->parent()->AddComputationAndUnifyNamesAndIds( + reduction->Clone(), false); + HloInstruction* all_reduce = + b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction_clone, device_groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/true)); + reduction_clone->SetCollectiveCallInstruction(all_reduce); + return all_reduce; }, [num_partitions](SpmdBuilder* b, HloInstruction* operand, std::vector>& src_dst_pairs,