diff --git a/.bazelrc b/.bazelrc index 142ed60871ce3..f8a9ef174f7ec 100644 --- a/.bazelrc +++ b/.bazelrc @@ -738,42 +738,42 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. # CPU WHEEL -test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cpu_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cuda_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... 
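The `-tf_tosa` entries added above rely on Bazel's tag-based filtering: any target tagged `tf_tosa` is now skipped by these suites at test time (--test_tag_filters) and not built at all (--build_tag_filters). A minimal sketch of the target side, assuming a hypothetical test target (name and source file are illustrative, not part of this change):

# Hypothetical BUILD entry; only the "tf_tosa" tag matters here.
py_test(
    name = "tosa_bridge_test",       # assumed name, not from this diff
    srcs = ["tosa_bridge_test.py"],  # assumed source, not from this diff
    tags = [
        "tf_tosa",  # excluded by --test_tag_filters=...,-tf_tosa and,
                    # via --build_tag_filters=...,-tf_tosa, never built
    ],
)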
# ARM64 WHEEL -test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL -test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... 
# MACOS X86 WHEEL -test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_x86_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. # LINUX CPU PYCPP: -test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... # LINUX CUDA PYCPP: -test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... @@ -786,8 +786,8 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # do not run them. 
By prefixing the configs with "build", we can run both # `bazel build` and `bazel test` commands with the same config as test configs # inherit from build. -build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test @@ -796,15 +796,15 @@ build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP -test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP # These are defined as build configs so that we can run a build only job. See # the note under "ARM64 PYCPP" for more details. 
-build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... # CROSS-COMPILE MACOS X86 PYCPP @@ -813,8 +813,8 @@ build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_co # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off build:windows_x86_cpu_pycpp_test_build_opts_debug --config=windows_x86_cpu_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 -test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test -test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_build_opts --build_tests_only test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... diff --git a/third_party/tsl/.bazelrc b/third_party/tsl/.bazelrc index 086e35096080a..5998529e822a8 100644 --- a/third_party/tsl/.bazelrc +++ b/third_party/tsl/.bazelrc @@ -738,42 +738,42 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. 
# CPU WHEEL -test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cpu_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # CUDA WHEEL -test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_cuda_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # ARM64 WHEEL -test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 -test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 +test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:linux_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS ARM64 WHEEL -test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_arm64_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # MACOS X86 WHEEL -test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium test:macos_x86_wheel_test --@tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. 
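The PYCPP suites below select targets the same way, but note that the CUDA variants combine positive and negative tag filters: a target must carry `gpu` or `requires-gpu` to be selected at all, and is then dropped if it also carries an excluded tag such as `no_gpu`. A sketch under the same caveat (hypothetical target, not from this change):

# Hypothetical C++ test picked up by linux_cuda_pycpp_test (positive "gpu"
# match) and skipped by linux_cpu_pycpp_test (whose filters include -gpu).
cc_test(
    name = "gpu_kernel_smoke_test",       # assumed name, not from this diff
    srcs = ["gpu_kernel_smoke_test.cc"],  # assumed source, not from this diff
    size = "small",  # fits --test_size_filters=small,medium
    tags = [
        "gpu",           # positive match for the CUDA configs
        "requires-gpu",  # also listed as a positive filter above
    ],
)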
# LINUX CPU PYCPP: -test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... # LINUX CUDA PYCPP: -test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 -test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 +test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_gpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... @@ -786,8 +786,8 @@ test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflo # do not run them. By prefixing the configs with "build", we can run both # `bazel build` and `bazel test` commands with the same config as test configs # inherit from build. -build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only -build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only +build:linux_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only build:linux_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --flaky_test_attempts=3 # TODO(michaelhudgins): Why do we need to specifically omit go and java here? build:linux_arm64_pycpp_test --config=linux_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... 
-//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/python/tools:aot_compiled_test @@ -796,15 +796,15 @@ build:cross_compile_linux_arm64_pycpp_test --config=linux_arm64_pycpp_test # Tests that fail only when cross-compiled build:cross_compile_linux_arm64_pycpp_test -//tensorflow/compiler/mlir/quantization/stablehlo:convert_tf_quant_to_mhlo_int_test # MACOS ARM64 PYCPP -test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 -test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 +test:macos_arm64_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:macos_arm64_pycpp_test --config=macos_arm64_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/core/kernels/image:resize_bicubic_op_test # MACOS X86 PYCPP # These are defined as build configs so that we can run a build only job. See # the note under "ARM64 PYCPP" for more details. -build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test -build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --test_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test +build:macos_x86_pycpp_test_filters --build_tag_filters=-no_oss,-tf_tosa,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test build:macos_x86_pycpp_test_filters --keep_going --test_lang_filters=cc,py --test_size_filters=small,medium build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/tools/toolchains/... -//tensorflow/lite/... -//tensorflow/compiler/aot/... 
# CROSS-COMPILE MACOS X86 PYCPP @@ -813,8 +813,8 @@ build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_co # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off build:windows_x86_cpu_pycpp_test_build_opts_debug --config=windows_x86_cpu_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 -test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test -test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_build_opts --build_tests_only test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... diff --git a/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl index 03a9dde83cfdd..ac3082fbcb305 100644 --- a/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -111,7 +111,7 @@ filegroup( ) filegroup( - name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + name = "crosstool_wrapper_driver_is_not_gcc", + srcs = [":clang/bin/crosstool_wrapper_driver_is_not_gcc"], + data = ["@local_config_rocm//rocm:all_files"], ) - diff --git a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index e1217ad09e670..e97d13f681217 100755 --- a/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -186,6 +186,7 @@ def InvokeHipcc(argv, log=False): hipccopts += defines hipccopts += std_options hipccopts += m_options + hipccopts += ' --rocm-path="%{rocm_path}" ' if depfiles: # Generate the dependency file diff --git a/third_party/tsl/third_party/gpus/rocm/BUILD.tpl b/third_party/tsl/third_party/gpus/rocm/BUILD.tpl index aa3688e335df3..7ebf2773eb48b 100644 --- a/third_party/tsl/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/tsl/third_party/gpus/rocm/BUILD.tpl @@ -1,8 +1,22 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@local_config_rocm//rocm:build_defs.bzl", "rocm_version_number", "select_threshold") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like -package(default_visibility = ["//visibility:public"]) +package(default_visibility = ["//visibility:private"]) + +bool_flag( + name = "use_rocm_hermetic_rpath", + build_setting_default = False, +) + +config_setting( + name = "build_hermetic", + 
flag_values = { + ":use_rocm_hermetic_rpath": "True", + }, +) config_setting( name = "using_hipcc", @@ -12,171 +26,434 @@ config_setting( ) cc_library( - name = "rocm_headers", + name = "config", hdrs = [ - "rocm/rocm_config.h", - %{rocm_headers} + "rocm_config/rocm_config.h", ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config", +) + +cc_library( + name = "config_hermetic", + hdrs = [ + "rocm_config_hermetic/rocm_config.h", + ], + include_prefix = "rocm", + strip_include_prefix = "rocm_config_hermetic", +) + +cc_library( + name = "rocm_config", + visibility = ["//visibility:public"], + deps = select({ + ":build_hermetic": [ + ":config_hermetic", + ], + "//conditions:default": [ + "config", + ], + }), +) + +cc_library( + name = "rocm_headers", + hdrs = glob([ + "%{rocm_root}/include/**", + "%{rocm_root}/lib/llvm/lib/**/*.h", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", - "rocm/include/roctracer", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", + "%{rocm_root}/include/roctracer", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [ + ":rocm_rpath", + ], ) cc_library( - name = "hip", - srcs = ["rocm/lib/%{hip_lib}"], - data = ["rocm/lib/%{hip_lib}"], + name = "rocm", + visibility = ["//visibility:public"], + deps = [ + ":hip", + ":hipblas", + ":hipblaslt", + ":hiprand", + ":hipsolver", + ":hipsparse", + ":hsa_rocr", + ":miopen", + ":rocblas", + ":rocm_config", + ":rocprofiler_register", + ":rocsolver", + ":roctracer", + ":rocsparse", + ] + select_threshold( + above_or_eq = [":hipfft"], + below = [":rocfft"], + threshold = 40100, + value = rocm_version_number(), + ), +) + +cc_library( + name = "hsa_rocr", + srcs = glob(["%{rocm_root}/lib/libhsa-runtime*.so*"]), + hdrs = glob(["%{rocm_root}/include/hsa/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_rpath", + linkopts = select({ + ":build_hermetic": [ + "-Wl,-rpath=%{rocm_toolkit_path}/lib", + ], + "//conditions:default": [ + "-Wl,-rpath=/opt/rocm/lib", + ], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hip", visibility = ["//visibility:public"], + deps = [ + ":rocm_hip", + ":rocm_rpath", + ], +) + +cc_library( + name = "rocm_hip", + srcs = glob(["%{rocm_root}/lib/libamdhip*.so*"]), + hdrs = glob(["%{rocm_root}/include/hip/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [ + ":amd_comgr", + ":hsa_rocr", + ":rocm_config", + ":rocm_smi", + ":rocprofiler_register", + ":system_libs", + ], ) cc_library( name = "rocblas", - srcs = ["rocm/lib/%{rocblas_lib}"], - data = ["rocm/lib/%{rocblas_lib}"], + hdrs = glob(["%{rocm_root}/include/rocblas/**"]), + data = glob([ + "%{rocm_root}/lib/librocblas*.so*", + "%{rocm_root}/lib/rocblas/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring tensile files to the same fs layout as expected in the lib + # rocblas assumes that tensile files are located in ../rocblas/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "%{hipfft_or_rocfft}", - srcs =
["rocm/lib/%{hipfft_or_rocfft_lib}"], - data = ["rocm/lib/%{hipfft_or_rocfft_lib}"], + name = "rocfft", + srcs = glob(["%{rocm_root}/lib/librocfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], linkstatic = 1, visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "hiprand", - srcs = ["rocm/lib/%{hiprand_lib}"], - data = ["rocm/lib/%{hiprand_lib}"], + name = "hipfft", + srcs = glob(["%{rocm_root}/lib/libhipfft*.so*"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", - "rocm/include/rocrand", + "%{rocm_root}/include", ], linkstatic = 1, - visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "miopen", - srcs = ["rocm/lib/%{miopen_lib}"], - data = ["rocm/lib/%{miopen_lib}"], + name = "hiprand", + srcs = glob(["%{rocm_root}/lib/libhiprand*.so*"]), + hdrs = glob(["%{rocm_root}/include/hiprand/**"]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", + "%{rocm_root}/include/rocrand", ], linkstatic = 1, + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rccl", - srcs = ["rocm/lib/%{rccl_lib}"], - data = ["rocm/lib/%{rccl_lib}"], + name = "miopen", + hdrs = glob(["%{rocm_root}/include/miopen/**"]), + data = glob([ + "%{rocm_root}/lib/libMIOpen*.so*", + "%{rocm_root}/share/miopen/**", + ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include", + "%{rocm_root}/include", ], - linkstatic = 1, + # workaround to bring miopen db files to the same fs layout as expected in the lib + # miopen assumes that its db files are located in ../share/miopen/db directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( - name = "rocm", - visibility = ["//visibility:public"], - deps = [ - ":rocm_headers", - ":hip", - ":rocblas", - ":hipblas", - ":%{hipfft_or_rocfft}", - ":hiprand", - ":miopen", - ":hipsparse", - ":roctracer", - ":rocsolver", - ":hipsolver", + name = "rccl", + srcs = glob(["%{rocm_root}/lib/librccl*.so*"]), + hdrs = glob(["%{rocm_root}/include/rccl/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", ], + linkstatic = 1, + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) bzl_library( name = "build_defs_bzl", srcs = ["build_defs.bzl"], + visibility = ["//visibility:public"], ) cc_library( name = "rocprim", srcs = [ - "rocm/include/hipcub/hipcub_version.hpp", - "rocm/include/rocprim/rocprim_version.hpp", + "%{rocm_root}/include/hipcub/hipcub_version.hpp", + "%{rocm_root}/include/rocprim/rocprim_version.hpp", ], hdrs = glob([ - "rocm/include/hipcub/**", - "rocm/include/rocprim/**", + "%{rocm_root}/include/hipcub/**", + "%{rocm_root}/include/rocprim/**", ]), + include_prefix = "rocm", includes = [ - ".", - "rocm/include/hipcub", - "rocm/include/rocprim", + "%{rocm_root}/include/hipcub", + "%{rocm_root}/include/rocprim", ], + strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], deps = [ - "@local_config_rocm//rocm:rocm_headers", + ":rocm_config", + ":rocm_headers", ], ) cc_library( name = "hipsparse", - srcs = ["rocm/lib/%{hipsparse_lib}"], - data = ["rocm/lib/%{hipsparse_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsparse/**"]), + data =
glob(["%{rocm_root}/lib/libhipsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "roctracer", - data = ["rocm/lib/%{roctracer_lib}"], + hdrs = glob(["%{rocm_root}/include/roctracer/**"]), + data = glob(["%{rocm_root}/lib/libroctracer*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "rocsolver", - srcs = ["rocm/lib/%{rocsolver_lib}"], - data = ["rocm/lib/%{rocsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/librocsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocsolver/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocsparse", + srcs = glob(["%{rocm_root}/lib/librocsparse*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipsolver", - srcs = ["rocm/lib/%{hipsolver_lib}"], - data = ["rocm/lib/%{hipsolver_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + hdrs = glob(["%{rocm_root}/include/hipsolver/**"]), + data = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], ) cc_library( name = "hipblas", - srcs = ["rocm/lib/%{hipblas_lib}"], - data = ["rocm/lib/%{hipblas_lib}"], + srcs = glob(["%{rocm_root}/lib/libhipblas.so*"]), + hdrs = glob(["%{rocm_root}/include/hipblas/**"]), + data = glob(["%{rocm_root}/lib/libhipblas.so*"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "hipblaslt", + hdrs = glob(["%{rocm_root}/include/hipblaslt/**"]), + data = glob([ + "%{rocm_root}/lib/hipblaslt/**", + "%{rocm_root}/lib/libhipblaslt.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + # workaround to bring tensile files to the same fs layout as expected in the lib + # hipblaslt assumes that tensile files are located in ../hipblaslt/libraries directory + linkopts = ["-Wl,-rpath=local_config_rocm/rocm/rocm_dis/lib"], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocrand", + srcs = glob(["%{rocm_root}/lib/librocrand*.so*"]), + hdrs = glob(["%{rocm_root}/include/rocrand/**"]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include/", + ], + strip_include_prefix = "%{rocm_root}", + visibility = ["//visibility:public"], + deps = [":rocm_config"], +) + +cc_library( + name = "rocprofiler_register", + srcs = glob([ + "%{rocm_root}/lib/librocprofiler-register.so*", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "amd_comgr", + srcs = glob([ + "%{rocm_root}/lib/libamd_comgr.so*", + ]), + hdrs = glob(["%{rocm_root}/include/amd_comgr/**"]), + include_prefix =
"rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "rocm_smi", + srcs = glob([ + "%{rocm_root}/lib/librocm_smi64.so*", + "%{rocm_root}/lib/libroam.so*", + ]), + hdrs = glob([ + "%{rocm_root}/include/oam/**", + "%{rocm_root}/include/rocm_smi/**", + ]), + include_prefix = "rocm", + includes = [ + "%{rocm_root}/include", + ], + strip_include_prefix = "%{rocm_root}", + deps = [":rocm_config"], +) + +cc_library( + name = "system_libs", + srcs = glob([ + "rocm_dist/usr/lib/**/libelf.so*", + "rocm_dist/usr/lib/**/libdrm.so*", + "rocm_dist/usr/lib/**/libnuma.so*", + "rocm_dist/usr/lib/**/libdrm_amdgpu.so*", + ]), + data = glob([ + "rocm_dist/usr/**", + ]), ) filegroup( name = "rocm_root", srcs = [ - "rocm/bin/clang-offload-bundler", + "%{rocm_root}/bin/clang-offload-bundler", ], + visibility = ["//visibility:public"], ) -%{copy_rules} +filegroup( + name = "all_files", + srcs = glob(["%{rocm_root}/**"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl index 83a7e9dababf3..d327083e4dc8e 100644 --- a/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/tsl/third_party/gpus/rocm/build_defs.bzl.tpl @@ -11,6 +11,8 @@ def if_rocm(if_true, if_false = []): "//conditions:default": if_false }) +def select_threshold(value, above_or_eq, threshold, below): + return below if value < threshold else above_or_eq def rocm_default_copts(): """Default options for all ROCm compilations.""" diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl new file mode 100644 index 0000000000000..c1cc501e1a2de --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist.bzl @@ -0,0 +1,18 @@ +load( + "@tsl//third_party/gpus/rocm:rocm_redist_ubuntu_20_04.bzl", + "rocm_redist_ubuntu_20_04", +) +load( + "@tsl//third_party/gpus/rocm:rocm_redist_ubuntu_22_04.bzl", + "rocm_redist_ubuntu_22_04", +) +load( + "@tsl//third_party/gpus/rocm:rocm_redist_ubuntu_24_04.bzl", + "rocm_redist_ubuntu_24_04", +) + +rocm_redist = { + "ubuntu_20.04": rocm_redist_ubuntu_20_04, + "ubuntu_22.04": rocm_redist_ubuntu_22_04, + "ubuntu_24.04": rocm_redist_ubuntu_24_04, +} diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl new file mode 100644 index 0000000000000..ecae2197563b3 --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_20_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_20_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~20.04_amd64.deb", + sha256 = "fabf4a831f21b5248932e08654149bc215da2a816613ad8d05b805d4e226171a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "215fae8759742bc048699feaacd6256a3ac2138771b69731dab7779325bb1b41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "e901d66275b3b520ee73250caa4a1836be142823083528b4db6cc31a18bfb94d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = 
"f8a20128b5c26198bd9ecec894f8a4c74fa28ee668e4ef1bf73d0c3edff8c144", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "ab3ee54b33eba013fbf3d9aefe64b54e1918b9fb72790ca0b57fb391cb662cf0", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~20.04_amd64.deb", + sha256 = "a68123c046b8c913705262014463a8a30768167a1b68a78d8455deaf85a802d7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "c71fab59f62ad9d4b60aa4217f4db42c6996d83d5ad7ba29e127cc13bda59afc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "25887526ea2e955d4c0afa4749f8db55a49e399a349d43ccf66e0ad99ff78b2a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~20.04_amd64.deb", + sha256 = "3cfec840c79c6bce4e83bf6e056e241cc13ff572352b040a952c7642b61d45aa", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "cb56dd79ff52eaddfed379831023484d9ec32b9538bc3d02ee34c328457cd20e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~20.04_amd64.deb", + sha256 = "1e968f9405c8b90fbb58dff09d8bab08cf31c8386880fff95e1cb8932320bc37", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "f08ba25b6b950754b5a2bb64c125a01b9f44280f227ff19eeb78e188f0b17320", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~20.04_amd64.deb", + sha256 = "e9464369619bbea7299ac83e17b3cbbabdeb16e6d4da116400532e7737332b65", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "2efed49be9413e08e91b3fb67736644bb0e8809fc673d310a0abab65b69eacad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~20.04_amd64.deb", + sha256 = "19564fb2f9616860234aa8bd69cca324a1a3ec33476581ec57200a1dac1d4dcb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~20.04_amd64.deb", + sha256 = "e4940a5d47e9e39d603f18936e7921c603fd8dde0e359e0be796f9c1cdacd431", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "638a28c5407c3af7d16e1b0179b7494b0aeb36c314114af148b1bcd52e883db1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "77c9d26c4f0053b71fb86f7a6b489655e27053f9605efca3a16344ccf286e313", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "2b3ce1ca2e58e891963f26d4bd31ae45894480483f691d371f269e698f75f8eb", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~20.04_amd64.deb", + sha256 = "0dedbffa5bb272d656086a9586e3705551345945f35f4f6be6dc8a27b63127a9", + ), + struct( + url = 
"https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "6e5b3caeadf592367f8638db67a70b8dd9231a8257dc2012a9c46e2c5974fff5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~20.04_amd64.deb", + sha256 = "eaefe5a7d75ef61314b83af5bb85d8e652a730deaa58e1d600b1e9c2e673673c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "b2bfe29ab688781bad5bc067ee682658085e22caaf09b18278f2f4b9905081d3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~20.04_amd64.deb", + sha256 = "e94d50fd6f24d70649ce046dbfe4dda2587d1d82892d4c126a4c3e91d1570071", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "0e16c9fc58fc904542be4dad63bb2ff34268b5c13957c432e91ec0e4fd149c82", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~20.04_amd64.deb", + sha256 = "14f47d79b508eb259bfe4e0e5f360edb5721b908caf3bb981a4eee4181783be9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~20.04_amd64.deb", + sha256 = "97e6e77eaea56de6cc4ea2c525dd8b9a587546eb99c782c7af46cdc5363b99bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "ae055b579d319e1a779783ba774f119fb0e1a731d058a03b36dc5c15214d210a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~20.04_amd64.deb", + sha256 = "3bcf3dc22dbede7da70299cde1484776827808b967d371441f6cf6d3fe8af30d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "ce17d2b85407b9539e0feda513fd360a48ebfd971c19af122dda21d60448c9fc", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~20.04_amd64.deb", + sha256 = "322ca8425c3a8f2ec17c551bad606b96d957b0c1eea07196dd66ac9f15460ed5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~20.04_amd64.deb", + sha256 = "1bbdb32d21dbc12bf9a736f6ca8726df9673e4401465d2b9b537c47b358b67f1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "e74e1907eb90a692344626e881cb88eeed5565ac3b487eb94ad4ac02ffd838ed", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~20.04_amd64.deb", + sha256 = "4be88c5010c2cf0223c1dd7dc9d4a430fc54ee401ca093de2dcca60dabea763a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~20.04_amd64.deb", + sha256 = "ddd0ac44b08470dfc128d6f6d2598a9728879f5a78bc5290645baebf22433b63", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "b94cdf230b372ebcaf97085cf67f01ef7977f814280fdaf1886797f39899ef41", + ), + struct( + url = 
"https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~20.04_amd64.deb", + sha256 = "9a85b57eea3790432eae06421081b3e59d3c9841d59646364ecd174f9ed4821a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "87dcd34a9b50f46161ecdb7781ab03c2b311fb7e13aa167c4a9c5e3bcf24b473", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~20.04_amd64.deb", + sha256 = "21e4aa1957e7bc5d293a418a983d9b3c3917fb78eb79d3d4d55a253b9bae7743", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~20.04_amd64.deb", + sha256 = "dacc13278f2be1cd847fca30ce409dcf95749df5f1a27635bc6dbd61be488d14", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.101-2_amd64.deb", + sha256 = "4cd2e10f9486456a2782487f8bfd39f330f35a4d5bd6d693412b9e4ca2a6acbd", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.101-2_amd64.deb", + sha256 = "d4567a30f7d68b4dcf794f8677b96e89083693c94e88279fecf577ceba8b9774", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.176-1.1build1_amd64.deb", + sha256 = "78a8761227efc04a1e37527f2f33ba608c6fb5d6c911616346ada5d7b9b72ee3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.12-1_amd64.deb", + sha256 = "0b1edf08cf9befecd21fe94e298ac25e476f87fd876ddd4adf42ef713449e637", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl new file mode 100644 index 0000000000000..88dca226f795b --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_22_04.bzl @@ -0,0 +1,183 @@ +rocm_redist_ubuntu_22_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~22.04_amd64.deb", + sha256 = "bc5d620e4e0db3746fc6b2279e463f618681f1f95ba973e40b687cef50ca2489", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-runtime-amd6.2.0/hip-runtime-amd6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "38e9670bedc7bbdc0b9f38c7a0fe90f73ef80f161cbf63c98d30e422438ce2c5", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "c66cc8c19b57cab740710811457f02a16e24cff761e5c99c3640f63ceefe8281", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "fbd647e1b13e7aa2c14c9581f9102c069ddab9ecb47a4b226d433ec37b19e92d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "885cf3f3a52ebde9caadf6348a6cda28fd15e3bc52bab0c90b587d72b29ff7ef", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~22.04_amd64.deb", + sha256 = "468026fa8eb70121f0c545557a926ddc41228cef9457b4a00d8fc3a36b04310f", + ), + struct( + url = 
"https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "c2c7d2ec5a8a31837c0addfc619ee67a374ea967cc6d43900472005489f62722", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "6e649430cc5e247bbd052dff2d681b6bf0ef09d0bc3446a4911f4ab4cd317140", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~22.04_amd64.deb", + sha256 = "389b0c83a39adbeeec442adde3fedba2820ed948179a4a0df03d67560501cd97", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "adf9aad1fc062445e34cdddbeca80db9c02f4c5f258e01c45e2a6222d15cb66d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~22.04_amd64.deb", + sha256 = "cb46dfbff3943a3167f6173fc381d744eb966a3451bcff49458c696888ec452c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "8c7a216aeef6ceeb3881d3e443a89a0f5c15a17deb5926cba4b787554c8fab87", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~22.04_amd64.deb", + sha256 = "501cad72df5f09572f99c11eebbb1eff49afb6ca8c91bcf4966f81068177a95d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "b20c86be57698a944f91048699d0fbde5253bea28ba9d4035ce1de1d3c20f9ac", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~22.04_amd64.deb", + sha256 = "9dab6f44b92b6020e183777f6f07219d68de5d10cad7538c7ddcae0192aa3e33", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~22.04_amd64.deb", + sha256 = "62d280204d8ff642b464dab03fc344442df6dc5f04e152da20604e8050303c41", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "6c2aa042067e51d5b70a264ca83c92ffaa6e81d00d08b55986917da860e66d85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "f3452b2bd9c2869c550c7f963cca65fb35a37183ad4a56d96e05c69adb2f1d04", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "f3205c0a7d736f457ee2262988260e8dc4c495fa74a394ff73a9dfe002aff335", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~22.04_amd64.deb", + sha256 = "953a248cd44f403e5423185918166bfa29a009519c3d7b5b5a8e067fdf672602", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "c306ca3e59b851ebb35872e09e5598adf2e2ebb736c1b200ff4ee204fe262f7e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~22.04_amd64.deb", + sha256 = "115d0e9ec1b93bf7cba5fa1e3de1428f0d999d931c2dd495e4cdad22b5078936", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~22.04_amd64.deb", + 
sha256 = "0d40fc9aa1da617cd8864258cd1259a0e7444ea0da446297d154b5b3422393af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~22.04_amd64.deb", + sha256 = "8c1e72cf1c165e20960b0c2f3c499900a809d59340d14a0acff95c543c7087f2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "22c80c1a704f4ce7d6a49a8b41acd64f3ed0513cd7f5570a0664a10df5858334", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~22.04_amd64.deb", + sha256 = "9c2ff1dc100e342969bd51a7cd4918048c8b25579de709efde56425d969cd50f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~22.04_amd64.deb", + sha256 = "1101f3edb9dbc9f4914d7f26b5569ec9bde076d52d4125c98d22a99dd730ab51", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "d5b660df350130e0ab04ddf3e36dd442bde27ae9cbb8e5f12c047b0d3cb05463", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~22.04_amd64.deb", + sha256 = "0d06a84ac53d388089b7b8c80133f60c1eea5bfd85155ecc113efb206a747c25", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "4a29539480a7e4b27991ccf533a35526dd3994a457fa84e4c960192c2fa05b46", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~22.04_amd64.deb", + sha256 = "febb8614cedd98f13ba0624072ffdd13b9a6dc3431380a17a0eaf87583627890", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "3d859bb735ff8bf1962ce680e9257dcc574ab36224f50069f833fa19c6d7e69d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~22.04_amd64.deb", + sha256 = "ffd4e064e8a1d52b9e72114e8a1d51c78004a960f1d923448af8ed07a1b6f30b", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~22.04_amd64.deb", + sha256 = "66df78d8c5e2d1a0ae43cd4a5e41cf75ec120c870a0bbd7da18a2ba4dec42f9c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~22.04_amd64.deb", + sha256 = "317c16a6e0b0b456153437406dd92225e17dbd454fc1304b0c3fef5fbfc69bc2", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9ddf8835f1e94d5004b4c466091c8110cb72e11eda545d0de395395832076c0a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~22.04_amd64.deb", + sha256 = "9a9ed0c66d3a9d9ff50f1fc3a9e9105bb8b1a6d93c1f856682625dfb68ab627f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = "5b86bf7b33a3ffa7098878f27d1b119aada69ebb02bd121b47209559c32703be", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~22.04_amd64.deb", + sha256 = 
"4573f99191fbe3a2afab84fdf5a05e024bd230ca7866d7eba71a5f2560a3a0bf", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~22.04_amd64.deb", + sha256 = "4fbc91db9085ecd80a5e051bba56863ae33b22516d727ab3fef15fb500187222", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.110-1ubuntu1_amd64.deb", + sha256 = "e5ea68db36b31aab442c790e1c78ecdf53646c16b0cd83db15966632ba04152c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.110-1ubuntu1_amd64.deb", + sha256 = "ae1f0d77668d7275d085ba820206ba91e90833dd1a02b8e251af0c73aa119ba3", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1_0.186-1build1_amd64.deb", + sha256 = "8effc4d7a0cc341bcf6cb11af0134f3defa6292376ecfdfc697a9b228606345c", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.14-3ubuntu2_amd64.deb", + sha256 = "0721c89001fbbd1ada23e89da5d60e762763c1a7b3dc814a2e9a518480a8043d", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl new file mode 100644 index 0000000000000..da9ef00998f93 --- /dev/null +++ b/third_party/tsl/third_party/gpus/rocm/rocm_redist_ubuntu_24_04.bzl @@ -0,0 +1,187 @@ +rocm_redist_ubuntu_24_04 = { + "6.2.0": { + "archives": [ + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/c/comgr6.2.0/comgr6.2.0_2.8.0.60200-66~24.04_amd64.deb", + sha256 = "7e1ff2d9f2435f5b9db9aa952bb57d1a878a8aa7d96bda61361c107b7e1428e3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev6.2.0/hip-dev6.2.0_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "5e6601ada30432ee0dab0473585bdf1fa7c398f0c655538d48eba9c44e6dc77a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas6.2.0/hipblas6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "7ff8f6308c744c71008959b17ab6338de1c6fd3e4581dd94271e6eca9fdc4c13", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblas-dev6.2.0/hipblas-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "e9f71e71db600d72dcb2b61e64b965b6c60d47bd4bb699e8abec85edb260b819", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt6.2.0/hipblaslt6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "e5dfd8ba9e49f919a96c102d3a652e8ef0c4d1a63b3f3909c856d40b1745e2a9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipblaslt-dev6.2.0/hipblaslt-dev6.2.0_0.8.0.60200-66~24.04_amd64.deb", + sha256 = "639bd47010035ee6719425510be33d2f54483004a909dfa4c64f853d7394a22f", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcc6.2.0/hipcc6.2.0_1.1.1.60200-66~24.04_amd64.deb", + sha256 = "c2782a98633e4400f46ba732605e56b2821366db60ec06d88db0615e4d1acf3c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipcub-dev6.2.0/hipcub-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "48fec4d06aef3159db4117125b728242a1eeb480ea3d55d3901d945d4b883694", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft6.2.0/hipfft6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = 
"8dd73cdbd4f0563f4a0481304771e4cbcac5905eea1f2d8ef41f922cdf9aba85", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipfft-dev6.2.0/hipfft-dev6.2.0_1.0.14.60200-66~24.04_amd64.deb", + sha256 = "e3c0a4ebda8d3aacd44b19c6872f23222513be0a5c04f793605088d9183f1be4", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver6.2.0/hipsolver6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "adbba9ffcf8b5e4202efbe45924d87520bf4100ec5464bd0ba3beb61cb535c6c", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsolver-dev6.2.0/hipsolver-dev6.2.0_2.2.0.60200-66~24.04_amd64.deb", + sha256 = "01d3dd6195111808b40a5837d3e51d8c27c4700b4bd8bb2d901e39d0474fd98a", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse6.2.0/hipsparse6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "2ba33a96388cd3edd7b5b8b261fe99cbd569894f4d7db291fc0dd0ff5d7c67ce", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hipsparse-dev6.2.0/hipsparse-dev6.2.0_3.1.1.60200-66~24.04_amd64.deb", + sha256 = "6a767f493a722e2d4260a9bc23cf9db66fd275a094b395c768e305f60d6b4fe9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand6.2.0/hiprand6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "82f182134b415080ba4a12fd7993b6099ee9b9e549c72bfebee24c8486704078", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hiprand-dev6.2.0/hiprand-dev6.2.0_2.11.0.60200-66~24.04_amd64.deb", + sha256 = "011d5c28f45cd9d756e0cf6ea6a3d37eabd98a3381ffd961c772ab92a37e4ee8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hsa-rocr6.2.0/hsa-rocr6.2.0_1.14.0.60200-66~24.04_amd64.deb", + sha256 = "fa04f707debb75087ea2bf5e327602034eaa3a6900421f2cf32ad5f5f1c887b9", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip6.2.0/miopen-hip6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "2dbf6d126d0de6930e0cd94d0e525e07d3019d90bd7256f3151a7f1fbc2250af", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/m/miopen-hip-dev/miopen-hip-dev_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "df5fdd2218e4d380b133ba402f3734fbe0589d9cdd8618a101b71b968909b4ba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl6.2.0/rccl6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4d7efa4ee6aa2bf69b0aab449cc1d01c25ca65814e1b3cb07f6b59fa8b1608b8", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rccl-dev6.2.0/rccl-dev6.2.0_2.20.5.60200-66~24.04_amd64.deb", + sha256 = "4ab4f880344e04d61b6fa746be5c4bdc2841409fb6987ee61e39c6420b4eca42", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas6.2.0/rocblas6.2.0_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "521c87ce396c6ce10076cc641b6035451fd68ddb36a684c5a9c9538dfc831ade", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocblas-dev/rocblas-dev_4.2.0.60200-66~24.04_amd64.deb", + sha256 = "00f135ce2ae47c35085ef06248ff7d5ce8c12fd0d5b82e7bd77b1dbc0ce7058e", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft6.2.0/rocfft6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "40c936452e84bfec87236f08de5a9d3f232c397a3305b6143c26697ed56ceda1", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocfft-dev6.2.0/rocfft-dev6.2.0_1.0.28.60200-66~24.04_amd64.deb", + sha256 = "eb3904263b396d46799eeea1081d8e8d1a551a890432a803364db2d013849f92", + ), + struct( + url = 
"https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-core/rocm-core_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "af5fcbe8dc2b6cbec30e2d39d30736e8a47a0b9d0ca2be7f179f2947f9c98245", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-hip-libraries/rocm-hip-libraries_6.2.0.60200-66~24.04_amd64.deb", + sha256 = "228f07a3caefc41f6efd5345eb9d3630f1db769f9b4abd1313cbcb32d077ce53", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/h/hip-dev/hip-dev_6.2.41133.60200-66~24.04_amd64.deb", + sha256 = "cda72054d2011dbb062e75386766d928fd8905c15c88685c3ef87fc963bd88ad", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-device-libs6.2.0/rocm-device-libs6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "298544f717dfb236b9257b19a0ab81abaaa770128976d4abfdea546cd32d8b02", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocminfo6.2.0/rocminfo6.2.0_1.0.0.60200-66~24.04_amd64.deb", + sha256 = "8e78ed8e480b55a496153b150acb22bab39c3bb8cf1e62f9aff7eaf75a3a3a92", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm6.2.0/rocm-llvm6.2.0_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "72c388eae7c0f54151b46fbd8fa6e26f1ca81e2b8b415c43411a156b3f25b6e7", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-llvm-dev/rocm-llvm-dev_18.0.0.24292.60200-66~24.04_amd64.deb", + sha256 = "3e85a859c5dafa82a9a57dda096d566b821217bacfac995f7cc45ed460b68999", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocm-smi-lib6.2.0/rocm-smi-lib6.2.0_7.3.0.60200-66~24.04_amd64.deb", + sha256 = "c094e3022c73fca2aa6c8bb435f93550109531de37fe8de5fbf6cfe1f047b645", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprim-dev6.2.0/rocprim-dev6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "6c832e2feb0885fbe481245825c76a466921b294f530eb0d0da70a44cfe6e608", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocprofiler-register6.2.0/rocprofiler-register6.2.0_0.4.0.60200-66~24.04_amd64.deb", + sha256 = "d198d010fedfbe51d3fd19444e2848d430e08f91d19a5b2661b94ac6d1135863", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocrand-dev/rocrand-dev_3.1.0.60200-66~24.04_amd64.deb", + sha256 = "2a2a95185ce0e54df226474b2f5cfcdc9e5ede5a6d88a8a70c2635ea2237abba", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer6.2.0/roctracer6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "2f2fb6f8d06ace89131934c833b0ea359335a4b45aeec1559b293d7bc14b1d1d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/roctracer-dev6.2.0/roctracer-dev6.2.0_4.1.60200.60200-66~24.04_amd64.deb", + sha256 = "c6c781ee87c459aed32e943b389137f98ecd402fb83a3d1c98de9a76abadc3a3", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver6.2.0/rocsolver6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5e4b3e38556f0826e5322971635a49a72283d60862ccc4d28efd11c8fb955b47", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsolver-dev6.2.0/rocsolver-dev6.2.0_3.26.0.60200-66~24.04_amd64.deb", + sha256 = "5bb6ae92a25f33488f2ee5f123ac4f67ad130e18e4949161715451509be3b89d", + ), + struct( + url = "https://repo.radeon.com/rocm/apt/6.2/pool/main/r/rocsparse6.2.0/rocsparse6.2.0_3.2.0.60200-66~24.04_amd64.deb", + sha256 = "1867833a569fbf3f87b82c81bc47f5d62085ea40f12d1cb33475c1f2dec89bc4", + ), + struct( + url = 
"https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm2_2.4.120-2build1_amd64.deb", + sha256 = "f5fb4e7ce17921cc466fb7911abf91495ffb181b36772f68e2e82cb621703112", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libdrm-amdgpu1_2.4.120-2build1_amd64.deb", + sha256 = "e149d4daea33f58853b8013fd6c24888429ce7716a4b26d1a1f45181b5a4e73e", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libelf1t64_0.190-1.1build4_amd64.deb", + sha256 = "b277e52769302778bd052376ac6687b52954b6605dd5f781bff8631e3504d58f", + ), + struct( + url = "https://mirror.bazel.build/github.com/alekstheod/rocm-deps/releases/download/rocm-6.2.0/libnuma1_2.0.18-1build1_amd64.deb", + sha256 = "508daa855e99959acaa945e6a89d218e0be6b5727fd28773580942ff37cf5805", + ), + ], + "rocm_root": "opt/rocm-6.2.0", + }, +} diff --git a/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/tsl/third_party/gpus/rocm_configure.bzl index 7b071f7f99a6e..935a018772443 100644 --- a/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -12,6 +12,10 @@ * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ +load( + "//third_party/gpus/rocm:rocm_redist.bzl", + "rocm_redist", +) load( "//third_party/remote_config:common.bzl", "config_repo_label", @@ -33,8 +37,6 @@ load( load( ":cuda_configure.bzl", "enable_cuda", - "make_copy_dir_rule", - "make_copy_files_rule", ) load( ":sycl_configure.bzl", @@ -48,6 +50,9 @@ _TF_SYSROOT = "TF_SYSROOT" _ROCM_TOOLKIT_PATH = "ROCM_PATH" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" +_DISTRIBUTION_PATH = "rocm/rocm_dist" +_OS = "OS" +_ROCM_VERSION = "ROCM_VERSION" _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" @@ -203,20 +208,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): """ inc_dirs = [] - # Add HSA headers (needs to match $HSA_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hsa/include") - - # Add HIP headers (needs to match $HIP_PATH) - inc_dirs.append(rocm_config.rocm_toolkit_path + "/hip/include") - if int(rocm_config.rocm_version_number) >= 50200: - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/hip") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocprim") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocsolver") - inc_dirs.append(rocm_config.rocm_toolkit_path + "/include/rocblas") - - # Add HIP-Clang headers (realpath relative to compiler binary) - rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin) + # Add full paths + rocm_toolkit_path = str(repository_ctx.path(rocm_config.rocm_toolkit_path)) inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/8.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include") @@ -367,7 +360,7 @@ def _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin): return libs -def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin): +def _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin): """Returns the ROCm libraries on the system. 
Args: @@ -383,7 +376,6 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ for name, path in [ ("amdhip64", rocm_config.rocm_toolkit_path), ("rocblas", rocm_config.rocm_toolkit_path), - (hipfft_or_rocfft, rocm_config.rocm_toolkit_path), ("hiprand", rocm_config.rocm_toolkit_path), ("MIOpen", miopen_path), ("rccl", rccl_path), @@ -401,17 +393,17 @@ def _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_ libs_paths.append(("hipblaslt", _rocm_lib_paths(repository_ctx, "hipblaslt", rocm_config.rocm_toolkit_path), True)) return _select_rocm_lib_paths(repository_ctx, libs_paths, bash_bin) -def find_rocm_config(repository_ctx): +def find_rocm_config(repository_ctx, rocm_path): """Returns ROCm config dictionary from running find_rocm_config.py""" python_bin = get_python_bin(repository_ctx) - exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config]) + exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_rocm_config], env_vars = {"ROCM_PATH": rocm_path}) if exec_result.return_code: auto_configure_fail("Failed to run find_rocm_config.py: %s" % err_out(exec_result)) # Parse the dict from stdout. return dict([tuple(x.split(": ")) for x in exec_result.stdout.splitlines()]) -def _get_rocm_config(repository_ctx, bash_bin): +def _get_rocm_config(repository_ctx, bash_bin, rocm_path, install_path): """Detects and returns information about the ROCm installation on the system. Args: @@ -426,7 +418,7 @@ def _get_rocm_config(repository_ctx, bash_bin): miopen_version_number: The version of MIOpen on the system. hipruntime_version_number: The version of HIP Runtime on the system. """ - config = find_rocm_config(repository_ctx) + config = find_rocm_config(repository_ctx, rocm_path) rocm_toolkit_path = config["rocm_toolkit_path"] rocm_version_number = config["rocm_version_number"] miopen_version_number = config["miopen_version_number"] @@ -437,6 +429,7 @@ def _get_rocm_config(repository_ctx, bash_bin): rocm_version_number = rocm_version_number, miopen_version_number = miopen_version_number, hipruntime_version_number = hipruntime_version_number, + install_path = install_path, ) def _tpl_path(repository_ctx, labelname): @@ -500,15 +493,12 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": "hipfft", - "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), "%{roctracer_lib}": _lib_name("roctracer64"), "%{rocsolver_lib}": _lib_name("rocsolver"), "%{hipsolver_lib}": _lib_name("hipsolver"), "%{hipblaslt_lib}": _lib_name("hipblaslt"), - "%{copy_rules}": "", "%{rocm_headers}": "", }, ) @@ -526,7 +516,7 @@ def _create_dummy_repository(repository_ctx): "%{rocm_toolkit_path}": _DEFAULT_ROCM_TOOLKIT_PATH, "%{hipblaslt_flag}": "0", }, - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", ) # If rocm_configure is not configured to build with GPU support, and the user @@ -578,6 +568,53 @@ def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def _download_package(repository_ctx, archive): + file_name = _get_file_name(archive.url) + tmp_dir = "tmp" + repository_ctx.file(tmp_dir + "/.idx") # create tmp dir + + 
repository_ctx.report_progress("Downloading and extracting {}, expected hash is {}".format(archive.url, archive.sha256)) # buildifier: disable=print + repository_ctx.download_and_extract( + url = archive.url, + output = tmp_dir if archive.url.endswith(".deb") else _DISTRIBUTION_PATH, + sha256 = archive.sha256, + ) + + all_files = repository_ctx.path(tmp_dir).readdir() + + matched_files = [f for f in all_files if _get_file_name(str(f)).startswith("data.")] + for f in matched_files: + repository_ctx.extract(f, _DISTRIBUTION_PATH) + + repository_ctx.delete(tmp_dir) + repository_ctx.delete(file_name) + +def _remove_root_dir(path, root_dir): + if path.startswith(root_dir + "/"): + return path[len(root_dir) + 1:] + return path + +def _setup_rocm_distro_dir(repository_ctx): + """Sets up the rocm hermetic installation directory to be used in hermetic build""" + bash_bin = get_bash_bin(repository_ctx) + os = repository_ctx.os.environ.get(_OS) + rocm_version = repository_ctx.os.environ.get(_ROCM_VERSION) + if os and rocm_version: + redist = rocm_redist[os][rocm_version] + repository_ctx.file("rocm/.index") + for archive in redist["archives"]: + _download_package(repository_ctx, archive) + return _get_rocm_config(repository_ctx, bash_bin, "{}/{}".format(_DISTRIBUTION_PATH, redist["rocm_root"]), "/{}".format(redist["rocm_root"])) + else: + rocm_path = repository_ctx.os.environ.get(_ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + repository_ctx.report_progress("Using local rocm installation {}".format(rocm_path)) # buildifier: disable=print + repository_ctx.symlink(rocm_path, _DISTRIBUTION_PATH) + return _get_rocm_config(repository_ctx, bash_bin, _DISTRIBUTION_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) + def _create_local_rocm_repository(repository_ctx): """Creates the repository containing files set up to build with ROCm.""" @@ -590,12 +627,8 @@ def _create_local_rocm_repository(repository_ctx): "rocm:rocm_config.h", ]} - bash_bin = get_bash_bin(repository_ctx) - rocm_config = _get_rocm_config(repository_ctx, bash_bin) - - # For ROCm 4.1 and above use hipfft, older ROCm versions use rocfft + rocm_config = _setup_rocm_distro_dir(repository_ctx) rocm_version_number = int(rocm_config.rocm_version_number) - hipfft_or_rocfft = "rocfft" if rocm_version_number < 40100 else "hipfft" # For ROCm 5.2 and above, find MIOpen and RCCL in the main rocm lib path miopen_path = rocm_config.rocm_toolkit_path + "/miopen" if rocm_version_number < 50200 else rocm_config.rocm_toolkit_path @@ -603,75 +636,19 @@ def _create_local_rocm_repository(repository_ctx): # Copy header and library files to execroot. # rocm_toolkit_path - rocm_toolkit_path = rocm_config.rocm_toolkit_path - copy_rules = [ - make_copy_dir_rule( - repository_ctx, - name = "rocm-include", - src_dir = rocm_toolkit_path + "/include", - out_dir = "rocm/include", - ), - ] - - # explicitly copy (into the local_config_rocm repo) the $ROCM_PATH/hiprand/include and - # $ROCM_PATH/rocrand/include dirs, only once the softlink to them in $ROCM_PATH/include - # dir has been removed. This removal will happen in a near-future ROCm release. 
- hiprand_include = "" - hiprand_include_softlink = rocm_config.rocm_toolkit_path + "/include/hiprand" - softlink_exists = files_exist(repository_ctx, [hiprand_include_softlink], bash_bin) - if not softlink_exists[0]: - hiprand_include = '":hiprand-include",\n' - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "hiprand-include", - src_dir = rocm_toolkit_path + "/hiprand/include", - out_dir = "rocm/include/hiprand", - ), - ) - - rocrand_include = "" - rocrand_include_softlink = rocm_config.rocm_toolkit_path + "/include/rocrand" - softlink_exists = files_exist(repository_ctx, [rocrand_include_softlink], bash_bin) - if not softlink_exists[0]: - rocrand_include = '":rocrand-include",\n' - copy_rules.append( - make_copy_dir_rule( - repository_ctx, - name = "rocrand-include", - src_dir = rocm_toolkit_path + "/rocrand/include", - out_dir = "rocm/include/rocrand", - ), - ) + rocm_toolkit_path = _remove_root_dir(rocm_config.rocm_toolkit_path, "rocm") - rocm_libs = _find_libs(repository_ctx, rocm_config, hipfft_or_rocfft, miopen_path, rccl_path, bash_bin) + bash_bin = get_bash_bin(repository_ctx) + rocm_libs = _find_libs(repository_ctx, rocm_config, miopen_path, rccl_path, bash_bin) rocm_lib_srcs = [] rocm_lib_outs = [] for lib in rocm_libs.values(): if lib: rocm_lib_srcs.append(lib.path) rocm_lib_outs.append("rocm/lib/" + lib.file_name) - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-lib", - srcs = rocm_lib_srcs, - outs = rocm_lib_outs, - )) clang_offload_bundler_path = rocm_toolkit_path + "/llvm/bin/clang-offload-bundler" - # copy files mentioned in third_party/gpus/rocm/BUILD - copy_rules.append(make_copy_files_rule( - repository_ctx, - name = "rocm-bin", - srcs = [ - clang_offload_bundler_path, - ], - outs = [ - "rocm/bin/" + "clang-offload-bundler", - ], - )) - have_hipblaslt = "1" if rocm_libs["hipblaslt"] != None else "0" # Set up BUILD file for rocm/ @@ -693,20 +670,8 @@ def _create_local_rocm_repository(repository_ctx): ) repository_dict = { - "%{hip_lib}": rocm_libs["amdhip64"].file_name, - "%{rocblas_lib}": rocm_libs["rocblas"].file_name, - "%{hipfft_or_rocfft}": hipfft_or_rocfft, - "%{hipfft_or_rocfft_lib}": rocm_libs[hipfft_or_rocfft].file_name, - "%{hiprand_lib}": rocm_libs["hiprand"].file_name, - "%{miopen_lib}": rocm_libs["MIOpen"].file_name, - "%{rccl_lib}": rocm_libs["rccl"].file_name, - "%{hipsparse_lib}": rocm_libs["hipsparse"].file_name, - "%{roctracer_lib}": rocm_libs["roctracer64"].file_name, - "%{rocsolver_lib}": rocm_libs["rocsolver"].file_name, - "%{copy_rules}": "\n".join(copy_rules), - "%{rocm_headers}": ('":rocm-include",\n' + - hiprand_include + - rocrand_include), + "%{rocm_root}": rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), } is_rocm_clang = _use_rocm_clang(repository_ctx) @@ -726,7 +691,6 @@ def _create_local_rocm_repository(repository_ctx): ) # Set up crosstool/ - cc = find_cc(repository_ctx, is_rocm_clang) host_compiler_includes = get_cxx_inc_directories( repository_ctx, @@ -785,6 +749,7 @@ def _create_local_rocm_repository(repository_ctx): repository_ctx.template( "crosstool/cc_toolchain_config.bzl", tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"], + rocm_defines, ) repository_ctx.template( @@ -792,12 +757,13 @@ def _create_local_rocm_repository(repository_ctx): tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"], { "%{cpu_compiler}": str(cc), - "%{compiler_is_clang}": "True" if is_rocm_clang else "False", - "%{hipcc_path}": 
rocm_config.rocm_toolkit_path + "/bin/hipcc", + "%{compiler_is_clang}": "True" if is_rocm_clang else "False", + "%{hipcc_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/bin/hipcc")), "%{hipcc_env}": _hipcc_env(repository_ctx), - "%{rocr_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", + "%{rocm_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), + "%{rocr_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{rocr_runtime_library}": "hsa-runtime64", - "%{hip_runtime_path}": rocm_config.rocm_toolkit_path + "/lib", + "%{hip_runtime_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path + "/lib")), "%{hip_runtime_library}": "amdhip64", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), @@ -807,13 +773,32 @@ def _create_local_rocm_repository(repository_ctx): # Set up rocm_config.h, which is used by # tensorflow/compiler/xla/stream_executor/dso_loader.cc. repository_ctx.template( - "rocm/rocm/rocm_config.h", + "rocm/rocm_config/rocm_config.h", + tpl_paths["rocm:rocm_config.h"], + { + "%{rocm_amdgpu_targets}": ",".join( + ["\"%s\"" % c for c in rocm_config.amdgpu_targets], + ), + "%{rocm_toolkit_path}": rocm_config.install_path, + "%{rocm_version_number}": rocm_config.rocm_version_number, + "%{miopen_version_number}": rocm_config.miopen_version_number, + "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, + "%{hipblaslt_flag}": have_hipblaslt, + "%{hip_soversion_number}": "6" if int(rocm_config.rocm_version_number) >= 60000 else "5", + "%{rocblas_soversion_number}": "4" if int(rocm_config.rocm_version_number) >= 60000 else "3", + }, + ) + + # Set up rocm_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "rocm/rocm_config_hermetic/rocm_config.h", tpl_paths["rocm:rocm_config.h"], { "%{rocm_amdgpu_targets}": ",".join( ["\"%s\"" % c for c in rocm_config.amdgpu_targets], ), - "%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path, + "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), "%{rocm_version_number}": rocm_config.rocm_version_number, "%{miopen_version_number}": rocm_config.miopen_version_number, "%{hipruntime_version_number}": rocm_config.hipruntime_version_number, @@ -889,6 +874,8 @@ _ENVIRONS = [ "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, + _OS, + _ROCM_VERSION, ] remote_rocm_configure = repository_rule( diff --git a/third_party/tsl/third_party/remote_config/common.bzl b/third_party/tsl/third_party/remote_config/common.bzl index 57fb6fcf7aca9..c70c0ba5b51db 100644 --- a/third_party/tsl/third_party/remote_config/common.bzl +++ b/third_party/tsl/third_party/remote_config/common.bzl @@ -212,7 +212,8 @@ def execute( cmdline, error_msg = None, error_details = None, - allow_failure = False): + allow_failure = False, + env_vars = {}): """Executes an arbitrary shell command. 
Args: @@ -222,10 +223,11 @@ def execute( error_details: string, details about the error or steps to fix it allow_failure: bool, if True, an empty stdout result or output to stderr is fine, otherwise either of these is an error + env_vars: environment variables Returns: The result of repository_ctx.execute(cmdline) """ - result = raw_exec(repository_ctx, cmdline) + result = raw_exec(repository_ctx, cmdline, env_vars) if (result.stderr or not result.stdout) and not allow_failure: fail( "\n".join([ @@ -236,7 +238,7 @@ def execute( ) return result -def raw_exec(repository_ctx, cmdline): +def raw_exec(repository_ctx, cmdline, env_vars = {}): """Executes a command via repository_ctx.execute() and returns the result. This method is useful for debugging purposes. For example, to print all @@ -245,11 +247,12 @@ def raw_exec(repository_ctx, cmdline): Args: repository_ctx: the repository_ctx cmdline: the list of args + env_vars: environment variables Returns: The 'exec_result' of repository_ctx.execute(). """ - return repository_ctx.execute(cmdline) + return repository_ctx.execute(cmdline, environment = env_vars) def files_exist(repository_ctx, paths, bash_bin = None): """Checks which files in paths exists. diff --git a/xla/backends/cpu/collectives/BUILD b/xla/backends/cpu/collectives/BUILD index 947173cc8a7f9..4363e15a7e13f 100644 --- a/xla/backends/cpu/collectives/BUILD +++ b/xla/backends/cpu/collectives/BUILD @@ -14,6 +14,59 @@ package_group( ], ) +cc_library( + name = "cpu_clique_key", + srcs = ["cpu_clique_key.cc"], + hdrs = ["cpu_clique_key.h"], + deps = [ + "//xla/core/collectives:clique_key", + "//xla/service:global_device_id", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@tsl//tsl/platform:casts", + ], +) + +cc_library( + name = "cpu_clique", + srcs = ["cpu_clique.cc"], + hdrs = ["cpu_clique.h"], + deps = [ + ":cpu_clique_key", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:logging", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "cpu_cliques", + srcs = ["cpu_cliques.cc"], + hdrs = ["cpu_cliques.h"], + deps = [ + ":cpu_clique", + ":cpu_clique_key", + ":cpu_collectives", + "//xla:util", + "//xla/core/collectives:clique", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + ], +) + cc_library( name = "cpu_collectives", srcs = ["cpu_collectives.cc"], @@ -23,14 +76,130 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/core/collectives", + "//xla/core/collectives:clique_id", "//xla/core/collectives:collectives_registry", "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@tsl//tsl/platform:casts", ], ) + +# TODO(b/380457503): Restrict visibility to private. 
+cc_library( + name = "gloo_communicator", + srcs = ["gloo_communicator.cc"], + hdrs = ["gloo_communicator.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + deps = [ + ":cpu_collectives", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@gloo", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + ], +) + +# TODO(b/380457503): Restrict visibility to private. +cc_library( + name = "in_process_communicator", + srcs = ["in_process_communicator.cc"], + hdrs = ["in_process_communicator.h"], + deps = [ + ":cpu_collectives", + "//xla:refcounting_hash_map", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@tsl//tsl/platform:errors", + ], +) + +# TODO(b/380457503): Restrict visibility to private. 
+cc_library( + name = "mpi_communicator", + srcs = ["mpi_communicator.cc"], + hdrs = ["mpi_communicator.h"], + compatible_with = [], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + # copybara:uncomment_begin(google-only) + # "-Ithird_party/openmpi/ompi/include", + # copybara:uncomment_end + ], + features = ["-use_header_modules"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/core/collectives:communicator", + "//xla/core/collectives:rank_id", + "//xla/service:collective_ops_utils", + "//xla/service:global_device_id", + "//xla/service/cpu:collectives_interface", + "//xla/stream_executor:device_memory", + "//xla/tsl/platform:errors", + "//xla/tsl/platform:logging", + "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@mpitrampoline", + "@tsl//tsl/platform:errors", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:statusor", + ], +) diff --git a/xla/backends/cpu/collectives/cpu_clique.cc b/xla/backends/cpu/collectives/cpu_clique.cc new file mode 100644 index 0000000000000..a81dd80392f9f --- /dev/null +++ b/xla/backends/cpu/collectives/cpu_clique.cc @@ -0,0 +1,59 @@ +/* Copyright 2025 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "xla/backends/cpu/collectives/cpu_clique.h"
+
+#include <cstdint>
+#include <string>
+#include <utility>
+
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "xla/backends/cpu/collectives/cpu_clique_key.h"
+#include "xla/core/collectives/clique.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/tsl/platform/logging.h"
+
+namespace xla::cpu {
+
+CpuClique::CpuClique(CpuCliqueKey key) : Clique({}), key_(std::move(key)) {}
+
+std::string CpuClique::DebugString() const {
+  std::string out =
+      absl::StrFormat("key: %s; size: %d; communicators: ", key_.ToString(),
+                      num_communicators());
+  int32_t cnt = 0;
+  ForEachComm([&](RankId rank, Communicator* comm) {
+    if (cnt++) absl::StrAppend(&out, ", ");
+    absl::StrAppendFormat(&out, "[rank=%d, comm=%s]", rank.value(),
+                          comm->ToString());
+  });
+  return out;
+}
+
+absl::Status CpuClique::HealthCheck() const {
+  absl::Status health_check = absl::OkStatus();
+  ForEachComm([&health_check](RankId rank, Communicator* comm) {
+    if (auto s = comm->HealthCheck(); !s.ok()) {
+      LOG(ERROR) << "CPU communicator error (rank " << rank << "): " << s;
+      if (health_check.ok()) health_check = std::move(s);  // return first error
+    }
+  });
+  return health_check;
+}
+
+}  // namespace xla::cpu
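For orientation (editorial sketch, not captured output): given the format strings above, a two-rank clique over devices 0 and 1 would render roughly as

    key: devices=[0,1]; size: 2; communicators: [rank=0, comm=<...>], [rank=1, comm=<...>]

with each comm=<...> supplied by the communicator's ToString().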
diff --git a/xla/backends/cpu/collectives/cpu_clique.h b/xla/backends/cpu/collectives/cpu_clique.h
new file mode 100644
index 0000000000000..e1ff3025a955b
--- /dev/null
+++ b/xla/backends/cpu/collectives/cpu_clique.h
@@ -0,0 +1,42 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_
+#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_
+
+#include <string>
+
+#include "absl/status/status.h"
+#include "xla/backends/cpu/collectives/cpu_clique_key.h"
+#include "xla/core/collectives/clique.h"
+
+namespace xla::cpu {
+
+// A group of CPU communicators making up a clique.
+class CpuClique final : public Clique {
+ public:
+  explicit CpuClique(CpuCliqueKey key);
+
+  absl::Status HealthCheck() const final;
+
+  std::string DebugString() const final;
+
+ private:
+  CpuCliqueKey key_;
+};
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_H_
diff --git a/xla/backends/cpu/collectives/cpu_clique_key.cc b/xla/backends/cpu/collectives/cpu_clique_key.cc
new file mode 100644
index 0000000000000..b66c844d4983e
--- /dev/null
+++ b/xla/backends/cpu/collectives/cpu_clique_key.cc
@@ -0,0 +1,59 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/collectives/cpu_clique_key.h"
+
+#include <string>
+#include <utility>
+
+#include "absl/algorithm/container.h"
+#include "absl/hash/hash.h"
+#include "absl/strings/str_format.h"
+#include "xla/core/collectives/clique_key.h"
+#include "xla/service/global_device_id.h"
+#include "tsl/platform/casts.h"
+
+namespace xla::cpu {
+
+bool CpuCliqueKey::IsSubsetOf(const CliqueKey& other) const {
+  auto* other_cpu = tsl::down_cast<const CpuCliqueKey*>(&other);
+  if (other_cpu == nullptr) return false;
+
+  return absl::c_all_of(devices(), [&](GlobalDeviceId id) {
+    return absl::c_linear_search(other_cpu->devices(), id);
+  });
+}
+
+std::string CpuCliqueKey::ToString() const {
+  return absl::StrFormat("devices=[%s]", GlobalDeviceIdsToString(devices()));
+}
+
+void CpuCliqueKey::HashValue(absl::HashState state) const {
+  absl::HashState::combine(std::move(state), devices());
+}
+
+bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b) {
+  return a.devices() == b.devices();
+}
+
+bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b) {
+  return a.devices() < b.devices();
+}
+
+bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b) {
+  return a.devices() > b.devices();
+}
+
+}  // namespace xla::cpu
diff --git a/xla/backends/cpu/collectives/cpu_clique_key.h b/xla/backends/cpu/collectives/cpu_clique_key.h
new file mode 100644
index 0000000000000..30b257c1a0d0c
--- /dev/null
+++ b/xla/backends/cpu/collectives/cpu_clique_key.h
@@ -0,0 +1,44 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_
+#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_
+
+#include <string>
+
+#include "absl/hash/hash.h"
+#include "xla/core/collectives/clique_key.h"
+
+namespace xla::cpu {
+
+// Clique key for identifying a particular CPU collectives clique.
+class CpuCliqueKey final : public CliqueKey {
+ public:
+  using CliqueKey::CliqueKey;
+
+  bool IsSubsetOf(const CliqueKey& other) const final;
+  std::string ToString() const final;
+
+  friend bool operator==(const CpuCliqueKey& a, const CpuCliqueKey& b);
+  friend bool operator<(const CpuCliqueKey& a, const CpuCliqueKey& b);
+  friend bool operator>(const CpuCliqueKey& a, const CpuCliqueKey& b);
+
+ private:
+  void HashValue(absl::HashState state) const final;
+};
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUE_KEY_H_
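A minimal usage sketch of the key semantics (editorial, not part of the patch; it assumes the inherited CliqueKey constructor taking a vector of GlobalDeviceIds):

    #include "xla/backends/cpu/collectives/cpu_clique_key.h"
    #include "xla/service/global_device_id.h"

    void CliqueKeyExample() {
      using xla::GlobalDeviceId;
      // Keys over devices {0, 1} and {0, 1, 2}.
      xla::cpu::CpuCliqueKey pair({GlobalDeviceId(0), GlobalDeviceId(1)});
      xla::cpu::CpuCliqueKey trio(
          {GlobalDeviceId(0), GlobalDeviceId(1), GlobalDeviceId(2)});

      bool subset = pair.IsSubsetOf(trio);  // true: per-device membership test
      bool equal = (pair == trio);          // false: device lists compared exactly
      (void)subset;
      (void)equal;
    }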
diff --git a/xla/backends/cpu/collectives/cpu_cliques.cc b/xla/backends/cpu/collectives/cpu_cliques.cc
new file mode 100644
index 0000000000000..6e6c437256ad1
--- /dev/null
+++ b/xla/backends/cpu/collectives/cpu_cliques.cc
@@ -0,0 +1,122 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/collectives/cpu_cliques.h"
+
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/base/thread_annotations.h"
+#include "absl/container/node_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/synchronization/mutex.h"
+#include "xla/backends/cpu/collectives/cpu_clique.h"
+#include "xla/backends/cpu/collectives/cpu_clique_key.h"
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+
+namespace xla::cpu {
+
+//===----------------------------------------------------------------------===//
+// ProcessCpuCliques
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// CpuClique is not thread-safe, so we wrap it in a thread-safe container as we
+// create new communicators lazily and potentially from multiple threads.
+struct ThreadSafeClique {
+  explicit ThreadSafeClique(CpuCliqueKey key) : clique(key) {}
+
+  absl::Mutex mu;
+  CpuClique clique ABSL_GUARDED_BY(mu);
+};
+
+// Container for initialized and ready to use CPU cliques. In contrast to GPU
+// cliques, CPU cliques are not lockable, and we create communicators lazily
+// when needed.
+struct ProcessCpuCliques {
+  absl::Mutex mu;
+  absl::node_hash_map<CpuCliqueKey, ThreadSafeClique> map ABSL_GUARDED_BY(mu);
+};
+}  // namespace
+
+// Returns process-local CPU cliques.
+static ProcessCpuCliques& GetProcessCpuCliques() {
+  static auto* cliques = new ProcessCpuCliques;
+  return *cliques;
+}
+
+//===----------------------------------------------------------------------===//
+
+// TODO(b/380457503): Consider switching to a lockable CPU clique model similar
+// to GPU cliques, and creating all communicators upfront.
+absl::StatusOr<Communicator*> AcquireCommunicator(
+    CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank) {
+  VLOG(3) << "Acquire communicator for clique key " << clique_key.ToString()
+          << " and rank " << rank;
+
+  ProcessCpuCliques& cliques = GetProcessCpuCliques();
+
+  // Synchronize access to the process cliques.
+  ThreadSafeClique& thread_safe_clique = [&]() -> ThreadSafeClique& {
+    absl::MutexLock lock(&cliques.mu);
+    auto [it, emplaced] = cliques.map.try_emplace(clique_key, clique_key);
+    return it->second;
+  }();
+
+  // Check if we already have a communicator for this rank.
+  std::optional<Communicator*> comm = [&]() -> std::optional<Communicator*> {
+    absl::MutexLock lock(&thread_safe_clique.mu);
+    return thread_safe_clique.clique.comm(rank);
+  }();
+
+  if (comm.has_value()) return *comm;
+
+  VLOG(3) << "Create a new communicator for clique key "
+          << clique_key.ToString() << " and rank " << rank;
+
+  // Create a new communicator and add it to the clique.
+  CpuCollectives::DeviceRank device_rank(/*device=*/nullptr, rank);
+  CpuCollectives::Config config;
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<std::unique_ptr<Communicator>> communicators,
+      collectives->CreateCommunicators(clique_key.num_devices(), clique_key,
+                                       std::nullopt, {device_rank}, config));
+
+  // We expect to create communicators lazily one at a time.
+  if (communicators.size() != 1) {
+    return Internal(
+        "Expected to create a single communicator for a clique key %s and rank "
+        "%d, but got %d",
+        clique_key.ToString(), rank.value(), communicators.size());
+  }
+
+  absl::MutexLock lock(&thread_safe_clique.mu);
+  TF_RETURN_IF_ERROR(thread_safe_clique.clique.AddComm(
+      rank, std::move(communicators.front())));
+
+  return *thread_safe_clique.clique.comm(rank);
+}
+
+}  // namespace xla::cpu
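A hypothetical call-site sketch for the function above (editorial, not part of the patch; `collectives` stands in for any concrete CpuCollectives implementation, and the inherited vector-of-devices CliqueKey constructor is assumed):

    absl::Status ExampleAcquire(xla::cpu::CpuCollectives* collectives) {
      using xla::GlobalDeviceId;
      // The first call for this (key, rank) creates the communicator lazily;
      // later calls return the cached instance from the process-wide map.
      TF_ASSIGN_OR_RETURN(
          xla::Communicator* comm,
          xla::cpu::AcquireCommunicator(
              collectives,
              xla::cpu::CpuCliqueKey({GlobalDeviceId(0), GlobalDeviceId(1)}),
              xla::RankId(0)));
      (void)comm;
      return absl::OkStatus();
    }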
diff --git a/xla/backends/cpu/collectives/cpu_cliques.h b/xla/backends/cpu/collectives/cpu_cliques.h
new file mode 100644
index 0000000000000..b42774619fe4b
--- /dev/null
+++ b/xla/backends/cpu/collectives/cpu_cliques.h
@@ -0,0 +1,33 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_
+#define XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_
+
+#include "absl/status/statusor.h"
+#include "xla/backends/cpu/collectives/cpu_clique_key.h"
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+
+namespace xla::cpu {
+
+// Returns a communicator for a given clique key and rank.
+absl::StatusOr<Communicator*> AcquireCommunicator(
+    CpuCollectives* collectives, const CpuCliqueKey& clique_key, RankId rank);
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_COLLECTIVES_CPU_CLIQUES_H_
diff --git a/xla/backends/cpu/collectives/cpu_collectives.h b/xla/backends/cpu/collectives/cpu_collectives.h
index a728e7cd3a399..330b35f52146d 100644
--- a/xla/backends/cpu/collectives/cpu_collectives.h
+++ b/xla/backends/cpu/collectives/cpu_collectives.h
@@ -16,11 +16,19 @@ limitations under the License.
 #ifndef XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_
 #define XLA_BACKENDS_CPU_COLLECTIVES_CPU_COLLECTIVES_H_
 
+#include <cstdint>
+#include <memory>
+#include <vector>
+
 #include "absl/status/statusor.h"
 #include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "xla/core/collectives/clique_id.h"
 #include "xla/core/collectives/collectives.h"
 #include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
 #include "xla/service/collective_ops_utils.h"
+#include "xla/util.h"
 #include "xla/xla_data.pb.h"
 
 namespace xla::cpu {
@@ -50,6 +58,17 @@ class CpuCollectives : public Collectives {
     absl::Duration timeout_;
   };
 
+  absl::StatusOr<CliqueId> CreateUniqueCliqueId() const final {
+    return Unimplemented("CPU collectives do not support clique ids");
+  }
+
+  absl::StatusOr<std::vector<std::unique_ptr<Communicator>>> SplitCommunicators(
+      absl::Span<const Communicator* const> comms, int32_t color,
+      absl::Span<const RankId> keys, const Config& config) final {
+    return Unimplemented(
+        "CPU collectives do not support communicator splitting");
+  }
+
   // Tries to cast a Collectives::Device to a CpuCollectives::Device.
   static absl::StatusOr<const Device*> TryCast(
       const Collectives::Device* device);
diff --git a/xla/backends/cpu/collectives/gloo_communicator.cc b/xla/backends/cpu/collectives/gloo_communicator.cc
new file mode 100644
index 0000000000000..e5e19aa3a1cfe
--- /dev/null
+++ b/xla/backends/cpu/collectives/gloo_communicator.cc
@@ -0,0 +1,443 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/collectives/gloo_communicator.h"
+
+#include <complex>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <exception>
+#include <memory>
+#include <optional>
+#include <utility>
+#include <vector>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "absl/types/span.h"
+#include "gloo/algorithm.h"
+#include "gloo/allgather.h"
+#include "gloo/allreduce.h"
+#include "gloo/context.h"
+#include "gloo/math.h"
+#include "gloo/reduce_scatter.h"
+#include "gloo/transport/device.h"
+#include "gloo/transport/unbound_buffer.h"
+#include "gloo/types.h"
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/primitive_util.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/types.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::cpu {
+
+GlooCommunicator::GlooCommunicator(std::shared_ptr<gloo::Context> context,
+                                   size_t rank, size_t num_ranks)
+    : context_(std::move(context)), rank_(rank), num_ranks_(num_ranks) {}
+
+GlooCommunicator::~GlooCommunicator() = default;
+
+template <typename T>
+static absl::Status SetAllReduceOptions(ReductionKind reduction_kind,
+                                        se::DeviceMemoryBase input_buffer,
+                                        se::DeviceMemoryBase output_buffer,
+                                        size_t num_elements,
+                                        gloo::AllreduceOptions& options) {
+  options.setInput(
+      reinterpret_cast<T*>(const_cast<void*>(input_buffer.opaque())),
+      num_elements);
+  options.setOutput(
+      reinterpret_cast<T*>(const_cast<void*>(output_buffer.opaque())),
+      num_elements);
+
+  using ReductionFn = void (*)(void*, const void*, const void*, size_t);
+
+  switch (reduction_kind) {
+    case ReductionKind::SUM:
+      options.setReduceFunction(static_cast<ReductionFn>(&gloo::sum<T>));
+      break;
+    case ReductionKind::PRODUCT:
+      options.setReduceFunction(static_cast<ReductionFn>(&gloo::product<T>));
+      break;
+    case ReductionKind::MIN:
+      if constexpr (!is_complex_v<T>) {
+        options.setReduceFunction(static_cast<ReductionFn>(&gloo::min<T>));
+      } else {
+        return absl::InvalidArgumentError(
+            "MIN reduction not supported for complex types");
+      }
+      break;
+    case ReductionKind::MAX:
+      if constexpr (!is_complex_v<T>) {
+        options.setReduceFunction(static_cast<ReductionFn>(&gloo::max<T>));
+      } else {
+        return absl::InvalidArgumentError(
+            "MAX reduction not supported for complex types");
+      }
+      break;
+  }
+  return absl::OkStatus();
+}
+
+absl::Status GlooCommunicator::AllReduce(se::DeviceMemoryBase send_buffer,
+                                         se::DeviceMemoryBase recv_buffer,
+                                         PrimitiveType dtype, size_t count,
+                                         ReductionKind reduction_kind,
+                                         const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+
+  gloo::AllreduceOptions options(context_);
+  // TODO(phawkins): how to do tags?
+  // options.setTag(tag);
+  switch (dtype) {
+    case S8:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<int8_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case PRED:
+    case U8:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<uint8_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case S16:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<int16_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case U16:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<uint16_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case S32:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<int32_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case U32:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<uint32_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case S64:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<int64_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case U64:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<uint64_t>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case F16:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<Eigen::half>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case BF16:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<bfloat16>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case F32:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<float>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case F64:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<double>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case C64:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<std::complex<float>>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    case C128:
+      TF_RETURN_IF_ERROR(SetAllReduceOptions<std::complex<double>>(
+          reduction_kind, send_buffer, recv_buffer, count, options));
+      break;
+    default:
+      return absl::InvalidArgumentError("Unknown datatype in allreduce");
+  }
+  options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING);
+  options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout()));
+
+  try {
+    gloo::allreduce(options);
+  } catch (std::exception& e) {
+    return absl::UnknownError(
+        absl::StrCat("Gloo all-reduce failed: ", e.what()));
+  }
+  return absl::OkStatus();
+}
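// Editorial illustration (not part of the patch): the semantics of the
// AllReduce above for two ranks, count = 4, dtype = F32 and
// ReductionKind::SUM: if rank A sends {1, 2, 3, 4} and rank B sends
// {10, 20, 30, 40}, gloo::allreduce leaves both recv buffers holding
// {11, 22, 33, 44}.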
+  const auto slot = gloo::Slot::build(kCollectivePermuteSlotPrefix, tag);
+
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  size_t num_bytes = count * primitive_util::ByteWidth(dtype);
+
+  try {
+    std::unique_ptr<gloo::transport::UnboundBuffer> in;
+    std::unique_ptr<gloo::transport::UnboundBuffer> out;
+    for (RankId target : target_ranks) {
+      if (target != context_->rank) {
+        VLOG(1) << "send from " << context_->rank << " to " << target.value();
+        if (!in) {
+          in = context_->createUnboundBuffer(send_buffer.opaque(), num_bytes);
+        }
+        in->send(target.value(), slot);
+      }
+    }
+    if (source_rank) {
+      if (*source_rank == context_->rank) {
+        std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes);
+      } else {
+        VLOG(1) << "recv at " << context_->rank << " from "
+                << source_rank->value();
+        out = context_->createUnboundBuffer(recv_buffer.opaque(), num_bytes);
+        out->recv(source_rank->value(), slot);
+      }
+    } else {
+      std::memset(recv_buffer.opaque(), 0, num_bytes);
+    }
+    VLOG(1) << "wait for send at " << context_->rank;
+    auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout());
+    if (in) {
+      in->waitSend(deadline);
+    }
+    VLOG(1) << "wait for recv at " << context_->rank;
+    if (out) {
+      out->waitRecv(deadline);
+    }
+    VLOG(1) << "done waiting at " << context_->rank;
+  } catch (std::exception& e) {
+    return absl::UnknownError(
+        absl::StrCat("Gloo collective permute failed: ", e.what()));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status GlooCommunicator::AllToAll(
+    absl::Span<const se::DeviceMemoryBase> send_buffers,
+    absl::Span<const se::DeviceMemoryBase> recv_buffers, PrimitiveType dtype,
+    size_t count, const Executor& executor) {
+  // We can't use Gloo's all-to-all implementation directly because it assumes
+  // that the inputs and outputs are contiguous. No big deal; it's just built
+  // on top of send/recv and we can do the same as it.
+  uint32_t tag = 0;  // TODO(phawkins): use better tags.
+  int my_rank = context_->rank;
+  int world_size = context_->size;
+
+  TF_RET_CHECK(world_size == send_buffers.size());
+  TF_RET_CHECK(world_size == recv_buffers.size());
+
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  size_t chunk_bytes = count * primitive_util::ByteWidth(dtype);
+
+  try {
+    const auto slot = gloo::Slot::build(gloo::kAlltoallSlotPrefix, tag);
+    std::vector<std::unique_ptr<gloo::transport::UnboundBuffer>> ins(
+        context_->size);
+    std::vector<std::unique_ptr<gloo::transport::UnboundBuffer>> outs(
+        context_->size);
+    for (size_t i = 0; i < world_size; ++i) {
+      if (i != my_rank) {
+        ins[i] = context_->createUnboundBuffer(
+            const_cast<void*>(send_buffers[i].opaque()), chunk_bytes);
+        outs[i] = context_->createUnboundBuffer(
+            const_cast<void*>(recv_buffers[i].opaque()), chunk_bytes);
+      }
+    }
+
+    for (int i = 1; i < world_size; i++) {
+      int send_rank = (my_rank + i) % world_size;
+      int recv_rank = (my_rank + world_size - i) % world_size;
+      ins[send_rank]->send(send_rank, slot);
+      outs[recv_rank]->recv(recv_rank, slot);
+    }
+
+    std::memcpy(const_cast<void*>(recv_buffers[my_rank].opaque()),
+                send_buffers[my_rank].opaque(), chunk_bytes);
+
+    auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout());
+    for (int i = 0; i < world_size; i++) {
+      if (i != my_rank) {
+        ins[i]->waitSend(deadline);
+        outs[i]->waitRecv(deadline);
+      }
+    }
+  } catch (std::exception& e) {
+    return absl::UnknownError(
+        absl::StrCat("Gloo all-to-all failed: ", e.what()));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status GlooCommunicator::AllGather(se::DeviceMemoryBase send_buffer,
+                                         se::DeviceMemoryBase recv_buffer,
+                                         PrimitiveType dtype, size_t count,
+                                         const Executor& executor) {
+  uint32_t tag = 0;  // TODO(phawkins): use better tags.
+
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  size_t chunk_bytes = count * primitive_util::ByteWidth(dtype);
+
+  gloo::AllgatherOptions options(context_);
+  options.setTag(tag);
+  options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout()));
+  options.setInput(reinterpret_cast<char*>(send_buffer.opaque()), chunk_bytes);
+  options.setOutput(reinterpret_cast<char*>(recv_buffer.opaque()),
+                    chunk_bytes * context_->size);
+
+  try {
+    gloo::allgather(options);
+  } catch (std::exception& e) {
+    return absl::UnknownError(
+        absl::StrCat("Gloo AllGather failed: ", e.what()));
+  }
+  return absl::OkStatus();
+}
+
+template <typename T>
+absl::Status ReduceScatterHelper(std::shared_ptr<gloo::Context> context,
+                                 ReductionKind reduction_kind, void* buffer,
+                                 size_t chunk_elems) {
+  const gloo::ReductionFunction<T>* reduction_function = nullptr;
+  if constexpr (is_complex_v<T>) {
+    switch (reduction_kind) {
+      case ReductionKind::SUM:
+        reduction_function = gloo::ReductionFunction<T>::sum;
+        break;
+      case ReductionKind::PRODUCT:
+        reduction_function = gloo::ReductionFunction<T>::product;
+        break;
+      default:
+        return absl::InvalidArgumentError(absl::StrCat(
+            "Unsupported reduction kind: ", static_cast<int>(reduction_kind)));
+    }
+  } else {
+    switch (reduction_kind) {
+      case ReductionKind::SUM:
+        reduction_function = gloo::ReductionFunction<T>::sum;
+        break;
+      case ReductionKind::PRODUCT:
+        reduction_function = gloo::ReductionFunction<T>::product;
+        break;
+      case ReductionKind::MAX:
+        reduction_function = gloo::ReductionFunction<T>::max;
+        break;
+      case ReductionKind::MIN:
+        reduction_function = gloo::ReductionFunction<T>::min;
+        break;
+      default:
+        return absl::InvalidArgumentError(absl::StrCat(
+            "Unsupported reduction kind: ", static_cast<int>(reduction_kind)));
+    }
+  }
+  try {
+    std::vector<int> recv_elems(context->size, chunk_elems);
+    gloo::ReduceScatterHalvingDoubling<T> algorithm(
+        context, std::vector<T*>{reinterpret_cast<T*>(buffer)},
+        chunk_elems * context->size, recv_elems, reduction_function);
+    algorithm.run();
+  } catch (std::exception& e) {
+    return absl::UnknownError(
+        absl::StrCat("Gloo ReduceScatter failed: ", e.what()));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status GlooCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer,
+                                             se::DeviceMemoryBase recv_buffer,
+                                             PrimitiveType dtype, size_t count,
+                                             ReductionKind reduction_kind,
+                                             const Executor& executor) {
+  size_t chunk_bytes = count * primitive_util::ByteWidth(dtype);
+  std::unique_ptr<char[]> temp(new char[chunk_bytes * context_->size]);
+  std::memcpy(temp.get(), send_buffer.opaque(), chunk_bytes * context_->size);
+  switch (dtype) {
+    case S8:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<int8_t>(context_, reduction_kind,
+                                                     temp.get(), count));
+      break;
+    case PRED:
+    case U8:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<uint8_t>(context_, reduction_kind,
+                                                      temp.get(), count));
+      break;
+    case S16:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<int16_t>(context_, reduction_kind,
+                                                      temp.get(), count));
+      break;
+    case U16:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<uint16_t>(
+          context_, reduction_kind, temp.get(), count));
+      break;
+    case S32:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<int32_t>(context_, reduction_kind,
+                                                      temp.get(), count));
+      break;
+    case U32:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<uint32_t>(
+          context_, reduction_kind, temp.get(), count));
+      break;
+    case S64:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<int64_t>(context_, reduction_kind,
+                                                      temp.get(), count));
+      break;
+    case U64:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<uint64_t>(
+          context_, reduction_kind, temp.get(), count));
+      break;
+    case BF16:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<bfloat16>(
+          context_, reduction_kind, temp.get(), count));
+      break;
+    case F16:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<gloo::float16>(
+          context_, reduction_kind, temp.get(), count));
+      break;
+    case F32:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<float>(context_, reduction_kind,
+                                                    temp.get(), count));
+      break;
+    case F64:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<double>(context_, reduction_kind,
+                                                     temp.get(), count));
+      break;
+    case C64:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<std::complex<float>>(
+          context_, reduction_kind, temp.get(), count));
+      break;
+    case C128:
+      TF_RETURN_IF_ERROR(ReduceScatterHelper<std::complex<double>>(
+          context_, reduction_kind, temp.get(), count));
+      break;
+    default:
+      return absl::InvalidArgumentError("Unknown datatype in reducescatter");
+  }
+  std::memcpy(recv_buffer.opaque(), temp.get(), chunk_bytes);
+  return absl::OkStatus();
+}
+
+}  // namespace xla::cpu
diff --git a/xla/backends/cpu/collectives/gloo_communicator.h b/xla/backends/cpu/collectives/gloo_communicator.h
new file mode 100644
index 0000000000000..234716da75934
--- /dev/null
+++ b/xla/backends/cpu/collectives/gloo_communicator.h
@@ -0,0 +1,103 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_
+#define XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "gloo/context.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::cpu {
+
+// XLA communicator implemented using Gloo communication library.
+class GlooCommunicator : public Communicator {
+ public:
+  GlooCommunicator(std::shared_ptr<gloo::Context> context, size_t rank,
+                   size_t num_ranks);
+  ~GlooCommunicator() override;
+
+  absl::Status AllReduce(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, ReductionKind reduction_kind,
+                         const Executor& executor) override;
+
+  absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
+                                 se::DeviceMemoryBase recv_buffer,
+                                 PrimitiveType dtype, size_t count,
+                                 std::optional<RankId> source_rank,
+                                 absl::Span<const RankId> target_ranks,
+                                 const Executor& executor) override;
+
+  absl::Status AllToAll(absl::Span<const se::DeviceMemoryBase> send_buffers,
+                        absl::Span<const se::DeviceMemoryBase> recv_buffers,
+                        PrimitiveType dtype, size_t count,
+                        const Executor& executor) override;
+
+  absl::Status AllGather(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, const Executor& executor) override;
+
+  absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
+                             se::DeviceMemoryBase recv_buffer,
+                             PrimitiveType dtype, size_t count,
+                             ReductionKind reduction_kind,
+                             const Executor& executor) override;
+
+  absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase,
+                         PrimitiveType, size_t, RankId,
+                         const Executor&) override {
+    return Unimplemented("Broadcast is not implemented");
+  }
+
+  absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId,
+                    const Executor&) override {
+    return Unimplemented("Send is not implemented");
+  }
+
+  absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId,
+                    const Executor&) override {
+    return Unimplemented("Recv is not implemented");
+  }
+
+  absl::StatusOr<size_t> NumRanks() const override { return num_ranks_; }
+
+  std::string ToString() const override {
+    return absl::StrCat("GlooCommunicator [rank: ", rank_,
+                        " num_ranks: ", num_ranks_, "]");
+  }
+
+ private:
+  std::shared_ptr<gloo::Context> context_;
+  size_t rank_;
+  size_t num_ranks_;
+};
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_COLLECTIVES_GLOO_COMMUNICATOR_H_
diff --git a/xla/backends/cpu/collectives/in_process_communicator.cc b/xla/backends/cpu/collectives/in_process_communicator.cc
new file mode 100644
index 0000000000000..a293c1e72672c
--- /dev/null
+++ b/xla/backends/cpu/collectives/in_process_communicator.cc
@@ -0,0 +1,576 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/collectives/in_process_communicator.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "absl/types/span.h"
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/primitive_util.h"
+#include "xla/refcounting_hash_map.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/service/global_device_id.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::cpu {
+namespace {
+
+void FormatGlobalId(std::string* out, const GlobalDeviceId& device) {
+  absl::StrAppend(out, device.value());
+}
+
+struct AllReduceParticipantData : ParticipantData {
+  explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p,
+                                    int rank)
+      : ParticipantData(rendezvous_key_p, rank) {}
+
+  int64_t element_count;
+  const void* source_data;
+  void* destination_data;
+  PrimitiveType primitive_type;
+
+  ReductionKind reduction_kind;
+
+  std::string ToString() const override {
+    return absl::StrFormat(
+        "AllReduceParticipantData{rank=%d, element_count=%d, type=%s, "
+        "rendezvous_key=%s}",
+        local_rank, element_count, PrimitiveType_Name(primitive_type),
+        rendezvous_key.ToString());
+  }
+};
+
+template <typename T>
+T GetInitialValue(ReductionKind reduction_kind) {
+  switch (reduction_kind) {
+    case ReductionKind::SUM:
+      return static_cast<T>(0);
+    case ReductionKind::PRODUCT:
+      return static_cast<T>(1);
+    case ReductionKind::MIN:
+      return std::numeric_limits<T>::has_infinity
+                 ? std::numeric_limits<T>::infinity()
+                 : std::numeric_limits<T>::max();
+    case ReductionKind::MAX:
+      return std::numeric_limits<T>::has_infinity
+                 ? -std::numeric_limits<T>::infinity()
+                 : std::numeric_limits<T>::lowest();
+  }
+}
+
+// We cannot use static_assert(false), because the C++ standard (prior to
+// CWG2518) does not allow the statement discarded by a constexpr if to
+// be ill-formed for every possible specialization.
+// See https://en.cppreference.com/w/cpp/language/if#Constexpr_if
+template <ReductionKind kind>
+constexpr bool always_false_v = false;
+
+template <typename T, ReductionKind reduction_kind>
+void ReduceHelper(absl::Span<T> acc, absl::Span<T const* const> inputs) {
+  // TODO(penporn): make sure this gets vectorized.
+  if constexpr (reduction_kind == ReductionKind::SUM) {
+    for (size_t j = 0; j < inputs.size(); ++j) {
+      for (size_t i = 0; i < acc.size(); ++i) {
+        acc[i] += inputs[j][i];
+      }
+    }
+  } else if constexpr (reduction_kind == ReductionKind::PRODUCT) {
+    for (size_t j = 0; j < inputs.size(); ++j) {
+      for (size_t i = 0; i < acc.size(); ++i) {
+        acc[i] *= inputs[j][i];
+      }
+    }
+  } else if constexpr (reduction_kind == ReductionKind::MIN) {
+    for (size_t j = 0; j < inputs.size(); ++j) {
+      for (size_t i = 0; i < acc.size(); ++i) {
+        acc[i] = std::min(acc[i], inputs[j][i]);
+      }
+    }
+  } else if constexpr (reduction_kind == ReductionKind::MAX) {
+    for (size_t j = 0; j < inputs.size(); ++j) {
+      for (size_t i = 0; i < acc.size(); ++i) {
+        acc[i] = std::max(acc[i], inputs[j][i]);
+      }
+    }
+  } else {
+    static_assert(always_false_v<reduction_kind>,
+                  "Unsupported reduction kind");
+  }
+}
+
+template <PrimitiveType PT>
+absl::Status ReduceScatter(ReductionKind reduction_kind,
+                           absl::Span<const void* const> inputs, void* output,
+                           int64_t num_elems) {
+  using T = primitive_util::NativeTypeOf<PT>;
+  T initial_value = GetInitialValue<T>(reduction_kind);
+
+  absl::Span<T> out_chunk =
+      absl::MakeSpan(reinterpret_cast<T*>(output), num_elems);
+  for (int64_t i = 0; i < num_elems; ++i) {
+    out_chunk[i] = initial_value;
+  }
+
+  absl::Span<T const* const> input_chunks(
+      reinterpret_cast<T const* const*>(inputs.data()), inputs.size());
+  switch (reduction_kind) {
+    case ReductionKind::SUM:
+      ReduceHelper<T, ReductionKind::SUM>(out_chunk, input_chunks);
+      break;
+    case ReductionKind::PRODUCT:
+      ReduceHelper<T, ReductionKind::PRODUCT>(out_chunk, input_chunks);
+      break;
+    case ReductionKind::MIN:
+      if constexpr (!is_complex_v<T>) {
+        ReduceHelper<T, ReductionKind::MIN>(out_chunk, input_chunks);
+      } else {
+        return absl::InvalidArgumentError(
+            "Min reductions not supported for complex types");
+      }
+      break;
+    case ReductionKind::MAX:
+      if constexpr (!is_complex_v<T>) {
+        ReduceHelper<T, ReductionKind::MAX>(out_chunk, input_chunks);
+      } else {
+        return absl::InvalidArgumentError(
+            "Max reductions not supported for complex types");
+      }
+      break;
+  }
+
+  return absl::OkStatus();
+}
+
+class CpuAllReduceRendezvous
+    : public Rendezvous<AllReduceParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuAllReduceRendezvous(const RendezvousKey& k)
+      : Rendezvous(k) {}
+
+ protected:
+  absl::StatusOr<std::nullptr_t> RunCollectiveOp(
+      const AllReduceParticipantData& me) override {
+    VLOG(3) << me.ToString();
+    int64_t world_size = participants_.size();
+    // Divide the buffer up into equal(ish) chunks. Rank r computes the r-th
+    // chunk of the output.
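+    // Worked example: element_count=10, world_size=4 gives chunk_elems=3,
+    // so ranks 0..3 reduce elements [0,3), [3,6), [6,9) and [9,10); the
+    // last rank ends up with the short tail chunk.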
+    int64_t chunk_elems = CeilOfRatio(me.element_count, world_size);
+
+    int64_t start_elem = me.local_rank * chunk_elems;
+    int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count);
+    chunk_elems = std::max(int64_t{0}, end_elem - start_elem);
+    if (chunk_elems == 0) {
+      return nullptr;
+    }
+
+    auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type);
+    int64_t chunk_offset = start_elem * bytes_per_elem;
+    int64_t chunk_bytes = chunk_elems * bytes_per_elem;
+    void* reduce_output =
+        reinterpret_cast<char*>(me.destination_data) + chunk_offset;
+
+    std::vector<const void*> inputs;
+    inputs.reserve(world_size);
+    for (const auto& p : participants_) {
+      inputs.push_back(reinterpret_cast<const char*>(p->source_data) +
+                       chunk_offset);
+    }
+
+    if (primitive_util::IsArrayType(me.primitive_type)) {
+      TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch<absl::Status>(
+          [&](const auto constant_type) {
+            return ReduceScatter<constant_type>(me.reduction_kind, inputs,
+                                                reduce_output, chunk_elems);
+          },
+          me.primitive_type));
+    } else {
+      return absl::UnimplementedError(absl::StrCat(
+          "Unexpected datatype: ",
+          primitive_util::LowercasePrimitiveTypeName(me.primitive_type)));
+    }
+
+    // All-gather the reduced chunks.
+    for (const auto& p : participants_) {
+      if (p->local_rank != me.local_rank) {
+        std::memcpy(
+            reinterpret_cast<char*>(p->destination_data) + chunk_offset,
+            reduce_output, chunk_bytes);
+      }
+    }
+    return nullptr;
+  }
+};
+
+struct CollectivePermuteParticipantData : ParticipantData {
+  CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key_p,
+                                   int rank)
+      : ParticipantData(rendezvous_key_p, rank) {}
+  const void* source_buffer;
+  void* destination_buffer;
+  size_t num_bytes;
+
+  // From which rank is this participant receiving its data? Optional; if
+  // absent fill with zeros.
+  std::optional<int> source_rank;
+
+  std::string ToString() const override {
+    return absl::StrFormat(
+        "CollectivePermuteParticipantData{rank=%d, "
+        "source_buffer=%p, destination_buffer=%p, num_bytes=%d, "
+        "source_replica_id=%d, "
+        "devices=[%s]}",
+        local_rank, source_buffer, destination_buffer, num_bytes,
+        source_rank.value_or(-1),
+        absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId));
+  }
+};
+
+class CpuCollectivePermuteRendezvous
+    : public Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuCollectivePermuteRendezvous(const RendezvousKey& k)
+      : Rendezvous(k) {}
+
+ protected:
+  absl::StatusOr<std::nullptr_t> RunCollectiveOp(
+      const CollectivePermuteParticipantData& p) override {
+    VLOG(3) << p.ToString();
+    if (p.source_rank) {
+      std::memcpy(p.destination_buffer,
+                  participants_[*p.source_rank]->source_buffer, p.num_bytes);
+    } else {
+      std::memset(p.destination_buffer, 0, p.num_bytes);
+    }
+    return nullptr;
+  }
+};
+
+struct AllToAllParticipantData : ParticipantData {
+  AllToAllParticipantData(const RendezvousKey& rendezvous_key_p, int rank)
+      : ParticipantData(rendezvous_key_p, rank) {}
+
+  std::vector<const void*> source_buffers;
+  std::vector<void*> destination_buffers;
+  size_t chunk_size;
+
+  std::string ToString() const override {
+    auto addr_formatter = [](std::string* out, const void* mem) {
+      absl::StrAppend(out, absl::StrFormat("%p", mem));
+    };
+    return absl::StrFormat(
+        "AllToAllParticipantData{rank=%d, "
+        "devices=[%s], source_buffers=[%s], "
+        "destination_buffers=[%s], chunk_size=%d}",
+        local_rank,
+        absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId),
+        absl::StrJoin(source_buffers, ", ", addr_formatter),
+        absl::StrJoin(destination_buffers, ", ", addr_formatter), chunk_size);
+  }
+};
+
+class CpuAllToAllRendezvous
+    : public Rendezvous<AllToAllParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuAllToAllRendezvous(const RendezvousKey& k)
+      : Rendezvous(k) {}
+
+ protected:
+  absl::StatusOr<std::nullptr_t> RunCollectiveOp(
+      const AllToAllParticipantData& p) override {
+    int world_size = p.rendezvous_key.global_devices.size();
+    for (int i = 0; i < world_size; ++i) {
+      std::memcpy(participants_[i]->destination_buffers[p.local_rank],
+                  p.source_buffers[i], p.chunk_size);
+    }
+    return nullptr;
+  }
+};
+
+struct AllGatherParticipantData : ParticipantData {
+  AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank)
+      : ParticipantData(rendezvous_key_p, rank) {}
+
+  const void* source_buffer;
+  void* destination_buffer;
+  size_t chunk_size;
+
+  std::string ToString() const override {
+    return absl::StrFormat(
+        "AllGatherParticipantData{rank=%d, "
+        "devices=[%s], source_buffer=%p, "
+        "destination_buffer=%p, chunk_size=%d}",
+        local_rank,
+        absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId),
+        source_buffer, destination_buffer, chunk_size);
+  }
+};
+
+class CpuAllGatherRendezvous
+    : public Rendezvous<AllGatherParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuAllGatherRendezvous(const RendezvousKey& k)
+      : Rendezvous(k) {}
+
+ protected:
+  absl::StatusOr<std::nullptr_t> RunCollectiveOp(
+      const AllGatherParticipantData& p) override {
+    int world_size = p.rendezvous_key.global_devices.size();
+    char* out = static_cast<char*>(p.destination_buffer);
+    for (int i = 0; i < world_size; ++i, out += p.chunk_size) {
+      std::memcpy(out, participants_[i]->source_buffer, p.chunk_size);
+    }
+    return nullptr;
+  }
+};
+
+struct ReduceScatterParticipantData : ParticipantData {
+  ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank)
+      : ParticipantData(rendezvous_key_p, rank) {}
+
+  ReductionKind reduction_kind;
+  PrimitiveType element_type;
+  const void* source_buffer;
+  void* destination_buffer;
+  size_t chunk_elems;
+
+  std::string ToString() const override {
+    return absl::StrFormat(
+        "ReduceScatterParticipantData{rank=%d, "
+        "devices=[%s], source_buffer=%p, "
+        "destination_buffer=%p, chunk_elems=%d}",
+        local_rank,
+        absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId),
+        source_buffer, destination_buffer, chunk_elems);
+  }
+};
+
+class CpuReduceScatterRendezvous
+    : public Rendezvous<ReduceScatterParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuReduceScatterRendezvous(const RendezvousKey& k)
+      : Rendezvous(k) {}
+
+ protected:
+  absl::StatusOr<std::nullptr_t> RunCollectiveOp(
+      const ReduceScatterParticipantData& me) override {
+    auto bytes_per_elem = primitive_util::ByteWidth(me.element_type);
+    int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem;
+
+    std::vector<const void*> inputs;
+    inputs.reserve(participants_.size());
+    for (const auto& p : participants_) {
+      inputs.push_back(reinterpret_cast<const char*>(p->source_buffer) +
+                       chunk_offset);
+    }
+
+    if (primitive_util::IsArrayType(me.element_type)) {
+      TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch<absl::Status>(
+          [&](const auto constant_type) {
+            return ReduceScatter<constant_type>(me.reduction_kind, inputs,
+                                                me.destination_buffer,
+                                                me.chunk_elems);
+          },
+          me.element_type));
+    } else {
+      return absl::UnimplementedError(absl::StrCat(
+          "Unexpected datatype: ",
+          primitive_util::LowercasePrimitiveTypeName(me.element_type)));
+    }
+    return nullptr;
+  }
+};
+
+}  // namespace
+
+struct InProcessCommunicator::State {
+  RefcountingHashMap<RendezvousKey, CpuAllReduceRendezvous>
+      all_reduce_rendezvous_map;
+  RefcountingHashMap<RendezvousKey, CpuCollectivePermuteRendezvous>
+      collective_permute_rendezvous_map;
+  RefcountingHashMap<RendezvousKey, CpuAllToAllRendezvous>
+      all_to_all_rendezvous_map;
+  RefcountingHashMap<RendezvousKey, CpuAllGatherRendezvous>
+      all_gather_rendezvous_map;
+  RefcountingHashMap<RendezvousKey, CpuReduceScatterRendezvous>
+      reduce_scatter_rendezvous_map;
+};
+
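The communicator methods below all follow the same rendezvous pattern: each rank fills in a ParticipantData, submits it to the matching per-key map in State, and the rendezvous runs the collective once every rank of the clique has arrived. A minimal usage sketch, assuming a MakeRendezvousKey helper that builds the same RendezvousKey on every rank (that helper and the exact key construction are illustrative, not part of this patch; CpuCollectives::Executor is built from a key and a timeout, as in the thunk code later in this diff):

    // Sketch: num_ranks threads in one process all-reduce one float each.
    void InProcessAllReduceSketch(size_t num_ranks) {
      auto state = InProcessCommunicator::CreateState();
      std::vector<std::thread> threads;
      for (size_t rank = 0; rank < num_ranks; ++rank) {
        threads.emplace_back([&, rank] {
          InProcessCommunicator comm(state, rank, num_ranks);
          float in = static_cast<float>(rank);
          float out = 0.0f;
          se::DeviceMemoryBase send(&in, sizeof(in));
          se::DeviceMemoryBase recv(&out, sizeof(out));
          // Every rank must pass an executor carrying the same rendezvous key.
          CpuCollectives::Executor executor(MakeRendezvousKey(num_ranks),
                                            absl::Seconds(10));
          CHECK_OK(comm.AllReduce(send, recv, F32, /*count=*/1,
                                  ReductionKind::SUM, executor));
          // On every rank, out == 0 + 1 + ... + (num_ranks - 1).
        });
      }
      for (auto& t : threads) t.join();
    }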
+InProcessCommunicator::InProcessCommunicator(std::shared_ptr<State> state,
+                                             size_t rank, size_t num_ranks)
+    : state_(std::move(state)), rank_(rank), num_ranks_(num_ranks) {}
+
+InProcessCommunicator::~InProcessCommunicator() = default;
+
+std::shared_ptr<InProcessCommunicator::State>
+InProcessCommunicator::CreateState() {
+  return std::make_shared<State>();
+}
+
+absl::Status InProcessCommunicator::AllReduce(se::DeviceMemoryBase send_buffer,
+                                              se::DeviceMemoryBase recv_buffer,
+                                              PrimitiveType dtype, size_t count,
+                                              ReductionKind reduction_kind,
+                                              const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  const RendezvousKey& key = cpu_executor->rendezvous_key();
+
+  AllReduceParticipantData participant(key, rank_);
+  participant.element_count = count;
+  participant.primitive_type = dtype;
+  participant.source_data = send_buffer.opaque();
+  participant.destination_data = recv_buffer.opaque();
+  participant.reduction_kind = reduction_kind;
+
+  auto make_cpu_rendezvous = [](const RendezvousKey& k) {
+    return std::make_unique<CpuAllReduceRendezvous>(k);
+  };
+
+  return CpuAllReduceRendezvous::SubmitParticipant(
+             [&] {
+               return state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent(
+                   key, make_cpu_rendezvous);
+             },
+             participant)
+      .status();
+}
+
+absl::Status InProcessCommunicator::CollectivePermute(
+    se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
+    PrimitiveType dtype, size_t count, std::optional<RankId> source_rank,
+    absl::Span<const RankId> target_ranks, const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  const RendezvousKey& key = cpu_executor->rendezvous_key();
+
+  CollectivePermuteParticipantData participant(key, rank_);
+  participant.source_buffer = send_buffer.opaque();
+  participant.destination_buffer = recv_buffer.opaque();
+  participant.num_bytes = count * primitive_util::ByteWidth(dtype);
+  participant.source_rank = std::nullopt;
+  if (source_rank) {
+    participant.source_rank = source_rank->value();
+  }
+  auto make_cpu_rendezvous = [](const RendezvousKey& k) {
+    return std::make_unique<CpuCollectivePermuteRendezvous>(k);
+  };
+  return CpuCollectivePermuteRendezvous::SubmitParticipant(
+             [&] {
+               return state_->collective_permute_rendezvous_map
+                   .GetOrCreateIfAbsent(key, make_cpu_rendezvous);
+             },
+             participant)
+      .status();
+}
+
+absl::Status InProcessCommunicator::AllToAll(
+    absl::Span<const se::DeviceMemoryBase> send_buffers,
+    absl::Span<const se::DeviceMemoryBase> recv_buffers, PrimitiveType dtype,
+    size_t count, const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  const RendezvousKey& key = cpu_executor->rendezvous_key();
+
+  AllToAllParticipantData participant(key, rank_);
+  TF_RET_CHECK(send_buffers.size() == recv_buffers.size());
+
+  size_t chunk_bytes = count * primitive_util::ByteWidth(dtype);
+
+  participant.chunk_size = chunk_bytes;
+  participant.source_buffers.reserve(send_buffers.size());
+  participant.destination_buffers.reserve(recv_buffers.size());
+  for (se::DeviceMemoryBase send_buffer : send_buffers) {
+    participant.source_buffers.push_back(send_buffer.opaque());
+  }
+  for (se::DeviceMemoryBase recv_buffer : recv_buffers) {
+    participant.destination_buffers.push_back(recv_buffer.opaque());
+  }
+  auto make_cpu_rendezvous = [](const RendezvousKey& k) {
+    return std::make_unique<CpuAllToAllRendezvous>(k);
+  };
+  return CpuAllToAllRendezvous::SubmitParticipant(
+             [&] {
+               return state_->all_to_all_rendezvous_map.GetOrCreateIfAbsent(
+                   key, make_cpu_rendezvous);
+             },
+             participant)
+      .status();
+}
+
+absl::Status InProcessCommunicator::AllGather(se::DeviceMemoryBase send_buffer,
+                                              se::DeviceMemoryBase recv_buffer,
+                                              PrimitiveType dtype, size_t count,
+                                              const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  const RendezvousKey& key = cpu_executor->rendezvous_key();
+
+  AllGatherParticipantData participant(key, rank_);
+  participant.chunk_size = count * primitive_util::ByteWidth(dtype);
+  participant.source_buffer = send_buffer.opaque();
+  participant.destination_buffer = recv_buffer.opaque();
+  auto make_cpu_rendezvous = [](const RendezvousKey& k) {
+    return std::make_unique<CpuAllGatherRendezvous>(k);
+  };
+  return CpuAllGatherRendezvous::SubmitParticipant(
+             [&] {
+               return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent(
+                   key, make_cpu_rendezvous);
+             },
+             participant)
+      .status();
+}
+
+absl::Status InProcessCommunicator::ReduceScatter(
+    se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
+    PrimitiveType dtype, size_t count, ReductionKind reduction_kind,
+    const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor));
+  const RendezvousKey& key = cpu_executor->rendezvous_key();
+
+  ReduceScatterParticipantData participant(key, rank_);
+  participant.element_type = dtype;
+  participant.reduction_kind = reduction_kind;
+  participant.chunk_elems = count;
+  participant.source_buffer = send_buffer.opaque();
+  participant.destination_buffer = recv_buffer.opaque();
+  auto make_cpu_rendezvous = [](const RendezvousKey& k) {
+    return std::make_unique<CpuReduceScatterRendezvous>(k);
+  };
+  return CpuReduceScatterRendezvous::SubmitParticipant(
+             [&] {
+               return state_->reduce_scatter_rendezvous_map
+                   .GetOrCreateIfAbsent(key, make_cpu_rendezvous);
+             },
+             participant)
+      .status();
+}
+
+}  // namespace xla::cpu
diff --git a/xla/backends/cpu/collectives/in_process_communicator.h b/xla/backends/cpu/collectives/in_process_communicator.h
new file mode 100644
index 0000000000000..abc82c7aba211
--- /dev/null
+++ b/xla/backends/cpu/collectives/in_process_communicator.h
@@ -0,0 +1,109 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_
+#define XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_
+
+#include <cstddef>
+#include <memory>
+#include <optional>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::cpu {
+
+// XLA communicator that implements collective operations using shared memory
+// and works only within a single process.
+class InProcessCommunicator : public Communicator {
+ public:
+  // A state shared by all InProcessCommunicators in the clique.
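+  // State owns the per-collective rendezvous maps (see the .cc file), so
+  // only communicators created from the same State can rendezvous with each
+  // other.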
+  struct State;
+
+  // Creates a new State for constructing InProcessCommunicators.
+  static std::shared_ptr<State> CreateState();
+
+  InProcessCommunicator(std::shared_ptr<State> state, size_t rank,
+                        size_t num_ranks);
+  ~InProcessCommunicator() override;
+
+  absl::Status AllReduce(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, ReductionKind reduction_kind,
+                         const Executor& executor) override;
+
+  absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
+                                 se::DeviceMemoryBase recv_buffer,
+                                 PrimitiveType dtype, size_t count,
+                                 std::optional<RankId> source_rank,
+                                 absl::Span<const RankId> target_ranks,
+                                 const Executor& executor) override;
+
+  absl::Status AllToAll(absl::Span<const se::DeviceMemoryBase> send_buffers,
+                        absl::Span<const se::DeviceMemoryBase> recv_buffers,
+                        PrimitiveType dtype, size_t count,
+                        const Executor& executor) override;
+
+  absl::Status AllGather(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, const Executor& executor) override;
+
+  absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
+                             se::DeviceMemoryBase recv_buffer,
+                             PrimitiveType dtype, size_t count,
+                             ReductionKind reduction_kind,
+                             const Executor& executor) override;
+
+  absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase,
+                         PrimitiveType, size_t, RankId,
+                         const Executor&) override {
+    return Unimplemented("Broadcast is not implemented");
+  }
+
+  absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId,
+                    const Executor&) override {
+    return Unimplemented("Send is not implemented");
+  }
+
+  absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId,
+                    const Executor&) override {
+    return Unimplemented("Recv is not implemented");
+  }
+
+  absl::StatusOr<size_t> NumRanks() const override { return num_ranks_; }
+
+  std::string ToString() const override {
+    return absl::StrCat("InProcessCommunicator [rank: ", rank_,
+                        " num_ranks: ", num_ranks_, "]");
+  }
+
+ private:
+  std::shared_ptr<State> state_;
+  size_t rank_;
+  size_t num_ranks_;
+};
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_COLLECTIVES_IN_PROCESS_COMMUNICATOR_H_
diff --git a/xla/backends/cpu/collectives/mpi_communicator.cc b/xla/backends/cpu/collectives/mpi_communicator.cc
new file mode 100644
index 0000000000000..0062593da7540
--- /dev/null
+++ b/xla/backends/cpu/collectives/mpi_communicator.cc
@@ -0,0 +1,242 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/cpu/collectives/mpi_communicator.h"
+
+#include <cstddef>
+#include <cstring>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "mpi.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/primitive_util.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/status_macros.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/xla_data.pb.h"
+#include "tsl/platform/statusor.h"
+
+namespace xla::cpu {
+
+absl::StatusOr<MPI_Datatype> PrimitiveTypeToMpiType(
+    PrimitiveType element_type) {
+  switch (element_type) {
+    case S8:
+      return MPI_INT8_T;
+    case U8:
+    case PRED:
+      return MPI_UINT8_T;
+    case S16:
+      return MPI_INT16_T;
+    case U16:
+      return MPI_UINT16_T;
+    case S32:
+      return MPI_INT32_T;
+    case U32:
+      return MPI_UINT32_T;
+    case S64:
+      return MPI_INT64_T;
+    case U64:
+      return MPI_UINT64_T;
+    case F32:
+      return MPI_FLOAT;
+    case F64:
+      return MPI_DOUBLE;
+    case C64:
+      return MPI_C_COMPLEX;
+    case C128:
+      return MPI_C_DOUBLE_COMPLEX;
+    default:
+      // For implementing the reduction of unsupported types
+      // see e.g. https://stackoverflow.com/a/29643391
+      return absl::InvalidArgumentError(absl::StrCat(
+          "Unsupported primitive type for reduction: ",
+          primitive_util::LowercasePrimitiveTypeName(element_type)));
+  }
+}
+
+bool MpiTypeIsComplex(MPI_Datatype type) {
+  return type == MPI_C_COMPLEX || type == MPI_C_DOUBLE_COMPLEX;
+}
+
+absl::StatusOr<MPI_Op> ReductionKindToMpiOp(ReductionKind reduction_kind,
+                                            MPI_Datatype type) {
+  switch (reduction_kind) {
+    case ReductionKind::SUM:
+      return MPI_SUM;
+    case ReductionKind::PRODUCT:
+      return MPI_PROD;
+    case ReductionKind::MIN:
+      if (!MpiTypeIsComplex(type)) {
+        return MPI_MIN;
+      } else {
+        return absl::InvalidArgumentError(
+            "MIN reduction not supported for complex types");
+      }
+    case ReductionKind::MAX:
+      if (!MpiTypeIsComplex(type)) {
+        return MPI_MAX;
+      } else {
+        return absl::InvalidArgumentError(
+            "MAX reduction not supported for complex types");
+      }
+    default:
+      return absl::InvalidArgumentError(
+          absl::StrCat("Unknown reduction kind: ", reduction_kind));
+  }
+}
+
+static absl::Status MpiErrorToAbslStatus(int error) {
+  if (error != MPI_SUCCESS) {
+    char error_str[MPI_MAX_ERROR_STRING];
+    int len;
+    MPI_Error_string(error, error_str, &len);
+    return absl::UnknownError(absl::StrCat("MPI error: ", error_str));
+  }
+  return absl::OkStatus();
+}
+
+MpiCommunicator::MpiCommunicator(int color, int key) {
+  MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_);
+  MPI_Comm_rank(comm_, &mpi_rank_);
+  MPI_Comm_size(comm_, &mpi_size_);
+}
+
+MpiCommunicator::~MpiCommunicator() { MPI_Comm_free(&comm_); }
+
+absl::Status MpiCommunicator::AllReduce(se::DeviceMemoryBase send_buffer,
+                                        se::DeviceMemoryBase recv_buffer,
+                                        PrimitiveType dtype, size_t count,
+                                        ReductionKind reduction_kind,
+                                        const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype));
+  TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type));
+  return MpiErrorToAbslStatus(MPI_Allreduce(
+      send_buffer.opaque(), recv_buffer.opaque(), count, type, op, comm_));
+}
+
+absl::Status MpiCommunicator::CollectivePermute(
+    se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer,
+    PrimitiveType dtype, size_t count, std::optional<RankId> source_rank,
+    absl::Span<const RankId> target_ranks, const Executor& executor) {
+  int tag = 0;  // TODO come up with better tags.
+
+  const int rank = mpi_rank_;
+
+  std::vector<MPI_Request> requests;
+
+  size_t num_bytes = count * primitive_util::ByteWidth(dtype);
+
+  if (source_rank) {
+    if (source_rank->value() == rank) {
+      std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes);
+    } else {
+      VLOG(1) << "recv at " << rank << " from " << source_rank->value();
+      requests.emplace_back();
+      TF_RETURN_IF_ERROR(MpiErrorToAbslStatus(
+          MPI_Irecv(recv_buffer.opaque(), num_bytes, MPI_BYTE,
+                    source_rank->value(), tag, comm_, &requests.back())));
+    }
+  } else {
+    std::memset(recv_buffer.opaque(), 0, num_bytes);
+  }
+
+  for (RankId target : target_ranks) {
+    if (target != rank) {
+      VLOG(1) << "send from " << rank << " to " << target.value();
+      requests.emplace_back();
+      TF_RETURN_IF_ERROR(MpiErrorToAbslStatus(
+          MPI_Isend(send_buffer.opaque(), num_bytes, MPI_BYTE, target.value(),
+                    tag, comm_, &requests.back())));
+    }
+  }
+
+  for (auto& request : requests) {
+    TF_RETURN_IF_ERROR(
+        MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE)));
+  }
+
+  return absl::OkStatus();
+}
+
+absl::Status MpiCommunicator::AllToAll(
+    absl::Span<const se::DeviceMemoryBase> send_buffers,
+    absl::Span<const se::DeviceMemoryBase> recv_buffers, PrimitiveType dtype,
+    size_t count, const Executor& executor) {
+  // We can't use MPI_Alltoall directly because it assumes that the inputs and
+  // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv.
+
+  int tag = 0;  // TODO use better tags.
+  const int rank = mpi_rank_;
+  const int size = mpi_size_;
+  TF_RET_CHECK(size == send_buffers.size());
+  TF_RET_CHECK(size == recv_buffers.size());
+
+  size_t chunk_bytes = count * primitive_util::ByteWidth(dtype);
+
+  std::vector<void*> input_buffers;
+  std::vector<void*> output_buffers;
+
+  for (int i = 0; i < size; i++) {
+    input_buffers.push_back(const_cast<void*>(send_buffers[i].opaque()));
+    output_buffers.push_back(const_cast<void*>(recv_buffers[i].opaque()));
+  }
+
+  std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes);
+
+  for (int i = 1; i < size; i++) {
+    int send_rank = (rank + i) % size;
+    int recv_rank = (rank + size - i) % size;
+    TF_RETURN_IF_ERROR(MpiErrorToAbslStatus(
+        MPI_Sendrecv(input_buffers[send_rank], chunk_bytes, MPI_BYTE, send_rank,
+                     tag, output_buffers[recv_rank], chunk_bytes, MPI_BYTE,
+                     recv_rank, tag, comm_, MPI_STATUS_IGNORE)));
+  }
+
+  return absl::OkStatus();
+}
+
+absl::Status MpiCommunicator::AllGather(se::DeviceMemoryBase send_buffer,
+                                        se::DeviceMemoryBase recv_buffer,
+                                        PrimitiveType dtype, size_t count,
+                                        const Executor& executor) {
+  TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype));
+  return MpiErrorToAbslStatus(MPI_Allgather(send_buffer.opaque(), count, type,
+                                            recv_buffer.opaque(), count, type,
+                                            comm_));
+}
+
+absl::Status MpiCommunicator::ReduceScatter(se::DeviceMemoryBase send_buffer,
+                                            se::DeviceMemoryBase recv_buffer,
+                                            PrimitiveType dtype, size_t count,
+                                            ReductionKind reduction_kind,
+                                            const Executor& executor) {
+  const int size = mpi_size_;
+  std::vector<int> recvcounts(size, count);
+  TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype));
+  TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type));
+  return MpiErrorToAbslStatus(
+      MPI_Reduce_scatter(send_buffer.opaque(), recv_buffer.opaque(),
+                         recvcounts.data(), type, op, comm_));
+}
+
+}  // namespace xla::cpu
diff --git a/xla/backends/cpu/collectives/mpi_communicator.h b/xla/backends/cpu/collectives/mpi_communicator.h
new file mode 100644
index 0000000000000..cfed534b66bd5
--- /dev/null
+++ b/xla/backends/cpu/collectives/mpi_communicator.h
@@ -0,0 +1,98 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_
+#define XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_
+
+#include <cstddef>
+#include <optional>
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "mpi.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
+#include "xla/service/collective_ops_utils.h"
+#include "xla/stream_executor/device_memory.h"
+#include "xla/util.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla::cpu {
+
+class MpiCommunicator : public Communicator {
+ public:
+  explicit MpiCommunicator(int color, int key);
+  ~MpiCommunicator() override;
+
+  absl::Status AllReduce(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, ReductionKind reduction_kind,
+                         const Executor& executor) override;
+
+  absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
+                                 se::DeviceMemoryBase recv_buffer,
+                                 PrimitiveType dtype, size_t count,
+                                 std::optional<RankId> source_rank,
+                                 absl::Span<const RankId> target_ranks,
+                                 const Executor& executor) override;
+
+  absl::Status AllToAll(absl::Span<const se::DeviceMemoryBase> send_buffers,
+                        absl::Span<const se::DeviceMemoryBase> recv_buffers,
+                        PrimitiveType dtype, size_t count,
+                        const Executor& executor) override;
+  absl::Status AllGather(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, const Executor& executor) override;
+  absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
+                             se::DeviceMemoryBase recv_buffer,
+                             PrimitiveType dtype, size_t count,
+                             ReductionKind reduction_kind,
+                             const Executor& executor) override;
+
+  absl::Status Broadcast(se::DeviceMemoryBase, se::DeviceMemoryBase,
+                         PrimitiveType, size_t, RankId,
+                         const Executor&) override {
+    return Unimplemented("Broadcast is not implemented");
+  }
+
+  absl::Status Send(se::DeviceMemoryBase, PrimitiveType, size_t, RankId,
+                    const Executor&) override {
+    return Unimplemented("Send is not implemented");
+  }
+
+  absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, RankId,
+                    const Executor&) override {
+    return Unimplemented("Recv is not implemented");
+  }
+
+  absl::StatusOr<size_t> NumRanks() const override { return mpi_size_; }
+
+  std::string ToString() const override {
+    return absl::StrCat("MpiCommunicator [rank: ", mpi_rank_,
+                        " num_ranks: ", mpi_size_, "]");
+  }
+
+ private:
+  MPI_Comm comm_;
+  int mpi_rank_;
+  int mpi_size_;
+};
+
+}  // namespace xla::cpu
+
+#endif  // XLA_BACKENDS_CPU_COLLECTIVES_MPI_COMMUNICATOR_H_
diff --git a/xla/backends/cpu/runtime/BUILD b/xla/backends/cpu/runtime/BUILD
index a83a5e51dca28..cd1e7b89e9c1a 100644
--- a/xla/backends/cpu/runtime/BUILD
+++ b/xla/backends/cpu/runtime/BUILD
@@ -145,6 +145,8 @@ cc_library(
        ":resource_use",
        "//xla:executable_run_options",
        "//xla:util",
+       "//xla/backends/cpu/collectives:cpu_collectives",
+       "//xla/core/collectives",
        "//xla/ffi:execution_context",
        "//xla/runtime:buffer_use",
        "//xla/service:global_device_id",
@@ -155,11 +157,12 @@ cc_library(
        "//xla/stream_executor:stream",
        "//xla/stream_executor:stream_executor_h",
        "//xla/tsl/concurrency:async_value",
+       "//xla/tsl/platform:logging",
+       "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/status:statusor",
-       "@tsl//tsl/platform:logging",
-       "@tsl//tsl/platform:statusor",
+       "@com_google_absl//absl/strings:string_view",
        "@tsl//tsl/profiler/lib:traceme",
        "@tsl//tsl/profiler/lib:traceme_encode",
    ],
@@ -593,6 +596,11 @@ cc_library(
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
+       "//xla/backends/cpu/collectives:cpu_clique_key",
+       "//xla/backends/cpu/collectives:cpu_cliques",
+       "//xla/backends/cpu/collectives:cpu_collectives",
+       "//xla/core/collectives:communicator",
+       "//xla/core/collectives:rank_id",
        "//xla/runtime:buffer_use",
        "//xla/service:buffer_assignment",
        "//xla/service:collective_ops_utils",
@@ -601,6 +609,9 @@ cc_library(
        "//xla/service/cpu:collectives_interface",
        "//xla/stream_executor:device_memory",
        "//xla/tsl/concurrency:async_value",
+       "//xla/tsl/platform:errors",
+       "//xla/tsl/platform:logging",
+       "//xla/tsl/platform:statusor",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/functional:any_invocable",
@@ -610,9 +621,6 @@ cc_library(
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/time",
        "@com_google_absl//absl/types:span",
-       "@tsl//tsl/platform:errors",
-       "@tsl//tsl/platform:logging",
-       "@tsl//tsl/platform:statusor",
    ],
)
diff --git a/xla/backends/cpu/runtime/all_gather_thunk.cc b/xla/backends/cpu/runtime/all_gather_thunk.cc
index c56fdf94903b4..9a3c2fff062de 100644
--- a/xla/backends/cpu/runtime/all_gather_thunk.cc
+++ b/xla/backends/cpu/runtime/all_gather_thunk.cc
@@ -77,7 +77,7 @@ tsl::AsyncValueRef<AllGatherThunk::ExecuteEvent> AllGatherThunk::Execute(
   return ExecuteWithCommunicator(
       params.collective_params,
-      [&](const RendezvousKey& key, CollectivesCommunicator& comm) {
+      [&](const RendezvousKey& key, Communicator& comm) {
        CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());
        for (int32_t i = 0; i < data.source.size(); ++i) {
diff --git a/xla/backends/cpu/runtime/all_reduce_thunk.cc b/xla/backends/cpu/runtime/all_reduce_thunk.cc
index d9be82226ec34..9dca34f90ceae 100644
--- a/xla/backends/cpu/runtime/all_reduce_thunk.cc
+++ b/xla/backends/cpu/runtime/all_reduce_thunk.cc
@@ -102,7 +102,7 @@ tsl::AsyncValueRef<AllReduceThunk::ExecuteEvent> AllReduceThunk::Execute(
   return ExecuteWithCommunicator(
       params.collective_params,
-      [&](const RendezvousKey& key, CollectivesCommunicator& comm) {
+      [&](const RendezvousKey& key, Communicator& comm) {
        CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());
        for (int32_t i = 0; i < data.source.size(); ++i) {
          const Shape& shape = destination_shape(i);
diff --git a/xla/backends/cpu/runtime/all_to_all_thunk.cc b/xla/backends/cpu/runtime/all_to_all_thunk.cc
index ee18d893c07bd..37235935754bc 100644
--- a/xla/backends/cpu/runtime/all_to_all_thunk.cc
+++ b/xla/backends/cpu/runtime/all_to_all_thunk.cc
@@ -76,7 +76,7 @@ tsl::AsyncValueRef<AllToAllThunk::ExecuteEvent> AllToAllThunk::Execute(
   return ExecuteWithCommunicator(
       params.collective_params,
-      [&](const RendezvousKey& key, CollectivesCommunicator& comm) {
+      [&](const RendezvousKey& key, Communicator& comm) {
        CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());
        const Shape& shape = destination_shape(0);
diff --git a/xla/backends/cpu/runtime/collective_permute_thunk.cc b/xla/backends/cpu/runtime/collective_permute_thunk.cc
index 5ee3a8ea2cb45..6387eb31f35be 100644
--- a/xla/backends/cpu/runtime/collective_permute_thunk.cc
+++ b/xla/backends/cpu/runtime/collective_permute_thunk.cc
@@ -131,7 +131,7 @@ CollectivePermuteThunk::Execute(const ExecuteParams& params) {
   return ExecuteWithCommunicator(
       params.collective_params,
-      [&](const RendezvousKey& key, CollectivesCommunicator& comm) {
+      [&](const RendezvousKey& key, Communicator& comm) {
        CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());
        for (int32_t i = 0; i < data.source.size(); ++i) {
diff --git a/xla/backends/cpu/runtime/collective_thunk.cc b/xla/backends/cpu/runtime/collective_thunk.cc
index 4bebdd09cd31c..35a6f72fb9671 100644
--- a/xla/backends/cpu/runtime/collective_thunk.cc
+++ b/xla/backends/cpu/runtime/collective_thunk.cc
@@ -32,23 +32,27 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "absl/time/time.h"
 #include "absl/types/span.h"
+#include "xla/backends/cpu/collectives/cpu_clique_key.h"
+#include "xla/backends/cpu/collectives/cpu_cliques.h"
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
 #include "xla/backends/cpu/runtime/resource_use.h"
 #include "xla/backends/cpu/runtime/thunk.h"
+#include "xla/core/collectives/communicator.h"
+#include "xla/core/collectives/rank_id.h"
 #include "xla/runtime/buffer_use.h"
 #include "xla/service/buffer_assignment.h"
 #include "xla/service/collective_ops_utils.h"
 #include "xla/service/computation_placer.h"
-#include "xla/service/cpu/collectives_interface.h"
 #include "xla/service/global_device_id.h"
 #include "xla/shape.h"
 #include "xla/status_macros.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/statusor.h"
 #include "xla/util.h"
 #include "xla/xla_data.pb.h"
-#include "tsl/platform/errors.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
 namespace xla::cpu {
@@ -172,7 +176,7 @@ CollectiveThunk::ExecuteWithCommunicator(
   TF_RET_CHECK(params)
       << "Collective parameters are not set for collective operation";
-  CollectivesInterface* collectives = params->collectives;
+  CpuCollectives* collectives = params->collectives;
   TF_RET_CHECK(collectives)
       << "Collectives interface is not set for collective operation";
@@ -183,8 +187,10 @@ CollectiveThunk::ExecuteWithCommunicator(
   VLOG(3) << absl::StreamFormat("  rank=%d, key=%s", rank, key.ToString());
-  TF_ASSIGN_OR_RETURN(std::shared_ptr<CollectivesCommunicator> communicator,
-                      collectives->GetCommunicator(key.global_devices, rank));
+  CpuCliqueKey clique_key(key.global_devices);
+  TF_ASSIGN_OR_RETURN(
+      Communicator * communicator,
+      AcquireCommunicator(collectives, clique_key, RankId(rank)));
   TF_RETURN_IF_ERROR(callback(key, *communicator));
diff --git a/xla/backends/cpu/runtime/collective_thunk.h b/xla/backends/cpu/runtime/collective_thunk.h
index 8efc767838806..60c98ce37547c 100644
--- a/xla/backends/cpu/runtime/collective_thunk.h
+++ b/xla/backends/cpu/runtime/collective_thunk.h
@@ -86,8 +86,8 @@ class CollectiveThunk : public Thunk {
 protected:
  // Callback for collective thunk implementations.
-  using Callback = absl::AnyInvocable<absl::Status(
-      const RendezvousKey& key, CollectivesCommunicator& comm)>;
+  using Callback = absl::AnyInvocable<absl::Status(const RendezvousKey& key,
+                                                   Communicator& comm)>;
  static bool IsDataTypeSupportedByCollectiveReduce(PrimitiveType datatype);
diff --git a/xla/backends/cpu/runtime/custom_call_thunk.cc b/xla/backends/cpu/runtime/custom_call_thunk.cc
index 8f693a1e3c537..974a77522ac77 100644
--- a/xla/backends/cpu/runtime/custom_call_thunk.cc
+++ b/xla/backends/cpu/runtime/custom_call_thunk.cc
@@ -132,6 +132,12 @@ absl::StatusOr<ffi::CallFrame> BuildCallFrameForTypedFFI(
  // memory addresses will be updated at runtime.
  for (int i = 0; i < op_buffers.arguments_buffers.size(); ++i) {
    auto& shape = op_buffers.arguments_shapes[i];
+
+    if (shape.IsToken()) {
+      builder.AddTokenArg();
+      continue;
+    }
+
    auto elements = absl::c_accumulate(shape.dimensions(), 1ULL,
                                       std::multiplies());
    auto dtype_bytes = primitive_util::ByteWidth(shape.element_type());
@@ -144,6 +150,12 @@ absl::StatusOr<ffi::CallFrame> BuildCallFrameForTypedFFI(
  // memory addresses will be updated at runtime.
  for (int i = 0; i < op_buffers.results_buffers.size(); ++i) {
    auto& shape = op_buffers.results_shapes[i];
+
+    if (shape.IsToken()) {
+      builder.AddTokenRet();
+      continue;
+    }
+
    auto elements = absl::c_accumulate(shape.dimensions(), 1ULL,
                                       std::multiplies());
    auto dtype_bytes = primitive_util::ByteWidth(shape.element_type());
diff --git a/xla/backends/cpu/runtime/reduce_scatter_thunk.cc b/xla/backends/cpu/runtime/reduce_scatter_thunk.cc
index badeb6a860c3e..20311adf01b7c 100644
--- a/xla/backends/cpu/runtime/reduce_scatter_thunk.cc
+++ b/xla/backends/cpu/runtime/reduce_scatter_thunk.cc
@@ -90,7 +90,7 @@ ReduceScatterThunk::Execute(const ExecuteParams& params) {
   return ExecuteWithCommunicator(
       params.collective_params,
-      [&](const RendezvousKey& key, CollectivesCommunicator& comm) {
+      [&](const RendezvousKey& key, Communicator& comm) {
        CpuCollectives::Executor executor(key, DefaultCollectiveTimeout());
        for (int32_t i = 0; i < data.source.size(); ++i) {
diff --git a/xla/backends/cpu/runtime/thunk.cc b/xla/backends/cpu/runtime/thunk.cc
index 8dab085b47fb6..a17de11724bda 100644
--- a/xla/backends/cpu/runtime/thunk.cc
+++ b/xla/backends/cpu/runtime/thunk.cc
@@ -22,6 +22,8 @@ limitations under the License.
 #include <memory>
 #include <utility>
+#include "absl/strings/string_view.h"
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
 #include "xla/executable_run_options.h"
 #include "xla/service/cpu/collectives_interface.h"
 #include "xla/service/cpu/cpu_executable_run_options.h"
@@ -30,7 +32,7 @@ limitations under the License.
 #include "xla/stream_executor/stream.h"
 #include "xla/stream_executor/stream_executor.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
-#include "tsl/platform/logging.h"
+#include "xla/tsl/platform/logging.h"
 #include "tsl/profiler/lib/traceme.h"
 #include "tsl/profiler/lib/traceme_encode.h"
@@ -121,8 +123,7 @@ Thunk::CollectiveExecuteParams::Create(
 Thunk::CollectiveExecuteParams::CollectiveExecuteParams(
     RunId run_id, int64_t local_device_ordinal,
     GlobalDeviceId global_device_id,
-    const DeviceAssignment* device_assignment,
-    CollectivesInterface* collectives)
+    const DeviceAssignment* device_assignment, CpuCollectives* collectives)
    : run_id(run_id),
      local_device_ordinal(local_device_ordinal),
      global_device_id(global_device_id),
diff --git a/xla/backends/cpu/runtime/thunk.h b/xla/backends/cpu/runtime/thunk.h
index 38d3f41d6a75b..2c86db9251774 100644
--- a/xla/backends/cpu/runtime/thunk.h
+++ b/xla/backends/cpu/runtime/thunk.h
@@ -28,21 +28,20 @@ limitations under the License.
 #include "absl/container/inlined_vector.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/backends/cpu/collectives/cpu_collectives.h"
 #include "xla/backends/cpu/runtime/buffer_allocations.h"
 #include "xla/backends/cpu/runtime/function_library.h"
-#include "xla/backends/cpu/runtime/kernel_c_api.h"
 #include "xla/backends/cpu/runtime/resource_use.h"
 #include "xla/executable_run_options.h"
 #include "xla/ffi/execution_context.h"
 #include "xla/runtime/buffer_use.h"
-#include "xla/service/cpu/collectives_interface.h"
 #include "xla/service/cpu/xfeed_manager.h"
 #include "xla/service/global_device_id.h"
 #include "xla/tsl/concurrency/async_value_ref.h"
 #include "xla/tsl/concurrency/chain.h"
-#include "xla/util.h"
-#include "tsl/platform/logging.h"
-#include "tsl/platform/statusor.h"
+#include "xla/tsl/platform/logging.h"
+#include "xla/tsl/platform/statusor.h"
 namespace Eigen {
 struct ThreadPoolDevice;
@@ -164,13 +163,13 @@ class Thunk {
    GlobalDeviceId global_device_id;
    const DeviceAssignment* device_assignment = nullptr;
-    CollectivesInterface* collectives = nullptr;
+    CpuCollectives* collectives = nullptr;
   private:
    CollectiveExecuteParams(RunId run_id, int64_t local_device_ordinal,
                            GlobalDeviceId global_device_id,
                            const DeviceAssignment* device_assignment,
-                            CollectivesInterface* collectives);
+                            CpuCollectives* collectives);
  };
  //===--------------------------------------------------------------------===//
diff --git a/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc b/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc
index 79efa4e752e9f..846925a925ce1 100644
--- a/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc
+++ b/xla/backends/gpu/codegen/ir/xla_gpu_ops.cc
@@ -114,7 +114,7 @@ LogicalResult MaterializeOp::verify() {
    return emitOpError()
           << "must have thread_id dimension in both indexing maps";
  }
-  if (map_in.GetDimVars(0).bounds != map_out.GetDimVars(0).bounds) {
+  if (map_in.GetDimVar(0).bounds != map_out.GetDimVar(0).bounds) {
    return emitOpError() << "thread_id dimension must have the same bounds in "
                            "both indexing maps";
  }
diff --git a/xla/backends/gpu/codegen/transforms/BUILD b/xla/backends/gpu/codegen/transforms/BUILD
index 3894a53825be8..090cf3d26325a 100644
--- a/xla/backends/gpu/codegen/transforms/BUILD
+++ b/xla/backends/gpu/codegen/transforms/BUILD
@@ -38,6 +38,7 @@ gentbl_cc_library(
cc_library(
    name = "passes",
    srcs = [
+        "atomic_rmw_utils.cc",
        "convert_float_nvidia.cc",
        "convert_xla_gpu_pure_call_ops.cc",
        "erase_dead_functions.cc",
diff --git a/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc b/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc
new file mode 100644
index 0000000000000..ad1c769447e01
--- /dev/null
+++ b/xla/backends/gpu/codegen/transforms/atomic_rmw_utils.cc
@@ -0,0 +1,120 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/ilist.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/UseDefLists.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/ir/xla_ops.h" + +namespace xla { +namespace gpu { + +#include "xla/backends/gpu/codegen/transforms/passes.h.inc" + +using mlir::Operation; +using mlir::Type; +using mlir::Value; + +namespace ml = ::mlir::LLVM; +namespace arith = ::mlir::arith; + +bool IsAtomicIntegral(Type element_type) { + if (!element_type.isInteger()) { + return false; + } + unsigned element_bitwidth = element_type.getIntOrFloatBitWidth(); + return element_bitwidth == 32 || element_bitwidth == 64; +} + +std::optional GetAtomicBinOp(Operation* modifier_op, + Type element_type) { + return llvm::TypeSwitch>( + modifier_op) + // Floating-point operations. + .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; }) + .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; }) + .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; }) + // Integer operations. + .Case([&](arith::AddIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::add) + : std::nullopt; + }) + .Case([&](arith::MaxUIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::umax) + : std::nullopt; + }) + .Case([&](arith::MinUIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::umin) + : std::nullopt; + }) + .Case([&](arith::MaxSIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::max) + : std::nullopt; + }) + .Case([&](arith::MinSIOp op) { + return IsAtomicIntegral(element_type) + ? std::make_optional(ml::AtomicBinOp::min) + : std::nullopt; + }) + .Default([](Operation* op) { return std::nullopt; }); +} + +// Returns atomic op modifier and the atomic bin op kind. +std::optional> GetAtomicModifierParameters( + AtomicRMWOp op) { + Type element_type = op.getInput().getType().getElementType(); + auto& operations = op.getBody()->getOperations(); + auto terminator = op.getBody()->getTerminator(); + if (operations.size() > 2) { + return std::nullopt; + } + // If the body contains only the terminator, then it is an atomic store. + if (operations.size() == 1) { + // TODO(b/336367145): Support complex atomic store. + if (element_type.isF32() || IsAtomicIntegral(element_type)) { + return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg); + } + return std::nullopt; + } + // Match the kind of the atomic op. + mlir::Operation* modifier_op = &operations.front(); + auto kind = GetAtomicBinOp(modifier_op, element_type); + if (!kind.has_value()) { + return std::nullopt; + } + // Find the modifier arg that does not match the argument of `atomic_rmw` + // body. + Value block_arg = op.getBody()->getArgument(0); + Value modifier_arg = modifier_op->getOperand(0) == block_arg + ? 
modifier_op->getOperand(1) + : modifier_op->getOperand(0); + return std::make_pair(modifier_arg, *kind); +} + +} // namespace gpu +} // namespace xla diff --git a/xla/backends/gpu/codegen/transforms/lower_tensors.cc b/xla/backends/gpu/codegen/transforms/lower_tensors.cc index 822ba8498800e..0fff3bc811bbc 100644 --- a/xla/backends/gpu/codegen/transforms/lower_tensors.cc +++ b/xla/backends/gpu/codegen/transforms/lower_tensors.cc @@ -755,71 +755,6 @@ class RewriteAtomicRMW : public OpRewritePattern { } private: - // Returns atomic op modifier and the atomic bin op kind. - std::optional> GetAtomicModifierParameters( - AtomicRMWOp op) const { - Type element_type = op.getInput().getType().getElementType(); - auto& operations = op.getBody()->getOperations(); - auto terminator = op.getBody()->getTerminator(); - if (operations.size() > 2) { - return std::nullopt; - } - // If the body contains only the terminator, then it is an atomic store. - if (operations.size() == 1) { - // TODO(b/336367145): Support complex atomic store. - if (element_type.isF32() || IsAtomicIntegral(element_type)) { - return std::make_pair(terminator->getOperand(0), ml::AtomicBinOp::xchg); - } - return std::nullopt; - } - // Match the kind of the atomic op. - mlir::Operation* modifier_op = &operations.front(); - std::optional kind = - llvm::TypeSwitch>( - modifier_op) - // Floating-point operations. - .Case([](arith::AddFOp op) { return ml::AtomicBinOp::fadd; }) - .Case([](arith::MaximumFOp op) { return ml::AtomicBinOp::fmax; }) - .Case([](arith::MinimumFOp op) { return ml::AtomicBinOp::fmin; }) - // Integer operations. - .Case([&](arith::AddIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::add) - : std::nullopt; - }) - .Case([&](arith::MaxUIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::umax) - : std::nullopt; - }) - .Case([&](arith::MinUIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::umin) - : std::nullopt; - }) - .Case([&](arith::MaxSIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::max) - : std::nullopt; - }) - .Case([&](arith::MinSIOp op) { - return IsAtomicIntegral(element_type) - ? std::make_optional(ml::AtomicBinOp::min) - : std::nullopt; - }) - .Default([](Operation* op) { return std::nullopt; }); - if (!kind.has_value()) { - return std::nullopt; - } - // Find the modifier arg that does not match the argument of `atomic_rmw` - // body. - Value block_arg = op.getBody()->getArgument(0); - Value modifier_arg = modifier_op->getOperand(0) == block_arg - ? modifier_op->getOperand(1) - : modifier_op->getOperand(0); - return std::make_pair(modifier_arg, *kind); - } - // Certain computations, such as floating-point addition and integer // maximization, can be simply implemented using an LLVM atomic instruction. // If "computation" is one of this kind, emits code to do that and returns diff --git a/xla/backends/gpu/codegen/transforms/passes.h b/xla/backends/gpu/codegen/transforms/passes.h index db6f75779b93b..de12227f94c0c 100644 --- a/xla/backends/gpu/codegen/transforms/passes.h +++ b/xla/backends/gpu/codegen/transforms/passes.h @@ -19,10 +19,12 @@ limitations under the License. 
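With GetAtomicModifierParameters hoisted out of RewriteAtomicRMW and exported (declaration added to passes.h below), other patterns can reuse it. A hypothetical caller, as a sketch only (the pattern body and the emitted op are illustrative, not part of this diff):

// Sketch: consume the shared helper from a rewrite pattern.
mlir::LogicalResult matchAndRewrite(AtomicRMWOp op,
                                    mlir::PatternRewriter& rewriter) const {
  auto params = GetAtomicModifierParameters(op);
  if (!params.has_value()) {
    return rewriter.notifyMatchFailure(op, "unsupported atomic modifier");
  }
  auto [modifier_arg, bin_op] = *params;
  // ... emit e.g. an llvm.atomicrmw with bin_op applied to modifier_arg ...
  return mlir::success();
}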
#include #include #include +#include +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "xla/hlo/analysis/indexing_map.h" +#include "xla/codegen/ir/xla_ops.h" #include "xla/stream_executor/device_description.h" namespace xla { @@ -31,6 +33,10 @@ namespace gpu { #define GEN_PASS_DECL #include "xla/backends/gpu/codegen/transforms/passes.h.inc" +// Returns atomic op modifier and the atomic bin op kind. +std::optional> +GetAtomicModifierParameters(AtomicRMWOp op); + std::unique_ptr CreateConvertFloatNvidiaPass(); std::optional> MaybeCreateConvertFloatNvidiaPass( const se::DeviceDescription& device_description); @@ -56,7 +62,10 @@ std::unique_ptr CreatePropagateSliceIndicesPass(); std::unique_ptr CreateSimplifyAffinePass(); std::unique_ptr CreateSimplifyArithPass(); std::unique_ptr CreateUnswitchLoopsPass(); -std::unique_ptr CreateVectorizeLoadsAndStoresPass(); +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const std::string& gpu_device_info = ""); +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description); #define GEN_PASS_REGISTRATION #include "xla/backends/gpu/codegen/transforms/passes.h.inc" diff --git a/xla/backends/gpu/codegen/transforms/passes.td b/xla/backends/gpu/codegen/transforms/passes.td index 1b5ffbdb24636..53b20387c62aa 100644 --- a/xla/backends/gpu/codegen/transforms/passes.td +++ b/xla/backends/gpu/codegen/transforms/passes.td @@ -256,6 +256,11 @@ def VectorizeLoadsAndStoresPass : "mlir::vector::VectorDialect", ]; + let options = [ + Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"", + "Serialized stream_executor::GPUDeviceInfo proto.">, + ]; + let constructor = "CreateVectorizeLoadsAndStoresPass()"; } diff --git a/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir b/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir index a3b7e816bb05f..d5d3d0a74fe4a 100644 --- a/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir +++ b/xla/backends/gpu/codegen/transforms/tests/vectorize_loads_stores.mlir @@ -1,5 +1,6 @@ // RUN: emitters_opt -allow-unregistered-dialect %s -split-input-file \ -// RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s +// RUN: -xla-gpu-vectorize-loads-stores="gpu_device_info='cuda_compute_capability {major: 6}'" -cse -canonicalize \ +// RUN: | FileCheck %s #map = #xla.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," "domain: d0 in [0, 63], s0 in [0, 1]"> @@ -251,7 +252,7 @@ func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f3 func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index + %c4 = arith.constant 4 : index %cst = arith.constant 0.0 : f32 %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { %inserted = tensor.insert %cst into %iter[%j] : tensor<64xf32> @@ -263,6 +264,7 @@ func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}) // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[V:.*]] = scf.for +// CHECK-SAME: (vector<4xf32>) // CHECK-NEXT: vector.insert // CHECK-NEXT: scf.yield // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[C0]]] diff --git a/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc b/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc index 8202ae05e8d07..19e6b7faf5e36 
100644 --- a/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc +++ b/xla/backends/gpu/codegen/transforms/vectorize_loads_stores.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include "llvm/ADT/APInt.h" @@ -40,7 +41,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/backends/gpu/codegen/ir/xla_gpu_ops.h" +#include "xla/backends/gpu/codegen/transforms/passes.h" +#include "xla/codegen/ir/xla_ops.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -326,21 +329,45 @@ class VectorizeLoadsAndStoresPass : public impl::VectorizeLoadsAndStoresPassBase< VectorizeLoadsAndStoresPass> { public: + explicit VectorizeLoadsAndStoresPass( + const VectorizeLoadsAndStoresPassOptions& options) + : VectorizeLoadsAndStoresPassBase(options) {} + + explicit VectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description) + : device_description_(device_description) {} + void runOnOperation() override { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (!gpu_device_info_.empty()) { + se::GpuDeviceInfoProto device_info; + CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_, + &device_info)); + device_description_ = se::DeviceDescription(device_info); + } + mlir::MLIRContext* mlir_context = &getContext(); + mlir::RewritePatternSet patterns(mlir_context); + patterns.add(mlir_context); + if (mlir::failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } + + se::DeviceDescription device_description_; }; } // namespace -std::unique_ptr> -CreateVectorizeLoadsAndStoresPass() { - return std::make_unique(); +std::unique_ptr<::mlir::Pass> CreateVectorizeLoadsAndStoresPass( + const std::string& gpu_device_info) { + VectorizeLoadsAndStoresPassOptions options; + options.gpu_device_info_ = gpu_device_info; + return std::make_unique(options); +} + +std::unique_ptr CreateVectorizeLoadsAndStoresPass( + const se::DeviceDescription& device_description) { + return std::make_unique(device_description); } } // namespace gpu diff --git a/xla/backends/profiler/gpu/device_tracer_cuda.cc b/xla/backends/profiler/gpu/device_tracer_cuda.cc index 578d4ab6d3021..2d675afba107d 100644 --- a/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -46,8 +46,7 @@ using tsl::ReadBoolFromEnvVar; // GpuTracer for GPU. 
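The VectorizeLoadsAndStoresPass changes above mean the pass now needs device information, either as a DeviceDescription or as a text-format se::GpuDeviceInfoProto through the new pass option. A usage sketch (the PassManager plumbing here is illustrative; the textproto string mirrors the updated .mlir test above):

// In-process: hand over the DeviceDescription directly.
pm.addNestedPass<mlir::func::FuncOp>(
    CreateVectorizeLoadsAndStoresPass(device_description));

// Via the string option (parsed with tsl::protobuf::TextFormat inside
// runOnOperation()), as the lit test does:
pm.addNestedPass<mlir::func::FuncOp>(
    CreateVectorizeLoadsAndStoresPass("cuda_compute_capability { major: 6 }"));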
class GpuTracer : public tsl::profiler::ProfilerInterface { public: - GpuTracer(CuptiTracer* cupti_tracer, CuptiInterface* cupti_interface) - : cupti_tracer_(cupti_tracer) { + explicit GpuTracer(CuptiTracer* cupti_tracer) : cupti_tracer_(cupti_tracer) { VLOG(1) << "GpuTracer created."; } ~GpuTracer() override {} @@ -227,8 +226,7 @@ std::unique_ptr CreateGpuTracer( if (!cupti_tracer->IsAvailable()) { return nullptr; } - profiler::CuptiInterface* cupti_interface = profiler::GetCuptiInterface(); - return std::make_unique(cupti_tracer, cupti_interface); + return std::make_unique(cupti_tracer); } auto register_gpu_tracer_factory = [] { diff --git a/xla/codegen/ir/xla_ops.cc b/xla/codegen/ir/xla_ops.cc index 1f48f5bdd5c9c..1d72b0264b66f 100644 --- a/xla/codegen/ir/xla_ops.cc +++ b/xla/codegen/ir/xla_ops.cc @@ -323,7 +323,7 @@ absl::StatusOr GetNewIndexingMapAfterFoldingSequence( replacement_expr = getAffineDimExpr(num_dims + added_dim_args.size(), ctx); added_dim_args.push_back(producer_operand.get()); - new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); + new_dim_vars.push_back(producer_map.GetDimVar(dim_num)); } producer_dim_replacements.push_back(replacement_expr); } @@ -529,7 +529,7 @@ struct FoldApplyIndexingOperands } else { new_operands.push_back(operand.get()); dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx)); - new_dim_vars.push_back(indexing_map.GetDimVars(operand_id)); + new_dim_vars.push_back(indexing_map.GetDimVar(operand_id)); } } rewriter.replaceOpWithNewOp( diff --git a/xla/core/collectives/BUILD b/xla/core/collectives/BUILD index 0ab61569ecc1e..2e9ace8f7aa25 100644 --- a/xla/core/collectives/BUILD +++ b/xla/core/collectives/BUILD @@ -68,6 +68,7 @@ cc_library( hdrs = ["communicator.h"], deps = [ ":rank_id", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", "//xla/stream_executor:device_memory", diff --git a/xla/core/collectives/clique.cc b/xla/core/collectives/clique.cc index 6eb73c1ea91cb..1a0a5d659aecb 100644 --- a/xla/core/collectives/clique.cc +++ b/xla/core/collectives/clique.cc @@ -21,8 +21,10 @@ limitations under the License. #include "absl/container/btree_map.h" #include "absl/functional/function_ref.h" +#include "absl/status/status.h" #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/rank_id.h" +#include "xla/util.h" namespace xla { @@ -44,4 +46,13 @@ void Clique::ForEachComm( } } +absl::Status Clique::AddComm(RankId rank, + std::unique_ptr communicator) { + auto emplaced = communicators_.emplace(rank, std::move(communicator)); + if (!emplaced.second) { + return InvalidArgument("Rank %d already exists in clique", rank.value()); + } + return absl::OkStatus(); +} + } // namespace xla diff --git a/xla/core/collectives/clique.h b/xla/core/collectives/clique.h index 69705ccfa524c..24f80a3f1682c 100644 --- a/xla/core/collectives/clique.h +++ b/xla/core/collectives/clique.h @@ -49,6 +49,9 @@ class Clique { // Returns a communicator for a given rank if it's in a clique. std::optional comm(RankId rank) const; + // Adds a communicator to the clique. + absl::Status AddComm(RankId rank, std::unique_ptr communicator); + // Calls `fn` for each communicator in the clique. void ForEachComm(absl::FunctionRef fn) const; @@ -61,8 +64,8 @@ class Clique { size_t num_communicators() const { return communicators_.size(); } private: - // We keep communicators in a sorted order by rank to guarantee deterministic - // traversal order in `ForEachComm`. 
+ // We keep communicators in a sorted order by rank to guarantee + // deterministic traversal order in `ForEachComm`. absl::btree_map> communicators_; }; diff --git a/xla/core/collectives/clique_key.cc b/xla/core/collectives/clique_key.cc index 2da8d6651c354..92749633bb91a 100644 --- a/xla/core/collectives/clique_key.cc +++ b/xla/core/collectives/clique_key.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/core/collectives/clique_key.h" +#include #include #include #include @@ -31,6 +32,8 @@ CliqueKey::CliqueKey(std::vector devices) absl::Span CliqueKey::devices() const { return devices_; } +size_t CliqueKey::num_devices() const { return devices_.size(); } + std::optional CliqueKey::rank(GlobalDeviceId id) const { if (auto it = absl::c_find(devices_, id); it != devices_.end()) { return RankId(it - devices_.begin()); diff --git a/xla/core/collectives/clique_key.h b/xla/core/collectives/clique_key.h index 0541177343150..37e16d5fb774a 100644 --- a/xla/core/collectives/clique_key.h +++ b/xla/core/collectives/clique_key.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ #define XLA_CORE_COLLECTIVES_CLIQUE_KEY_H_ +#include #include #include #include @@ -52,6 +53,7 @@ class CliqueKey { std::optional rank(GlobalDeviceId id) const; absl::Span devices() const; + size_t num_devices() const; // Returns true if this clique is a subset of `other`. virtual bool IsSubsetOf(const CliqueKey& other) const = 0; diff --git a/xla/core/collectives/communicator.h b/xla/core/collectives/communicator.h index b6139dec3684b..af95f7063fc80 100644 --- a/xla/core/collectives/communicator.h +++ b/xla/core/collectives/communicator.h @@ -28,6 +28,7 @@ limitations under the License. #include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { @@ -53,23 +54,24 @@ class Communicator { virtual absl::Status Unregister() = 0; }; + // Register `buffer` for efficient collective operations (i.e. on NCCL backend + // it registers the buffer for zero-copy collective operations). + virtual absl::StatusOr> + RegisterBuffer(stream_executor::DeviceMemoryBase buffer) { + return Unimplemented("User-managed buffer registration is not supported"); + } + // Abort any uncompleted operations and destroys the underlying communicator // object. It is undefined behavior to use the communicator after calling // this method. - virtual absl::Status Abort() = 0; + virtual absl::Status Abort() { + return Unimplemented("Aborting communicator is not implemented"); + } // Checks the health of the communicator. It might return an error from the // previously launched asynchronous collective operations, and it does not // have to wait for the completion of scheduled operations. - virtual absl::Status HealthCheck() const = 0; - - // Returns the number of ranks in the communicator. - virtual absl::StatusOr NumRanks() const = 0; - - // Register `buffer` for efficient collective operations (i.e. on NCCL backend - // it registers the buffer for zero-copy collective operations). - virtual absl::StatusOr> - RegisterBuffer(stream_executor::DeviceMemoryBase buffer) = 0; + virtual absl::Status HealthCheck() const { return absl::OkStatus(); } // Reduce buffers of length `count` in `send_buff` using `reduction_kind` // reduction and leaves identical copies of the result on each `recv_buff`. 
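Since Abort(), HealthCheck(), and RegisterBuffer() now carry default bodies, backends without those capabilities no longer need boilerplate overrides. A sketch of a minimal subclass (hypothetical type; the NumRanks return type is an assumption, since template arguments are elided in this patch text):

class InProcessCommunicator : public Communicator {
 public:
  // Inherited defaults: RegisterBuffer() and Abort() return Unimplemented,
  // HealthCheck() returns OkStatus.
  absl::StatusOr<size_t> NumRanks() const override { return num_ranks_; }
  std::string ToString() const override { return "InProcessCommunicator"; }
  // The collective operations (AllReduce, etc.) remain pure virtual and
  // still must be overridden before this class is instantiable.
 private:
  size_t num_ranks_ = 1;
};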
@@ -129,6 +131,10 @@ class Communicator { PrimitiveType dtype, size_t count, RankId peer, const Executor& executor) = 0; + // Returns the number of ranks in the communicator. + virtual absl::StatusOr NumRanks() const = 0; + + // Returns a human-readable description of the communicator. virtual std::string ToString() const = 0; }; diff --git a/xla/hlo/analysis/indexing_map.h b/xla/hlo/analysis/indexing_map.h index 17038aa05f73e..77ea7ec24f3be 100644 --- a/xla/hlo/analysis/indexing_map.h +++ b/xla/hlo/analysis/indexing_map.h @@ -286,7 +286,7 @@ class IndexingMap { RangeEvaluator GetRangeEvaluator() const; // Getters for dimension vars. - const Variable& GetDimVars(int64_t id) const { return dim_vars_[id]; } + const Variable& GetDimVar(int64_t id) const { return dim_vars_[id]; } const std::vector& GetDimVars() const { return dim_vars_; } int64_t GetDimVarsCount() const { return dim_vars_.size(); } @@ -407,18 +407,18 @@ class IndexingMap { mlir::AffineMap affine_map_; - // Dimension variable represents a dimension of a tensor or a GPU grid. - // Dimensions correspond to the dimension parameter of `affine_map_`. + // A dimension variable represents a dimension of a tensor or a GPU grid. + // Dimension variables correspond to the dimensions of the `affine_map_`. std::vector dim_vars_; - // RangeSymbol variable represents a range of values, e.g. to compute a single + // A range variable represents a range of values, e.g. to compute a single // element of the reduction's result we need a range of values from the input - // tensor. RangeSymbol variables correspond to the front portion of the + // tensor. Range variables correspond to the front portion of the // symbols in `affine_map_`. std::vector range_vars_; - // RTSymbol variable represents a runtime symbol, e.g. a dynamic offset in - // HLO dynamic-update-slice op. RTSymbol variables correspond to the back + // A runtime variable represents a runtime symbol, e.g. a dynamic offset in of + // a HLO dynamic-update-slice op. Runtime variables correspond to the back // portion of the symbols in `affine_map_`. std::vector rt_vars_; diff --git a/xla/hlo/ir/BUILD b/xla/hlo/ir/BUILD index d1b6e0a409ee4..560f46517e07e 100644 --- a/xla/hlo/ir/BUILD +++ b/xla/hlo/ir/BUILD @@ -232,6 +232,19 @@ cc_library( ], ) +xla_cc_test( + name = "hlo_instruction_utils_test", + srcs = ["hlo_instruction_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_instruction_utils", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_query", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "hlo_reachability", hdrs = ["hlo_reachability.h"], diff --git a/xla/hlo/ir/hlo_computation.h b/xla/hlo/ir/hlo_computation.h index 757505980a079..4411e3102b5a2 100644 --- a/xla/hlo/ir/hlo_computation.h +++ b/xla/hlo/ir/hlo_computation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_COMPUTATION_H_ #define XLA_HLO_IR_HLO_COMPUTATION_H_ +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -420,11 +422,23 @@ class HloComputation { // with respect to HloComputation::Equal() method. 
template friend H AbslHashValue(H h, const HloComputation& computation) { + // Walk the computation in post-order, computing (and caching) the + // absl::Hash after each instruction to use as an operand for + // subsequent instructions. auto instructions = computation.MakeInstructionPostOrder(); + absl::flat_hash_map instruction_hash_cache; + instruction_hash_cache.reserve(instructions.size()); for (auto* instruction : instructions) { - h = H::combine(std::move(h), *instruction); + absl::InlinedVector operand_hashes; + for (auto* operand : instruction->operands()) { + operand_hashes.push_back(instruction_hash_cache[operand]); + } + instruction_hash_cache.emplace( + instruction, absl::HashOf(*instruction, operand_hashes)); } - return H::combine(std::move(h), instructions.size()); + return H::combine(std::move(h), + instruction_hash_cache[computation.root_instruction()], + instructions.size()); } using InstructionSequence = tsl::gtl::iterator_range< diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index cd8d5368cc832..db3d994215963 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -41,6 +41,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -1736,27 +1737,20 @@ class HloInstruction { /*ignore_commutative_operand_order=*/true); } + // Allow subclasses to contribute additional attributes to the hash. + virtual void HashAdditionalAttributes(absl::HashState h) const {}; + // Generates a hash value of an HLO instruction. Hash considers - // information on opcode, shape, operands, and typically a root instruction. - // This function returns the same hash value for equivalent HLO instructions, - // with respect to HloInstruction::Identical() method. - // TODO(majnemer): Make the comment here more crisp & accurate. + // information on opcode, shape, number of operands, and other relevant + // additional attributes (e.g. literal values, parameters, etc.). template friend H AbslHashValue(H h, const HloInstruction& hlo) { h = H::combine(std::move(h), hlo.opcode(), hlo.shape()); - if (!hlo.IsCrossModuleAllReduce()) { - for (size_t i = 0; i < hlo.operands().size(); ++i) { - h = H::combine(std::move(h), hlo.operand(i)->shape()); - } h = H::combine(std::move(h), hlo.operand_count()); } - - if (hlo.opcode() == HloOpcode::kFusion) { - h = H::combine(std::move(h), *hlo.fused_expression_root(), - hlo.fusion_kind(), hlo.fused_instruction_count(), - hlo.fused_parameters().size()); - } + // Allow subclasses to mix additional data into h before returning. + hlo.HashAdditionalAttributes(absl::HashState::Create(&h)); return h; } diff --git a/xla/hlo/ir/hlo_instruction_utils_test.cc b/xla/hlo/ir/hlo_instruction_utils_test.cc new file mode 100644 index 0000000000000..fe8c488b154e8 --- /dev/null +++ b/xla/hlo/ir/hlo_instruction_utils_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_instruction_utils.h" + +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/utils/hlo_query.h" + +namespace xla { + +namespace hlo_instruction_utils { + +namespace { + +class HloInstructionUtilsTest : public HloHardwareIndependentTestBase {}; + +TEST_F(HloInstructionUtilsTest, TestIsUnstridedSlice) { + const char* hlo_text = R"( + HloModule test + ENTRY main { + param = f32[2,8] parameter(0) + strided_slice = f32[2,2] slice(param), slice={[0:2:1], [4:8:2]} + unstrided_slice = f32[2,4] slice(param), slice={[0:2:1], [4:8:1]} + ROOT tuple = (f32[2,2], f32[2,4]) tuple(strided_slice, unstrided_slice) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo_text)); + HloInstruction* unstrided_slice = + hlo_query::FindInstruction(m->entry_computation(), "unstrided_slice"); + HloInstruction* strided_slice = + hlo_query::FindInstruction(m->entry_computation(), "strided_slice"); + EXPECT_NE(unstrided_slice, nullptr); + EXPECT_NE(strided_slice, nullptr); + EXPECT_TRUE(IsUnstridedSlice(unstrided_slice)); + EXPECT_FALSE(IsUnstridedSlice(strided_slice)); +} + +TEST_F(HloInstructionUtilsTest, TestAddOrUpdateVectorOfPairsAsAttribute) { + const char* hlo = R"( + HloModule test + ENTRY main { + ROOT param = s32[] parameter(0), frontend_attributes={foo="bar", baz="qux"} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + HloInstruction* param = m->entry_computation()->root_instruction(); + EXPECT_EQ(param->frontend_attributes().map().size(), 2); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + + std::string new_key = "quux"; + std::vector> value = {{1, 2}, {3, 4}}; + AddOrUpdateVectorOfPairsAsAttribute(param, new_key, value); + EXPECT_EQ(param->frontend_attributes().map().size(), 3); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + EXPECT_EQ(param->frontend_attributes().map().at("quux"), "{{1,2},{3,4}}"); + + std::vector> new_value = {{5, 6}, {7, 8}}; + AddOrUpdateVectorOfPairsAsAttribute(param, new_key, new_value); + EXPECT_EQ(param->frontend_attributes().map().size(), 3); + EXPECT_EQ(param->frontend_attributes().map().at("foo"), "bar"); + EXPECT_EQ(param->frontend_attributes().map().at("baz"), "qux"); + EXPECT_EQ(param->frontend_attributes().map().at("quux"), "{{5,6},{7,8}}"); +} + +} // namespace + +} // namespace hlo_instruction_utils + +} // namespace xla diff --git a/xla/hlo/ir/hlo_instructions.h b/xla/hlo/ir/hlo_instructions.h index 1ca2bfddd5559..c21dddeee907b 100644 --- a/xla/hlo/ir/hlo_instructions.h +++ b/xla/hlo/ir/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_INSTRUCTIONS_H_ #define XLA_HLO_IR_HLO_INSTRUCTIONS_H_ +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. 
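The hlo_instructions.h hunks below add HashAdditionalAttributes overrides for constants, fusions, and parameters. They all follow one pattern, sketched here with a hypothetical subclass (HloExampleInstruction and my_flag_ are illustrative names, not part of this diff):

// Sketch: a subclass mixing extra state into the instruction hash.
class HloExampleInstruction : public HloInstruction {
 public:
  void HashAdditionalAttributes(absl::HashState h) const override {
    // Fold the subclass-specific attribute into the parent hash state.
    absl::HashState::combine(std::move(h), my_flag_);
  }
 private:
  bool my_flag_ = false;
};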
#include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/hash/hash.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -1356,6 +1358,14 @@ class HloConstantInstruction : public HloInstruction { return false; } + // Add literal to the hash state. + void HashAdditionalAttributes(absl::HashState h) const override { + if (HasLiteral()) { + absl::HashState::combine(std::move(h), + Literal::AbslHashable(literal())); + } + } + private: bool IsElementwiseImpl( const std::optional& operand_idx) const override; @@ -1595,6 +1605,13 @@ class HloFusionInstruction : public HloCallableInstruction { return hlo->opcode() == HloOpcode::kFusion; } + // Add various fusion parameters to the hash. + void HashAdditionalAttributes(absl::HashState h) const override { + absl::HashState::combine(std::move(h), *fused_expression_root(), + fusion_kind(), fused_instruction_count(), + fused_parameters().size()); + } + protected: std::string default_called_computation_name() const override { return "fused_computation"; @@ -1714,6 +1731,11 @@ class HloParameterInstruction : public HloInstruction { return hlo->opcode() == HloOpcode::kParameter; } + // Add parameter number to the hash. + void HashAdditionalAttributes(absl::HashState h) const override { + absl::HashState::combine(std::move(h), parameter_number()); + } + private: void PrintExtraAttributesImpl(AttributePrinter& printer, const HloPrintOptions& options) const override; diff --git a/xla/hlo/ir/hlo_module_test.cc b/xla/hlo/ir/hlo_module_test.cc index 226bf5c892a21..01756318c93ec 100644 --- a/xla/hlo/ir/hlo_module_test.cc +++ b/xla/hlo/ir/hlo_module_test.cc @@ -32,9 +32,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla { namespace { @@ -204,5 +201,231 @@ TEST(HloModuleTest, CloneWithNewConfig) { m1.config().device_memory_size()); } +TEST(HloModuleTest, AbslHashInstructionOrdering) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Add.0 and add.1 are swapped. 
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.1 = f32[32,32] add(b, c) // Swapped with below + add.0 = f32[32,32] add(a, b) // Swapped with above + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_EQ(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionOpcodes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Second add changed to sub + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] subtract(b, c) // Changed from add to subtract + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionShapes) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Second add has different shape. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + ENTRY main { + // Shapes changed from [32,32] to [16,16] + a = f32[16,16] parameter(0) + b = f32[16,16] parameter(1) + c = f32[16,16] parameter(2) + add.0 = f32[16,16] add(a, b) + add.1 = f32[16,16] add(b, c) + ROOT result = f32[16,16] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashInstructionNaming) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Add x to all names + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + // All names changed to x + ax = f32[32,32] parameter(0) + bx = f32[32,32] parameter(1) + cx = f32[32,32] parameter(2) + add.0x = f32[32,32] add(ax, bx) + add.1x = f32[32,32] add(bx, cx) + ROOT resultx = f32[32,32] add(add.0x, add.1x) + } + )")); + + EXPECT_EQ(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashGraphChanges) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Changed from (a+b)+(b+c) to ((a+b)+c)+a + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule 
HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(add.0, c) // Changed from add(b, c) + ROOT result = f32[32,32] add(add.1, a) // Changed from add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashParameterChanges) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(0) + b = f32[32,32] parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + // Change parameter numbers + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = f32[32,32] parameter(1) // Changed from parameter(0) + b = f32[32,32] parameter(0) // Changed from parameter(1) + c = f32[32,32] parameter(2) + add.0 = f32[32,32] add(a, b) + add.1 = f32[32,32] add(b, c) + ROOT result = f32[32,32] add(add.0, add.1) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + +TEST(HloModuleTest, AbslHashConstantValues) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = s32[32,32] parameter(0) + c = s32[] constant(42) + b = s32[32,32] broadcast(c), dimensions={} + ROOT result = s32[32,32] add(a, b) + } + )")); + + // Changed from 42 to 43 + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnUnverifiedModule(R"( + HloModule HashTest + + ENTRY main { + a = s32[32,32] parameter(0) + c = s32[] constant(43) // Changed from constant(42) + b = s32[32,32] broadcast(c), dimensions={} + ROOT result = s32[32,32] add(a, b) + } + )")); + + EXPECT_NE(absl::HashOf(*module1), absl::HashOf(*module2)); +} + } // namespace } // namespace xla diff --git a/xla/hlo/ir/hlo_original_value.cc b/xla/hlo/ir/hlo_original_value.cc index c1617888510a4..e76cd15d989ce 100644 --- a/xla/hlo/ir/hlo_original_value.cc +++ b/xla/hlo/ir/hlo_original_value.cc @@ -53,15 +53,14 @@ std::string OriginalValueToStringHelper(const OriginalValue& original_value, return result; } - // The original_value may refer to an empty array, such as origin {}, so let's - // check whether that's the case before accessing them. Generally speaking the - // index _should_ be good, but let's double check. const auto& leaf = original_value.element(shape_index); if (leaf.has_value()) { absl::StrAppend( &result, "{", "\"", leaf->instruction_name, "\"", (leaf->shape_index.empty() ? 
"" : " " + leaf->shape_index.ToString()), "}"); + } else { + absl::StrAppend(&result, "{}"); } return result; } diff --git a/xla/hlo/parser/hlo_parser.cc b/xla/hlo/parser/hlo_parser.cc index 01335cb5ff28d..3436fd408890f 100644 --- a/xla/hlo/parser/hlo_parser.cc +++ b/xla/hlo/parser/hlo_parser.cc @@ -6488,18 +6488,25 @@ bool HloParserImpl::ParseOriginalValue( ++leaf_shape_index.back(); } else if (lexer_.GetKind() == TokKind::kLbrace) { lexer_.Lex(); - std::string instruction_name; - ShapeIndex shape_index; - if (!ParseString(&instruction_name)) { - return false; - } if (lexer_.GetKind() != TokKind::kRbrace) { - if (!ParseShapeIndex(&shape_index)) { + std::string instruction_name; + ShapeIndex shape_index; + if (!ParseString(&instruction_name)) { return false; } + if (lexer_.GetKind() != TokKind::kRbrace) { + if (!ParseShapeIndex(&shape_index)) { + return false; + } + } + *(**original_value)->mutable_element(leaf_shape_index) = { + instruction_name, shape_index}; + } else { + // The original_value is not expected to have any leaf without values. + // However we should not fail the execution here. This should + // be done in HloVerifier instead. + LOG(WARNING) << "Found an empty leaf node in an original value"; } - *(**original_value)->mutable_element(leaf_shape_index) = { - instruction_name, shape_index}; if (!ParseToken(TokKind::kRbrace, "Expects '} at end of each OriginalArray'")) { return false; diff --git a/xla/hlo/parser/hlo_parser_test.cc b/xla/hlo/parser/hlo_parser_test.cc index f1ce17e4a57b7..61de9ca31adcd 100644 --- a/xla/hlo/parser/hlo_parser_test.cc +++ b/xla/hlo/parser/hlo_parser_test.cc @@ -5726,6 +5726,20 @@ ENTRY %test { HasSubstr("expects instruction shape"))); } +TEST_F(HloParserTest, EmptyLeafInOriginalValue) { + const std::string hlo_string = R"(HloModule test + +ENTRY %test { + ROOT op = ((f32[], f32[3]{0}), f32[2,3]) parameter(0), origin={(({}, {"v2"}), {"v3"})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + ExpectHasSubstr(module->ToString(HloPrintOptions::ShortParsable()), + "origin={(({}, {\"v2\"}), {\"v3\"})}"); +} + TEST_F(HloParserTest, TranscendentalAccuracyMode) { constexpr absl::string_view hlo_string = R"( HloModule exponential_hw @@ -5842,21 +5856,5 @@ ENTRY main { "error: unexpected attribute \"result_accuracy\""); } -TEST_F(HloParserTest, EmptyOriginalValueIsPrintedCorrectly) { - const std::string hlo_string = R"(HloModule test - -ENTRY %test { - ROOT op = f32[] parameter(0), origin={} -} - - -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - ExpectHasSubstr(module->ToString(HloPrintOptions::Fingerprint()), - "origin={}"); -} - } // namespace } // namespace xla diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc index bbe1ecea736a3..d5af349ef6dec 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.cc +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_replace.h" @@ -119,7 +120,7 @@ HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements, instruction_can_change_layout_func_); TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); - return std::move(module); + return module; } /* static */ @@ -258,9 +259,11 @@ HloHardwareIndependentTestBase::RunAndCheckHloRewrite( VLOG(7) << "Input HLO: " << hlo_string; TF_ASSIGN_OR_RETURN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string)); + VLOG(7) << "Input HLO parsed. Running the pass: + " << hlo_pass.name(); TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); VLOG(7) << "Output HLO: " - << module->ToString(HloPrintOptions::ShortParsable()); + << module->ToString(HloPrintOptions::ShortParsable() + .set_print_control_dependencies(true)); EXPECT_EQ(changed, expect_change); return module; } diff --git a/xla/hlo/testlib/hlo_hardware_independent_test_base.h b/xla/hlo/testlib/hlo_hardware_independent_test_base.h index 2a7f1f488b54e..e41bcea3e4d82 100644 --- a/xla/hlo/testlib/hlo_hardware_independent_test_base.h +++ b/xla/hlo/testlib/hlo_hardware_independent_test_base.h @@ -55,6 +55,23 @@ class HloHardwareIndependentTestBase : public ::testing::Test { public: static PrecisionConfig DefaultPrecisionConfig(int operands); + // Gets the computation/instruction from the given module with the given name. + // Note that it is encouraged to use these functions directly via the + // hlo_query.h header instead since they are independent from any test-time + // variables or contexts. + + // This is useful for tests which create HLOs from a string and then want to + // inspect a particular computation or instruction. + static HloComputation* FindComputation(HloModule* module, + absl::string_view name); + static HloInstruction* FindInstruction(HloModule* module, + absl::string_view name); + // Gets the instruction from the given module with the given opcode. + static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); + // Gets all the instructions from the given module with the given opcode. + static std::vector FindInstructions(HloModule* module, + HloOpcode opcode); + protected: explicit HloHardwareIndependentTestBase( bool verifier_layout_sensitive = false, @@ -199,22 +216,6 @@ class HloHardwareIndependentTestBase : public ::testing::Test { ->Clear(); } - // Gets the computation/instruction from the given module with the given name. - // Note that it is encouraged to use these functions directly via the - // hlo_query.h header instead since they are independent from any test-time - // variables or contexts. - - // This is useful for tests which create HLOs from a string and then want to - // inspect a particular computation or instruction. - static HloComputation* FindComputation(HloModule* module, - absl::string_view name); - static HloInstruction* FindInstruction(HloModule* module, - absl::string_view name); - // Gets the instruction from the given module with the given opcode. - static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); - // Gets all the instructions from the given module with the given opcode. 
- static std::vector FindInstructions(HloModule* module, - HloOpcode opcode); bool verifier_layout_sensitive() const { return verifier_layout_sensitive_; } void set_verifier_layout_sensitive(bool verifier_layout_sensitive) { diff --git a/xla/hlo/transforms/BUILD b/xla/hlo/transforms/BUILD index 774be3834cc41..2c8f45317a59e 100644 --- a/xla/hlo/transforms/BUILD +++ b/xla/hlo/transforms/BUILD @@ -1107,13 +1107,3 @@ xla_cc_test( "@tsl//tsl/platform:test_main", ], ) - -alias( - name = "hlo_dce", - actual = "//xla/hlo/transforms/simplifiers:hlo_dce", -) - -alias( - name = "dynamic_dimension_simplifier", - actual = "//xla/hlo/transforms/simplifiers:dynamic_dimension_simplifier", -) diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc b/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index 269284b021d5d..4b96bf2a81d50 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -5939,8 +5939,9 @@ AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( new_operands.push_back(operand); } } - VLOG(4) << "Sinking broadcast after user:" << "\n old broadcast: " - << broadcast->ToString() << "\n old user: " << user->ToString(); + VLOG(4) << "Sinking broadcast after user:" + << "\n old broadcast: " << broadcast->ToString() + << "\n old user: " << user->ToString(); changed_shape = ShapeUtil::ChangeElementType(operand->shape(), user->shape().element_type()); simplifier_->UpdateLayout(&changed_shape); @@ -8233,6 +8234,24 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { return absl::OkStatus(); } +absl::Status AlgebraicSimplifierVisitor::HandleReducePrecision( + HloInstruction* hlo) { + HloReducePrecisionInstruction* reduce_precision = + Cast(hlo); + PrimitiveType element_type = + reduce_precision->operand(0)->shape().element_type(); + if (options_.enable_remove_no_op_reduce_precision() && + reduce_precision->exponent_bits() == + primitive_util::ExponentWidth(element_type) && + reduce_precision->mantissa_bits() + 1 == + primitive_util::SignificandWidth(element_type)) { + return ReplaceInstruction( + /*old_instruction=*/hlo, + /*new_instruction=*/reduce_precision->mutable_operand(0)); + } + return absl::OkStatus(); +} + absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* hlo) { auto* reduce_window = Cast(hlo); diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier.h b/xla/hlo/transforms/simplifiers/algebraic_simplifier.h index 96c50ba251a94..f3ded542605db 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier.h +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier.h @@ -322,6 +322,16 @@ class AlgebraicSimplifierOptions { return enable_broadcast_degenerate_dimension_; } + void set_enable_remove_no_op_reduce_precision( + bool enable_remove_no_op_reduce_precision) { + enable_remove_no_op_reduce_precision_ = + enable_remove_no_op_reduce_precision; + } + + bool enable_remove_no_op_reduce_precision() const { + return enable_remove_no_op_reduce_precision_; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplifierOptions that can be later used in an @@ -364,6 +374,7 @@ class AlgebraicSimplifierOptions { bool disable_dynamic_slice_to_slice_conversion_{false}; bool enable_fast_math_{false}; bool enable_broadcast_degenerate_dimension_{true}; + bool enable_remove_no_op_reduce_precision_{false}; Metadata metadata_; }; @@ -484,6 +495,8 @@ class 
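To make the no-op condition in HandleReducePrecision above concrete: bf16 has an 8-bit exponent and 7 explicit mantissa bits, and SignificandWidth counts the implicit leading bit, so it is 8. The check therefore fires for exponent_bits=8, mantissa_bits=7, which is exactly what the new tests below exercise. In code form (illustrative, mirroring the handler above):

// bf16: ExponentWidth(BF16) == 8, SignificandWidth(BF16) == 8.
bool is_noop =
    reduce_precision->exponent_bits() == 8 &&    // matches ExponentWidth
    reduce_precision->mantissa_bits() + 1 == 8;  // matches SignificandWidth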
AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { absl::Status HandleReduce(HloInstruction* hlo) override; + absl::Status HandleReducePrecision(HloInstruction* hlo) override; + absl::Status HandleReduceWindow(HloInstruction* hlo) override; absl::Status HandleReverse(HloInstruction* reverse) override; diff --git a/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc b/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 5b0519107ad65..e30822e37f578 100644 --- a/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc +++ b/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -12688,5 +12688,36 @@ TEST_F(AlgebraicSimplifierTest, TestNew123) { EXPECT_FALSE(simplifier.Run(module.get()).value()); } +TEST_F(AlgebraicSimplifierTest, + ReducePrecisionWithSamePrecisionAsOperandIsRemovedIfRemoveNoOpIsSet) { + const char* hlo = R"( + HloModule test + ENTRY main { + p0 = bf16[64]{0} parameter(0) + ROOT reduce-precision = bf16[64] reduce-precision(p0), exponent_bits=8, mantissa_bits=7 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + default_options_.set_enable_remove_no_op_reduce_precision(true); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter())); +} + +TEST_F(AlgebraicSimplifierTest, + ReducePrecisionWithDifferentPrecisionFromOperandIsNotModifiedByDefault) { + const char* hlo = R"( + HloModule test + ENTRY main { + p0 = bf16[64]{0} parameter(0) + ROOT reduce-precision = bf16[64] reduce-precision(p0), exponent_bits=7, mantissa_bits=8 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr m, + ParseAndReturnVerifiedModule(hlo)); + + default_options_.set_enable_remove_no_op_reduce_precision(true); + EXPECT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + } // namespace } // namespace xla diff --git a/xla/literal.h b/xla/literal.h index 0c028bd1aa60e..1b76f2effe6a9 100644 --- a/xla/literal.h +++ b/xla/literal.h @@ -367,9 +367,9 @@ class LiteralBase { static_assert(sizeof(H) == 0, "Do not use Literal directly as a hash key, because it has " "multiple definitions of equality - layout sensitive or " - "insensitive. Instead, provide an external hash function " - "that uses Literal::Hash which allows you to specify layout " - "sensitivity."); + "insensitive. Instead, use AbslHashable<...>() to create a " + "wrapper with layout sensitivity specified suitable for " + "passing to Absl::Hash"); } // Always use this together with the Equal method and not operator== in order @@ -419,6 +419,17 @@ class LiteralBase { return std::move(state); } + // Templated wrapper struct to control layout sensitivity during Absl::Hash. + template + struct AbslHashable { + const LiteralBase& literal; + explicit AbslHashable(const LiteralBase& l) : literal(l) {} + template + friend H AbslHashValue(H h, const AbslHashable& w) { + return LiteralBase::Hash(std::move(h), w.literal); + } + }; + // Converts this literal to the given shape. Returns an error is the // conversion is not possible. 
absl::StatusOr ConvertToShape(const Shape& dest_shape) const; diff --git a/xla/mlir/framework/transforms/outline_with_xla_framework.cc b/xla/mlir/framework/transforms/outline_with_xla_framework.cc index b960958a7d634..7d9b8fc700767 100644 --- a/xla/mlir/framework/transforms/outline_with_xla_framework.cc +++ b/xla/mlir/framework/transforms/outline_with_xla_framework.cc @@ -164,7 +164,7 @@ class OutlineWithXLAFrameworkPass patterns.add(ctx); // Set target. - if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) { + if (failed(applyPatternsGreedily(m, std::move(patterns)))) { signalPassFailure(); } m->walk([](func::FuncOp f) { diff --git a/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 064978aec3982..3a09b6e3b3381 100644 --- a/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -536,7 +536,7 @@ struct BufferReusePass : public impl::BufferReusePassBase { eliminateCopies(block, /*root=*/block); do { // Eliminate dead code. - (void)applyPatternsAndFoldGreedily(getOperation(), {}); + (void)applyPatternsGreedily(getOperation(), {}); // Only coalesce dealloc/alloc pairs that are immediate neighbors, to // make sure we don't accidentally extend the live range of a buffer. result = reuseBuffers(block, BufferReuseMode::CONSERVATIVE); @@ -547,7 +547,7 @@ struct BufferReusePass : public impl::BufferReusePassBase { // Now we can also coalesce distant dealloc/alloc pairs. reuseBuffers(block, BufferReuseMode::AGGRESSIVE); promoteBuffers(block); - (void)applyPatternsAndFoldGreedily(getOperation(), {}); + (void)applyPatternsGreedily(getOperation(), {}); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc b/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc index da27173913f81..c8268e4335dca 100644 --- a/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc +++ b/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc @@ -439,8 +439,8 @@ struct BroadcastPropagationPass GreedyRewriteConfig config; config.useTopDownTraversal = false; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc b/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc index 60fcd19885391..cbe532ba959f7 100644 --- a/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc +++ b/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc @@ -92,8 +92,7 @@ struct CollapseElementwiseMapPass MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc b/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc index e986bdc5ad694..79e55a4c9f3d5 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc 
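Returning to the literal.h hunk above: a usage sketch for the new AbslHashable wrapper (the boolean template argument and its name are assumptions here, since template parameters are elided in this patch text; the wrapper itself simply forwards to LiteralBase::Hash):

// Sketch: hash a literal with layout sensitivity made explicit.
Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f});
size_t h = absl::HashOf(
    Literal::AbslHashable</*layout_sensitive=*/true>(literal));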
@@ -68,8 +68,7 @@ struct LegalizeDotToDotGeneralPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateDotToDotGeneralPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc b/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index c35ce560146dc..e861dec331848 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -179,8 +179,7 @@ struct LegalizeEinsumToDotGeneralPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateEinsumToDotGeneralPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc b/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc index 8cc65ea23f04c..865c07fc316d8 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc @@ -139,8 +139,7 @@ struct LegalizeTorchIndexSelectToGatherPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateTorchIndexSelectToGatherPatterns(&getContext(), &patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc b/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc index 2e7018e2fd17c..ccf2ed1151ccc 100644 --- a/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc +++ b/xla/mlir_hlo/mhlo/transforms/legalize_trigonometric_to_approximation/legalize_trigonometric_to_approximation.cc @@ -172,8 +172,7 @@ struct LegalizeTrigonometricToApproximationPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateTrigonometricToApproximationPatterns(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc b/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc index 185b2c9d7caa1..d6c4b4767297d 100644 --- a/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc +++ b/xla/mlir_hlo/mhlo/transforms/merge_assuming_ops/merge_assuming_ops.cc @@ -434,8 +434,8 @@ struct MergeAssumingOpsPass mhlo::populateMergeAssumingOpsPatterns(ctx, &patterns); GreedyRewriteConfig config; config.maxIterations = 
GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc b/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc index deccadf230d5a..b86038624c4c2 100644 --- a/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc +++ b/xla/mlir_hlo/mhlo/transforms/mhlo_flatten_tuple/mhlo_flatten_tuple.cc @@ -132,8 +132,7 @@ class FlattenTuplePass : public impl::FlattenTuplePassBase { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); patterns.add(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc b/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc index b96370f71cf23..1747bd93b492e 100644 --- a/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc +++ b/xla/mlir_hlo/mhlo/transforms/shape_simplification/shape_simplification.cc @@ -242,7 +242,7 @@ struct ShapeSimplification ExtractFromBroadcastedTensorCanonicalizationPattern>(context); auto func = getOperation(); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc b/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc index 20808e4d12d9e..961e512d23968 100644 --- a/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc +++ b/xla/mlir_hlo/mhlo/transforms/symbolic_shape_optimization/symbolic_shape_optimization.cc @@ -793,8 +793,8 @@ class SymbolicShapeOptimizationPass final shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx); shape::ShapeOfOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed( + mlir::applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc b/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc index 8bd3bbc140961..d585ea0b9d159 100644 --- a/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/test_infer_shaped_type/test_infer_shaped_type_pass.cc @@ -95,8 +95,7 @@ struct TestInferShapedTypeMethodsPass RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); patterns.add(&getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc index 7409def78d770..285f056008da7 100644 --- a/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc +++ b/xla/mlir_hlo/mhlo/transforms/unfuse_batch_norm/unfuse_batch_norm_pass.cc @@ 
-43,8 +43,7 @@ struct TestUnfuseBatchNormPass RewritePatternSet patterns(&getContext()); populateUnfuseBatchNormInferencePattern(&getContext(), &patterns); populateUnfuseBatchNormTrainingPattern(&getContext(), &patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp index 0ad3029f96ccf..9cd3e90e6f5df 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp +++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_canonicalize_dynamism.cpp @@ -200,8 +200,7 @@ struct StablehloCanonicalizeDynamismPass patterns.add(&getContext()); auto funcOp = getOperation(); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns), config))) { funcOp.emitError("Failed to converge StablehloCanonicalizeDynamism in ") << config.maxIterations << " iterations"; return signalPassFailure(); diff --git a/xla/mlir_hlo/transforms/detensorize_scf_ops.cc b/xla/mlir_hlo/transforms/detensorize_scf_ops.cc index 2a8be4e6b09ae..12d8b3814646e 100644 --- a/xla/mlir_hlo/transforms/detensorize_scf_ops.cc +++ b/xla/mlir_hlo/transforms/detensorize_scf_ops.cc @@ -120,7 +120,7 @@ struct DetensorizeScfOpsPass patterns.add, RegionOpPattern, RegionOpPattern>(&getContext()); - if (failed(applyPatternsAndFoldGreedily(f, std::move(patterns)))) { + if (failed(applyPatternsGreedily(f, std::move(patterns)))) { signalPassFailure(); } } diff --git a/xla/mlir_hlo/transforms/generic_host_to_llvm.cc b/xla/mlir_hlo/transforms/generic_host_to_llvm.cc index 9df69afbaf55a..8cd4bf99f5133 100644 --- a/xla/mlir_hlo/transforms/generic_host_to_llvm.cc +++ b/xla/mlir_hlo/transforms/generic_host_to_llvm.cc @@ -86,7 +86,7 @@ class GenericHostToLLVMPass // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. 
vector::populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } LLVMConversionTarget target(*ctx); diff --git a/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc b/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc index 3e22aa5588832..d490588de4508 100644 --- a/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc +++ b/xla/mlir_hlo/transforms/gpu_kernel_lowering_passes.cc @@ -96,7 +96,7 @@ void GpuKernelToNVVMPass::runOnOperation() { { RewritePatternSet patterns(&getContext()); populateAllCommonVectorProgressiveLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } RewritePatternSet patterns(&getContext()); diff --git a/xla/mlir_hlo/transforms/lower_index_cast_pass.cc b/xla/mlir_hlo/transforms/lower_index_cast_pass.cc index 489d8fb4cb811..b773792e67b5c 100644 --- a/xla/mlir_hlo/transforms/lower_index_cast_pass.cc +++ b/xla/mlir_hlo/transforms/lower_index_cast_pass.cc @@ -64,8 +64,7 @@ struct LowerIndexCastPass patterns.add, IndexCastConverter>( patterns.getContext()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/transforms/naive_copy_removal.cc b/xla/mlir_hlo/transforms/naive_copy_removal.cc index 55ab2fbb2e0ee..a13f0396a85e6 100644 --- a/xla/mlir_hlo/transforms/naive_copy_removal.cc +++ b/xla/mlir_hlo/transforms/naive_copy_removal.cc @@ -80,7 +80,7 @@ struct NaiveCopyRemovalPass RewritePatternSet patterns(ctx); patterns.add(removeCopy); memref::AllocOp::getCanonicalizationPatterns(patterns, ctx); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/xla/mlir_hlo/transforms/tile_loops_pass.cc b/xla/mlir_hlo/transforms/tile_loops_pass.cc index ee3b935cff277..d6efd72d2437c 100644 --- a/xla/mlir_hlo/transforms/tile_loops_pass.cc +++ b/xla/mlir_hlo/transforms/tile_loops_pass.cc @@ -127,7 +127,7 @@ void TileLoopsPass::runOnOperation() { getContext() .getOrLoadDialect() ->getCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/xla/mlir_hlo/transforms/vectorize_copy.cc b/xla/mlir_hlo/transforms/vectorize_copy.cc index 1b68cd8b28b74..5650e83be0c2d 100644 --- a/xla/mlir_hlo/transforms/vectorize_copy.cc +++ b/xla/mlir_hlo/transforms/vectorize_copy.cc @@ -215,7 +215,7 @@ struct VectorizeCopyPass RewritePatternSet patterns(ctx); patterns.add( ctx, /*numElementsThreshold = */ 8); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/pjrt/c/CHANGELOG.md b/xla/pjrt/c/CHANGELOG.md index 5852c9a54dcc0..d56741eb3500b 100644 --- a/xla/pjrt/c/CHANGELOG.md +++ b/xla/pjrt/c/CHANGELOG.md @@ -1,4 +1,10 @@ # PJRT C API changelog + +## 0.61 +* Added ``PJRT_KeyValueTryGet`` to the KV store interface, + which is non-blocking and immediately returns an error if the + key is not found. 
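The changelog entry just added is the theme of the rest of this patch. As a hedged sketch of the contract difference between the two lookup flavors on xla::KeyValueStoreInterface, which the C API below mirrors (`store` and the key name are placeholders, not from this patch):

#include <string>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "xla/pjrt/distributed/key_value_store_interface.h"

// Illustrative only: contrasts the blocking and non-blocking lookups.
void LookupFlavors(xla::KeyValueStoreInterface* store) {
  // Get() blocks, waiting up to the timeout for the key to be published.
  absl::StatusOr<std::string> blocking = store->Get("addr", absl::Seconds(10));
  // TryGet() returns NotFoundError immediately when the key is absent.
  absl::StatusOr<std::string> immediate = store->TryGet("addr");
  if (absl::IsNotFound(immediate.status())) {
    // Key not published yet; the caller can retry or proceed without waiting.
  }
}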
+ ## 0.60 * Added ``PJRT_Client_CreateBuffersForAsyncHostToDevice`` and ``PJRT_AsyncHostToDeviceTransferManager_TransferRawDataToSubBuffer``. diff --git a/xla/pjrt/c/pjrt_c_api.h b/xla/pjrt/c/pjrt_c_api.h index 36d82b0787ba4..f2fc3b1c507a3 100644 --- a/xla/pjrt/c/pjrt_c_api.h +++ b/xla/pjrt/c/pjrt_c_api.h @@ -80,7 +80,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 60 +#define PJRT_API_MINOR 61 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -351,6 +351,35 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueGetCallback_Args, typedef PJRT_Error* (*PJRT_KeyValueGetCallback)( PJRT_KeyValueGetCallback_Args* args); +// Same as KeyValueGet, but returns `NotFoundError` immediately if the key is +// not found. +typedef void (*PJRT_KeyValueTryGetCallback_ValueDeleter)(char* value); + +struct PJRT_KeyValueTryGetCallback_Args { + size_t struct_size; + PJRT_Extension_Base* extension_start; + const char* key; + size_t key_size; + PJRT_CallbackError* callback_error; + void* user_arg; + char* value; // out + size_t value_size; // out + // The caller needs to set a PJRT_KeyValueTryGetCallback_ValueDeleter to + // delete the value returned by PJRT_KeyValueTryGetCallback. The + // implementation is responsible for copying `value` and then calling + // value_deleter_callback. + PJRT_KeyValueTryGetCallback_ValueDeleter value_deleter_callback; // out +}; +PJRT_DEFINE_STRUCT_TRAITS(PJRT_KeyValueTryGetCallback_Args, + value_deleter_callback); + +// Requirements for PJRT_KeyValueTryGetCallback implementation: (1) Thread-safe. +// (2) The caller that provides the two callbacks is responsible for avoiding +// key collisions between different users of key-value store (i.e. between +// different plugins, but not between different nodes in one plugin). +typedef PJRT_Error* (*PJRT_KeyValueTryGetCallback)( + PJRT_KeyValueTryGetCallback_Args* args); + struct PJRT_KeyValuePutCallback_Args { size_t struct_size; PJRT_Extension_Base* extension_start; @@ -389,8 +418,15 @@ struct PJRT_Client_Create_Args { void* kv_put_user_arg; PJRT_Client* client; // out + + // Key-value try-get callback provided by the caller of PJRT_Client_Create. + // Same as key-value get callback, but returns `NotFoundError` immediately if + // the key is not found. + PJRT_KeyValueTryGetCallback kv_try_get_callback; + // Will be passed to `kv_try_get_callback` as `user_arg` argument. + void* kv_try_get_user_arg; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, client); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Client_Create_Args, kv_try_get_user_arg); // Creates and initializes a new PJRT_Client and returns in `client`. 
typedef PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args); diff --git a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 4f53c640a6a3d..68d36fdb7f5c8 100644 --- a/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -154,9 +154,9 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { options.num_nodes = num_nodes; options.allowed_devices = visible_devices; options.platform_name = platform_name; - options.kv_store = - pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, - args->kv_put_callback, args->kv_put_user_arg); + options.kv_store = pjrt::ToCppKeyValueStore( + args->kv_get_callback, args->kv_get_user_arg, args->kv_try_get_callback, + args->kv_try_get_user_arg, args->kv_put_callback, args->kv_put_user_arg); options.enable_mock_nccl = enable_mock_nccl; options.mock_gpu_topology = mock_gpu_topology; PJRT_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client, diff --git a/xla/pjrt/c/pjrt_c_api_helpers.cc b/xla/pjrt/c/pjrt_c_api_helpers.cc index 2060a73a634a4..c5d4b92c1a541 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -797,6 +797,25 @@ static PJRT_KeyValueGetCFunc ToKVGetCFunc( }; } +static PJRT_KeyValueTryGetCFunc ToKVTryGetCFunc( + xla::KeyValueStoreInterface* kv_store) { + return [kv_store](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { + absl::StatusOr<std::string> output = + kv_store->TryGet(absl::string_view(args->key, args->key_size)); + if (!output.ok()) { + absl::string_view message = output.status().message(); + return (*args->callback_error)( + StatusCodeToPjrtErrorCode(output.status().code()), message.data(), + message.size()); + } + args->value = new char[output->size()]; + std::copy(output->begin(), output->end(), args->value); + args->value_size = output->size(); + args->value_deleter_callback = &PjRtValueDeleterCallback; + return nullptr; + }; +} + static PJRT_KeyValuePutCFunc ToKVPutCFunc( xla::KeyValueStoreInterface* kv_store) { return [kv_store](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -828,6 +847,22 @@ static PJRT_KeyValueGetCallback ToCKVGetCallback( }; } +static PJRT_KeyValueTryGetCallback ToCKVTryGetCallback( + PJRT_KeyValueTryGetCFunc* kv_try_get_c_func) { + return [](PJRT_KeyValueTryGetCallback_Args* args) -> PJRT_Error* { + PJRT_KeyValueTryGetCFunc* kv_try_get_c_func = + reinterpret_cast<PJRT_KeyValueTryGetCFunc*>(args->user_arg); + if (kv_try_get_c_func == nullptr) { + absl::Status status = xla::InvalidArgument( + "got nullptr for PJRT_KeyValueTryGet_Args.user_arg"); + return (*args->callback_error)(StatusCodeToPjrtErrorCode(status.code()), + status.message().data(), + status.message().size()); + } + return (*kv_try_get_c_func)(args); + }; +} + static PJRT_KeyValuePutCallback ToCKVPutCallback( PJRT_KeyValuePutCFunc* kv_put_c_func) { return [](PJRT_KeyValuePutCallback_Args* args) -> PJRT_Error* { @@ -848,9 +883,12 @@ std::unique_ptr<PJRT_KeyValueCallbackData> ConvertToCKeyValueCallbacks( std::shared_ptr<xla::KeyValueStoreInterface> kv_store) { auto kv_callback_data = std::make_unique<PJRT_KeyValueCallbackData>(); kv_callback_data->kv_get_c_func = ToKVGetCFunc(kv_store.get()); + kv_callback_data->kv_try_get_c_func = ToKVTryGetCFunc(kv_store.get()); kv_callback_data->kv_put_c_func = ToKVPutCFunc(kv_store.get()); kv_callback_data->c_kv_get = ToCKVGetCallback(&kv_callback_data->kv_get_c_func); + kv_callback_data->c_kv_try_get = + ToCKVTryGetCallback(&kv_callback_data->kv_try_get_c_func); kv_callback_data->c_kv_put = ToCKVPutCallback(&kv_callback_data->kv_put_c_func); kv_callback_data->kv_store = std::move(kv_store); diff --git
a/xla/pjrt/c/pjrt_c_api_helpers.h b/xla/pjrt/c/pjrt_c_api_helpers.h index 709558fba465a..d7a4286571b73 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/xla/pjrt/c/pjrt_c_api_helpers.h @@ -218,6 +218,9 @@ int GetId(const PJRT_Api* api, PJRT_DeviceDescription* device_desc); using PJRT_KeyValueGetCFunc = std::function<PJRT_Error*(PJRT_KeyValueGetCallback_Args* args)>; +using PJRT_KeyValueTryGetCFunc = + std::function<PJRT_Error*(PJRT_KeyValueTryGetCallback_Args* args)>; + using PJRT_KeyValuePutCFunc = std::function<PJRT_Error*(PJRT_KeyValuePutCallback_Args* args)>; @@ -228,17 +231,21 @@ struct PJRT_KeyValueCallbackData { std::shared_ptr<xla::KeyValueStoreInterface> kv_store; - // kv_get_c_func and kv_put_c_func are holding pointers to kv_store. + // kv_get_c_func, kv_try_get_c_func and kv_put_c_func are holding pointers to + // kv_store. pjrt::PJRT_KeyValueGetCFunc kv_get_c_func; pjrt::PJRT_KeyValuePutCFunc kv_put_c_func; - // c_kv_get and c_kv_put are holding pointers to kv_get_c_func and - // kv_put_c_func. + // c_kv_get, c_kv_try_get and c_kv_put are holding pointers to kv_get_c_func, + // kv_try_get_c_func and kv_put_c_func. PJRT_KeyValueGetCallback c_kv_get; PJRT_KeyValuePutCallback c_kv_put; + pjrt::PJRT_KeyValueTryGetCFunc kv_try_get_c_func; + PJRT_KeyValueTryGetCallback c_kv_try_get; }; -// The returned &kv_get_c_func and &kv_put_c_func must be set as -// PJRT_Client_Create_Args.kv_get_user_arg and +// The returned &kv_get_c_func, &kv_try_get_c_func and &kv_put_c_func must be +// set as PJRT_Client_Create_Args.kv_get_user_arg, +// PJRT_Client_Create_Args.kv_try_get_user_arg and // PJRT_Client_Create_Args.kv_put_user_arg, respectively. The entire // PJRT_KeyValueCallbackData must be kept alive as long as c_kv_get, c_kv_try_get // and c_kv_put may be called. diff --git a/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 4b8a59287589e..6dfce81a1e451 100644 --- a/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -108,14 +108,22 @@ TEST(PjRtCApiHelperTest, Callback) { auto kv_callback_data = ConvertToCKeyValueCallbacks(kv_store); auto converted_kv_store = ToCppKeyValueStore( kv_callback_data->c_kv_get, &kv_callback_data->kv_get_c_func, + kv_callback_data->c_kv_try_get, &kv_callback_data->kv_try_get_c_func, kv_callback_data->c_kv_put, &kv_callback_data->kv_put_c_func); + auto v_not_found = converted_kv_store->Get("key", absl::Seconds(1)); + EXPECT_TRUE(absl::IsNotFound(v_not_found.status())) << v_not_found.status(); + auto s = converted_kv_store->Set("key", "value"); TF_EXPECT_OK(s); auto v = converted_kv_store->Get("key", absl::Seconds(1)); TF_EXPECT_OK(v.status()); EXPECT_EQ(*v, "value"); + + auto v_2 = converted_kv_store->TryGet("key"); + TF_EXPECT_OK(v_2.status()); + EXPECT_EQ(*v_2, "value"); } TEST(PjRtCApiHelperTest, ConvertToCLayoutFromStrides) { diff --git a/xla/pjrt/c/pjrt_c_api_test_base.cc b/xla/pjrt/c/pjrt_c_api_test_base.cc index 9602813c573c5..f867846ebcbd5 100644 --- a/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -47,9 +47,11 @@ PJRT_Client* CreateClient(const PJRT_Api* api) { create_args.create_options = nullptr; create_args.num_options = 0; create_args.kv_get_callback = nullptr; + create_args.kv_get_user_arg = nullptr; create_args.kv_put_callback = nullptr; create_args.kv_put_user_arg = nullptr; - create_args.kv_get_user_arg = nullptr; + create_args.kv_try_get_callback = nullptr; + create_args.kv_try_get_user_arg = nullptr; PJRT_Error* error = api->PJRT_Client_Create(&create_args); CHECK_EQ(error, nullptr); CHECK_NE(create_args.client, nullptr); diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index
64aa20bac3c0e..f832fad0c997c 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -235,9 +235,13 @@ static absl::Status PopulateExecutableOutputMemoryKinds( class CApiKeyValueStore : public xla::KeyValueStoreInterface { public: CApiKeyValueStore(PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, + void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) : c_get_callback_(c_get_callback), get_user_arg_(get_user_arg), + c_try_get_callback_(c_try_get_callback), + try_get_user_arg_(try_get_user_arg), c_put_callback_(c_put_callback), put_user_arg_(put_user_arg) {} @@ -264,6 +268,27 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { return result; } + absl::StatusOr<std::string> TryGet(absl::string_view key) override { + PJRT_CallbackError callback_error = [](PJRT_Error_Code code, + const char* message, + size_t message_size) { + return new PJRT_Error{absl::Status(static_cast<absl::StatusCode>(code), + std::string(message, message_size))}; + }; + PJRT_KeyValueTryGetCallback_Args args; + args.key = key.data(); + args.key_size = key.size(); + args.callback_error = &callback_error; + args.user_arg = try_get_user_arg_; + std::unique_ptr<PJRT_Error> error(c_try_get_callback_(&args)); + if (error != nullptr) { + return error->status; + } + auto result = std::string(args.value, args.value_size); + args.value_deleter_callback(args.value); + return result; + } + absl::Status Set(absl::string_view key, absl::string_view value) override { PJRT_CallbackError callback_error = [](PJRT_Error_Code code, const char* message, @@ -288,18 +313,23 @@ class CApiKeyValueStore : public xla::KeyValueStoreInterface { private: PJRT_KeyValueGetCallback c_get_callback_; void* get_user_arg_; + PJRT_KeyValueTryGetCallback c_try_get_callback_; + void* try_get_user_arg_; PJRT_KeyValuePutCallback c_put_callback_; void* put_user_arg_; }; std::shared_ptr<xla::KeyValueStoreInterface> ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg) { - if (c_get_callback == nullptr || c_put_callback == nullptr) { + if (c_get_callback == nullptr || c_try_get_callback == nullptr || + c_put_callback == nullptr) { return nullptr; } - return std::make_shared<CApiKeyValueStore>(c_get_callback, get_user_arg, - c_put_callback, put_user_arg); + return std::make_shared<CApiKeyValueStore>( + c_get_callback, get_user_arg, c_try_get_callback, try_get_user_arg, + c_put_callback, put_user_arg); } // ---------------------------------- Errors ----------------------------------- diff --git a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index 04463410ee7e0..27b1cac051dbd 100644 --- a/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -464,6 +464,7 @@ PJRT_Client* CreateWrapperClient(std::unique_ptr<xla::PjRtClient> cpp_client); // Helper functions for converting C key-value store callbacks to C++ callbacks. std::shared_ptr<xla::KeyValueStoreInterface> ToCppKeyValueStore( PJRT_KeyValueGetCallback c_get_callback, void* get_user_arg, + PJRT_KeyValueTryGetCallback c_try_get_callback, void* try_get_user_arg, PJRT_KeyValuePutCallback c_put_callback, void* put_user_arg); // A method that does nothing other than returning a nullptr.
Can be used as diff --git a/xla/pjrt/cpu/BUILD b/xla/pjrt/cpu/BUILD index ba0265eaed3c2..1ce663e34dc1d 100644 --- a/xla/pjrt/cpu/BUILD +++ b/xla/pjrt/cpu/BUILD @@ -1,6 +1,6 @@ load("//xla:xla.bzl", "xla_cc_test") load("//xla/pjrt/cpu:package_groups.bzl", "xla_cpu_internal_packages") -load("//xla/tsl:tsl.bzl", "if_oss", "internal_visibility") +load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl/platform:rules_cc.bzl", "cc_library") package( @@ -298,8 +298,11 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:gloo_communicator", + "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", @@ -361,33 +364,42 @@ xla_cc_test( cc_library( name = "mpi_collectives", - srcs = if_oss(["mpi_collectives.cc"]), - hdrs = if_oss(["mpi_collectives.h"]), + srcs = ["mpi_collectives.cc"], + hdrs = ["mpi_collectives.h"], + compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", + # copybara:uncomment_begin(google-only) + # "-Ithird_party/openmpi/ompi/include", + # copybara:uncomment_end ], features = ["-use_header_modules"], visibility = [ "//xla/pjrt/cpu:legacy_cpu_internal_users", ], - deps = if_oss([ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", + deps = [ "//xla:shape_util", "//xla:status_macros", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:mpi_communicator", + "//xla/core/collectives:communicator", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/service/cpu:collectives_interface", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@mpitrampoline", "@tsl//tsl/platform:errors", "@tsl//tsl/platform:logging", - "@mpitrampoline", - ]), + ], ) diff --git a/xla/pjrt/cpu/cpu_client.h b/xla/pjrt/cpu/cpu_client.h index 2a1517a1b53fc..e325e15e29137 100644 --- a/xla/pjrt/cpu/cpu_client.h +++ b/xla/pjrt/cpu/cpu_client.h @@ -202,13 +202,6 @@ class TfrtCpuClient final : public PjRtClient { std::function on_delete_callback, std::optional stream) override; - absl::StatusOr CreateChannelHandle() override { - return Unimplemented("CreateChannelHandle not implemented."); - } - absl::StatusOr CreateDeviceToHostChannelHandle() override { - return Unimplemented("CreateDeviceToHostChannelHandle not implemented."); - } - absl::Status Defragment() override { return Unimplemented("Defragment not implemented."); } diff --git a/xla/pjrt/cpu/gloo_collectives.cc b/xla/pjrt/cpu/gloo_collectives.cc index 02e5602dd28f2..09451f220b97d 100644 --- a/xla/pjrt/cpu/gloo_collectives.cc +++ b/xla/pjrt/cpu/gloo_collectives.cc @@ -15,13 +15,8 @@ limitations under the License. 
#include "xla/pjrt/cpu/gloo_collectives.h" -#include -#include -#include -#include #include #include -#include #include #include #include @@ -33,419 +28,19 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" #include "absl/types/span.h" -#include "gloo/algorithm.h" -#include "gloo/allgather.h" -#include "gloo/allreduce.h" #include "gloo/context.h" -#include "gloo/math.h" -#include "gloo/reduce_scatter.h" #include "gloo/rendezvous/context.h" #include "gloo/rendezvous/prefix_store.h" #include "gloo/rendezvous/store.h" #include "gloo/transport/device.h" -#include "gloo/transport/unbound_buffer.h" -#include "gloo/types.h" -#include "xla/backends/cpu/collectives/cpu_collectives.h" -#include "xla/core/collectives/rank_id.h" -#include "xla/primitive_util.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" +#include "xla/backends/cpu/collectives/gloo_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/global_device_id.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -GlooCollectivesCommunicator::GlooCollectivesCommunicator( - std::shared_ptr context) - : context_(std::move(context)) {} -GlooCollectivesCommunicator::~GlooCollectivesCommunicator() = default; - -template -static absl::Status SetAllReduceOptions(ReductionKind reduction_kind, - se::DeviceMemoryBase input_buffer, - se::DeviceMemoryBase output_buffer, - size_t num_elements, - gloo::AllreduceOptions& options) { - options.setInput( - reinterpret_cast(const_cast(input_buffer.opaque())), - num_elements); - options.setOutput( - reinterpret_cast(const_cast(output_buffer.opaque())), - num_elements); - - using ReductionFn = void (*)(void*, const void*, const void*, size_t); - - switch (reduction_kind) { - case ReductionKind::SUM: - options.setReduceFunction(static_cast(&gloo::sum)); - break; - case ReductionKind::PRODUCT: - options.setReduceFunction(static_cast(&gloo::product)); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - options.setReduceFunction(static_cast(&gloo::min)); - } else { - return absl::InvalidArgumentError( - "MIN reduction not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - options.setReduceFunction(static_cast(&gloo::max)); - } else { - return absl::InvalidArgumentError( - "MAX reduction not supported for complex types"); - } - break; - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::AllReduce( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - - gloo::AllreduceOptions options(context_); - // TODO(phawkins): how to do tags? 
- // options.setTag(tag); - switch (dtype) { - case S8: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case PRED: - case U8: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case S16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case U16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case S32: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case U32: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case S64: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case U64: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case F16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case BF16: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case F32: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case F64: - TF_RETURN_IF_ERROR(SetAllReduceOptions( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case C64: - TF_RETURN_IF_ERROR(SetAllReduceOptions>( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - case C128: - TF_RETURN_IF_ERROR(SetAllReduceOptions>( - reduction_kind, send_buffer, recv_buffer, count, options)); - break; - default: - return absl::InvalidArgumentError("Unknown datatype in allreduce"); - } - options.setAlgorithm(gloo::AllreduceOptions::Algorithm::RING); - options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); - - try { - gloo::allreduce(options); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo all-reduce failed: ", e.what())); - } - return absl::OkStatus(); -} - -static constexpr uint8_t kCollectivePermuteSlotPrefix = 0x40; - -absl::Status GlooCollectivesCommunicator::CollectivePermute( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, std::optional source_rank, - absl::Span target_ranks, const Executor& executor) { - uint32_t tag = 0; // TODO(phawkins): come up with better tags. 
- const auto slot = gloo::Slot::build(kCollectivePermuteSlotPrefix, tag); - - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - size_t num_bytes = count * primitive_util::ByteWidth(dtype); - - try { - std::unique_ptr in; - std::unique_ptr out; - for (RankId target : target_ranks) { - if (target != context_->rank) { - VLOG(1) << "send from " << context_->rank << " to " << target.value(); - if (!in) { - in = context_->createUnboundBuffer(send_buffer.opaque(), num_bytes); - } - in->send(target.value(), slot); - } - } - if (source_rank) { - if (*source_rank == context_->rank) { - std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); - } else { - VLOG(1) << "recv at " << context_->rank << " from " - << source_rank->value(); - out = context_->createUnboundBuffer(recv_buffer.opaque(), num_bytes); - out->recv(source_rank->value(), slot); - } - } else { - std::memset(recv_buffer.opaque(), 0, num_bytes); - } - VLOG(1) << "wait for send at " << context_->rank; - auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); - if (in) { - in->waitSend(deadline); - } - VLOG(1) << "wait for recv at " << context_->rank; - if (out) { - out->waitRecv(deadline); - } - VLOG(1) << "done waiting at " << context_->rank; - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo collective permute failed: ", e.what())); - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::AllToAll( - absl::Span send_buffers, - absl::Span recv_buffers, PrimitiveType dtype, - size_t count, const Executor& executor) { - // We can't use Gloo's all-to-all implementation directly because it assumes - // that the inputs and outputs are contiguous. No big deal; it's just built - // on top of send/recv and we can do the same as it. - uint32_t tag = 0; // TODO(phawkins): use better tags. 
- int my_rank = context_->rank; - int world_size = context_->size; - - TF_RET_CHECK(world_size == send_buffers.size()); - TF_RET_CHECK(world_size == recv_buffers.size()); - - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - try { - const auto slot = gloo::Slot::build(gloo::kAlltoallSlotPrefix, tag); - std::vector> ins( - context_->size); - std::vector> outs( - context_->size); - for (size_t i = 0; i < world_size; ++i) { - if (i != my_rank) { - ins[i] = context_->createUnboundBuffer( - const_cast(send_buffers[i].opaque()), chunk_bytes); - outs[i] = context_->createUnboundBuffer( - const_cast(recv_buffers[i].opaque()), chunk_bytes); - } - } - - for (int i = 1; i < world_size; i++) { - int send_rank = (my_rank + i) % world_size; - int recv_rank = (my_rank + world_size - i) % world_size; - ins[send_rank]->send(send_rank, slot); - outs[recv_rank]->recv(recv_rank, slot); - } - - std::memcpy(const_cast(recv_buffers[my_rank].opaque()), - send_buffers[my_rank].opaque(), chunk_bytes); - - auto deadline = absl::ToChronoTime(absl::Now() + cpu_executor->timeout()); - for (int i = 0; i < world_size; i++) { - if (i != my_rank) { - ins[i]->waitSend(deadline); - outs[i]->waitRecv(deadline); - } - } - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo all-to-all failed: ", e.what())); - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::AllGather( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, const Executor& executor) { - uint32_t tag = 0; // TODO(phawkins): use better tags. - - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - gloo::AllgatherOptions options(context_); - options.setTag(tag); - options.setTimeout(absl::ToChronoMilliseconds(cpu_executor->timeout())); - options.setInput(reinterpret_cast(send_buffer.opaque()), chunk_bytes); - options.setOutput(reinterpret_cast(recv_buffer.opaque()), - chunk_bytes * context_->size); - - try { - gloo::allgather(options); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo AllGather failed: ", e.what())); - } - return absl::OkStatus(); -} - -template -absl::Status ReduceScatterHelper(std::shared_ptr context, - ReductionKind reduction_kind, void* buffer, - size_t chunk_elems) { - const gloo::ReductionFunction* reduction_function = nullptr; - if constexpr (is_complex_v) { - switch (reduction_kind) { - case ReductionKind::SUM: - reduction_function = gloo::ReductionFunction::sum; - break; - case ReductionKind::PRODUCT: - reduction_function = gloo::ReductionFunction::product; - break; - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported reduction kind: ", static_cast(reduction_kind))); - } - } else { - switch (reduction_kind) { - case ReductionKind::SUM: - reduction_function = gloo::ReductionFunction::sum; - break; - case ReductionKind::PRODUCT: - reduction_function = gloo::ReductionFunction::product; - break; - case ReductionKind::MAX: - reduction_function = gloo::ReductionFunction::max; - break; - case ReductionKind::MIN: - reduction_function = gloo::ReductionFunction::min; - break; - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported reduction kind: ", static_cast(reduction_kind))); - } - } - try { - std::vector recv_elems(context->size, chunk_elems); - gloo::ReduceScatterHalvingDoubling 
algorithm( - context, std::vector{reinterpret_cast(buffer)}, - chunk_elems * context->size, recv_elems, reduction_function); - algorithm.run(); - } catch (std::exception& e) { - return absl::UnknownError( - absl::StrCat("Gloo ReduceScatter failed: ", e.what())); - } - return absl::OkStatus(); -} - -absl::Status GlooCollectivesCommunicator::ReduceScatter( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - std::unique_ptr temp(new char[chunk_bytes * context_->size]); - std::memcpy(temp.get(), send_buffer.opaque(), chunk_bytes * context_->size); - switch (dtype) { - case S8: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case PRED: - case U8: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case S16: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case U16: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case S32: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case U32: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case S64: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case U64: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case BF16: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case F16: - TF_RETURN_IF_ERROR(ReduceScatterHelper( - context_, reduction_kind, temp.get(), count)); - break; - case F32: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case F64: - TF_RETURN_IF_ERROR(ReduceScatterHelper(context_, reduction_kind, - temp.get(), count)); - break; - case C64: - TF_RETURN_IF_ERROR(ReduceScatterHelper>( - context_, reduction_kind, temp.get(), count)); - break; - case C128: - TF_RETURN_IF_ERROR(ReduceScatterHelper>( - context_, reduction_kind, temp.get(), count)); - break; - default: - return absl::InvalidArgumentError("Unknown datatype in reducescatter"); - } - std::memcpy(recv_buffer.opaque(), temp.get(), chunk_bytes); - return absl::OkStatus(); -} - GlooCollectives::GlooCollectives( std::unique_ptr store, std::shared_ptr device) @@ -453,8 +48,7 @@ GlooCollectives::GlooCollectives( GlooCollectives::~GlooCollectives() = default; -absl::StatusOr> -GlooCollectives::GetCommunicator( +absl::StatusOr> GlooCollectives::GetCommunicator( absl::Span global_devices, int rank) { Context* context; { @@ -487,8 +81,8 @@ GlooCollectives::GetCommunicator( return absl::UnknownError( absl::StrCat("Gloo context initialization failed: ", e.what())); } - context->communicator = - std::make_shared(std::move(gloo_context)); + context->communicator = std::make_shared( + std::move(gloo_context), rank, global_devices.size()); return context->communicator; } diff --git a/xla/pjrt/cpu/gloo_collectives.h b/xla/pjrt/cpu/gloo_collectives.h index 401ad0c54f728..174cdb48acceb 100644 --- a/xla/pjrt/cpu/gloo_collectives.h +++ b/xla/pjrt/cpu/gloo_collectives.h @@ -16,61 +16,26 @@ limitations under the License. 
#ifndef XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ #define XLA_PJRT_CPU_GLOO_COLLECTIVES_H_ -#include #include -#include #include #include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "absl/time/time.h" #include "absl/types/span.h" #include "gloo/context.h" #include "gloo/rendezvous/store.h" #include "gloo/transport/device.h" -#include "xla/service/collective_ops_utils.h" +#include "xla/backends/cpu/collectives/gloo_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class GlooCollectivesCommunicator : public CollectivesCommunicator { - public: - explicit GlooCollectivesCommunicator(std::shared_ptr context); - ~GlooCollectivesCommunicator() override; - - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - private: - std::shared_ptr context_; -}; - class GlooCollectives : public CollectivesInterface { public: GlooCollectives(std::unique_ptr store, @@ -78,17 +43,19 @@ class GlooCollectives : public CollectivesInterface { ~GlooCollectives() override; // Thread-safe. - absl::StatusOr> GetCommunicator( + absl::StatusOr> GetCommunicator( absl::Span devices, int rank) override; private: - std::unique_ptr store_; - std::shared_ptr device_; - absl::Mutex mu_; struct Context { absl::Mutex mu; - std::shared_ptr communicator; + std::shared_ptr communicator; }; + + std::unique_ptr store_; + std::shared_ptr device_; + + absl::Mutex mu_; absl::flat_hash_map, int>, std::unique_ptr> contexts_ ABSL_GUARDED_BY(mu_); diff --git a/xla/pjrt/cpu/gloo_collectives_test.cc b/xla/pjrt/cpu/gloo_collectives_test.cc index 4537b1073fb56..e4c79982beeaa 100644 --- a/xla/pjrt/cpu/gloo_collectives_test.cc +++ b/xla/pjrt/cpu/gloo_collectives_test.cc @@ -59,7 +59,7 @@ constexpr int kNumParticipants = 2; constexpr size_t kBufferSize = 256; constexpr absl::Duration kTimeout = absl::Seconds(5); -absl::StatusOr> GetCommunicator( +absl::StatusOr> GetCommunicator( size_t kNumParticipants, absl::Span global_devices, const std::shared_ptr& kv_store, int rank) { auto collectives = std::make_shared( diff --git a/xla/pjrt/cpu/mpi_collectives.cc b/xla/pjrt/cpu/mpi_collectives.cc index aaf1ebe6bb581..88dc69a31917d 100644 --- a/xla/pjrt/cpu/mpi_collectives.cc +++ b/xla/pjrt/cpu/mpi_collectives.cc @@ -15,242 +15,25 @@ limitations under the License. 
#include "xla/pjrt/cpu/mpi_collectives.h" -#include -#include -#include -#include -#include #include -#include -#include #include -#include #include -#include "mpi.h" // NOLINT +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_join.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" #include "absl/types/span.h" -#include "xla/primitive_util.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" +#include "mpi.h" +#include "xla/backends/cpu/collectives/mpi_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/global_device_id.h" -#include "xla/status_macros.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" namespace xla::cpu { -absl::StatusOr PrimitiveTypeToMpiType( - PrimitiveType element_type) { - switch (element_type) { - case S8: - return MPI_INT8_T; - case U8: - case PRED: - return MPI_UINT8_T; - case S16: - return MPI_INT16_T; - case U16: - return MPI_UINT16_T; - case S32: - return MPI_INT32_T; - case U32: - return MPI_UINT32_T; - case S64: - return MPI_INT64_T; - case U64: - return MPI_UINT64_T; - case F32: - return MPI_FLOAT; - case F64: - return MPI_DOUBLE; - case C64: - return MPI_C_COMPLEX; - case C128: - return MPI_C_DOUBLE_COMPLEX; - default: - // For implementing the reduction of unsupported types - // see e.g. https://stackoverflow.com/a/29643391 - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported primitive type for reduction: ", - primitive_util::LowercasePrimitiveTypeName(element_type))); - } -} - -bool MpiTypeIsComplex(MPI_Datatype type) { - return type == MPI_C_COMPLEX || type == MPI_C_DOUBLE_COMPLEX; -} - -absl::StatusOr ReductionKindToMpiOp(ReductionKind reduction_kind, - MPI_Datatype type) { - switch (reduction_kind) { - case ReductionKind::SUM: - return MPI_SUM; - case ReductionKind::PRODUCT: - return MPI_PROD; - case ReductionKind::MIN: - if (!MpiTypeIsComplex(type)) { - return MPI_MIN; - } else { - return absl::InvalidArgumentError( - "MIN reduction not supported for complex types"); - } - case ReductionKind::MAX: - if (!MpiTypeIsComplex(type)) { - return MPI_MAX; - } else { - return absl::InvalidArgumentError( - "MAX reduction not supported for complex types"); - } - default: - return absl::InvalidArgumentError( - absl::StrCat("Unknown reduction kind: ", reduction_kind)); - } -} - -static absl::Status MpiErrorToAbslStatus(int error) { - if (error != MPI_SUCCESS) { - char error_str[MPI_MAX_ERROR_STRING]; - int len; - MPI_Error_string(error, error_str, &len); - return absl::UnknownError(absl::StrCat("MPI error: ", error_str)); - } - return absl::OkStatus(); -} - -MpiCollectivesCommunicator::MpiCollectivesCommunicator(int color, int key) { - MPI_Comm_split(MPI_COMM_WORLD, color, key, &comm_); - MPI_Comm_rank(comm_, &mpi_rank_); - MPI_Comm_size(comm_, &mpi_size_); -} - -MpiCollectivesCommunicator::~MpiCollectivesCommunicator() { - MPI_Comm_free(&comm_); -}; - -absl::Status MpiCollectivesCommunicator::AllReduce( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); - TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus(MPI_Allreduce( - send_buffer.opaque(), 
recv_buffer.opaque(), count, type, op, comm_)); -} - -absl::Status MpiCollectivesCommunicator::CollectivePermute( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, std::optional source_rank, - absl::Span target_ranks, const Executor& executor) { - int tag = 0; // TODO come up with better tags. - - const int rank = mpi_rank_; - - std::vector requests; - - size_t num_bytes = count * primitive_util::ByteWidth(dtype); - - if (source_rank) { - if (source_rank->value() == rank) { - std::memcpy(recv_buffer.opaque(), send_buffer.opaque(), num_bytes); - } else { - VLOG(1) << "recv at " << rank << " from " << source_rank->value(); - requests.emplace_back(); - TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Irecv(recv_buffer.opaque(), num_bytes, MPI_BYTE, - source_rank->value(), tag, comm_, &requests.back()))); - } - } else { - std::memset(recv_buffer.opaque(), 0, num_bytes); - } - - for (RankId target : target_ranks) { - if (target != rank) { - VLOG(1) << "send from " << rank << " to " << target.value(); - requests.emplace_back(); - TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Isend(send_buffer.opaque(), num_bytes, MPI_BYTE, target.value(), - tag, comm_, &requests.back()))); - } - } - - for (auto& request : requests) { - TF_RETURN_IF_ERROR( - MpiErrorToAbslStatus(MPI_Wait(&request, MPI_STATUS_IGNORE))); - } - - return absl::OkStatus(); -} - -absl::Status MpiCollectivesCommunicator::AllToAll( - absl::Span send_buffers, - absl::Span recv_buffers, PrimitiveType dtype, - size_t count, const Executor& executor) { - // We can't use MPI_Alltoall directly because it assumes that the inputs and - // outputs are contiguous. Therefore here we implement it using MPI_Sendrecv. - - int tag = 0; // TODO use better tags. 
- const int rank = mpi_rank_; - const int size = mpi_size_; - TF_RET_CHECK(size == send_buffers.size()); - TF_RET_CHECK(size == recv_buffers.size()); - - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - std::vector input_buffers; - std::vector output_buffers; - - for (int i = 0; i < size; i++) { - input_buffers.push_back(const_cast(send_buffers[i].opaque())); - output_buffers.push_back(const_cast(recv_buffers[i].opaque())); - } - - std::memcpy(output_buffers[rank], input_buffers[rank], chunk_bytes); - - for (int i = 1; i < size; i++) { - int send_rank = (rank + i) % size; - int recv_rank = (rank + size - i) % size; - TF_RETURN_IF_ERROR(MpiErrorToAbslStatus( - MPI_Sendrecv(input_buffers[send_rank], chunk_bytes, MPI_BYTE, send_rank, - tag, output_buffers[recv_rank], chunk_bytes, MPI_BYTE, - recv_rank, tag, comm_, MPI_STATUS_IGNORE))); - } - - return absl::OkStatus(); -} - -absl::Status MpiCollectivesCommunicator::AllGather( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, const Executor& executor) { - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); - return MpiErrorToAbslStatus(MPI_Allgather(send_buffer.opaque(), count, type, - recv_buffer.opaque(), count, type, - comm_)); -} - -absl::Status MpiCollectivesCommunicator::ReduceScatter( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - const int size = mpi_size_; - std::vector recvcounts(size, count); - TF_ASSIGN_OR_RETURN(MPI_Datatype type, PrimitiveTypeToMpiType(dtype)); - TF_ASSIGN_OR_RETURN(MPI_Op op, ReductionKindToMpiOp(reduction_kind, type)); - return MpiErrorToAbslStatus( - MPI_Reduce_scatter(send_buffer.opaque(), recv_buffer.opaque(), - recvcounts.data(), type, op, comm_)); -} - void MpiCollectives::Init() { int provided; - MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &provided); + MPI_Init_thread(nullptr, nullptr, MPI_THREAD_FUNNELED, &provided); MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank_); MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size_); VLOG(1) << "MPI rank=" << mpi_world_rank_ << " size=" << mpi_world_size_; @@ -261,16 +44,15 @@ void MpiCollectives::Finalize() { MPI_Finalize(); } -absl::StatusOr> -MpiCollectives::GetCommunicator(absl::Span global_devices, - int rank) { +absl::StatusOr> MpiCollectives::GetCommunicator( + absl::Span global_devices, int rank) { int flag; MPI_Is_thread_main(&flag); if (!flag) { return absl::UnknownError( - absl::StrCat("MPI: Communicator requested from a thread that is not " - "the one MPI was initialized from. Multiple " - "threads/devices per process are not yet supported.")); + "MPI: Communicator requested from a thread that is not " + "the one MPI was initialized from. Multiple " + "threads/devices per process are not yet supported."); } auto& context = contexts_[std::make_tuple( @@ -288,7 +70,7 @@ MpiCollectives::GetCommunicator(absl::Span global_devices, } else { color = MPI_UNDEFINED; } - context = std::make_shared(color, key); + context = std::make_shared(color, key); return context; } diff --git a/xla/pjrt/cpu/mpi_collectives.h b/xla/pjrt/cpu/mpi_collectives.h index 8058c5f38077e..5db5f13f410bd 100644 --- a/xla/pjrt/cpu/mpi_collectives.h +++ b/xla/pjrt/cpu/mpi_collectives.h @@ -16,60 +16,22 @@ limitations under the License. 
#ifndef XLA_PJRT_CPU_MPI_COLLECTIVES_H_ #define XLA_PJRT_CPU_MPI_COLLECTIVES_H_ -#include #include -#include #include #include -#include "mpi.h" // NOLINT -#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/time/time.h" #include "absl/types/span.h" -#include "xla/service/collective_ops_utils.h" +#include "xla/backends/cpu/collectives/mpi_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -class MpiCollectivesCommunicator : public CollectivesCommunicator { - public: - explicit MpiCollectivesCommunicator(int color, int key); - ~MpiCollectivesCommunicator() override; - - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - private: - MPI_Comm comm_; - int mpi_rank_; - int mpi_size_; -}; - class MpiCollectives : public CollectivesInterface { public: /* @@ -84,7 +46,7 @@ class MpiCollectives : public CollectivesInterface { void Init(); void Finalize(); - absl::StatusOr> GetCommunicator( + absl::StatusOr> GetCommunicator( absl::Span global_devices, int rank) override; private: @@ -94,7 +56,7 @@ class MpiCollectives : public CollectivesInterface { int mpi_world_rank_; int mpi_world_size_; absl::flat_hash_map, int>, - std::shared_ptr> + std::shared_ptr> contexts_; }; diff --git a/xla/pjrt/distributed/client.cc b/xla/pjrt/distributed/client.cc index 280c60873e9d0..305afe7ae4c6d 100644 --- a/xla/pjrt/distributed/client.cc +++ b/xla/pjrt/distributed/client.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -53,6 +54,7 @@ class DistributedRuntimeCoordinationServiceClient absl::Status Shutdown() override; absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) override; + absl::StatusOr KeyValueTryGet(absl::string_view key) override; absl::StatusOr>> KeyValueDirGet(absl::string_view key) override; absl::Status KeyValueSet(absl::string_view key, @@ -144,6 +146,12 @@ DistributedRuntimeCoordinationServiceClient::BlockingKeyValueGet( return coord_agent_->GetKeyValue(key, timeout); } +absl::StatusOr +DistributedRuntimeCoordinationServiceClient::KeyValueTryGet( + absl::string_view key) { + return coord_agent_->TryGetKeyValue(key); +} + absl::StatusOr>> DistributedRuntimeCoordinationServiceClient::KeyValueDirGet( absl::string_view key) { @@ -216,6 +224,10 @@ class DistributedKeyValueStore : public KeyValueStoreInterface { return client_->BlockingKeyValueGet(absl::StrCat(prefix_, key), timeout); } + absl::StatusOr TryGet(absl::string_view key) override { + return client_->KeyValueTryGet(absl::StrCat(prefix_, key)); + } + absl::Status Set(absl::string_view key, absl::string_view value) override { return client_->KeyValueSet(absl::StrCat(prefix_, key), value); } diff --git a/xla/pjrt/distributed/client.h b/xla/pjrt/distributed/client.h index e597ff158cc67..58f4fe367681d 100644 --- a/xla/pjrt/distributed/client.h +++ b/xla/pjrt/distributed/client.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" @@ -116,6 +117,9 @@ class DistributedRuntimeClient { virtual absl::StatusOr BlockingKeyValueGet( absl::string_view key, absl::Duration timeout) = 0; + // Returns `NotFoundError` immediately if the key is not found. + virtual absl::StatusOr KeyValueTryGet(absl::string_view key) = 0; + // Get all key-value pairs under a directory (key). // A value is considered to be in the directory if its key is prefixed with // the directory. 
diff --git a/xla/pjrt/distributed/client_server_test.cc b/xla/pjrt/distributed/client_server_test.cc
index f5b7e656fe69a..baec103eced93 100644
--- a/xla/pjrt/distributed/client_server_test.cc
+++ b/xla/pjrt/distributed/client_server_test.cc
@@ -1029,6 +1029,20 @@ TEST_F(ClientServerTest, KeyValueSet_Duplicate_Overwrites) {
   EXPECT_EQ(result.value(), "overwritten_value");
 }

+TEST_F(ClientServerTest, KeyValueTryGet) {
+  StartService(/*num_nodes=*/1);
+  auto client = GetClient(/*node_id=*/0);
+  TF_ASSERT_OK(client->Connect());
+
+  ASSERT_THAT(client->KeyValueTryGet("test_key").status(),
+              StatusIs(absl::StatusCode::kNotFound));
+
+  TF_ASSERT_OK(client->KeyValueSet("test_key", "value"));
+  auto result = client->KeyValueTryGet("test_key");
+  TF_ASSERT_OK(result.status());
+  EXPECT_EQ(result.value(), "value");
+}
+
 TEST_F(ClientServerTest, KeyValueDelete) {
   StartService(/*num_nodes=*/1);
   auto client = GetClient(/*node_id=*/0);
diff --git a/xla/pjrt/distributed/in_memory_key_value_store.cc b/xla/pjrt/distributed/in_memory_key_value_store.cc
index 70cc5360ecf7b..49fc73ec87f16 100644
--- a/xla/pjrt/distributed/in_memory_key_value_store.cc
+++ b/xla/pjrt/distributed/in_memory_key_value_store.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/time/time.h"
@@ -40,6 +41,17 @@ absl::StatusOr<std::string> InMemoryKeyValueStore::Get(absl::string_view key,
   return kv_store_.find(key)->second;
 }

+absl::StatusOr<std::string> InMemoryKeyValueStore::TryGet(
+    absl::string_view key) {
+  absl::MutexLock lock(&mu_);
+  auto it = kv_store_.find(key);
+  if (it == kv_store_.end()) {
+    return absl::NotFoundError(
+        absl::StrCat(key, " is not found in the kv store."));
+  }
+  return it->second;
+}
+
 absl::Status InMemoryKeyValueStore::Set(absl::string_view key,
                                         absl::string_view value) {
   absl::MutexLock lock(&mu_);
diff --git a/xla/pjrt/distributed/in_memory_key_value_store.h b/xla/pjrt/distributed/in_memory_key_value_store.h
index 1530633a98b75..13f50c722bd12 100644
--- a/xla/pjrt/distributed/in_memory_key_value_store.h
+++ b/xla/pjrt/distributed/in_memory_key_value_store.h
@@ -21,7 +21,9 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
 #include "xla/pjrt/distributed/key_value_store_interface.h"

 namespace xla {
@@ -31,6 +33,8 @@ class InMemoryKeyValueStore : public KeyValueStoreInterface {
   absl::StatusOr<std::string> Get(absl::string_view key,
                                   absl::Duration timeout) override;

+  absl::StatusOr<std::string> TryGet(absl::string_view key) override;
+
   absl::Status Set(absl::string_view key, absl::string_view value) override;

  private:
diff --git a/xla/pjrt/distributed/key_value_store_interface.h b/xla/pjrt/distributed/key_value_store_interface.h
index 29580fb86847b..312ebb8abb646 100644
--- a/xla/pjrt/distributed/key_value_store_interface.h
+++ b/xla/pjrt/distributed/key_value_store_interface.h
@@ -38,11 +38,18 @@ class KeyValueStoreInterface {
   virtual ~KeyValueStoreInterface() = default;

   // Blocking Get().
+  // Useful for listening for a key-value pair that may be set later on.
   // There are no concurrency guarantees. To avoid a race / impose an ordering
   // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier().
   virtual absl::StatusOr<std::string> Get(absl::string_view key,
                                           absl::Duration timeout) = 0;

+  // Returns `NotFoundError` immediately if the key is not found.
+  // Useful for checking key existence.
+  // There are no concurrency guarantees. To avoid a race / impose an ordering
+  // on potentially concurrent ops (e.g. set, delete), use WaitAtBarrier().
+  virtual absl::StatusOr<std::string> TryGet(absl::string_view key) = 0;
+
   virtual absl::Status Set(absl::string_view key, absl::string_view value) = 0;
 };

diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc
index b7dea23fe13c3..00e242434f437 100644
--- a/xla/pjrt/pjrt_c_api_client.cc
+++ b/xla/pjrt/pjrt_c_api_client.cc
@@ -2599,6 +2599,8 @@ absl::StatusOr<std::unique_ptr<PjRtClient>> WrapClientAroundCApi(
     kv_callback_data = pjrt::ConvertToCKeyValueCallbacks(kv_store);
     init_args.kv_get_callback = kv_callback_data->c_kv_get;
     init_args.kv_get_user_arg = &kv_callback_data->kv_get_c_func;
+    init_args.kv_try_get_callback = kv_callback_data->c_kv_try_get;
+    init_args.kv_try_get_user_arg = &kv_callback_data->kv_try_get_c_func;
     init_args.kv_put_callback = kv_callback_data->c_kv_put;
     init_args.kv_put_user_arg = &kv_callback_data->kv_put_c_func;
   }
diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h
index 03e41ec398590..fe98aa5ecce39 100644
--- a/xla/pjrt/pjrt_c_api_client.h
+++ b/xla/pjrt/pjrt_c_api_client.h
@@ -401,28 +401,12 @@ class PjRtCApiClient : public PjRtClient {
         "this feature.");
   }

-  absl::StatusOr<ChannelHandle> CreateChannelHandle() override {
-    return Unimplemented(
-        "PJRT C API does not support CreateChannelHandle. Please report an "
-        "issue at https://github.com/google/jax/issues if you need this "
-        "feature.");
-  }
-
-  absl::StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
-    return Unimplemented(
-        "PJRT C API does not support CreateDeviceToHostChannelHandle. Please "
-        "report an issue at https://github.com/google/jax/issues if you need "
-        "this feature.");
-  }
-
   absl::Status Defragment() override {
     return Unimplemented(
         "PJRT C API does not support Defragment. Please report an issue at "
         "https://github.com/google/jax/issues if you need this feature.");
   }

-  bool SupportsSendRecvCallbacks() const override { return true; }
-
   const PJRT_Api* pjrt_c_api() const;

   PJRT_Client* pjrt_c_client() { return c_client_.get(); }
diff --git a/xla/pjrt/pjrt_client.h b/xla/pjrt/pjrt_client.h
index 0b1da9ef4660a..c0a07ae66d4e5 100644
--- a/xla/pjrt/pjrt_client.h
+++ b/xla/pjrt/pjrt_client.h
@@ -1070,25 +1070,12 @@ class PjRtClient {
         "MakeCrossHostReceiveBuffersForGather is not implemented.");
   }

-  // Create ChannelHandles for XLA send/recv.
-  virtual absl::StatusOr<ChannelHandle> CreateChannelHandle() {
-    return Unimplemented("CreateChannelHandle is not implemented.");
-  }
-  virtual absl::StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
-    return Unimplemented("CreateDeviceToHostChannelHandle is not implemented.");
-  }
-
   // TODO(zhangqiaorjc): Experimental API to be removed.
   // Defragment device memory.
   virtual absl::Status Defragment() {
     return Unimplemented("Defragment is not implemented.");
   }

-  // If false, this client does not support send/recv host callbacks, and
-  // callers should not set the `send_callbacks` and `recv_callbacks` arguments
-  // in ExecuteOptions.
-  virtual bool SupportsSendRecvCallbacks() const { return false; }
-
   // Return the PjRtHostMemoryForDeviceManager for this client. It can be
   // nullptr if the implementation does not provide one.
   virtual PjRtHostMemoryForDeviceManager* GetPjRtHostMemoryForDeviceManager
diff --git a/xla/pjrt/pjrt_executable.h b/xla/pjrt/pjrt_executable.h
index fc4f76ef4776a..1244039ede0cd 100644
--- a/xla/pjrt/pjrt_executable.h
+++ b/xla/pjrt/pjrt_executable.h
@@ -101,7 +101,9 @@ struct CompileOptions {
   // Key-value string pairs, parsed in order to set miscellaneous options,
   // overriding if appropriate.
   using OptionOverride = std::variant<std::string, bool, int64_t, double>;
-  std::vector<std::pair<std::string, OptionOverride>> env_option_overrides;
+  using EnvironmentOptionOverrides =
+      std::vector<std::pair<std::string, OptionOverride>>;
+  EnvironmentOptionOverrides env_option_overrides;

   std::optional<xla::Compiler::TargetConfig> target_config;

diff --git a/xla/pjrt/pjrt_stream_executor_client.cc b/xla/pjrt/pjrt_stream_executor_client.cc
index 39b0d9740afc9..35a8267ae1486 100644
--- a/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/xla/pjrt/pjrt_stream_executor_client.cc
@@ -347,32 +347,11 @@ void StallStreamOnError(LocalDeviceState* local_device, se::Stream* stream) {
 //                after the usage of device_buffer was enqueued.
 //   usage_stream: the stream the operation using device_buffer
 //                 was enqueued on.
-//   prefer_to_retain_reference: relevant only for the compute synchronous
-//                               allocation model. If true, retain a reference
-//                               to device_buffer until after the operation
-//                               completes. If false then the compute stream
-//                               will have to be synchronized past event before
-//                               device_buffer can be freed.
-//
-// prefer_to_retain_reference encodes a heuristic set by the caller for the
-// compute synchronous model:
-//
-// Generally when a buffer is the destination of a copy to a device, it will
-// subsequently be used on the device's compute stream before being freed. In
-// that case, there is no need to retain a reference to the buffer. If the
-// buffer is freed before being used on the compute stream, the free will be
-// delayed until the host knows that event has completed, but this is expected
-// to be uncommon.
-//
-// When a buffer is the source of a copy from a device, we need to either retain
-// a reference to the buffer until the copy completes or serialize the compute
-// stream behind the copy. It is often better to retain a reference since while
-// that keeps memory alive longer, it avoids stalling the compute stream.
 void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
                  LocalDeviceState* buffer_local_device,
                  LocalDeviceState* stream_local_device,
                  std::shared_ptr<BufferSequencingEvent> event,
-                 se::Stream* usage_stream, bool prefer_to_retain_reference,
+                 se::Stream* usage_stream,
                  std::vector<std::shared_ptr<TrackedDeviceBuffer>>*
                      buffers_to_release = nullptr) {
   tsl::profiler::TraceMe traceme("RecordUsage");
@@ -382,11 +361,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
       (stream_local_device != buffer_local_device) ||
       // In the synchronous allocation model, always retain a reference.
       (stream_local_device->allocation_model() ==
-       LocalDeviceState::kSynchronous) ||
-      // In the compute synchronous model, use the caller's heuristic.
-      (stream_local_device->allocation_model() ==
-       LocalDeviceState::kComputeSynchronized &&
-       prefer_to_retain_reference);
+       LocalDeviceState::kSynchronous);
   if (retain_buffer_until_completion) {
     if (buffers_to_release) {
       buffers_to_release->push_back(device_buffer.buffer());
     }
@@ -415,15 +390,8 @@ absl::Status AddDestinationBufferSynchronization(
   }
   definition_event->SetSequencingEvent(std::move(event_or).value(),
                                        copy_stream);
-  // prefer_to_retain_reference=false means don't retain a memory reference
-  // until the transfer is complete when using the ComputeSynchronized
-  // allocation model. This is a heuristic because in the common case
-  // destination buffers will be used on the compute stream and therefore don't
-  // require any synchronization before being freed. If the buffer is allocated
-  // and never used, the free will take longer and this is assumed to be ok.
   RecordUsage(std::move(device_buffer), local_device, local_device,
-              definition_event, copy_stream,
-              /*prefer_to_retain_reference=*/false);
+              definition_event, copy_stream);
   return absl::OkStatus();
 }

@@ -583,16 +551,9 @@ AllocateDestinationBuffer(
   if (on_device_shape.IsTuple()) {
     // Add a usage hold for the tuple table write and immediately convert it to
-    // the appropriate form of synchronization. prefer_to_retain_reference=false
-    // means don't retain a memory reference until the transfer is complete when
-    // using the ComputeSynchronized allocation model. This is a heuristic
-    // because in the common case destination buffers will be used on the
-    // compute stream and therefore don't require any synchronization before
-    // being freed. If the buffer is allocated and never used, the free will
-    // take longer and this is assumed to be ok.
+    // the appropriate form of synchronization.
     RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device,
-                definition_events.back(), tuple_table_stream,
-                /*prefer_to_retain_reference=*/false);
+                definition_events.back(), tuple_table_stream);
   }

   return py_buffer;
@@ -1954,8 +1915,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
         std::move(async_copy_to_device));

     RecordUsage(std::move(dst_device_buffer), transfer_local_device,
-                transfer_local_device, copy_event, transfer_stream,
-                /*prefer_to_retain_reference=*/false);
+                transfer_local_device, copy_event, transfer_stream);

     return std::pair<std::unique_ptr<PjRtBuffer>,
                      std::shared_ptr<BufferSequencingEvent>>(
@@ -2039,12 +1999,6 @@ PjRtStreamExecutorBuffer::CopyToDeviceMemorySpace(
   std::unique_ptr<PjRtBuffer>& buffer = buffer_and_event.first;
   std::shared_ptr<BufferSequencingEvent>& event = buffer_and_event.second;

-  // prefer_to_retain_reference=*/true means that, when using the
-  // ComputeSynchronized allocation model, retain a reference to the
-  // src_device_buffer until the copy completes. This is a heuristic; the
-  // alternative is to ensure, before freeing the buffer, that the compute
-  // stream is synchronized past the transfer, but it seems better to hold onto
-  // the buffer too long than to stall the compute stream.
   src_device_buffer.ConvertUsageHold(transfer_stream, event,
                                      /*reference_held=*/true);

@@ -2340,7 +2294,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
                                    memory_space);
   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device,
               local_device, definition_event, local_device->compute_stream(),
-              /*prefer_to_retain_reference=*/false, &buffers_to_release);
+              &buffers_to_release);
   return std::unique_ptr<PjRtBuffer>(std::move(pjrt_buffer));
 }

@@ -3118,14 +3072,9 @@ PjRtStreamExecutorLoadedExecutable::ExecuteHelper(
                                    buffers_to_release));

   for (PjRtStreamExecutorBuffer::ScopedHold& b : device_buffers) {
-    // prefer_to_retain_reference=false because when using the
-    // ComputeSynchronized allocation model we don't need to retain a reference
-    // to the device_buffer during execution because by definition the compute
-    // stream is synchronized past the execution.
     if (b.type() == PjRtStreamExecutorBuffer::ScopedHold::kUsage) {
       RecordUsage(std::move(b), device_state, device_state, definition_event,
-                  stream,
-                  /*prefer_to_retain_reference=*/false, &buffers_to_release);
+                  stream, &buffers_to_release);
     } else {
       CHECK(b.type() == PjRtStreamExecutorBuffer::ScopedHold::kDonation);
       b.ConfirmDonation();
diff --git a/xla/pjrt/pjrt_stream_executor_client.h b/xla/pjrt/pjrt_stream_executor_client.h
index 394777b07ff47..f753df6d6fcc2 100644
--- a/xla/pjrt/pjrt_stream_executor_client.h
+++ b/xla/pjrt/pjrt_stream_executor_client.h
@@ -394,13 +394,6 @@ class PjRtStreamExecutorClient : public PjRtClient {
       std::function<void()> on_delete_callback,
       std::optional<std::intptr_t> stream) override;

-  absl::StatusOr<ChannelHandle> CreateChannelHandle() override {
-    return client()->CreateChannelHandle();
-  }
-  absl::StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
-    return client()->CreateDeviceToHostChannelHandle();
-  }
-
   // TODO(zhangqiaorjc): Experimental. Will be removed.
   absl::Status Defragment() override {
     return Unimplemented("Defragment not implemented");
diff --git a/xla/pjrt/tf_pjrt_client.h b/xla/pjrt/tf_pjrt_client.h
index 8933a2482c868..49b8d5db5e92e 100644
--- a/xla/pjrt/tf_pjrt_client.h
+++ b/xla/pjrt/tf_pjrt_client.h
@@ -340,12 +340,6 @@ class TfPjRtClient : public PjRtClient {
     return wrapped_->MakeCrossHostReceiveBuffersForGather(
         shapes, std::move(gather_details), device, std::move(notifier));
   }
-  absl::StatusOr<ChannelHandle> CreateChannelHandle() override {
-    return wrapped_->CreateChannelHandle();
-  }
-  absl::StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() override {
-    return wrapped_->CreateDeviceToHostChannelHandle();
-  }

   absl::StatusOr<const PjRtTopologyDescription*> GetTopologyDescription()
       const override {
     return wrapped_->GetTopologyDescription();
diff --git a/xla/python/BUILD b/xla/python/BUILD
index 4b024b2c8c3df..2b32267047719 100644
--- a/xla/python/BUILD
+++ b/xla/python/BUILD
@@ -12,6 +12,7 @@ load(
     "//xla/tsl:tsl.bzl",
    "if_cuda_or_rocm",
     "if_google",
+    "if_oss",
     "internal_visibility",
 )
 load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable", "tsl_pybind_extension")
@@ -735,7 +736,6 @@ cc_library(
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/status",
@@ -1358,9 +1358,8 @@ tsl_pybind_extension(
     }) + select({
         # mpitrampoline does not build on windows
         "//xla/tsl:windows": [],
-        "//conditions:default": [
-            "//xla/pjrt/cpu:mpi_collectives",
-        ],
+        # we support MPI collectives only in OSS builds
+        "//conditions:default": if_oss(["//xla/pjrt/cpu:mpi_collectives"]),
     }),
 )
diff --git a/xla/python/ifrt/support/module_parsing.cc b/xla/python/ifrt/support/module_parsing.cc
index b1740cd5cf0ca..8d6efaf1a4a56 100644
--- a/xla/python/ifrt/support/module_parsing.cc
+++ b/xla/python/ifrt/support/module_parsing.cc
@@ -52,6 +52,7 @@ void RegisterMlirDialects(mlir::MLIRContext& context) {
   mlir::DialectRegistry registry;
   InitializeMlirDialectRegistry(registry);
   context.appendDialectRegistry(registry);
+  context.loadAllAvailableDialects();
 }

 absl::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ParseMlirModuleString(
diff --git a/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc
index ab36e6c0f17f6..2c8d52e7e7cff 100644
--- a/xla/python/ifrt_proxy/client/grpc_host_buffer.cc
+++ b/xla/python/ifrt_proxy/client/grpc_host_buffer.cc
@@ -105,8 +105,10 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle,
   }

   if (!writer->WritesDone()) {
+    writer->Finish().IgnoreError();
     promise.Set(
         absl::InternalError("Failed to write all host buffer chunks"));
+    return;
   }
 }

@@ -150,6 +152,7 @@ Future<> GrpcClientHostBufferStore::Store(uint64_t handle,
     }
   }
   if (!writer->WritesDone()) {
+    writer->Finish().IgnoreError();
     return Future<>(
         absl::InternalError("Failed to write all host buffer chunks"));
   }
diff --git a/xla/python/pjit.cc b/xla/python/pjit.cc
index d492311a81ba4..88c3d7c9bd5fb 100644
--- a/xla/python/pjit.cc
+++ b/xla/python/pjit.cc
@@ -34,7 +34,6 @@ limitations under the License.
 #include "absl/base/thread_annotations.h"
 #include "absl/cleanup/cleanup.h"
 #include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/hash/hash.h"
 #include "absl/status/status.h"
@@ -325,6 +324,11 @@ class PjitFunction {
     executables_->Clear();
   }

+  std::shared_ptr<PjitFunctionCache::Cache> executables() {
+    nb::ft_object_guard lock(cache_);
+    return executables_;
+  }
+
   nb::object PythonSignature() {
     if (!fun_.has_value()) {
       throw nb::value_error(
@@ -362,41 +366,6 @@ class PjitFunction {
   std::shared_ptr<PjitFunctionCache::Cache> executables_;
 };

-// Thread-safe.
-class PjitFunctionStore {
- public:
-  void Insert(PjitFunction* function) {
-    nb::ft_lock_guard lock(mu_);
-    compiled_functions_.insert(function);
-  }
-
-  void Erase(PjitFunction* function) {
-    nb::ft_lock_guard lock(mu_);
-    compiled_functions_.erase(function);
-  }
-
-  void ClearFunctionCache() {
-    absl::flat_hash_set<PjitFunction*> functions;
-    {
-      nb::ft_lock_guard lock(mu_);
-      std::swap(functions, compiled_functions_);
-    }
-    for (auto* function : functions) {
-      function->ClearCache();
-    }
-  }
-
- private:
-  // Protected by the GIL in GIL mode, and by mu_ in freethreading mode.
-  nb::ft_mutex mu_;
-  absl::flat_hash_set<PjitFunction*> compiled_functions_;
-};
-
-PjitFunctionStore& GetGlobalPjitFunctionStore() {
-  static auto* const store = new PjitFunctionStore();
-  return *store;
-}
-
 PjitFunction::PjitFunction(
     std::string function_name, std::optional<nb::callable> fun,
     nb::callable cache_miss, std::vector<int> static_argnums,
@@ -418,8 +387,6 @@ PjitFunction::PjitFunction(
     PyUnicode_InternInPlace(&s);
     static_argnames_.push_back(nb::steal<nb::str>(s));
   }
-
-  GetGlobalPjitFunctionStore().Insert(this);
 }

 void PjitFunction::InitExecutables() {
@@ -432,7 +399,7 @@ void PjitFunction::InitExecutables() {
   }
 }

-PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); }
+PjitFunction::~PjitFunction() = default;

 void CallShardArgFallback(
     nb::handle arg, nb::handle sharding, nb::handle layout,
@@ -969,8 +936,64 @@ struct PjitFunctionObject {
 #endif  // PY_VERSION_HEX < 0x030C0000
   vectorcallfunc vectorcall;
   PjitFunction fun;
+
+  // Doubly-linked list of PjitFunctionObjects, protected by
+  // PjitFunctionStore::mu_ or the GIL in GIL mode.
+  PjitFunctionObject* next;
+  PjitFunctionObject* prev;
 };

+// Contains a list of all PjitFunctionObjects.
+// Thread-safe.
+class PjitFunctionStore {
+ public:
+  void Insert(PjitFunctionObject* o) {
+    nb::ft_lock_guard lock(mu_);
+    o->next = compiled_functions_;
+    o->prev = nullptr;
+    if (o->next) {
+      o->next->prev = o;
+    }
+    compiled_functions_ = o;
+  }
+
+  void Remove(PjitFunctionObject* o) {
+    nb::ft_lock_guard lock(mu_);
+    if (o->next) {
+      o->next->prev = o->prev;
+    }
+    if (o->prev) {
+      o->prev->next = o->next;
+    } else {
+      compiled_functions_ = o->next;
+    }
+  }
+
+  void ClearCaches() {
+    std::vector<
+        std::pair<xla::nb_class_ptr<PjitFunctionCache>,
+                  std::shared_ptr<PjitFunctionCache::Cache>>>
+        caches;
+    {
+      nb::ft_lock_guard lock(mu_);
+      for (PjitFunctionObject* fn = compiled_functions_; fn != nullptr;
+           fn = fn->next) {
+        caches.emplace_back(fn->fun.cache(), fn->fun.executables());
+      }
+    }
+    for (auto& [cache, executables] : caches) {
+      nb::ft_object_guard lock(cache);
+      executables->Clear();
+    }
+  };
+
+ private:
+  // Protected by the GIL in GIL mode, and by mu_ in freethreading mode.
+  nb::ft_mutex mu_;
+  PjitFunctionObject* compiled_functions_;
+};
+
+PjitFunctionStore pjit_function_store;
+
 PyObject* PjitFunction_Type = nullptr;

 bool PjitFunction::IsPjitFunction(nb::handle handle) {
@@ -1036,6 +1059,7 @@ void PjitFunction_tp_dealloc(PyObject* self) {
   PyObject_GC_UnTrack(self);
   PyTypeObject* tp = Py_TYPE(self);
   PjitFunctionObject* o = reinterpret_cast<PjitFunctionObject*>(self);
+  pjit_function_store.Remove(o);
   PyObject_ClearWeakRefs(self);
 #if PY_VERSION_HEX < 0x030C0000
   Py_CLEAR(o->dict);
@@ -1125,6 +1149,7 @@ void InitializePjitFunction(
     xla::nb_class_ptr<xla::PyTreeRegistry> pytree_registry,
     nb::callable shard_arg_fallback, xla::nb_class_ptr<PjitFunctionCache> cache) {
+  fn_obj->next = fn_obj->prev = nullptr;
   if (nb::isinstance<nb::list>(global_cache_key)) {
     global_cache_key = nb::tuple(global_cache_key);
   }
@@ -1136,6 +1161,10 @@ void InitializePjitFunction(
   // Handled separately because it is not exception safe to call this
   // in the constructor because it leaves the object improperly constructed.
   fn_obj->fun.InitExecutables();
+
+  // Only add the executable to the store after executables_ has been
+  // initialized. We want only fully constructed executables in the store.
+  pjit_function_store.Insert(fn_obj);
 }

 nb::object MakePjitFunction(
@@ -1201,8 +1230,7 @@ void BuildPjitSubmodule(nb::module_& m) {
   cache.def("size", &PjitFunctionCache::Size, nb::lock_self());
   cache.def("capacity", &PjitFunctionCache::Capacity, nb::lock_self());
   cache.def("clear", &PjitFunctionCache::Clear, nb::lock_self());
-  cache.def_static("clear_all",
-                   []() { GetGlobalPjitFunctionStore().ClearFunctionCache(); });
+  cache.def_static("clear_all", []() { pjit_function_store.ClearCaches(); });
   cache.def(
       "__getstate__",
       // Pickles as an empty cache; the client can repopulate as needed.
diff --git a/xla/python/pjrt_ifrt/pjrt_memory.cc b/xla/python/pjrt_ifrt/pjrt_memory.cc
index 8edb3bfa29fe2..5217eb72b1fbd 100644
--- a/xla/python/pjrt_ifrt/pjrt_memory.cc
+++ b/xla/python/pjrt_ifrt/pjrt_memory.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" @@ -29,6 +30,7 @@ namespace ifrt { char PjRtCompatibleMemory::ID = 0; char PjRtMemory::ID = 0; +char PjRtMemoryDescription::ID = 0; PjRtMemory::PjRtMemory(PjRtClient* client, xla::PjRtMemorySpace* pjrt_memory) : client_(client), pjrt_memory_(pjrt_memory), kind_(pjrt_memory->kind()) { @@ -51,6 +53,29 @@ absl::string_view PjRtMemory::DebugString() const { absl::Span PjRtMemory::Devices() const { return devices_; } +PjRtMemoryDescription::PjRtMemoryDescription( + PjRtClient* client, absl::Span devices, + const xla::PjRtMemorySpaceDescription* desc) + : desc_(desc), kind_(desc->kind()) { + for (auto device : devices) { + devices_.push_back(device); + } +} + +MemoryId PjRtMemoryDescription::Id() const { + return MemoryId(desc_->kind_id()); +} + +const MemoryKind& PjRtMemoryDescription::Kind() const { return kind_; } + +absl::string_view PjRtMemoryDescription::ToString() const { + return desc_->kind(); +} + +absl::string_view PjRtMemoryDescription::DebugString() const { + return desc_->kind(); +} + MemoryKind CanonicalizeMemoryKindWithPjRtDevice(MemoryKind memory_kind, xla::PjRtDevice* device) { if (memory_kind.memory_kind().has_value()) { diff --git a/xla/python/pjrt_ifrt/pjrt_memory.h b/xla/python/pjrt_ifrt/pjrt_memory.h index 3964ac56b184d..f6517f9e191d9 100644 --- a/xla/python/pjrt_ifrt/pjrt_memory.h +++ b/xla/python/pjrt_ifrt/pjrt_memory.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_device_description.h" #include "xla/python/ifrt/memory.h" namespace xla { @@ -60,6 +61,30 @@ class PjRtMemory final std::vector devices_; }; +class PjRtMemoryDescription final + : public llvm::RTTIExtends { + public: + PjRtMemoryDescription(PjRtClient* client, absl::Span devices, + const xla::PjRtMemorySpaceDescription* desc); + + PjRtClient* client() const { return client_; } + xla::PjRtMemorySpace* pjrt_memory() override { return nullptr; } + + MemoryId Id() const override; + const MemoryKind& Kind() const override; + absl::string_view ToString() const override; + absl::string_view DebugString() const override; + absl::Span Devices() const override { return devices_; } + + static char ID; // NOLINT + + private: + PjRtClient* client_; + const xla::PjRtMemorySpaceDescription* desc_; + MemoryKind kind_; + std::vector devices_; +}; + // Canonicalizes `MemoryKind`. If `MemoryKind` has no memory kind chosen, // returns a default `MemoryKind` chosen for the PjRt device. If there is no // default indicated by the device, simply returns `MemoryKind` with no memory diff --git a/xla/python/pmap_lib.cc b/xla/python/pmap_lib.cc index 3999b7b7473a6..609cee2deb46f 100644 --- a/xla/python/pmap_lib.cc +++ b/xla/python/pmap_lib.cc @@ -432,8 +432,10 @@ class PmapFunction { // passed to the underlying PyLoadedExecutable. In sorted order. std::vector static_argnums_; xla::nb_class_ptr pytree_registry_; - // We need a `unique_ptr` here to ensure value pointer stability. - absl::flat_hash_map> + // We need a `shared_ptr` here to ensure value pointer stability, and to + // ensure that the cache entry remains alive in the presence of concurrent + // removals. 
+  absl::flat_hash_map<CallSignature, std::shared_ptr<PmapCacheEntry>>
       executables_;

   // The fallback function to use with `ShardArgs`.
@@ -581,15 +583,14 @@ absl::StatusOr<nb::object> PmapFunction::Call(nb::handle callable,
   }

   // Retrieve/Maybe add the executable to the cache.
-  absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>::iterator
-      it;
-  bool inserted;
-  std::tie(it, inserted) = executables_.try_emplace(
-      call_signature, std::unique_ptr<PmapCacheEntry>());
-  if (inserted) {
-    it->second = std::make_unique<PmapCacheEntry>(pytree_registry_.get());
+  bool inserted = false;
+  std::shared_ptr<PmapCacheEntry>& cache_entry_ptr =
+      executables_[call_signature];
+  if (cache_entry_ptr == nullptr) {
+    inserted = true;
+    cache_entry_ptr = std::make_shared<PmapCacheEntry>(pytree_registry_.get());
   }
-  PmapCacheEntry& cache_entry = *(it->second);
+  PmapCacheEntry& cache_entry = *cache_entry_ptr;

   if (!cache_entry.compilation_complete.HasBeenNotified()) {
     // In case of several threads attempting to compile the executable, only
diff --git a/xla/python/xla.cc b/xla/python/xla.cc
index 219d6704b4f79..46ecfb4a6dd4f 100644
--- a/xla/python/xla.cc
+++ b/xla/python/xla.cc
@@ -337,8 +337,8 @@ NB_MODULE(xla_extension, m) {
       [](bool asynchronous,
         std::shared_ptr<DistributedRuntimeClient> distributed_client,
         int node_id, int num_nodes,
-         std::shared_ptr<cpu::CollectivesInterface> collectives)
-          -> nb_class_ptr<PyClient> {
+         std::shared_ptr<cpu::CollectivesInterface> collectives,
+         std::optional<int> num_devices) -> nb_class_ptr<PyClient> {
         std::unique_ptr<ifrt::PjRtClient> ifrt_client;
         {
           nb::gil_scoped_release gil_release;
@@ -347,6 +347,7 @@ NB_MODULE(xla_extension, m) {
           options.asynchronous = asynchronous;
           options.collectives = std::move(collectives);
           options.process_id = node_id;
+          options.cpu_device_count = num_devices;
           std::unique_ptr<PjRtClient> client =
               xla::ValueOrThrow(xla::GetXlaPjrtCpuClient(std::move(options)));
           ifrt::PjRtClient::CreateOptions ifrt_options;
@@ -367,7 +368,8 @@ NB_MODULE(xla_extension, m) {
       nb::arg("asynchronous") = true, nb::arg("distributed_client") = nullptr,
       nb::arg("node_id") = 0, nb::arg("num_nodes") = 1,
       nb::arg("collectives").none() =
-          std::shared_ptr<cpu::CollectivesInterface>());
+          std::shared_ptr<cpu::CollectivesInterface>(),
+      nb::arg("num_devices").none() = std::nullopt);
   m.def("pjrt_plugin_loaded", [](std::string platform_name) -> bool {
     absl::StatusOr<const PJRT_Api*> pjrt_api = pjrt::PjrtApi(platform_name);
     return pjrt_api.ok();
@@ -672,6 +674,21 @@ NB_MODULE(xla_extension, m) {
             return nb::bytes(result.data(), result.size());
           },
           nb::arg("key"), nb::arg("timeout_in_ms"))
+      .def(
+          "key_value_try_get",
+          [](DistributedRuntimeClient& client, std::string key) {
+            nb::gil_scoped_release gil_release;
+            return xla::ValueOrThrow(client.KeyValueTryGet(key));
+          },
+          nb::arg("key"))
+      .def(
+          "key_value_try_get_bytes",
+          [](DistributedRuntimeClient& client, std::string key) -> nb::bytes {
+            nb::gil_scoped_release gil_release;
+            std::string result = xla::ValueOrThrow(client.KeyValueTryGet(key));
+            return nb::bytes(result.data(), result.size());
+          },
+          nb::arg("key"))
       .def(
           "wait_at_barrier",
           [](DistributedRuntimeClient& client, std::string barrier_id,
diff --git a/xla/python/xla_client.py b/xla/python/xla_client.py
index 040c781cd087d..46dd4a72edd1e 100644
--- a/xla/python/xla_client.py
+++ b/xla/python/xla_client.py
@@ -50,7 +50,7 @@

 # Just an internal arbitrary increasing number to help with backward-compatible
 # changes. In JAX, reference this via jax._src.lib.xla_extension_version.
-_version = 302
+_version = 303

 # Version number for MLIR:Python components.
 mlir_api_version = 57

@@ -70,7 +70,8 @@ def make_cpu_client(
     distributed_client=None,
     node_id=0,
     num_nodes=1,
-    collectives=None
+    collectives=None,
+    num_devices=None,
 ) -> ...:
   register_custom_call_handler('cpu', _xla.register_custom_call_target)
   register_custom_type_id_handler('cpu', _xla.register_custom_type_id)
@@ -80,6 +81,7 @@ def make_cpu_client(
       node_id=node_id,
       num_nodes=num_nodes,
       collectives=collectives,
+      num_devices=num_devices,
   )

diff --git a/xla/python/xla_client.pyi b/xla/python/xla_client.pyi
index cac63a98c1b2d..efc3d2573b222 100644
--- a/xla/python/xla_client.pyi
+++ b/xla/python/xla_client.pyi
@@ -89,6 +89,7 @@ def make_cpu_client(
     node_id: int = ...,
     num_nodes: int = ...,
     collectives: _xla.CpuCollectives | None = ...,
+    num_devices: int | None = ...,
 ) -> Client:
   ...

diff --git a/xla/python/xla_client_test.py b/xla/python/xla_client_test.py
index 35b4a1ee77964..f0cecc9903295 100644
--- a/xla/python/xla_client_test.py
+++ b/xla/python/xla_client_test.py
@@ -2757,6 +2757,8 @@ def testDevices(self):

   def testLocalDevices(self):
     self.assertNotEmpty(self.backend.local_devices())
+    if self.backend.platform == "cpu":
+      self.assertLen(self.backend.local_devices(), 2)

   def testGetAllDevices(self):
     # TODO(hyeontaek): Remove this method once we have a unified API for
@@ -3692,7 +3694,7 @@ def InstantiateTests(globals_dict, backend_fn, test_prefix="", **kw):


 backends = {
-    "cpu": xla_client.make_cpu_client,
+    "cpu": functools.partial(xla_client.make_cpu_client, num_devices=2),
     "gpu": xla_client.make_gpu_client,
 }

diff --git a/xla/python/xla_extension/__init__.pyi b/xla/python/xla_extension/__init__.pyi
index 2e3862285898f..67eadd44c14a4 100644
--- a/xla/python/xla_extension/__init__.pyi
+++ b/xla/python/xla_extension/__init__.pyi
@@ -607,6 +607,7 @@ def get_tfrt_cpu_client(
     node_id: int = ...,
     num_nodes: int = ...,
     collectives: Optional[CpuCollectives] = ...,
+    num_devices: int | None = ...,
 ) -> Client: ...
 def get_gpu_client(
     asynchronous: bool = ...,
@@ -830,6 +831,8 @@ class DistributedRuntimeClient:
   def blocking_key_value_get_bytes(
       self, key: str, timeout_in_ms: int
   ) -> _Status: ...
+  def key_value_try_get(self, key: str) -> _Status: ...
+  def key_value_try_get_bytes(self, key: str) -> _Status: ...
   def key_value_dir_get(self, key: str) -> _Status: ...
   def key_value_dir_get_bytes(self, key: str) -> _Status: ...
   def key_value_set(self, key: str, value: str,
diff --git a/xla/service/BUILD b/xla/service/BUILD
index 8cd9cac1da809..bf03c3f792992 100644
--- a/xla/service/BUILD
+++ b/xla/service/BUILD
@@ -1837,21 +1837,21 @@ xla_cc_test(
     name = "hlo_schedule_test",
     srcs = ["hlo_schedule_test.cc"],
     deps = [
+        ":buffer_value",
+        "//xla:literal_util",
         "//xla:shape_util",
         "//xla:test_helpers",
-        "//xla:types",
         "//xla:xla_data_proto_cc",
-        "//xla/hlo/analysis:hlo_ordering",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/transforms/simplifiers:hlo_dce",
         "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
         "@com_google_googletest//:gtest",
-        "@tsl//tsl/platform:statusor",
     ],
 )

@@ -2024,14 +2024,22 @@ xla_cc_test(
         ":hlo_creation_utils",
         ":pattern_matcher",
         ":pattern_matcher_gmock",
+        "//xla:array2d",
+        "//xla:literal",
+        "//xla:literal_util",
         "//xla:shape_util",
         "//xla:test",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/evaluator:hlo_evaluator",
         "//xla/hlo/ir:hlo",
+        "//xla/hlo/testlib:verified_hlo_module",
         "//xla/tests:hlo_test_base",
+        "//xla/tests:literal_test_util",
         "//xla/tests:xla_internal_test_main",
-        "@tsl//tsl/platform:test",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/types:span",
+        "@com_google_googletest//:gtest",
     ],
 )

@@ -2230,13 +2238,16 @@ xla_cc_test(
     shard_count = 12,
     deps = [
         ":triangular_solve_expander",
+        "//xla:array2d",
+        "//xla:error_spec",
         "//xla:literal",
+        "//xla:literal_util",
         "//xla:reference_util",
-        "//xla:test",
-        "//xla:types",
         "//xla/tests:hlo_test_base",
+        "//xla/tests:literal_test_util",
         "//xla/tests:xla_internal_test_main",
-        "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:statusor",
+        "//xla/tsl/platform:test",
     ],
 )

@@ -3493,25 +3504,35 @@ xla_cc_test(
     name = "hlo_module_test",
     srcs = ["hlo_module_test.cc"],
     deps = [
+        ":buffer_value",
         ":computation_placer_hdr",
+        ":hlo_module_config",
         ":test_compilation_environment_proto_cc",
-        "//xla:literal",
+        "//xla:comparison_util",
+        "//xla:debug_options_flags",
+        "//xla:literal_util",
         "//xla:shape_util",
         "//xla:test",
         "//xla:xla_data_proto_cc",
         "//xla:xla_proto_cc",
         "//xla/hlo/ir:hlo",
+        "//xla/hlo/testlib:verified_hlo_module",
         "//xla/hlo/transforms/simplifiers:hlo_memory_scheduler",
         "//xla/hlo/utils:hlo_matchers",
         "//xla/tests:hlo_test_base",
         "//xla/tests:xla_internal_test_main",
         "//xla/tsl/lib/core:status_test_util",
         "//xla/tsl/lib/strings:proto_serialization",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/types:span",
         "@com_google_googletest//:gtest_main",
-        "@tsl//tsl/platform:errors",
-        "@tsl//tsl/platform:statusor",
+        "@tsl//tsl/platform:casts",
+        "@tsl//tsl/platform:protobuf",
     ],
 )

@@ -4636,10 +4657,12 @@ cc_library(
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:die_if_null",
        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
+        "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
"@tsl//tsl/platform:casts", ], diff --git a/xla/service/cpu/BUILD b/xla/service/cpu/BUILD index 5dbad693e7847..012112662640a 100644 --- a/xla/service/cpu/BUILD +++ b/xla/service/cpu/BUILD @@ -669,6 +669,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1386,6 +1387,7 @@ xla_cc_test( tags = ["not_run:arm"], deps = [ ":cpu_instruction_fusion", + "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1539,6 +1541,7 @@ xla_cc_test( deps = [ ":conv_canonicalization", ":target_machine_features_stub", + "//xla:literal_util", "//xla:test", "//xla:test_helpers", "//xla:util", @@ -1958,12 +1961,17 @@ cc_library( name = "collectives_interface", hdrs = ["collectives_interface.h"], deps = [ + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/core/collectives:clique_id", + "//xla/core/collectives:clique_key", "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/stream_executor:device_memory", + "//xla/tsl/platform:statusor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", @@ -1983,16 +1991,20 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/collectives:cpu_collectives", + "//xla/backends/cpu/collectives:in_process_communicator", + "//xla/core/collectives:communicator", "//xla/core/collectives:rank_id", "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/stream_executor:device_memory", "//xla/tsl/platform:statusor", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@tsl//tsl/platform:errors", diff --git a/xla/service/cpu/collectives_interface.h b/xla/service/cpu/collectives_interface.h index faba50bc2280a..77e159e1535bc 100644 --- a/xla/service/cpu/collectives_interface.h +++ b/xla/service/cpu/collectives_interface.h @@ -17,71 +17,108 @@ limitations under the License. #define XLA_SERVICE_CPU_COLLECTIVES_INTERFACE_H_ #include +#include #include #include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/collectives/cpu_collectives.h" +#include "xla/core/collectives/clique_id.h" +#include "xla/core/collectives/clique_key.h" #include "xla/core/collectives/communicator.h" +#include "xla/core/collectives/rank_id.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla::cpu { -// TODO(b/380457503): We are in the middle of migrating this API to the new XLA -// collectives API defined under `xla/core/collectives`. 
-class CollectivesCommunicator {
- public:
-  using Executor = Communicator::Executor;
-
-  virtual ~CollectivesCommunicator() = default;
+namespace internal {

-  // Performs an all-reduce.
-  virtual absl::Status AllReduce(se::DeviceMemoryBase send_buffer,
-                                 se::DeviceMemoryBase recv_buffer,
-                                 PrimitiveType dtype, size_t count,
-                                 ReductionKind reduction_kind,
-                                 const Executor& executor) = 0;
-
-  // Performs a collective permute.
-  // Arguments:
-  //  source_rank: the rank from which this rank should receive its data.
-  //    Optional; if absent, then the output is filled with zeros.
-  //  target_rank: the ranks to which this rank should send its data.
-  virtual absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
-                                         se::DeviceMemoryBase recv_buffer,
-                                         PrimitiveType dtype, size_t count,
-                                         std::optional<RankId> source_rank,
-                                         absl::Span<const RankId> target_ranks,
-                                         const Executor& executor) = 0;
-
-  // Performs an all-to-all.
-  // The all-to-all chunks are passed separately and do not have to be
-  // contiguous in memory.
-  virtual absl::Status AllToAll(
-      absl::Span<const se::DeviceMemoryBase> send_buffers,
-      absl::Span<se::DeviceMemoryBase> recv_buffers, PrimitiveType dtype,
-      size_t count, const Executor& executor) = 0;
-
-  // Performs an all-gather.
-  virtual absl::Status AllGather(se::DeviceMemoryBase send_buffer,
+// An adapter from a shared_ptr<Communicator> to a Communicator.
+class CommunicatorWrapper final : public Communicator {
+ public:
+  explicit CommunicatorWrapper(std::shared_ptr<Communicator> comm)
+      : comm_(std::move(comm)) {}
+
+  absl::Status AllReduce(stream_executor::DeviceMemoryBase send_buffer,
+                         stream_executor::DeviceMemoryBase recv_buffer,
+                         PrimitiveType dtype, size_t count,
+                         ReductionKind reduction_kind,
+                         const Executor& executor) final {
+    return comm_->AllReduce(send_buffer, recv_buffer, dtype, count,
+                            reduction_kind, executor);
+  }
+
+  absl::Status Broadcast(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, RankId root,
+                         const Executor& executor) final {
+    return comm_->Broadcast(send_buffer, recv_buffer, dtype, count, root,
+                            executor);
+  }
+
+  absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
+                             se::DeviceMemoryBase recv_buffer,
+                             PrimitiveType dtype, size_t count,
+                             ReductionKind reduction_kind,
+                             const Executor& executor) final {
+    return comm_->ReduceScatter(send_buffer, recv_buffer, dtype, count,
+                                reduction_kind, executor);
+  }
+
+  absl::Status AllGather(se::DeviceMemoryBase send_buffer,
+                         se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                         size_t count, const Executor& executor) final {
+    return comm_->AllGather(send_buffer, recv_buffer, dtype, count, executor);
+  }
+
+  absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer,
                                  se::DeviceMemoryBase recv_buffer,
                                  PrimitiveType dtype, size_t count,
-                                 const Executor& executor) = 0;
-
-  // Performs a reduce-scatter
-  virtual absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer,
-                                     se::DeviceMemoryBase recv_buffer,
-                                     PrimitiveType dtype, size_t count,
-                                     ReductionKind reduction_kind,
-                                     const Executor& executor) = 0;
+                                 std::optional<RankId> source_rank,
+                                 absl::Span<const RankId> target_ranks,
+                                 const Executor& executor) final {
+    return comm_->CollectivePermute(send_buffer, recv_buffer, dtype, count,
+                                    source_rank, target_ranks, executor);
+  }
+
+  absl::Status AllToAll(absl::Span<const se::DeviceMemoryBase> send_buffers,
+                        absl::Span<se::DeviceMemoryBase> recv_buffers,
+                        PrimitiveType dtype, size_t count,
+                        const Executor& executor) final {
+    return comm_->AllToAll(send_buffers, recv_buffers, dtype, count, executor);
+  }
+
+  absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype,
+                    size_t count, RankId peer, const Executor& executor) final {
+    return comm_->Send(send_buffer, dtype, count, peer, executor);
+  }
+
+  absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype,
+                    size_t count, RankId peer, const Executor& executor) final {
+    return comm_->Recv(recv_buffer, dtype, count, peer, executor);
+  }
+
+  absl::StatusOr<size_t> NumRanks() const final { return comm_->NumRanks(); }
+
+  std::string ToString() const final { return comm_->ToString(); }
+
+ private:
+  std::shared_ptr<Communicator> comm_;
 };

-class CollectivesInterface {
+}  // namespace internal
+
+class CollectivesInterface : public CpuCollectives {
  public:
   virtual ~CollectivesInterface() = default;

@@ -89,8 +126,27 @@ class CollectivesInterface {
   //  Args:
   //  devices: the devices participating in this collective.
   //  rank: the rank of this process.
-  virtual absl::StatusOr<std::shared_ptr<CollectivesCommunicator>>
-  GetCommunicator(absl::Span<GlobalDeviceId const> devices, int rank) = 0;
+  virtual absl::StatusOr<std::shared_ptr<Communicator>> GetCommunicator(
+      absl::Span<GlobalDeviceId const> devices, int rank) = 0;
+
+  absl::StatusOr<std::vector<std::unique_ptr<Communicator>>>
+  CreateCommunicators(int32_t nranks, const CliqueKey& clique_key,
+                      const std::optional<CliqueId>& clique_id,
+                      absl::Span<const DeviceRank> ranks,
+                      const Config& config) final {
+    // We expect to create CPU communicators lazily one at a time.
+    if (ranks.size() != 1) {
+      return InvalidArgument("Expected 1 rank, got %d", ranks.size());
+    }
+
+    TF_ASSIGN_OR_RETURN(auto comm, GetCommunicator(clique_key.devices(),
+                                                   ranks[0].rank.value()));
+
+    std::vector<std::unique_ptr<Communicator>> comms;
+    comms.reserve(1);
+    comms.push_back(std::make_unique<internal::CommunicatorWrapper>(comm));
+    return comms;
+  }
 };

 }  // namespace xla::cpu
diff --git a/xla/service/cpu/conv_canonicalization_test.cc b/xla/service/cpu/conv_canonicalization_test.cc
index 00c9ee256452c..6f6ebd96fb64c 100644
--- a/xla/service/cpu/conv_canonicalization_test.cc
+++ b/xla/service/cpu/conv_canonicalization_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
+#include "xla/literal_util.h"
 #include "xla/service/cpu/target_machine_features_stub.h"
 #include "xla/test.h"
 #include "xla/test_helpers.h"
diff --git a/xla/service/cpu/cpu_compiler.cc b/xla/service/cpu/cpu_compiler.cc
index 41b3847b50613..5c28de6021def 100644
--- a/xla/service/cpu/cpu_compiler.cc
+++ b/xla/service/cpu/cpu_compiler.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" @@ -1503,7 +1504,17 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { std::string ir_module_string; if (embed_ir_in_executable) { - ir_module_string = llvm_ir::DumpToString(llvm_module.get()); + std::string emitter2_ir = llvm_ir::DumpToString(llvm_module.get()); + + auto thunk_kernel_fmt = [](std::string* out, + const ThunkEmitter::EmittedKernel& kernel) { + absl::StrAppend( + out, llvm_ir::DumpToString(kernel.module.getModuleUnlocked())); + }; + std::string thunks_ir = + absl::StrJoin(thunk_emitter.kernels(), "\n", thunk_kernel_fmt); + + ir_module_string = absl::StrCat(emitter2_ir, "\n", thunks_ir); } TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); diff --git a/xla/service/cpu/cpu_instruction_fusion_test.cc b/xla/service/cpu/cpu_instruction_fusion_test.cc index 6b4de145d8e80..787c4d138b344 100644 --- a/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" #include "xla/service/transpose_folding.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" diff --git a/xla/service/cpu/in_process_collectives.cc b/xla/service/cpu/in_process_collectives.cc index 46e5d47993d15..a7d759348fefd 100644 --- a/xla/service/cpu/in_process_collectives.cc +++ b/xla/service/cpu/in_process_collectives.cc @@ -15,575 +15,34 @@ limitations under the License. #include "xla/service/cpu/in_process_collectives.h" -#include -#include -#include -#include -#include #include -#include -#include -#include +#include #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/time/time.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/backends/cpu/collectives/cpu_collectives.h" -#include "xla/core/collectives/rank_id.h" -#include "xla/primitive_util.h" -#include "xla/refcounting_hash_map.h" -#include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/collectives_interface.h" +#include "xla/backends/cpu/collectives/in_process_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/global_device_id.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/tsl/platform/statusor.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" -namespace xla { -namespace cpu { -namespace runtime { -namespace { +namespace xla::cpu::runtime { -void FormatGlobalId(std::string* out, const GlobalDeviceId& device) { - absl::StrAppend(out, device.value()); -} - -struct AllReduceParticipantData : ParticipantData { - explicit AllReduceParticipantData(const RendezvousKey& rendezvous_key_p, - int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - int64_t element_count; - const void* source_data; - void* destination_data; - PrimitiveType primitive_type; - - ReductionKind reduction_kind; - - std::string ToString() const override { - return absl::StrFormat( - "AllReduceParticipantData{rank=%d, element_count=%d, type=%s, " - "rendezvous_key=%s}", - local_rank, 
element_count, PrimitiveType_Name(primitive_type), - rendezvous_key.ToString()); - } -}; - -template -T GetInitialValue(ReductionKind reduction_kind) { - switch (reduction_kind) { - case ReductionKind::SUM: - return static_cast(0); - case ReductionKind::PRODUCT: - return static_cast(1); - case ReductionKind::MIN: - return std::numeric_limits::has_infinity - ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - case ReductionKind::MAX: - return std::numeric_limits::has_infinity - ? -std::numeric_limits::infinity() - : std::numeric_limits::lowest(); - } -} - -// We cannot use static_assert(false), because the C++ standard (prior to -// CWG2518) does not allow the statement discarded by a constexpr if to -// be ill-formed for every possible specialization. -// See https://en.cppreference.com/w/cpp/language/if#Constexpr_if -template -constexpr bool always_false_v = false; - -template -void ReduceHelper(absl::Span acc, absl::Span inputs) { - // TODO(penporn): make sure this gets vectorized. - if constexpr (reduction_kind == ReductionKind::SUM) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] += inputs[j][i]; - } - } - } else if constexpr (reduction_kind == ReductionKind::PRODUCT) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] *= inputs[j][i]; - } - } - } else if constexpr (reduction_kind == ReductionKind::MIN) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] = std::min(acc[i], inputs[j][i]); - } - } - } else if constexpr (reduction_kind == ReductionKind::MAX) { - for (size_t j = 0; j < inputs.size(); ++j) { - for (size_t i = 0; i < acc.size(); ++i) { - acc[i] = std::max(acc[i], inputs[j][i]); - } - } - } else { - static_assert(always_false_v, "Unsupported reduction kind"); - } -} - -template -absl::Status ReduceScatter(ReductionKind reduction_kind, - absl::Span inputs, void* output, - int64_t num_elems) { - using T = primitive_util::NativeTypeOf; - T initial_value = GetInitialValue(reduction_kind); - - absl::Span out_chunk = - absl::MakeSpan(reinterpret_cast(output), num_elems); - for (int64_t i = 0; i < num_elems; ++i) { - out_chunk[i] = initial_value; - } - - absl::Span input_chunks( - reinterpret_cast(inputs.data()), inputs.size()); - switch (reduction_kind) { - case ReductionKind::SUM: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::PRODUCT: - ReduceHelper(out_chunk, input_chunks); - break; - case ReductionKind::MIN: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Min reductions not supported for complex types"); - } - break; - case ReductionKind::MAX: - if constexpr (!is_complex_v) { - ReduceHelper(out_chunk, input_chunks); - } else { - return absl::InvalidArgumentError( - "Max reductions not supported for complex types"); - } - break; - } - - return absl::OkStatus(); -} - -class CpuAllReduceRendezvous - : public Rendezvous { - public: - explicit CpuAllReduceRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - absl::StatusOr RunCollectiveOp( - const AllReduceParticipantData& me) override { - VLOG(3) << me.ToString(); - int64_t world_size = participants_.size(); - // Divide the buffer up into equal(ish) chunks. Rank r computes the r-th - // chunk of the output. 
- int64_t chunk_elems = CeilOfRatio(me.element_count, world_size); - - int64_t start_elem = me.local_rank * chunk_elems; - int64_t end_elem = std::min(start_elem + chunk_elems, me.element_count); - chunk_elems = std::max(int64_t{0}, end_elem - start_elem); - if (chunk_elems == 0) { - return nullptr; - } - - auto bytes_per_elem = primitive_util::ByteWidth(me.primitive_type); - int64_t chunk_offset = start_elem * bytes_per_elem; - int64_t chunk_bytes = chunk_elems * bytes_per_elem; - void* reduce_output = - reinterpret_cast(me.destination_data) + chunk_offset; - - std::vector inputs; - inputs.reserve(world_size); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_data) + - chunk_offset); - } - - if (primitive_util::IsArrayType(me.primitive_type)) { - TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( - [&](const auto constant_type) { - return ReduceScatter(me.reduction_kind, inputs, - reduce_output, chunk_elems); - }, - me.primitive_type)); - } else { - return absl::UnimplementedError(absl::StrCat( - "Unexpected datatype: ", - primitive_util::LowercasePrimitiveTypeName(me.primitive_type))); - } - - // All-gather the reduced chunks. - for (const auto& p : participants_) { - if (p->local_rank != me.local_rank) { - std::memcpy(reinterpret_cast(p->destination_data) + chunk_offset, - reduce_output, chunk_bytes); - } - } - return nullptr; - } -}; - -struct CollectivePermuteParticipantData : ParticipantData { - CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key_p, - int rank) - : ParticipantData(rendezvous_key_p, rank) {} - const void* source_buffer; - void* destination_buffer; - size_t num_bytes; - - // From which rank is this participant receiving its data? Optional; if - // absent fill with zeros. - std::optional source_rank; - - std::string ToString() const override { - return absl::StrFormat( - "CollectivePermuteParticipantData{rank=%d, " - "source_buffer=%p, destination_buffer=%p, num_bytes=%d, " - "source_replica_id=%d, " - "devices=[%s]}", - local_rank, source_buffer, destination_buffer, num_bytes, - source_rank.value_or(-1), - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId)); - } -}; - -class CpuCollectivePermuteRendezvous - : public Rendezvous { - public: - explicit CpuCollectivePermuteRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - - absl::StatusOr RunCollectiveOp( - const CollectivePermuteParticipantData& p) override { - VLOG(3) << p.ToString(); - if (p.source_rank) { - std::memcpy(p.destination_buffer, - participants_[*p.source_rank]->source_buffer, p.num_bytes); - } else { - std::memset(p.destination_buffer, 0, p.num_bytes); - } - return nullptr; - } -}; - -struct AllToAllParticipantData : ParticipantData { - AllToAllParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - std::vector source_buffers; - std::vector destination_buffers; - size_t chunk_size; - - std::string ToString() const override { - auto addr_formatter = [](std::string* out, const void* mem) { - absl::StrAppend(out, absl::StrFormat("%p", mem)); - }; - return absl::StrFormat( - "AllToAllParticipantData{rank=%d, " - "devices=[%s], source_buffers=[%s], " - "destination_buffers=[%s], chunk_size=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - absl::StrJoin(source_buffers, ", ", addr_formatter), - absl::StrJoin(destination_buffers, ", ", addr_formatter), chunk_size); - } -}; - -class 
CpuAllToAllRendezvous - : public Rendezvous { - public: - explicit CpuAllToAllRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const AllToAllParticipantData& p) override { - int world_size = p.rendezvous_key.global_devices.size(); - for (int i = 0; i < world_size; ++i) { - std::memcpy(participants_[i]->destination_buffers[p.local_rank], - p.source_buffers[i], p.chunk_size); - } - return nullptr; - } -}; - -struct AllGatherParticipantData : ParticipantData { - AllGatherParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - const void* source_buffer; - void* destination_buffer; - size_t chunk_size; - - std::string ToString() const override { - return absl::StrFormat( - "AllGatherParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_size=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_size); - } -}; - -class CpuAllGatherRendezvous - : public Rendezvous { - public: - explicit CpuAllGatherRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const AllGatherParticipantData& p) override { - int world_size = p.rendezvous_key.global_devices.size(); - char* out = static_cast(p.destination_buffer); - for (int i = 0; i < world_size; ++i, out += p.chunk_size) { - std::memcpy(out, participants_[i]->source_buffer, p.chunk_size); - } - return nullptr; - } -}; - -struct ReduceScatterParticipantData : ParticipantData { - ReduceScatterParticipantData(const RendezvousKey& rendezvous_key_p, int rank) - : ParticipantData(rendezvous_key_p, rank) {} - - ReductionKind reduction_kind; - PrimitiveType element_type; - const void* source_buffer; - void* destination_buffer; - size_t chunk_elems; - - std::string ToString() const override { - return absl::StrFormat( - "ReduceScatterParticipantData{rank=%d, " - "devices=[%s], source_buffer=%p, " - "destination_buffer=%p, chunk_elems=%d}", - local_rank, - absl::StrJoin(rendezvous_key.global_devices, ", ", FormatGlobalId), - source_buffer, destination_buffer, chunk_elems); - } -}; - -class CpuReduceScatterRendezvous - : public Rendezvous { - public: - explicit CpuReduceScatterRendezvous(const RendezvousKey& k) - : Rendezvous(k) {} - - protected: - CollectivesInterface* collectives_; - absl::StatusOr RunCollectiveOp( - const ReduceScatterParticipantData& me) override { - auto bytes_per_elem = primitive_util::ByteWidth(me.element_type); - int64_t chunk_offset = me.local_rank * me.chunk_elems * bytes_per_elem; - - std::vector inputs; - inputs.reserve(participants_.size()); - for (const auto& p : participants_) { - inputs.push_back(reinterpret_cast(p->source_buffer) + - chunk_offset); - } - - if (primitive_util::IsArrayType(me.element_type)) { - TF_RETURN_IF_ERROR(primitive_util::ArrayTypeSwitch( - [&](const auto constant_type) { - return ReduceScatter(me.reduction_kind, inputs, - me.destination_buffer, - me.chunk_elems); - }, - me.element_type)); - } else { - return absl::UnimplementedError(absl::StrCat( - "Unexpected datatype: ", - primitive_util::LowercasePrimitiveTypeName(me.element_type))); - } - return nullptr; - } -}; - -} // namespace - -struct InProcessCollectivesState { - RefcountingHashMap - all_reduce_rendezvous_map; - RefcountingHashMap - collective_permute_rendezvous_map; - RefcountingHashMap - 
all_to_all_rendezvous_map; - RefcountingHashMap - all_gather_rendezvous_map; - RefcountingHashMap - reduce_scatter_rendezvous_map; -}; - -InProcessCollectivesCommunicator::InProcessCollectivesCommunicator( - InProcessCollectivesState* state, int rank, int size) - : state_(state), rank_(rank) {} -InProcessCollectivesCommunicator::~InProcessCollectivesCommunicator() = default; - -absl::Status InProcessCollectivesCommunicator::AllReduce( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - AllReduceParticipantData participant(key, rank_); - participant.element_count = count; - participant.primitive_type = dtype; - participant.source_data = send_buffer.opaque(); - participant.destination_data = recv_buffer.opaque(); - participant.reduction_kind = reduction_kind; - - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - - return CpuAllReduceRendezvous::SubmitParticipant( - [&] { - return state_->all_reduce_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::CollectivePermute( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, std::optional source_rank, - absl::Span target_ranks, const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); +absl::StatusOr> +InProcessCollectives::GetCommunicator(absl::Span devices, + int rank) { + absl::MutexLock lock(&mu_); - CollectivePermuteParticipantData participant(key, rank_); - participant.source_buffer = send_buffer.opaque(); - participant.destination_buffer = recv_buffer.opaque(); - participant.num_bytes = count * primitive_util::ByteWidth(dtype); - participant.source_rank = std::nullopt; - if (source_rank) { - participant.source_rank = source_rank->value(); + std::shared_ptr state = state_.lock(); + if (state == nullptr) { + state = InProcessCommunicator::CreateState(); + state_ = state; } - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuCollectivePermuteRendezvous::SubmitParticipant( - [&] { - return state_->collective_permute_rendezvous_map - .GetOrCreateIfAbsent(key, make_cpu_rendezvous); - }, - participant) - .status(); -} -absl::Status InProcessCollectivesCommunicator::AllToAll( - absl::Span send_buffers, - absl::Span recv_buffers, PrimitiveType dtype, - size_t count, const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - AllToAllParticipantData participant(key, rank_); - TF_RET_CHECK(send_buffers.size() == recv_buffers.size()); - - size_t chunk_bytes = count * primitive_util::ByteWidth(dtype); - - participant.chunk_size = chunk_bytes; - participant.source_buffers.reserve(send_buffers.size()); - participant.destination_buffers.reserve(recv_buffers.size()); - for (se::DeviceMemoryBase send_buffer : send_buffers) { - participant.source_buffers.push_back(send_buffer.opaque()); - } - for (se::DeviceMemoryBase recv_buffer : recv_buffers) { - participant.destination_buffers.push_back(recv_buffer.opaque()); - } - auto make_cpu_rendezvous = 
[](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuAllToAllRendezvous::SubmitParticipant( - [&] { - return state_->all_to_all_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::AllGather( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - AllGatherParticipantData participant(key, rank_); - participant.chunk_size = count * primitive_util::ByteWidth(dtype); - participant.source_buffer = send_buffer.opaque(); - participant.destination_buffer = recv_buffer.opaque(); - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuAllGatherRendezvous::SubmitParticipant( - [&] { - return state_->all_gather_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} - -absl::Status InProcessCollectivesCommunicator::ReduceScatter( - se::DeviceMemoryBase send_buffer, se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, ReductionKind reduction_kind, - const Executor& executor) { - TF_ASSIGN_OR_RETURN(auto cpu_executor, CpuCollectives::TryCast(&executor)); - const RendezvousKey& key = cpu_executor->rendezvous_key(); - - ReduceScatterParticipantData participant(key, rank_); - participant.element_type = dtype; - participant.reduction_kind = reduction_kind; - participant.chunk_elems = count; - participant.source_buffer = send_buffer.opaque(); - participant.destination_buffer = recv_buffer.opaque(); - auto make_cpu_rendezvous = [](const RendezvousKey& k) { - return std::make_unique(k); - }; - return CpuReduceScatterRendezvous::SubmitParticipant( - [&] { - return state_->reduce_scatter_rendezvous_map.GetOrCreateIfAbsent( - key, make_cpu_rendezvous); - }, - participant) - .status(); -} -InProcessCollectives::InProcessCollectives() - : state_(std::make_unique()) {} -InProcessCollectives::~InProcessCollectives() = default; - -absl::StatusOr> -InProcessCollectives::GetCommunicator(absl::Span devices, - int rank) { // We don't care about devices here: we share rendezvous state globally. - return std::make_shared(state_.get(), rank, - devices.size()); + return std::make_shared(std::move(state), rank, + devices.size()); } -} // namespace runtime -} // namespace cpu -} // namespace xla +} // namespace xla::cpu::runtime diff --git a/xla/service/cpu/in_process_collectives.h b/xla/service/cpu/in_process_collectives.h index 9f04e9890eda0..976470ac07b8a 100644 --- a/xla/service/cpu/in_process_collectives.h +++ b/xla/service/cpu/in_process_collectives.h @@ -16,73 +16,31 @@ limitations under the License. 
#ifndef XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ #define XLA_SERVICE_CPU_IN_PROCESS_COLLECTIVES_H_ -#include #include -#include -#include "absl/status/status.h" +#include "absl/base/thread_annotations.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/core/collectives/rank_id.h" -#include "xla/service/collective_ops_utils.h" +#include "xla/backends/cpu/collectives/in_process_communicator.h" +#include "xla/core/collectives/communicator.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" -#include "xla/stream_executor/device_memory.h" #include "xla/xla_data.pb.h" namespace xla::cpu::runtime { -struct InProcessCollectivesState; - -class InProcessCollectivesCommunicator : public CollectivesCommunicator { - public: - InProcessCollectivesCommunicator(InProcessCollectivesState* state, int rank, - int size); - ~InProcessCollectivesCommunicator() override; - - absl::Status AllReduce(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, ReductionKind reduction_kind, - const Executor& executor) override; - - absl::Status CollectivePermute(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - std::optional source_rank, - absl::Span target_ranks, - const Executor& executor) override; - - absl::Status AllToAll(absl::Span send_buffers, - absl::Span recv_buffers, - PrimitiveType dtype, size_t count, - const Executor& executor) override; - - absl::Status AllGather(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, - size_t count, const Executor& executor) override; - - absl::Status ReduceScatter(se::DeviceMemoryBase send_buffer, - se::DeviceMemoryBase recv_buffer, - PrimitiveType dtype, size_t count, - ReductionKind reduction_kind, - const Executor& executor) override; - - private: - InProcessCollectivesState* state_; - int rank_; -}; - class InProcessCollectives : public CollectivesInterface { public: - InProcessCollectives(); - ~InProcessCollectives() override; - // Thread-safe. - absl::StatusOr> GetCommunicator( + absl::StatusOr> GetCommunicator( absl::Span devices, int rank) override; private: - std::unique_ptr state_; + absl::Mutex mu_; + + // State shared by all constructed communicators. + std::weak_ptr state_ ABSL_GUARDED_BY(mu_); }; } // namespace xla::cpu::runtime diff --git a/xla/service/cpu/ir_emitter2.cc b/xla/service/cpu/ir_emitter2.cc index ca6f1d2610116..1890d5377bfb4 100644 --- a/xla/service/cpu/ir_emitter2.cc +++ b/xla/service/cpu/ir_emitter2.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -114,46 +115,6 @@ IrEmitter2::KernelInfo::KernelInfo(KernelPrototype prototype, thread_dims(thread_dims), invariant_arguments(std::move(prototype.invariant_arguments)) {} -absl::StatusOr IrEmitter2::EmitElementalHostKernel( - const HloInstruction* instr) { - VLOG(2) << "Emit elemental host kernel: " << instr->name(); - - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); - - llvm::IRBuilder<> b(module_->getContext()); - b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - - IrEmitter::IRBuilderGuard builder_guard = nested_ir_emitter_->WithBuilder(b); - - CpuElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; - for (int64_t i = 0; i < instr->operand_count(); ++i) { - const HloInstruction* operand = instr->operand(i); - operand_to_generator[operand] = [&, i](const llvm_ir::IrArray::Index& idx) { - return kernel_prototype.arguments[i].EmitReadArrayElement(idx, &b); - }; - } - - if (instr->has_to_apply()) { - HloComputation* nested_computation = instr->to_apply(); - bool is_reducer = instr->opcode() == HloOpcode::kReduce || - instr->opcode() == HloOpcode::kReduceWindow; - TF_RETURN_IF_ERROR(nested_ir_emitter_->EmitNestedComputation( - *nested_computation, llvm_ir::IrName(instr), is_reducer)); - } - - CpuElementalIrEmitter elemental_emitter = ElementalIrEmmiterFactory(&b); - llvm_ir::ElementGenerator element_generator = - elemental_emitter.MakeElementGenerator(instr, operand_to_generator); - - TF_ASSIGN_OR_RETURN( - se::ThreadDim thread_dims, - EmitElementalLoops(b, instr, kernel_prototype, element_generator)); - - return kernels_.emplace_back( - KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); -} - absl::StatusOr IrEmitter2::EmitPadHostKernel( const HloInstruction* pad) { VLOG(2) << "Emit Pad host kernel."; @@ -247,14 +208,6 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); } -absl::StatusOr IrEmitter2::EmitReductionHostKernel( - const HloInstruction* instr) { - VLOG(2) << "Emit reduction host kernel: " << instr->name(); - - // TODO(ezhulenev): Port vectorized reduction emitter from IrEmitter. - return EmitElementalHostKernel(instr); -} - // Dot (fusion) host kernel only supports strategies that emit LLVM IR. 
static bool IsDotCodegenStrategy(DotImplementationStrategy strategy) { static std::array kDotCodegenStrategies = { @@ -303,25 +256,20 @@ absl::StatusOr IrEmitter2::EmitConcatenateHostKernel( const HloInstruction* instr) { VLOG(2) << "Emit concatenate host kernel: " << instr->name(); - auto fast_impl_reason = CanDoFastConcatenate(instr); - if (fast_impl_reason.ok()) { - VLOG(1) << "Emitting fast concatenate for " << instr->ToString() << ": " - << fast_impl_reason.message(); - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); - llvm::IRBuilder<> ir_builder(module_->getContext()); - ir_builder.SetInsertPoint( - kernel_prototype.function->getEntryBlock().getTerminator()); - - llvm_ir::IrArray output_array = kernel_prototype.results[0]; - TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( - instr, kernel_prototype.arguments, output_array, module_, ir_builder)); - return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), - se::BlockDim(), se::ThreadDim())); - } - VLOG(1) << "Could not emit fast concatenate for " << instr->ToString() << ": " - << fast_impl_reason.message(); - return EmitElementalHostKernel(instr); + DCHECK_OK(CanDoFastConcatenate(instr)); + + VLOG(1) << "Emitting fast concatenate for " << instr->ToString(); + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(instr)); + llvm::IRBuilder<> ir_builder(module_->getContext()); + ir_builder.SetInsertPoint( + kernel_prototype.function->getEntryBlock().getTerminator()); + + llvm_ir::IrArray output_array = kernel_prototype.results[0]; + TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( + instr, kernel_prototype.arguments, output_array, module_, ir_builder)); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitDotFusionHostKernel( @@ -401,26 +349,22 @@ absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( absl::StatusOr IrEmitter2::EmitDynamicUpdateSliceHostKernel(const HloInstruction* instr) { - if (llvm_ir::CanUpdateDynamicSliceInPlace(const_cast(instr), - nested_ir_emitter_->assignment())) { - VLOG(2) << "Emit in-place dynamic-update-slice kernel: " << instr->name(); + DCHECK(CanUpdateDynamicSliceInPlace(instr)); - TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, - EmitKernelPrototype(instr)); + VLOG(2) << "Emit in-place dynamic-update-slice kernel: " << instr->name(); - llvm::IRBuilder<> b(module_->getContext()); - b.SetInsertPoint( - kernel_prototype.function->getEntryBlock().getTerminator()); + TF_ASSIGN_OR_RETURN(KernelPrototype kernel_prototype, + EmitKernelPrototype(instr)); - TF_RETURN_IF_ERROR(llvm_ir::EmitDynamicUpdateSliceInPlace( - kernel_prototype.arguments, kernel_prototype.results.front(), - llvm_ir::IrName(instr, "in_place"), &b)); + llvm::IRBuilder<> b(module_->getContext()); + b.SetInsertPoint(kernel_prototype.function->getEntryBlock().getTerminator()); - return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), - se::BlockDim(), se::ThreadDim())); - } + TF_RETURN_IF_ERROR(llvm_ir::EmitDynamicUpdateSliceInPlace( + kernel_prototype.arguments, kernel_prototype.results.front(), + llvm_ir::IrName(instr, "in_place"), &b)); - return EmitElementalHostKernel(instr); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitSortComparator( @@ -499,6 +443,12 @@ absl::Status IrEmitter2::CanDoFastConcatenate( return absl::OkStatus(); }; +bool 
IrEmitter2::CanUpdateDynamicSliceInPlace( + const HloInstruction* update) const { + return llvm_ir::CanUpdateDynamicSliceInPlace( + const_cast(update), nested_ir_emitter_->assignment()); +} + IrEmitter2::ParallelPartitionBounds IrEmitter2::EmitParallelPartitionBounds( llvm::IRBuilderBase& b, const KernelPrototype& kernel_prototype, const ParallelConfig& parallel_config, const Shape& shape, diff --git a/xla/service/cpu/ir_emitter2.h b/xla/service/cpu/ir_emitter2.h index 2bcb7c1c9316f..77ea6647d4ec9 100644 --- a/xla/service/cpu/ir_emitter2.h +++ b/xla/service/cpu/ir_emitter2.h @@ -98,10 +98,6 @@ class IrEmitter2 { absl::Span comparators() const { return comparators_; } - // Emits an elemental host kernel for the given HLO instruction. - absl::StatusOr EmitElementalHostKernel( - const HloInstruction* instr); - // Emits a host kernel for the pad instruction. absl::StatusOr EmitPadHostKernel(const HloInstruction* pad); @@ -109,10 +105,6 @@ class IrEmitter2 { absl::StatusOr EmitFusionHostKernel( const HloFusionInstruction* fusion); - // Emits a host kernel for the given reduction instruction. - absl::StatusOr EmitReductionHostKernel( - const HloInstruction* instr); - // Emits a host kernel for the given dot instruction. Small dot operations // are emitted as LLVM IR directly, while larger ones are emitted as a dot // thunk that calls into libraries. @@ -137,6 +129,9 @@ class IrEmitter2 { // Emits a comparator function for the given sort instruction. absl::StatusOr EmitSortComparator(HloComputation* comparator); + absl::Status CanDoFastConcatenate(const HloInstruction* concatenate) const; + bool CanUpdateDynamicSliceInPlace(const HloInstruction* update) const; + private: class ElementalIrEmitter; @@ -160,8 +155,6 @@ class IrEmitter2 { // the instruction has to be compiled to a single threaded loop. std::optional GetParallelConfig(const HloInstruction* instr); - absl::Status CanDoFastConcatenate(const HloInstruction* concatenate) const; - // Emits LLVM IR that computes parallel partition bounds from the call frame's // block and thread dimensions and parallel execution config. ParallelPartitionBounds EmitParallelPartitionBounds( diff --git a/xla/service/cpu/thunk_emitter.cc b/xla/service/cpu/thunk_emitter.cc index 5a3b848c3db3c..a5d0aeade482f 100644 --- a/xla/service/cpu/thunk_emitter.cc +++ b/xla/service/cpu/thunk_emitter.cc @@ -526,6 +526,13 @@ absl::StatusOr ThunkEmitter::EmitCallThunk( absl::StatusOr ThunkEmitter::EmitConcatenateKernelThunk( const HloInstruction* instruction) { + if (absl::Status status = ir_emitter_.CanDoFastConcatenate(instruction); + !status.ok()) { + VLOG(1) << "Could not emit fast concatenate for " << instruction->ToString() + << ": " << status.message(); + return EmitElementalKernelThunk(instruction); + } + auto* concatenate = Cast(instruction); TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitConcatenateHostKernel(concatenate)); @@ -661,13 +668,8 @@ absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( absl::StatusOr ThunkEmitter::EmitReductionKernelThunk( const HloInstruction* instruction) { - TF_ASSIGN_OR_RETURN(auto kernel, - ir_emitter_.EmitReductionHostKernel(instruction)); - TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - - return MakeKernelThunkSequence( - instruction, buffers, kernel, - /*min_alignment=*/cpu_function_runtime::MinAlign()); + // TODO(ezhulenev): Port vectorized reduction emitter from IrEmitter. 
+ return EmitElementalKernelThunk(instruction); } absl::StatusOr ThunkEmitter::EmitRngThunk( @@ -1041,6 +1043,12 @@ absl::StatusOr ThunkEmitter::EmitSliceThunk( absl::StatusOr ThunkEmitter::EmitDynamicUpdateSliceThunk( const HloInstruction* instruction) { + if (!ir_emitter_.CanUpdateDynamicSliceInPlace(instruction)) { + VLOG(2) << "Could not emit in-place dynamic-update-slice kernel: " + << instruction->name(); + return EmitElementalKernelThunk(instruction); + } + TF_ASSIGN_OR_RETURN( auto kernel, ir_emitter_.EmitDynamicUpdateSliceHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index 3bf55230507d4..605b8f62f8450 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1728,6 +1728,7 @@ xla_test( "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", "@com_google_googletest//:gtest", "@tsl//tsl/platform:logging", ], @@ -1960,6 +1961,7 @@ xla_cc_test( ":amdgpu_compiler_impl", ]) + [ ":gpu_transfer_manager", + "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/service:compiler", @@ -1971,6 +1973,7 @@ xla_cc_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", @@ -2316,6 +2319,7 @@ gpu_kernel_library( "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", + "@local_config_rocm//rocm:rocm_config", ]), ) @@ -2476,6 +2480,10 @@ cc_library( "@tsl//tsl/platform:logging", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:statusor", + ]) + if_rocm_is_configured([ + # keep sorted + "@local_config_rocm//rocm:rocm_config", + "@local_config_rocm//rocm:rocm_headers", ]), ) @@ -2486,7 +2494,9 @@ gpu_kernel_library( deps = [ "//xla:shape_util", "//xla:types", - ], + ] + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) xla_test( @@ -3069,13 +3079,13 @@ cc_library( hdrs = ["gpu_collective_combiner_utils.h"], deps = [ ":backend_configs_cc", + ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:collective_utils", "//xla/stream_executor:device_description", "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@tsl//tsl/platform:statusor", ], ) @@ -3086,7 +3096,6 @@ xla_cc_test( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -3094,16 +3103,12 @@ xla_cc_test( "//xla/hlo/transforms/simplifiers:hlo_dce", "//xla/hlo/utils:hlo_query", "//xla/service:collective_pipeliner", - "//xla/service:collective_utils", "//xla/service:hlo_module_config", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@tsl//tsl/platform:statusor", - "@tsl//tsl/platform:test", ], ) @@ -3114,7 +3119,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/collectives:all_gather_combiner", @@ -3153,7 
+3157,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_domain_map", @@ -3190,7 +3193,6 @@ cc_library( deps = [ ":backend_configs_cc", ":gpu_collective_combiner_utils", - ":gpu_hlo_schedule", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms/collectives:all_reduce_combiner", diff --git a/xla/service/gpu/all_gather_combiner.cc b/xla/service/gpu/all_gather_combiner.cc index 996d3a1fe83be..96f10d43113b5 100644 --- a/xla/service/gpu/all_gather_combiner.cc +++ b/xla/service/gpu/all_gather_combiner.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/all_gather_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "tsl/platform/statusor.h" @@ -78,8 +77,7 @@ absl::StatusOr GpuAllGatherCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllGather, pointer_size_); + *module, device_info_, HloOpcode::kAllGather, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/xla/service/gpu/all_reduce_combiner.cc b/xla/service/gpu/all_reduce_combiner.cc index 108d10cee3e5d..5fb3d960bb237 100644 --- a/xla/service/gpu/all_reduce_combiner.cc +++ b/xla/service/gpu/all_reduce_combiner.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/transforms/collectives/all_reduce_combiner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "tsl/platform/statusor.h" @@ -76,8 +75,7 @@ absl::StatusOr GpuAllReduceCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size_); + *module, device_info_, HloOpcode::kAllReduce, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/xla/service/gpu/auto_sharding_gpu_compiler_test.cc b/xla/service/gpu/auto_sharding_gpu_compiler_test.cc index 89be2dac856e0..ad5a80d836ea2 100644 --- a/xla/service/gpu/auto_sharding_gpu_compiler_test.cc +++ b/xla/service/gpu/auto_sharding_gpu_compiler_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/log/log.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index ba6743ee4801a..39f31c0dcf0b5 100644 --- a/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -1016,11 +1016,11 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, << " with config '" << ConfigToString(config) << "'\nFused HLO computation:\n" << fusion->fused_instructions_computation()->ToString(); + log(*executable != nullptr); if (*executable != nullptr) { absl::MutexLock lock(&results_mu); results[fusion].push_back({config, std::move(*executable)}); } - log(*executable != nullptr); counter.DecrementCount(); }); } @@ -1047,10 +1047,10 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, TF_ASSIGN_OR_RETURN( std::unique_ptr executable, compile(fusion, config, gemm_config_set.size() > 1)); + log(executable != nullptr); if (executable != nullptr) { results[fusion].push_back({config, std::move(executable)}); } - log(executable != nullptr); } } } diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto index 84f008d3717b3..906baaa33d512 100644 --- a/xla/service/gpu/backend_configs.proto +++ b/xla/service/gpu/backend_configs.proto @@ -270,6 +270,11 @@ message CudnnfMHABackendConfig { // Sliding window length // ignored if the value <= 0 int32 sliding_window_length = 24; + + // The maximum number of segments in each batch + // Only used with packed layout + // ignored if the valued <= 1 + int32 max_seg_per_batch = 25; } // Backend config for a general custom call instruction, e.g. XLA FFI. diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index f859c70af9405..17d79786b802b 100644 --- a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -608,7 +608,7 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm, // opportunities for LICM. This would not be necessary if LICM also moved // instructions over ifs. 
pm.addPass(mlir::createLoopInvariantCodeMotionPass()); - pm.addNestedPass(CreateVectorizeLoadsAndStoresPass()); + pm.addNestedPass(CreateVectorizeLoadsAndStoresPass(device)); pm.addNestedPass(CreateOptimizeLoopsPass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); diff --git a/xla/service/gpu/fusions/scatter_mlir.cc b/xla/service/gpu/fusions/scatter_mlir.cc index 5163375e38cdb..4f98d4bfd61dc 100644 --- a/xla/service/gpu/fusions/scatter_mlir.cc +++ b/xla/service/gpu/fusions/scatter_mlir.cc @@ -301,8 +301,8 @@ class EmitterHelper { Value write_to_output_required, ValueRange thread_and_block_ids, Value iv, const IndexingMap& slice_indexing, - Value offsets_changed, ValueRange offsets, - Value accumulator, Value output_tensor) const; + ValueRange offsets, Value accumulator, + Value output_tensor) const; private: Value GetElement(ImplicitLocOpBuilder& b, int operand_index, @@ -371,8 +371,8 @@ SmallVector EmitterHelper::WriteAccumulatedElementToOutput( Value EmitterHelper::WriteAccumulatorToOutput( ImplicitLocOpBuilder& b, Value write_to_output_required, ValueRange thread_and_block_ids, Value iv, - const IndexingMap& slice_indexing, Value offsets_changed, - ValueRange offsets, Value accumulator, Value output_tensor) const { + const IndexingMap& slice_indexing, ValueRange offsets, Value accumulator, + Value output_tensor) const { SmallVector dims = Pack({thread_and_block_ids, iv}); return EmitUpdateIf( b, write_to_output_required, output_tensor, @@ -721,11 +721,15 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( // Prepare loop initial values. Inits are packed as // [index_changed, is_inbounds, index_0, ..., accumulator]. Value is_inbounds_init = b.create(0, b.getI1Type()); + Value slice_id_init = b.create(0); std::vector indices_init(description_.index_vector_length, b.create(-1)); Value accumulator_init = InitializeAccumulator(b); SmallVector inits = - Pack({indices_init, is_inbounds_init, accumulator_init, output_tensor}); + Pack({slice_id_init, indices_init, is_inbounds_init, accumulator_init, + output_tensor}); + + int64_t output_rank = description_.output_shape.size(); auto loop_over_indices_fn = [&](ImplicitLocOpBuilder& nested_b, ValueRange ivs, @@ -733,14 +737,13 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( ValueRange outer_iter_args) -> SmallVector { // Unpack the iter_args. 
SmallVector iter_args_unpack = - Unpack(outer_iter_args, {description_.index_vector_length, 1, 1, 1}); - ValueRange trimmed_offsets = iter_args_unpack[0]; - Value iter_is_inbounds = iter_args_unpack[1].front(); - Value iter_acc = iter_args_unpack[2].front(); - Value iter_output = iter_args_unpack[3].front(); + Unpack(outer_iter_args, {1, description_.index_vector_length, 1, 1, 1}); + ValueRange trimmed_offsets = iter_args_unpack[1]; + Value iter_is_inbounds = iter_args_unpack[2].front(); + Value iter_acc = iter_args_unpack[3].front(); + Value iter_output = iter_args_unpack[4].front(); Value iter_slice_id = ivs.front(); - int64_t output_rank = description_.output_shape.size(); SmallVector offsets = PadWithZeros(trimmed_offsets, output_rank, nested_b); @@ -767,78 +770,95 @@ absl::Status ScatterWithDistributedIndices::EmitEntryFunctionImpl( b.create(offsets_changed, iter_is_inbounds)); iter_output = helper.WriteAccumulatorToOutput( b, write_to_output_required, thread_and_block_ids, iter_slice_id, - slice_indexing, offsets_changed, offsets, iter_acc, iter_output); + slice_indexing, offsets, iter_acc, iter_output); // Update `is_inbounds` if the offsets changed. Value new_is_inbounds = UpdateIsInbounds( nested_b, iter_is_inbounds, offsets_changed, new_offsets, description_.slice_shape, description_.output_shape); - // Update accumulator and/or output. - auto is_last_iteration = nested_b.create( - arith::CmpIPredicate::eq, iter_slice_id, - b.create(num_indices_per_warp_ - 1)); - - SmallVector acc_and_output = {iter_acc, iter_output}; - auto loop_over_slices_fn = - [&](ImplicitLocOpBuilder& update_loop_b, ValueRange accumulator_indices, - ValueRange slice_indices, - ValueRange inner_iter_args) -> SmallVector { - Value acc_arg = inner_iter_args.front(); - Value output_arg = inner_iter_args.back(); - auto update_elem = helper.GetUpdateElement(update_loop_b, slice_indices); - auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); - // If the index changed, overwrite the accumulator element, otherwise - // apply the scatter computation to reduce with the accumulator element. - auto updated_accumulator = - update_loop_b - .create( - offsets_changed, - [&](OpBuilder& then_b, Location then_loc) -> void { - Value updated_accumulator = then_b.create( - then_loc, update_elem, acc_arg, acc_ind_opfold); - then_b.create(then_loc, updated_accumulator); - }, - [&](OpBuilder& else_b, Location else_loc) -> void { - ImplicitLocOpBuilder implicit_else_b(else_loc, else_b); - Value accumulator_elem = - implicit_else_b.create( - acc_arg, acc_ind_opfold); - auto reduced_val = mlir_converter::InlineBlock( - implicit_else_b, helper.GetReducer().getBody().front(), - {accumulator_elem, update_elem})[0]; - Value updated_ac = implicit_else_b.create( - reduced_val, acc_arg, acc_ind_opfold); - implicit_else_b.create(updated_ac); - }) - .getResult(0); - // If this is the last index, that this warp has to process, then we write - // to the output. - auto updated_output = - EmitUpdateIf(update_loop_b, is_last_iteration, output_arg, - [&](ImplicitLocOpBuilder& nested_b) { - return helper.WriteAccumulatedElementToOutput( - nested_b, updated_accumulator, accumulator_indices, - slice_indices, new_offsets, output_arg); - }) - .front(); - return {updated_accumulator, updated_output}; + // Emits a loop that overwrites the accumulator with the new update elements + // if the offsets changed. 
+ auto emit_overwrite_accumulator_fn = [&](OpBuilder& then_b, + Location then_loc) -> void { + ImplicitLocOpBuilder implicit_then_b(then_loc, then_b); + auto then_results = EmitXlaLoopOp( + implicit_then_b, Pack({thread_and_block_ids, iter_slice_id}), + {iter_acc}, slice_indexing, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange inner_iter_args) -> SmallVector { + Value acc_arg = inner_iter_args.front(); + auto update_elem = + helper.GetUpdateElement(update_loop_b, slice_indices); + auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); + return update_loop_b + .create(then_loc, update_elem, acc_arg, + acc_ind_opfold) + ->getResults(); + }); + implicit_then_b.create(then_loc, then_results); + }; + // Emits a loop that combines the accumulator with the new update elements + // if the offsets did not change. + auto emit_combine_accumulator_fn = [&](OpBuilder& else_b, + Location else_loc) -> void { + ImplicitLocOpBuilder implicit_else_b(else_loc, else_b); + auto else_results = EmitXlaLoopOp( + implicit_else_b, Pack({thread_and_block_ids, iter_slice_id}), + {iter_acc}, slice_indexing, + [&](ImplicitLocOpBuilder& update_loop_b, + ValueRange accumulator_indices, ValueRange slice_indices, + ValueRange inner_iter_args) -> SmallVector { + Value acc_arg = inner_iter_args.front(); + auto update_elem = + helper.GetUpdateElement(update_loop_b, slice_indices); + auto acc_ind_opfold = mlir::getAsOpFoldResult(accumulator_indices); + Value accumulator_elem = update_loop_b.create( + acc_arg, acc_ind_opfold); + auto reduced_val = mlir_converter::InlineBlock( + update_loop_b, helper.GetReducer().getBody().front(), + {accumulator_elem, update_elem})[0]; + return update_loop_b + .create(reduced_val, acc_arg, acc_ind_opfold) + ->getResults(); + }); + implicit_else_b.create(else_results); }; - auto updated_accumulator_and_output = - EmitUpdateIf(nested_b, new_is_inbounds, acc_and_output, + auto updated_accumulator = + EmitUpdateIf(nested_b, new_is_inbounds, {iter_acc}, [&](ImplicitLocOpBuilder& if_b) { - return EmitXlaLoopOp( - if_b, Pack({thread_and_block_ids, iter_slice_id}), - acc_and_output, slice_indexing, loop_over_slices_fn); - }); - SmallVector updated_if_loop_results = Pack( - {new_trimmed_offsets, new_is_inbounds, updated_accumulator_and_output}); + return nested_b + .create(offsets_changed, + emit_overwrite_accumulator_fn, + emit_combine_accumulator_fn) + .getResults(); + }) + .front(); + SmallVector updated_if_loop_results = + Pack({iter_slice_id, new_trimmed_offsets, new_is_inbounds, + updated_accumulator, iter_output}); return updated_if_loop_results; }; auto loop_over_indices_results = EmitXlaLoopOp(b, thread_and_block_ids, inits, thread_id_to_update_id_map, loop_over_indices_fn); - b.create(loop_over_indices_results.back()); + + // Write the accumulator to the output tensor. 
+ SmallVector loop_over_indices_results_unpacked = + Unpack(loop_over_indices_results, + {1, description_.index_vector_length, 1, 1, 1}); + Value result_slice_id = loop_over_indices_results_unpacked[0].front(); + auto result_offsets = + PadWithZeros(loop_over_indices_results_unpacked[1], output_rank, b); + Value result_is_inbounds = loop_over_indices_results_unpacked[2].front(); + Value result_acc = loop_over_indices_results_unpacked[3].front(); + Value result_output = loop_over_indices_results_unpacked[4].front(); + result_output = helper.WriteAccumulatorToOutput( + b, result_is_inbounds, thread_and_block_ids, result_slice_id, + slice_indexing, result_offsets, result_acc, result_output); + + b.create(result_output); return absl::OkStatus(); } diff --git a/xla/service/gpu/fusions/scatter_mlir.h b/xla/service/gpu/fusions/scatter_mlir.h index 6b555c17c0490..676123d74b11a 100644 --- a/xla/service/gpu/fusions/scatter_mlir.h +++ b/xla/service/gpu/fusions/scatter_mlir.h @@ -147,28 +147,36 @@ class ScatterWithDistributedUpdates : public MlirScatterFusion { %acc = vector // #indices_map - for %i = 0 to %num_indices_per_warp_ step 1 { - %new_indices = PadWithZeros(ExtractOffsets(%indices_operand, %i)) - %indices_changed = EmitInequalityCheck(%new_indices, %indices) - if (%indices_changed && %i != 0) { - %output_tensor = WriteAccumulatorToTheOutput(%acc, %output_tensor); - } - if (%indices_changed) { - %inbounds = EmitBoundsCheck(%new_indices, %slice_shape, %output_shape) - } - if (%inbounds) { + %updated_accumulator, %updated_out = for %i = 0 to %num_indices_per_warp_ { + %new_indices = PadWithZeros(ExtractOffsets(%indices_operand, %i)) + %indices_changed = EmitInequalityCheck(%new_indices, %indices) + if (%indices_changed && %i != 0) { + %output_tensor = WriteAccumulatorToOutput(%current_acc, %current_out); + } + if (%indices_changed) { + %inbounds = EmitBoundsCheck(%new_indices, %slice_shape, %output_shape) + } + if (%inbounds) { + if (%indices_changed) { // updates_map(%i) for %j = 0 to %num_slice_iterations_per_warp step 1 { for %k = 0 to %vector_size step 1 { %update_elem = GetUpdateElement - %acc = %indices_changed ? 
%update_elem : Reduce(%update_elem, %acc) - if (%i = %num_indices_per_warp - 1) { - %output_tensor = WriteAccumulatorToTheOutput(%acc, %output_tensor); - } + %acc = %update_elem } } - } - } + } else { + // updates_map(%i) + for %j = 0 to %num_slice_iterations_per_warp step 1 { + for %k = 0 to %vector_size step 1 { + %update_elem = GetUpdateElement + %acc = Reduce(%update_elem, %acc) + } + } + } + } +} +%final_out = WriteAccumulatorToOutput(%updated_accumulator, %updated_out); */ class ScatterWithDistributedIndices : public MlirScatterFusion { public: diff --git a/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo b/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo index 69fdf05c86cd3..332eb543af61b 100644 --- a/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo +++ b/xla/service/gpu/fusions/tests/scatter/sorted_indices.hlo @@ -9,13 +9,13 @@ add { } scatter { %operand = f32[100] parameter(0) - %indices = s32[2000,1] parameter(1) - %update = f32[2000,32] parameter(2) + %indices = s32[2001,1] parameter(1) + %update = f32[2001,32] parameter(2) ROOT %scatter = f32[100] scatter( f32[100] %operand, - s32[2000,1] %indices, - f32[2000,32] %update + s32[2001,1] %indices, + f32[2001,32] %update ), update_window_dims={1}, inserted_window_dims={}, diff --git a/xla/service/gpu/fusions/triton/emitter_helpers.cc b/xla/service/gpu/fusions/triton/emitter_helpers.cc index 60f4132b9e7f1..7f3b990219c23 100644 --- a/xla/service/gpu/fusions/triton/emitter_helpers.cc +++ b/xla/service/gpu/fusions/triton/emitter_helpers.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -64,13 +65,9 @@ namespace mh = ::mlir::mhlo; namespace mm = ::mlir::math; namespace mt = ::mlir::triton; -ScalarOrTensor::ScalarOrTensor(mlir::Value value) { - if (auto tt = mlir::dyn_cast(value.getType())) { - CHECK_GT(tt.getRank(), 0); - value_ = TensorValue{value}; - } else { - value_ = ScalarValue{value}; - } +ScalarOrTensor::ScalarOrTensor(mlir::Value value) : value_(value) { + CHECK(IsScalar() || UnwrapTensor().getType().getRank() > 0) + << "0D tensors are not supported by Triton"; } SmallVector GetPaddedTileSizes(ArrayRef tile_sizes) { @@ -313,7 +310,7 @@ Value Minimum(EmitterLocOpBuilder& b, const se::DeviceDescription& device_info, ScalarOrTensor Splat(EmitterLocOpBuilder& b, ScalarOrTensor value, ArrayRef shape) { CHECK(!shape.empty()); - auto type = mlir::RankedTensorType::get(shape, value.Type()); + auto type = mlir::RankedTensorType::get(shape, value.getType()); return ScalarOrTensor(b.create(type, value.UnwrapUnsafe())); } diff --git a/xla/service/gpu/fusions/triton/emitter_helpers.h b/xla/service/gpu/fusions/triton/emitter_helpers.h index fe283bada6f5e..7e20b6b3f6157 100644 --- a/xla/service/gpu/fusions/triton/emitter_helpers.h +++ b/xla/service/gpu/fusions/triton/emitter_helpers.h @@ -48,6 +48,8 @@ namespace xla::gpu::triton { // non-0D tensor. An attempt to use this class with 0D tensors will CHECK-fail // because 0D tensors are not supported by Triton. class ScalarOrTensor { + using TensorValue = mlir::TypedValue; + public: ScalarOrTensor() = default; @@ -55,17 +57,17 @@ class ScalarOrTensor { // value is a 0D tensor, because Triton does not support 0D tensors. 
explicit ScalarOrTensor(mlir::Value value); - bool IsScalar() const { return std::holds_alternative(value_); } - bool IsTensor() const { return std::holds_alternative(value_); } + bool IsScalar() const { return !IsTensor(); } + bool IsTensor() const { return mlir::isa(value_); } - mlir::Value UnwrapScalar() { + mlir::Value UnwrapScalar() const { CHECK(IsScalar()); - return std::get(value_).scalar_value; + return value_; } - mlir::Value UnwrapTensor() { + TensorValue UnwrapTensor() const { CHECK(IsTensor()); - return std::get(value_).tensor_value; + return mlir::cast(value_); } // Returns the underlying value regardless of whether it is a scalar or a @@ -73,25 +75,12 @@ class ScalarOrTensor { // both needs to use an `mlir::Value` and functions identically for scalars // and tensors. In other cases, prefer to use the `UnwrapScalar` or // `UnwrapTensor` methods. - mlir::Value UnwrapUnsafe() { - if (auto* scalar = std::get_if(&value_)) { - return scalar->scalar_value; - } - return std::get(value_).tensor_value; - } + mlir::Value UnwrapUnsafe() const { return value_; } - mlir::Type Type() { return UnwrapUnsafe().getType(); } + mlir::Type getType() const { return value_.getType(); } private: - struct ScalarValue { - mlir::Value scalar_value; - }; - - struct TensorValue { - mlir::Value tensor_value; - }; - - std::variant value_; + mlir::Value value_; }; // Triton requires that all block dimensions are a power of 2. diff --git a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 46655c5be8622..d0afa63f72177 100644 --- a/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -218,7 +218,7 @@ absl::StatusOr EmitReduce( *::xla::Cast(tiled_hlo_reduce.hlo()); ScalarOrTensor input = values[tiled_hlo_reduce.operand(0)]; llvm::ArrayRef input_shape = - mlir::cast(input.Type()).getShape(); + mlir::cast(input.getType()).getShape(); absl::Span source_tensor_shape = hlo_reduce.operand(0)->shape().dimensions(); @@ -511,7 +511,7 @@ absl::StatusOr EmitTiledReshape(EmitterLocOpBuilder& b, // At this point we know that the input is a non-0D tensor. - auto input_shaped_type = mlir::cast(input.Type()); + auto input_shaped_type = mlir::cast(input.getType()); // Handle the case of reshaping [1,1,1...] to a scalar. if (tile_sizes.empty()) { @@ -621,7 +621,7 @@ absl::StatusOr EmitTiledHloInstruction( // as i8. It's important to type checking that we perform a conversion after // loading if the type of the loaded parameter does not match what is // expected. - Type loaded_element_type = getElementTypeOrSelf(parameter.Type()); + Type loaded_element_type = getElementTypeOrSelf(parameter.getType()); TF_ASSIGN_OR_RETURN(Type expected_element_type, TritonType(b, hlo->shape().element_type())); @@ -976,7 +976,7 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, // as i8. It's important to type checking that we perform a conversion before // storing if the type of the result does not match the type of the output // pointer. 
- Type result_element_type = getElementTypeOrSelf(result.Type()); + Type result_element_type = getElementTypeOrSelf(result.getType()); Type result_storage_type = StorageType(b, result_element_type); if (result_element_type != result_storage_type) { diff --git a/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc b/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc index d4c84259f2dbd..08d4bc8894a2e 100644 --- a/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc +++ b/xla/service/gpu/fusions/triton/xla_triton_sparse_passes.cc @@ -360,7 +360,7 @@ struct SparseBlockedToMMAPass auto pattern = std::make_unique(context, compute_capability); RewritePatternSet patterns(context, std::move(pattern)); - if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) { + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { return signalPassFailure(); } } @@ -975,8 +975,7 @@ struct SparseWGMMAOpToLLVMPass MLIRContext *context = &getContext(); auto pattern = std::make_unique(context); RewritePatternSet patterns(context, std::move(pattern)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/xla/service/gpu/gpu_aot_compilation_test.cc b/xla/service/gpu/gpu_aot_compilation_test.cc index 945f63a1f87c0..76efde170bca3 100644 --- a/xla/service/gpu/gpu_aot_compilation_test.cc +++ b/xla/service/gpu/gpu_aot_compilation_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/literal_util.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/fusions/triton/triton_support.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/xla/service/gpu/gpu_collective_combiner_utils.cc b/xla/service/gpu/gpu_collective_combiner_utils.cc index d789b652df6d4..43a99ea4fe612 100644 --- a/xla/service/gpu/gpu_collective_combiner_utils.cc +++ b/xla/service/gpu/gpu_collective_combiner_utils.cc @@ -25,14 +25,11 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/stream_executor/device_description.h" #include "tsl/platform/statusor.h" namespace xla::gpu { - -using MemoryAwareScheduler = std::function( - const HloModule*, int64_t, int64_t*)>; - namespace { int64_t GetDefaultValue(HloOpcode opcode) { @@ -52,13 +49,13 @@ int64_t GetDefaultValue(HloOpcode opcode) { int64_t ComputeSuggestedCombinerThreshold( const HloModule& module, const se::DeviceDescription& device_info, - MemoryAwareScheduler scheduler, HloOpcode collective_opcode, - int64_t pointer_size) { + HloOpcode collective_opcode, int64_t pointer_size) { int64_t base_limit = module.config().device_memory_size() != 0 ? 
module.config().device_memory_size() : device_info.device_memory_size(); int64_t peak_memory_bytes = -1; - auto mem_schedule = scheduler(&module, pointer_size, &peak_memory_bytes); + auto mem_schedule = ScheduleGpuModuleWithMemoryScheduler( + &module, pointer_size, &peak_memory_bytes); if (!mem_schedule.ok() || peak_memory_bytes == -1) { VLOG(1) << "Cannot schedule module: " << mem_schedule.status().message(); diff --git a/xla/service/gpu/gpu_collective_combiner_utils.h b/xla/service/gpu/gpu_collective_combiner_utils.h index 38a7890decb59..d78abf552eeb3 100644 --- a/xla/service/gpu/gpu_collective_combiner_utils.h +++ b/xla/service/gpu/gpu_collective_combiner_utils.h @@ -17,10 +17,8 @@ limitations under the License. #define XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_ #include -#include #include "absl/status/status.h" -#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -36,9 +34,6 @@ namespace xla::gpu { // `collective_opcode`. int64_t ComputeSuggestedCombinerThreshold( const HloModule& module, const se::DeviceDescription& device_info, - std::function(const HloModule*, int64_t, - int64_t*)> - scheduler, HloOpcode collective_opcode, int64_t pointer_size); // Adds information that `instr` has been pipelined to the diff --git a/xla/service/gpu/gpu_collective_combiner_utils_test.cc b/xla/service/gpu/gpu_collective_combiner_utils_test.cc index f0b213f343e58..9d7a959664161 100644 --- a/xla/service/gpu/gpu_collective_combiner_utils_test.cc +++ b/xla/service/gpu/gpu_collective_combiner_utils_test.cc @@ -19,27 +19,20 @@ limitations under the License. #include #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_pipeliner.h" -#include "xla/service/collective_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" namespace xla::gpu { namespace { @@ -65,8 +58,7 @@ TEST_F(CollectiveCombinerUtilsTest, device_info.set_device_memory_size(20000); int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( - *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size); + *module, device_info, HloOpcode::kAllReduce, pointer_size); // device size = 20000 bytes // slop factor = 0.95 @@ -96,8 +88,7 @@ TEST_F(CollectiveCombinerUtilsTest, stream_executor::DeviceDescription device_info; int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( - *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kAllReduce, pointer_size); + *module, device_info, HloOpcode::kAllReduce, pointer_size); // device size = 20000 bytes // slop factor = 0.95 @@ -106,45 +97,6 @@ TEST_F(CollectiveCombinerUtilsTest, EXPECT_EQ(suggested_threshold, 6712); } -TEST_F( - CollectiveCombinerUtilsTest, - ComputeSuggestedCombinerThresholdReturnsDefaultValueUponSchedulingFailure) { - absl::string_view 
kHloText = R"( - HloModule m - - ENTRY ar { - p0 = f32[32,32] parameter(0) - p1 = f32[32,32] parameter(1) - - ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), - custom_call_target="__cublas$gemm" - })"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - int pointer_size = 4; - stream_executor::DeviceDescription device_info; - device_info.set_device_memory_size(20000); - - auto sched_fun = [](const HloModule* m, int64_t p_sz, - int64_t* p) -> absl::StatusOr { - return absl::UnimplementedError("Fail."); - }; - - int64_t suggested_threshold_all_reduce = ComputeSuggestedCombinerThreshold( - *module, device_info, sched_fun, HloOpcode::kAllReduce, pointer_size); - int64_t suggested_threshold_all_gather = ComputeSuggestedCombinerThreshold( - *module, device_info, sched_fun, HloOpcode::kAllGather, pointer_size); - int64_t suggested_threshold_reduce_scatter = - ComputeSuggestedCombinerThreshold(*module, device_info, sched_fun, - HloOpcode::kReduceScatter, - pointer_size); - - EXPECT_EQ(suggested_threshold_all_reduce, kDefaultAllReduceCombineThreshold); - EXPECT_EQ(suggested_threshold_all_gather, kDefaultAllGatherCombineThreshold); - EXPECT_EQ(suggested_threshold_reduce_scatter, - kDefaultReduceScatterCombineThreshold); -} - TEST_F(CollectiveCombinerUtilsTest, AppendPipelinedInstructionAppendsPipelinedInstructionInfoForward) { // This is just a canonical IR which makes it easy to pipeline a collective diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc index 705c0eb327e1b..5c6a5ab6ca172 100755 --- a/xla/service/gpu/gpu_compiler.cc +++ b/xla/service/gpu/gpu_compiler.cc @@ -1676,6 +1676,22 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); } + { + // Because of an issue with JAX remat and `SimplifyFPConversions` (see PR: + // https://github.com/jax-ml/jax/pull/22244), we can only eliminate the + // no-op reduce-precision operations after the last call to + // `SimplifyFPConversions`. We are creating a sub-pipeline here because that + // allows us to test this order in a unit test. + HloPassPipeline& remove_no_op_reduce_precision_pipeline = + pipeline.AddPass( + "remove-no-op-reduce-precision-algebraic-simplifier"); + AlgebraicSimplifierOptions simplifier_options_{simplifier_options}; + simplifier_options_.set_enable_remove_no_op_reduce_precision(true); + remove_no_op_reduce_precision_pipeline + .AddPass>(simplifier_options_, + gpu_version); + } + pipeline.AddPass(/*is_layout_sensitive=*/true); pipeline.AddPass( diff --git a/xla/service/gpu/gpu_compiler_test.cc b/xla/service/gpu/gpu_compiler_test.cc index d60a7f5daedcb..26e8899aa6560 100644 --- a/xla/service/gpu/gpu_compiler_test.cc +++ b/xla/service/gpu/gpu_compiler_test.cc @@ -1554,6 +1554,19 @@ TEST_F(PassOrderTest, GemmRewriterRunsAfterDotNormalizer) { VerifyNotRunInBetween(pass_range, /*pass_regex=*/"algsimp"); } +TEST_F(PassOrderTest, + ReducePrecisionIsRemovedAfterAllCallsToSimplifyFPConversions) { + // Because of an issue with JAX remat and `SimplifyFPConversions` (see PR: + // https://github.com/jax-ml/jax/pull/22244), we can only eliminate the + // no-op reduce-precision operations after the last call to + // `SimplifyFPConversions`. No-op reduce-precisions are removed within + // algebraic simplifier, if the option to remove them is set. In the compiler + // pipeline, this is done as a subpipeline, which should be after the last + // invocation of SimplifyFPConversions. 
+ VerifyPassOrder("simplify-fp-conversions", + "remove-no-op-reduce-precision-algebraic-simplifier"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/model/coalescing_analysis.cc b/xla/service/gpu/model/coalescing_analysis.cc index a2ceba1f01a29..a583c692c2d8b 100644 --- a/xla/service/gpu/model/coalescing_analysis.cc +++ b/xla/service/gpu/model/coalescing_analysis.cc @@ -548,7 +548,7 @@ std::vector FindContiguousIntervals( } // Case 2: f(thread_x) != thread_x * multiplier. auto intervals = FindIntervals(partitioned_expr.func_of_d0, - {indexing_map.GetDimVars(0).bounds}); + {indexing_map.GetDimVar(0).bounds}); // Case 2.1: g(s) != s. if (partitioned_expr.func_of_s0 != range) { return intervals; diff --git a/xla/service/gpu/reduce_scatter_combiner.cc b/xla/service/gpu/reduce_scatter_combiner.cc index 6b07a79cd4ecd..2d9813dda1e6a 100644 --- a/xla/service/gpu/reduce_scatter_combiner.cc +++ b/xla/service/gpu/reduce_scatter_combiner.cc @@ -26,7 +26,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_collective_combiner_utils.h" -#include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/hlo_domain_map.h" #include "xla/service/reduce_scatter_combiner.h" #include "tsl/platform/statusor.h" @@ -76,8 +75,7 @@ absl::StatusOr GpuReduceScatterCombiner::Run( // Combine as much as possible for pipelined collectives. int previous_combiner_threshold = combine_threshold_in_bytes_; combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( - *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, - HloOpcode::kReduceScatter, pointer_size_); + *module, device_info_, HloOpcode::kReduceScatter, pointer_size_); TF_ASSIGN_OR_RETURN( bool combined_pipelined_instructions, RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); diff --git a/xla/service/gpu/tests/BUILD b/xla/service/gpu/tests/BUILD index 1c524b8c35c70..3140bdd3bbe03 100644 --- a/xla/service/gpu/tests/BUILD +++ b/xla/service/gpu/tests/BUILD @@ -889,7 +889,11 @@ xla_test( srcs = ["nop_custom_call_test.cc"], backends = ["gpu"], deps = [ + "//xla:literal", + "//xla:literal_util", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", + "//xla/tsl/platform:test", "@tsl//tsl/platform:test_main", ], ) diff --git a/xla/service/gpu/tests/gpu_fused_mha_test.cc b/xla/service/gpu/tests/gpu_fused_mha_test.cc index 33214758e230f..abdb9f471d1ce 100644 --- a/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -1263,6 +1263,136 @@ class FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM } }; +class FlashAttentionBMMScaleSegmentMaskSoftmaxBMM + : public MultiHeadedAttentionTest { + protected: + const std::string // NOLINT + GetModuleFlash_Attention_Training_Sequence_Packing_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit_impl, entry_computation_layout={(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})->(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + ENTRY main.22 { + Arg_0.1 = bf16[2,512,2,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,512,2,64]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[2,512,2,64]{3,2,1,0} parameter(2) + constant.5 = s32[] constant(256) + broadcast.6 = 
s32[4]{0} broadcast(constant.5), dimensions={} + constant.7 = s32[5]{0} constant({0, 32768, 65536, 98304, 131072}) + custom-call.8 = (bf16[2,2,512,64]{3,1,2,0}, f32[4,2,512]{2,1,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, broadcast.6, broadcast.6, /*index=5*/constant.7, constant.7), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}, s32[5]{0}, s32[5]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 2}} + get-tuple-element.11 = u8[0]{0} get-tuple-element(custom-call.8), index=2 + get-tuple-element.10 = f32[4,2,512]{2,1,0} get-tuple-element(custom-call.8), index=1 + Arg_3.4 = bf16[2,512,2,64]{3,2,1,0} parameter(3) + get-tuple-element.9 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.8), index=0 + transpose.12 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.9), dimensions={0,2,1,3} + custom-call.13 = (bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, get-tuple-element.10, Arg_3.4, /*index=5*/transpose.12, broadcast.6, broadcast.6, constant.7, constant.7), custom_call_target="__cudnn$fmhaSoftmaxBackward", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, f32[4,2,512]{2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, s32[4]{0}, s32[4]{0}, s32[5]{0}, s32[5]{0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, 
"is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 2}} + get-tuple-element.17 = u8[0]{0} get-tuple-element(custom-call.13), index=3 + get-tuple-element.14 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=0 + transpose.18 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.14), dimensions={0,2,1,3} + get-tuple-element.15 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=1 + transpose.19 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.15), dimensions={0,2,1,3} + get-tuple-element.16 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.13), index=2 + transpose.20 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.16), dimensions={0,2,1,3} + ROOT tuple.21 = (bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}) tuple(transpose.12, transpose.18, transpose.19, transpose.20) + } // main.22 + )"; + return hlo_text; + } + + const std::string // NOLINT + GetModuleFlash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_HloString_BF16() { // NOLINT + const std::string hlo_text = R"( + HloModule jit_ref, entry_computation_layout={(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})->(bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true} + + _where.9 { + Arg_0.10 = pred[512]{0} parameter(0) + Arg_1.11 = s32[512]{0} parameter(1) + Arg_2.12 = s32[512]{0} parameter(2) + ROOT select.13 = s32[512]{0} select(Arg_0.10, Arg_1.11, Arg_2.12) + } + + floor_divide.14 { + Arg_0.15 = s32[512]{0} parameter(0) + sign.23 = s32[512]{0} sign(Arg_0.15) + Arg_1.16 = s32[] parameter(1) + sign.24 = s32[] sign(Arg_1.16) + broadcast.25 = s32[512]{0} broadcast(sign.24), dimensions={} + compare.26 = pred[512]{0} compare(sign.23, broadcast.25), direction=NE + broadcast.27 = s32[512]{0} broadcast(Arg_1.16), dimensions={} + remainder.28 = s32[512]{0} remainder(Arg_0.15, broadcast.27) + constant.19 = s32[] constant(0) + broadcast.20 = s32[512]{0} broadcast(constant.19), dimensions={} + compare.29 = pred[512]{0} compare(remainder.28, broadcast.20), direction=NE + and.30 = pred[512]{0} and(compare.26, compare.29) + broadcast.21 = s32[512]{0} broadcast(Arg_1.16), dimensions={} + divide.22 = s32[512]{0} divide(Arg_0.15, broadcast.21) + constant.17 = s32[] constant(1) + broadcast.18 = s32[512]{0} broadcast(constant.17), dimensions={} + subtract.31 = s32[512]{0} subtract(divide.22, broadcast.18) + ROOT call.32 = s32[512]{0} call(and.30, subtract.31, divide.22), 
to_apply=_where.9 + } // floor_divide.14 + + ENTRY main.61 { + Arg_0.1 = bf16[2,512,2,64]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[2,512,2,64]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[2,512,2,64]{3,2,1,0} parameter(2) + iota.8 = s32[512]{0} iota(), iota_dimension=0 + constant.7 = s32[] constant(256) + call.33 = s32[512]{0} call(iota.8, constant.7), to_apply=floor_divide.14 + broadcast.34 = s32[2,512]{1,0} broadcast(call.33), dimensions={1} + reshape.35 = s32[2,512,1]{2,1,0} reshape(broadcast.34) + broadcast.37 = s32[2,512,1]{2,1,0} broadcast(reshape.35), dimensions={0,1,2} + reshape.38 = s32[2,512]{1,0} reshape(broadcast.37) + broadcast.39 = s32[2,512,512]{2,1,0} broadcast(reshape.38), dimensions={0,1} + reshape.36 = s32[2,1,512]{2,1,0} reshape(broadcast.34) + broadcast.40 = s32[2,1,512]{2,1,0} broadcast(reshape.36), dimensions={0,1,2} + reshape.41 = s32[2,512]{1,0} reshape(broadcast.40) + broadcast.42 = s32[2,512,512]{2,1,0} broadcast(reshape.41), dimensions={0,2} + compare.43 = pred[2,512,512]{2,1,0} compare(broadcast.39, broadcast.42), direction=NE + convert.44 = bf16[2,512,512]{2,1,0} convert(compare.43) + reshape.45 = bf16[2,1,512,512]{3,2,1,0} reshape(convert.44) + constant.5 = bf16[] constant(-2.199e+12) + broadcast.6 = bf16[2,1,512,512]{3,2,1,0} broadcast(constant.5), dimensions={} + multiply.46 = bf16[2,1,512,512]{3,2,1,0} multiply(reshape.45, broadcast.6) + custom-call.47 = (bf16[2,2,512,64]{3,1,2,0}, f32[2,2,512]{2,1,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, Arg_2.3, multiply.46), custom_call_target="__cudnn$fmhaScaleBiasSoftmax", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,1,512,512]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1}} + get-tuple-element.50 = u8[0]{0} get-tuple-element(custom-call.47), index=2 + get-tuple-element.49 = f32[2,2,512]{2,1,0} get-tuple-element(custom-call.47), index=1 + Arg_3.4 = bf16[2,512,2,64]{3,2,1,0} parameter(3) + get-tuple-element.48 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.47), index=0 + transpose.51 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.48), dimensions={0,2,1,3} + custom-call.52 = (bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, bf16[2,2,512,64]{3,1,2,0}, u8[0]{0}) custom-call(Arg_0.1, Arg_1.2, 
Arg_2.3, get-tuple-element.49, Arg_3.4, /*index=5*/multiply.46, transpose.51), custom_call_target="__cudnn$fmhaScaleBiasSoftmaxBackward", operand_layout_constraints={bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, f32[2,2,512]{2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,1,512,512]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 0.1, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["2", "2", "512", "512"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm1_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["2"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_grad_gemm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "dropout_rate": 0, "seed": 42, "sliding_window_length": 0, "max_seg_per_batch": 1}} + get-tuple-element.56 = u8[0]{0} get-tuple-element(custom-call.52), index=3 + get-tuple-element.53 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=0 + transpose.57 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.53), dimensions={0,2,1,3} + get-tuple-element.54 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=1 + transpose.58 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.54), dimensions={0,2,1,3} + get-tuple-element.55 = bf16[2,2,512,64]{3,1,2,0} get-tuple-element(custom-call.52), index=2 + transpose.59 = bf16[2,512,2,64]{3,2,1,0} transpose(get-tuple-element.55), dimensions={0,2,1,3} + ROOT tuple.60 = (bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}, bf16[2,512,2,64]{3,2,1,0}) tuple(transpose.51, transpose.57, transpose.58, transpose.59) + } // main.61 + )"; + return hlo_text; + } + + template + void TestImpl_Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2() { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 6, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.6.0."; + } + XlaBuilder builder(TestName()); + // Cudnn sequence packing packs multiple batches(segments) into one batch + // using offsets and seqlen tensors to indicate where each segment begins + std::string hlo_string = + GetModuleFlash_Attention_Training_Sequence_Packing_HloString_BF16(); // NOLINT + // Reference 
implementation is regular attention with segment mask + std::string hlo_string_ref = + GetModuleFlash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_HloString_BF16(); // NOLINT + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{1e-3, 1e-3})); + } +}; + class FlashAttentionBMMScaleSoftmaxBMMF8 : public MultiHeadedAttentionTest {}; class FlashAttentionBMMScaleSoftmaxDropoutBMM @@ -1378,6 +1508,13 @@ XLA_TEST_F(FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM, bfloat16>(); // NOLINT } +// BMM1 - Scale - SegmentMask - Softmax - BMM2 +XLA_TEST_F(FlashAttentionBMMScaleSegmentMaskSoftmaxBMM, + Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2_BF16) { + TestImpl_Flash_Attention_Training_BMM1_SegmentMask_Softmax_BMM2< + bfloat16>(); // NOLINT +} + absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef() { static constexpr absl::string_view hlo_text = R"( diff --git a/xla/service/gpu/tests/nop_custom_call_test.cc b/xla/service/gpu/tests/nop_custom_call_test.cc index d979d18aa8ac9..06df6792eb3e9 100644 --- a/xla/service/gpu/tests/nop_custom_call_test.cc +++ b/xla/service/gpu/tests/nop_custom_call_test.cc @@ -13,9 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace gpu { diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 0dc92c47d2cb5..67f33164fa263 100644 --- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -149,12 +149,14 @@ absl::StatusOr HloCustomCallToCuDnnGraph( GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); const int sliding_window_length = config.sliding_window_length(); + const int max_seg_per_batch = config.max_seg_per_batch(); TF_ASSIGN_OR_RETURN( se::gpu::CudnnGraph graph, se::gpu::GetCudnnFlashAttentionOperationGraph( dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, bias, activation, static_cast(config.fmha_scale()), dropout_rate > 0.0, - dropout_rate, dnn_mask_type, sliding_window_length)); + dropout_rate, dnn_mask_type, sliding_window_length, + max_seg_per_batch)); return graph; } else if (IsFwdCustomCallTofMHAF8(*custom_call)) { TF_ASSIGN_OR_RETURN( @@ -230,12 +232,19 @@ absl::StatusOr HloCustomCallToCuDnnGraph( // Unused fwd_output_shape ++input_index; + const int max_seg_per_batch = config.max_seg_per_batch(); if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || config.mask_type() == - xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { + xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL || + max_seg_per_batch > 1) { // skip q_seqlen and kv_seqlen input_index += 2; } + + if (max_seg_per_batch > 1) { + // skip q_offsets and kv_offsets + input_index += 2; + } TF_RET_CHECK(input_index == custom_call->operand_count()); int output_index = 0; @@ -312,7 +321,8 @@ absl::StatusOr HloCustomCallToCuDnnGraph( bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs, d_bmm1_rhs, d_bmm2_rhs, bias, dropout_rate, config.seed(), config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt, - dnn_mask_type, force_deterministic, sliding_window_length)); + dnn_mask_type, force_deterministic, 
sliding_window_length, + max_seg_per_batch)); return graph; } else { TF_ASSIGN_OR_RETURN( diff --git a/xla/service/hlo_creation_utils_test.cc b/xla/service/hlo_creation_utils_test.cc index 252345fbbbc5f..debabe09c3c51 100644 --- a/xla/service/hlo_creation_utils_test.cc +++ b/xla/service/hlo_creation_utils_test.cc @@ -15,19 +15,29 @@ limitations under the License. #include "xla/service/hlo_creation_utils.h" +#include #include +#include +#include "absl/log/check.h" +#include "absl/types/span.h" +#include "xla/array2d.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { namespace { diff --git a/xla/service/hlo_module_test.cc b/xla/service/hlo_module_test.cc index 339feeb8fd2d4..960f107c9117b 100644 --- a/xla/service/hlo_module_test.cc +++ b/xla/service/hlo_module_test.cc @@ -24,25 +24,37 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_original_value.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" #include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/test_compilation_environment.pb.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/xla/service/hlo_runner_pjrt.cc b/xla/service/hlo_runner_pjrt.cc index b4b9e1cd889c3..dce3bc9e1ca5b 100644 --- a/xla/service/hlo_runner_pjrt.cc +++ b/xla/service/hlo_runner_pjrt.cc @@ -23,11 +23,13 @@ limitations under the License. 
#include #include "absl/algorithm/container.h" +#include "absl/log/die_if_null.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" @@ -135,6 +137,27 @@ absl::StatusOr GetStaticDeviceAssignmentOrComputeDefault( module.config().num_partitions()); } +std::vector BufferVecToPointerVec( + const absl::Span> buffer) { + std::vector argument_ptrs; + argument_ptrs.resize(buffer.size()); + for (int i = 0; i < buffer.size(); ++i) { + argument_ptrs[i] = buffer[i].get(); + } + + return argument_ptrs; +} + +std::vector> BufferMatToPointerMat( + const absl::Span>> buffer) { + std::vector> argument_ptrs; + argument_ptrs.reserve(buffer.size()); + for (int i = 0; i < buffer.size(); ++i) { + argument_ptrs.push_back(BufferVecToPointerVec(buffer[i])); + } + return argument_ptrs; +} + } // namespace // TODO(b/245550554): Remove the use of PjRtWrappedExecutable. @@ -214,6 +237,9 @@ absl::StatusOr HloRunnerPjRt::GenerateDefaultCompileOptions( compile_options.executable_build_options.set_result_layout( module->entry_computation_layout().result_shape()); + compile_options.executable_build_options.set_use_spmd_partitioning( + module->config().use_spmd_partitioning()); + return compile_options; } @@ -305,36 +331,12 @@ absl::StatusOr HloRunnerPjRt::Execute( ExecutionProfile* profile) { // TODO (b/245550554) : Remove UpdateEntryComputationLayout from runner. UpdateEntryComputationLayout(module.get()); - TF_ASSIGN_OR_RETURN(auto compile_options, GenerateDefaultCompileOptions( - module.get(), run_hlo_passes)); - TF_ASSIGN_OR_RETURN(auto executable, CreateExecutable(std::move(module), run_hlo_passes)); return ExecuteWithExecutable(executable.get(), arguments, {}); } -std::vector HloRunnerPjRt::BufferVecToPointerVec( - const std::vector>& buffer) { - std::vector argument_ptrs; - argument_ptrs.resize(buffer.size()); - for (int i = 0; i < buffer.size(); ++i) { - argument_ptrs[i] = buffer[i].get(); - } - - return argument_ptrs; -} - -std::vector> HloRunnerPjRt::BufferMatToPointerMat( - std::vector>>& buffer) { - std::vector> argument_ptrs; - argument_ptrs.reserve(buffer.size()); - for (int i = 0; i < buffer.size(); ++i) { - argument_ptrs.push_back(BufferVecToPointerVec(buffer[i])); - } - return argument_ptrs; -} - absl::StatusOr> HloRunnerPjRt::CreateExecutable(HloModule* module, CompileOptions compile_options) { @@ -442,7 +444,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( const HloRunnerInterface::ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile) { return ExecuteReplicatedImpl( - [&](absl::Span>& argument_buffer_slices) + [&](absl::Span> argument_buffer_slices) -> absl::StatusOr>> { PjRtWrappedExecutable* wrapped_executable = static_cast(executable); @@ -476,7 +478,7 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( TF_RET_CHECK(device_assignment->computation_count() == 1) << "Only single-computation execution is supported."; return ExecuteReplicatedImpl( - [&](absl::Span>& argument_buffer_slices) + [&](absl::Span> argument_buffer_slices) -> absl::StatusOr>> { TF_RET_CHECK(options.use_threads); @@ -538,26 +540,29 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( std::function>>( - absl::Span>&)> + absl::Span>)> execution_helper, std::function 
argument_count_provider, std::function argument_provider, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment) { + TF_RET_CHECK(options.infeed_values.empty() || + options.infeed_values.size() == options.num_replicas); + + std::vector replica_devices(options.num_replicas, nullptr); std::vector>> argument_buffer_slices; argument_buffer_slices.reserve(options.num_replicas); - for (int64_t i = 0; i < options.num_replicas; ++i) { - TF_ASSIGN_OR_RETURN(PjRtDevice * device_ptr, + // Amortize device lookup. + TF_ASSIGN_OR_RETURN(PjRtDevice* const device_ptr, pjrt_client_->LookupDevice( DeviceIdForInvocation(*device_assignment, i))); + replica_devices[i] = device_ptr; // Transfer literals to device. const int64_t argument_count = argument_count_provider(i); - std::vector> replica_buffers; replica_buffers.reserve(argument_count); - for (int64_t arg_index = 0; arg_index < argument_count; arg_index++) { const Literal* const argument = argument_provider(i, arg_index); TF_RET_CHECK(argument != nullptr); @@ -570,37 +575,93 @@ absl::StatusOr> HloRunnerPjRt::ExecuteReplicatedImpl( : pjrt_client_->BufferFromHostLiteral(*argument, device_ptr)); replica_buffers.push_back(std::move(assignment)); } - argument_buffer_slices.push_back(std::move(replica_buffers)); } - TF_RET_CHECK(options.infeed_values.empty() || - options.infeed_values.size() == options.num_replicas); - - if (!options.infeed_values.empty()) { - // TODO(b/245550554): Infeed/Outfeed + // Handle infeed and outfeed. + const bool has_infeed = !options.infeed_values.empty(); + const bool has_outfeed = ShapeUtil::IsInitialized(options.outfeed_shape); + std::unique_ptr pool = nullptr; + absl::Mutex infeed_outfeed_status_mu; + absl::Status infeed_outfeed_status = absl::OkStatus(); + if (has_infeed || has_outfeed) { + // One infeed per infeed value and one outfeed per replica. + const int64_t num_threads = + options.infeed_values.size() + (has_outfeed ? 
options.num_replicas : 0); + pool = std::make_unique( + tsl::Env::Default(), "infeed_outfeed", num_threads); } - - if (ShapeUtil::IsInitialized(options.outfeed_shape)) { - // TODO(b/245550554): Infeed/Outfeed + if (has_infeed) { + for (int64_t i = 0; i < options.num_replicas; ++i) { + pool->Schedule( + [device = replica_devices[i], + &infeed_literal = *ABSL_DIE_IF_NULL(options.infeed_values[i]), + infeed_steps = options.infeed_steps, &infeed_outfeed_status_mu, + &infeed_outfeed_status]() { + VLOG(1) << "Starting infeed on device " << device->ToString(); + absl::Status per_feed_status = absl::OkStatus(); + for (int64_t step = 1; infeed_steps < 0 || step <= infeed_steps; + ++step) { + per_feed_status.Update(device->TransferToInfeed(infeed_literal)); + if (step % 100 == 0) { + VLOG(1) << "Infeed step " << step; + } + } + absl::MutexLock lock(&infeed_outfeed_status_mu); + infeed_outfeed_status.Update(per_feed_status); + }); + } + } + if (has_outfeed) { + if (options.outfeed_values != nullptr) { + options.outfeed_values->resize(options.num_replicas); + } + for (int64_t i = 0; i < options.num_replicas; ++i) { + pool->Schedule([i, device = replica_devices[i], + outfeed_values = options.outfeed_values, + outfeed_shape = options.outfeed_shape, + infeed_steps = options.infeed_steps, + &infeed_outfeed_status_mu, &infeed_outfeed_status]() { + VLOG(1) << "Starting outfeed on device " << device->ToString(); + absl::Status per_feed_status = absl::OkStatus(); + for (int64_t step = 1; infeed_steps < 0 || step <= infeed_steps; + ++step) { + Literal literal(outfeed_shape); + per_feed_status.Update(device->TransferFromOutfeed(&literal)); + if (outfeed_values != nullptr) { + outfeed_values->at(i) = std::move(literal); + } + if (step % 100 == 0) { + VLOG(1) << "Outfeed step " << step; + } + } + absl::MutexLock lock(&infeed_outfeed_status_mu); + infeed_outfeed_status.Update(per_feed_status); + }); + } } - auto mat = BufferMatToPointerMat(argument_buffer_slices); - - auto span = absl::Span>(mat); - - TF_ASSIGN_OR_RETURN(auto results, execution_helper(span)); - std::vector exec_results; - exec_results.reserve(options.num_replicas); + VLOG(1) << "Replicated execution started"; + TF_ASSIGN_OR_RETURN( + const std::vector> result_buffers, + execution_helper(BufferMatToPointerMat(argument_buffer_slices))); + VLOG(1) << "Replicated execution terminated"; + // Get the result from execution. + std::vector result_literals; + result_literals.reserve(options.num_replicas); for (int64_t i = 0; i < options.num_replicas; ++i) { TF_ASSIGN_OR_RETURN(Literal literal, - TransferLiteralFromDevice(*results[i])); - - exec_results.push_back(std::move(literal)); + TransferLiteralFromDevice(*result_buffers[i])); + result_literals.push_back(std::move(literal)); } - return std::move(exec_results); + // Join infeed and outfeed threads, if they exist. The thread pool's threads + // are joined on destruction. No-op otherwise. + pool = nullptr; + TF_RETURN_IF_ERROR(infeed_outfeed_status); + + return std::move(result_literals); } absl::string_view HloRunnerPjRt::Name() const { return "HloRunnerPjRt"; } diff --git a/xla/service/hlo_runner_pjrt.h b/xla/service/hlo_runner_pjrt.h index dc4ec3921b4a6..db0f258895866 100644 --- a/xla/service/hlo_runner_pjrt.h +++ b/xla/service/hlo_runner_pjrt.h @@ -25,7 +25,13 @@ limitations under the License. 
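For context on the feed-thread accounting above: the pool is sized to one thread per infeed value plus one outfeed thread per replica. A minimal sketch of that computation, with an illustrative helper name that is not part of this change:

#include <cstdint>

// Mirrors the num_threads computation in ExecuteReplicatedImpl: one thread
// per infeed value, plus one outfeed thread per replica when an outfeed
// shape is set.
int64_t FeedThreadCount(int64_t num_infeed_values, int64_t num_replicas,
                        bool has_outfeed) {
  return num_infeed_values + (has_outfeed ? num_replicas : 0);
}
// Example: 4 replicas with per-replica infeed and outfeed -> 4 + 4 = 8.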
#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_runner_interface.h" #include "xla/xla_data.pb.h" @@ -118,28 +124,22 @@ class HloRunnerPjRt : public HloRunnerInterface { } private: - std::unique_ptr pjrt_client_; - DeviceShapeRepresentationFn device_shape_representation_fn_; - DeviceShapeSizeFn device_shape_size_fn_; - bool use_parameter_layout_on_device_ = false; - - std::vector BufferVecToPointerVec( - const std::vector>& buffer); - - std::vector> BufferMatToPointerMat( - std::vector>>& buffer); - absl::StatusOr GenerateDefaultCompileOptions( HloModule* module, bool run_hlo_passes); absl::StatusOr> ExecuteReplicatedImpl( std::function>>( - absl::Span>&)> + absl::Span>)> execution_helper, std::function argument_count_provider, std::function argument_provider, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment); + + std::unique_ptr pjrt_client_; + DeviceShapeRepresentationFn device_shape_representation_fn_; + DeviceShapeSizeFn device_shape_size_fn_; + bool use_parameter_layout_on_device_ = false; }; } // namespace xla diff --git a/xla/service/hlo_schedule_test.cc b/xla/service/hlo_schedule_test.cc index d18c8527893c8..fd89bcc5b23fc 100644 --- a/xla/service/hlo_schedule_test.cc +++ b/xla/service/hlo_schedule_test.cc @@ -22,19 +22,20 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/log/log.h" -#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/literal_util.h" +#include "xla/service/buffer_value.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/xla/service/hlo_verifier.cc b/xla/service/hlo_verifier.cc index 88823f1dd9e5c..9e84f287beb87 100644 --- a/xla/service/hlo_verifier.cc +++ b/xla/service/hlo_verifier.cc @@ -2483,6 +2483,27 @@ absl::Status VerifyLayoutConstrainedAllReduce(const HloModule& module) { return absl::OkStatus(); } +// Verifies that leaf nodes in an original value contain values. +absl::Status VerifyOriginalValue(const HloModule& module) { + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + if (auto original_value = instruction->original_value()) { + // An original value is expected to have intermediate nodes that are + // always nullopt and leaves with actual values. + for (const auto& leaf : original_value->leaves()) { + if (!leaf.second.has_value()) { + return Internal( + "Leaf nodes in an original value is expected to contain values." + " Instruction: %s.", + instruction->ToString()); + } + } + } + } + } + return absl::OkStatus(); +} + // Checks various invariants of channel instructions (send/recv and // collectives). 
absl::Status VerifyChannels(const HloModule& module, @@ -3117,6 +3138,7 @@ absl::StatusOr HloVerifier::Run( TF_RETURN_IF_ERROR(module->buffer_donor_config().Verify(*module)); TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module)); + TF_RETURN_IF_ERROR(VerifyOriginalValue(*module)); return false; }(); if (status_or_changed.ok()) { diff --git a/xla/service/hlo_verifier_test.cc b/xla/service/hlo_verifier_test.cc index 419156664e7f4..6e2207726caeb 100644 --- a/xla/service/hlo_verifier_test.cc +++ b/xla/service/hlo_verifier_test.cc @@ -3635,5 +3635,19 @@ TEST_F(HloVerifierTest, UnaryOpWithResultAccuracy) { EXPECT_TRUE(status.ok()) << status; } +TEST_F(HloVerifierTest, EmptyLeafInOriginalValue) { + const std::string hlo_string = R"( +HloModule module +ENTRY %entry_computation { + ROOT op = ((f32[], f32[3]{0}), f32[2,3]) parameter(0), origin={(({}, {"v2"}), {"v3"})} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + EXPECT_FALSE(status.ok()); +} + } // namespace } // namespace xla diff --git a/xla/service/latency_hiding_scheduler.cc b/xla/service/latency_hiding_scheduler.cc index 6532e9c993407..d199e1f046daa 100644 --- a/xla/service/latency_hiding_scheduler.cc +++ b/xla/service/latency_hiding_scheduler.cc @@ -384,19 +384,17 @@ AsyncTracker::RecursivelyComputeResourceMap( int64_t AsyncTracker::GetNumResourcesPerInstruction( int64_t resource_type, const HloInstruction& instr) const { - // For instructions not calling a computation then return 1 if the instruction - // has opcode equal to 'async_done' + // For instructions not calling a computation, or async start/done + // instructions, we directly check the resources from the instruction. if (instr.called_computations().empty() || instr.opcode() == HloOpcode::kAsyncStart || instr.opcode() == HloOpcode::kAsyncDone) { - return absl::c_any_of(GetResourcesFromInstruction(instr), - [resource_type](const ResourcePair& resource) { - return resource.second == - ResourceUsageType::kResourceOccupy && - (resource_type == resource.first); - }) - ? 1 - : 0; + return absl::c_count_if(GetResourcesFromInstruction(instr), + [resource_type](const ResourcePair& resource) { + return resource.second == + ResourceUsageType::kResourceOccupy && + (resource_type == resource.first); + }); } int64_t num_resources = 0; for (const HloComputation* computation : instr.called_computations()) { diff --git a/xla/service/spmd/shardy/round_trip_common/BUILD b/xla/service/spmd/shardy/round_trip_common/BUILD index 48fb0862daa5f..b3ab4176a0be7 100644 --- a/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/xla/service/spmd/shardy/round_trip_common/BUILD @@ -110,6 +110,7 @@ cc_library( "//xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], ) diff --git a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index c4d7a13a55bb9..1438d40cf61fc 100644 --- a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h" @@ -48,7 +49,13 @@ void addCommonPreImportPasses(mlir::OpPassManager& pm) { // We need to canonicalize redundant mhlo::GetTupleElementOp and // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before // `createOpenWhileFreeVarsShardingPass`. - pm.addPass(mlir::createCanonicalizerPass()); + mlir::GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; + config.fold = false; + config.cseConstants = false; + // TODO(tomnatan): consider only enabling the specific passes we need. + pm.addPass(mlir::createCanonicalizerPass(config)); // Shardy is currently operating on stablehlo, since this is what JAX // emits. Long term shardy will be fully dialect agnostic, and both mhlo // and stablehlo can register their ops for sdy propagation. diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc index 1abdf7359f71b..9d0912d4b4c5a 100644 --- a/xla/service/spmd/spmd_partitioner.cc +++ b/xla/service/spmd/spmd_partitioner.cc @@ -3355,48 +3355,36 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicSlice(HloInstruction* hlo) { if (hlo->sharding().IsTileMaximal()) { return DefaultAction(hlo); } - - // Replicate along the slice dims to get temp_sharding. - std::vector slice_dims; for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - if (hlo->dynamic_slice_sizes()[i] != - hlo->operand(0)->shape().dimensions(i)) { - slice_dims.push_back(i); + if (hlo->sharding().tile_assignment().dim(i) != 1 && + hlo->dynamic_slice_sizes()[i] != + hlo->operand(0)->shape().dimensions(i)) { + // We currently do not partition the sliced dimensions. + return DefaultAction(hlo); } } - const HloSharding temp_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(), - slice_dims); - - // Reshard the input to temp_sharding. - HloInstruction* input_with_temp_sharding = - GetPartitionedHlo(hlo->operand(0)).Reshard(temp_sharding).hlo(); - - std::vector new_indices; - new_indices.reserve(hlo->shape().rank()); - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - if (hlo->dynamic_slice_sizes()[i] != + std::vector new_indices(hlo->shape().rank()); + auto new_input = + GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo(); + for (int64_t i = 0; i < new_indices.size(); ++i) { + if (hlo->dynamic_slice_sizes()[i] == hlo->operand(0)->shape().dimensions(i)) { - new_indices.push_back( - GetPartitionedHlo(hlo->operand(i + 1)).Replicate().hlo()); - } else { - // Index must be clamped to be 0. - new_indices.push_back(CreateZero(hlo->operand(i + 1)->shape(), &b_)); + // Trivial slice dim: index must be clampped to 0. + new_indices[i] = CreateZero(hlo->operand(i + 1)->shape(), &b_); + continue; } + // Replicate the indices.; + new_indices[i] = GetPartitionedHlo(hlo->operand(i + 1)) + .Reshard(HloSharding::Replicate()) + .hlo(); } - - // Apply dynamic slice with temp_sharding. 
- Shape temp_sharded_shape = MakePartitionedShape(hlo->shape(), temp_sharding); - HloInstruction* ds_with_temp_sharding = - b_.AddInstruction(HloInstruction::CreateDynamicSlice( - temp_sharded_shape, input_with_temp_sharding, new_indices, - temp_sharded_shape.dimensions())); - ds_with_temp_sharding->set_sharding(temp_sharding); - - // Reshard the output to the final sharding. - SetPartitionedHlo(hlo, PartitionedHlo(ds_with_temp_sharding, hlo->shape(), - MakePartitioningState()) - .Reshard(hlo->sharding())); + SetPartitionedHlo(hlo, [&]() { + auto partitioned_shape = + MakePartitionedShape(hlo->shape(), hlo->sharding()); + return b_.AddInstruction(HloInstruction::CreateDynamicSlice( + partitioned_shape, new_input, new_indices, + partitioned_shape.dimensions())); + }); return absl::OkStatus(); } diff --git a/xla/service/spmd/spmd_partitioner_test.cc b/xla/service/spmd/spmd_partitioner_test.cc index 727448674a5e1..8e9823d413ac4 100644 --- a/xla/service/spmd/spmd_partitioner_test.cc +++ b/xla/service/spmd/spmd_partitioner_test.cc @@ -7531,7 +7531,7 @@ ENTRY entry { EXPECT_THAT(root, op::PartitionId()); } -TEST_P(SpmdPartitioningTest, DynamicSlicePartitionedBatchDimension) { +TEST_P(SpmdPartitioningTest, DynamicSliceAlongNonPartitionedDimension) { absl::string_view hlo_string = R"( HloModule module @@ -7539,71 +7539,19 @@ ENTRY entry { %input = s32[128,64] parameter(0), sharding={devices=[2,1]0,1} %index = s32[] parameter(1) %trivial_index = s32[] parameter(2) - ROOT %dynamic-slice = s32[128,16] dynamic-slice(%input, %trivial_index, %index), - dynamic_slice_sizes={128,16}, sharding={devices=[2,1]0,1} + ROOT %dynamic-slice = s32[128,2] dynamic-slice(%input, %trivial_index, %index), + dynamic_slice_sizes={128,2}, sharding={devices=[2,1]0,1} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); + const auto root = module->entry_computation()->root_instruction(); auto input = AllOf(op::Parameter(0), op::Shape("s32[64,64]")); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(root, AllOf(op::DynamicSlice(input, op::Constant(), op::Parameter(1)), - op::Shape("s32[64,16]"))); -} - -TEST_P(SpmdPartitioningTest, DynamicSlicePartitionedSliceDimension) { - absl::string_view hlo_string = R"( -HloModule module - -ENTRY entry { - %input = s32[128,64] parameter(0), sharding={devices=[1,2]0,1} - %index = s32[] parameter(1) - %trivial_index = s32[] parameter(2) - ROOT %dynamic-slice = s32[128,16] dynamic-slice(%input, %trivial_index, %index), - dynamic_slice_sizes={128,16}, sharding={devices=[1,2]0,1} -})"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - PartitionComputation(hlo_string, /*num_devices=*/2)); - - auto input = AllOf(op::Parameter(0), op::Shape("s32[128,32]")); - auto input_replicated = - AllOf(op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), input, _, _)), - op::Shape("s32[128,64]")); - auto ds_replicated = AllOf( - op::DynamicSlice(input_replicated, op::Constant(), op::Parameter(1)), - op::Shape("s32[128,16]")); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - AllOf(op::DynamicSlice(ds_replicated, _, _), op::Shape("s32[128,8]"))); -} - -TEST_P(SpmdPartitioningTest, DynamicSlicePartitionedBothDimensions) { - absl::string_view hlo_string = R"( -HloModule module - -ENTRY entry { - %input = s32[128,64] parameter(0), sharding={devices=[2,2]<=[4]} - %index = s32[] parameter(1) - %trivial_index = s32[] parameter(2) - ROOT %dynamic-slice = s32[128,16] dynamic-slice(%input, %trivial_index, 
%index), - dynamic_slice_sizes={128,16}, sharding={devices=[2,2]<=[4]} -})"; - - TF_ASSERT_OK_AND_ASSIGN(auto module, - PartitionComputation(hlo_string, /*num_devices=*/4)); - - auto input = AllOf(op::Parameter(0), op::Shape("s32[64,32]")); - auto input_reshard = - AllOf(op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(), input, _, _)), - op::Shape("s32[64,64]")); - auto ds = - AllOf(op::DynamicSlice(input_reshard, op::Constant(), op::Parameter(1)), - op::Shape("s32[64,16]")); - EXPECT_THAT(module->entry_computation()->root_instruction(), - AllOf(op::DynamicSlice(ds, _, _), op::Shape("s32[64,8]"))); + op::Shape("s32[64,2]"))); } TEST_P(SpmdPartitioningTest, DynamicUpdateSliceAlongNonPartitionedDimension) { diff --git a/xla/service/triangular_solve_expander_test.cc b/xla/service/triangular_solve_expander_test.cc index fa382b24d0d9d..1a2ba8c71ece6 100644 --- a/xla/service/triangular_solve_expander_test.cc +++ b/xla/service/triangular_solve_expander_test.cc @@ -15,15 +15,20 @@ limitations under the License. #include "xla/service/triangular_solve_expander.h" +#include #include +#include #include +#include "xla/array2d.h" +#include "xla/error_spec.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" namespace xla { namespace { diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc index 57448f9c01319..cc1494e5096f6 100644 --- a/xla/stream_executor/cuda/cuda_dnn.cc +++ b/xla/stream_executor/cuda/cuda_dnn.cc @@ -4965,6 +4965,10 @@ static absl::StatusOr RebuildExecutionPlan( } // namespace +void FixDimsForRaggedOffset(std::vector& dims, int max_seg_per_batch) { + dims[0] *= max_seg_per_batch; +} + absl::StatusOr GetCudnnFlashAttentionOperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_descriptor, @@ -4974,7 +4978,8 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional bias_descriptor, const std::optional stats_descriptor, double scale, const bool use_dropout, const std::optional dropout_rate, - const dnn::FMHAMaskKind mask_type, const int sliding_window_length) { + const dnn::FMHAMaskKind mask_type, const int sliding_window_length, + const int max_seg_per_batch) { using cudnn_frontend::graph::Tensor_attributes; #if CUDNN_VERSION >= 90000 @@ -5007,23 +5012,34 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::vector q_dims = q_descriptor.GetCudnnCompatibleDimensions(true); + std::vector k_dims = k_descriptor.GetCudnnCompatibleDimensions(true); + std::vector v_dims = + v_descriptor.GetCudnnCompatibleDimensions(false); + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(q_dims, max_seg_per_batch); + FixDimsForRaggedOffset(k_dims, max_seg_per_batch); + FixDimsForRaggedOffset(v_dims, max_seg_per_batch); + } + std::shared_ptr q_tensor = graph.tensor(Tensor_attributes() .set_name("Q") - .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) + .set_dim(q_dims) .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) .set_uid(next_uid())); std::shared_ptr k_tensor = graph.tensor(Tensor_attributes() .set_name("K") - .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) + .set_dim(k_dims)
.set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) .set_uid(next_uid())); std::shared_ptr v_tensor = graph.tensor( Tensor_attributes() .set_name("V") - .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) + .set_dim(v_dims) .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) .set_uid(next_uid())); @@ -5049,9 +5065,9 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - if (is_padding) { - auto q_dim = q_descriptor.GetCudnnCompatibleDimensions(true); - auto b = q_dim[0]; + if (is_padding || max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; auto seq_q_tensor = graph.tensor(Tensor_attributes() .set_name("seq_q") @@ -5070,6 +5086,30 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( sdpa_options.set_seq_len_q(seq_q_tensor); sdpa_options.set_seq_len_kv(seq_kv_tensor); } + + std::shared_ptr offset_q; + if (max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; + offset_q = + graph.tensor(Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + auto offset_kv = + graph.tensor(Tensor_attributes() + .set_name("offset_kv") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + q_tensor->set_ragged_offset(offset_q); + k_tensor->set_ragged_offset(offset_kv); + v_tensor->set_ragged_offset(offset_kv); + } + // Setting seed and offset std::shared_ptr seed_tensor; std::shared_ptr offset_tensor; @@ -5100,10 +5140,16 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( auto [o_tensor, stats_tensor] = graph.sdpa(q_tensor, k_tensor, v_tensor, sdpa_options); + auto o_dims = o_descriptor.dimensions(); + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(o_dims, max_seg_per_batch); + o_tensor->set_ragged_offset(offset_q); + } // Set output attributes. 
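As an aside on the ragged offsets wired up above: each offset tensor has shape {b + 1, 1, 1, 1} and holds element offsets into the packed Q/K/V buffers. A hedged sketch of the arithmetic behind the offsets constant in the sequence-packing test earlier in this change ({0, 32768, 65536, 98304, 131072}); the helper name is illustrative:

#include <cstdint>
#include <vector>

// Hypothetical reconstruction of the packed-offset math: 2 batches of 2
// segments (4 segment slots), 256 tokens per segment, 2 heads, head dim 64,
// so each segment spans 256 * 2 * 64 = 32768 elements.
std::vector<int32_t> MakeRaggedOffsets(int32_t num_segments, int32_t seq_len,
                                       int32_t num_heads, int32_t head_dim) {
  std::vector<int32_t> offsets(num_segments + 1);
  for (int32_t i = 0; i <= num_segments; ++i) {
    offsets[i] = i * seq_len * num_heads * head_dim;
  }
  return offsets;  // (4, 256, 2, 64) -> {0, 32768, 65536, 98304, 131072}
}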
o_tensor->set_name("O") .set_output(true) - .set_dim(o_descriptor.dimensions()) + .set_dim(o_dims) .set_stride(o_descriptor.GetLogicalStrides()) .set_uid(next_uid()); if (stats_descriptor.has_value()) { @@ -5488,7 +5534,8 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( const std::optional bias_descriptor, std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_bias, dnn::FMHAMaskKind mask_type, - bool force_deterministic, const int sliding_window_length) { + bool force_deterministic, const int sliding_window_length, + const int max_seg_per_batch) { #if CUDNN_VERSION >= 90000 if (VLOG_IS_ON(4)) { VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString() @@ -5514,19 +5561,38 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) .set_io_data_type(ioDataType); - auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); - auto p_strides = p_desc.GetCudnnCompatibleStrides(false); - std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); - p_reduction_dims.push_back(1); - + // Get dims and strides + std::vector q_dims = q_desc.GetCudnnCompatibleDimensions(false); + std::vector k_dims = k_desc.GetCudnnCompatibleDimensions(false); + std::vector v_dims = v_desc.GetCudnnCompatibleDimensions(true); + std::vector p_dims = p_desc.GetCudnnCompatibleDimensions(false); + std::vector p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector do_dims = do_desc.GetCudnnCompatibleDimensions(false); + std::vector dq_dims = dq_desc.dimensions(); + std::vector dk_dims = dk_desc.dimensions(); + std::vector dv_dims = dv_desc.dimensions(); + std::vector stats_dims(p_dims.begin(), p_dims.end() - 1); + stats_dims.push_back(1); // Divide every stride by the last dim value. 
- std::vector p_reduction_strides; - p_reduction_strides.reserve(p_strides.size()); + std::vector stats_strides; + stats_strides.reserve(p_strides.size()); int64_t p_reduced_dim_len = p_dims.back(); for (auto stride : p_strides) { - p_reduction_strides.push_back(stride / p_reduced_dim_len); + stats_strides.push_back(stride / p_reduced_dim_len); + } + stats_strides[3] = 1; + + if (max_seg_per_batch > 1) { + FixDimsForRaggedOffset(q_dims, max_seg_per_batch); + FixDimsForRaggedOffset(k_dims, max_seg_per_batch); + FixDimsForRaggedOffset(v_dims, max_seg_per_batch); + FixDimsForRaggedOffset(p_dims, max_seg_per_batch); + FixDimsForRaggedOffset(do_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dq_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dk_dims, max_seg_per_batch); + FixDimsForRaggedOffset(dv_dims, max_seg_per_batch); + FixDimsForRaggedOffset(stats_dims, max_seg_per_batch); } - p_reduction_strides[3] = 1; bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; auto sdpa_backward_options = @@ -5541,52 +5607,51 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::shared_ptr q = graph.tensor(Tensor_attributes() .set_name("Q") - .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(q_dims) .set_stride(q_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr k = graph.tensor(Tensor_attributes() .set_name("K") - .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(k_dims) .set_stride(k_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr v = graph.tensor(Tensor_attributes() .set_name("V") - .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) + .set_dim(v_dims) .set_stride(v_desc.GetCudnnCompatibleStrides(true)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr stats = graph.tensor(Tensor_attributes() .set_name("stats") - .set_dim(p_reduction_dims) - .set_stride(p_reduction_strides) + .set_dim(stats_dims) + .set_stride(stats_strides) .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::FLOAT)); std::shared_ptr dO = graph.tensor(Tensor_attributes() .set_name("dO") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(do_dims) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr d_bias_tensor; if (use_bias) { DCHECK(bias_descriptor != std::nullopt); - auto bias_dim = bias_descriptor->dimensions(); - auto q_dim = q_desc.GetCudnnCompatibleDimensions(false); - auto b = bias_dim[0]; - auto n = bias_dim[1]; - auto q_n = q_dim[1]; - auto bias_tensor = - graph.tensor(Tensor_attributes() - .set_name("bias") - .set_dim(bias_descriptor->dimensions()) - .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(next_uid())); + auto bias_dims = bias_descriptor->dimensions(); + auto bias_strides = bias_descriptor->GetLogicalStrides(); + auto b = bias_dims[0]; + auto n = bias_dims[1]; + auto q_n = q_dims[1]; + auto bias_tensor = graph.tensor(Tensor_attributes() + .set_name("bias") + .set_dim(bias_dims) + .set_stride(bias_strides) + .set_uid(next_uid())); sdpa_backward_options.set_bias(bias_tensor); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for @@ -5604,7 +5669,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::shared_ptr o = graph.tensor(Tensor_attributes() .set_name("O") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_dim(do_dims) 
.set_stride(do_desc.GetCudnnCompatibleStrides(false)) .set_uid(next_uid()) .set_data_type(ioDataType)); @@ -5612,9 +5677,10 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - if (is_padding) { - auto q_dim = q_desc.GetCudnnCompatibleDimensions(false); - auto b = q_dim[0]; + + if (is_padding || max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; auto seq_q_tensor = graph.tensor(Tensor_attributes() .set_name("seq_q") @@ -5633,6 +5699,31 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( sdpa_backward_options.set_seq_len_q(seq_q_tensor); sdpa_backward_options.set_seq_len_kv(seq_kv_tensor); } + + std::shared_ptr offset_q, offset_kv; + if (max_seg_per_batch > 1) { + // Get batch size + auto b = q_dims[0]; + offset_q = + graph.tensor(Tensor_attributes() + .set_name("offset_q") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + offset_kv = + graph.tensor(Tensor_attributes() + .set_name("offset_k") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::INT32)); + q->set_ragged_offset(offset_q); + k->set_ragged_offset(offset_kv); + v->set_ragged_offset(offset_kv); + o->set_ragged_offset(offset_q); + dO->set_ragged_offset(offset_q); + } // Setting seed and offset std::shared_ptr seed_tensor; std::shared_ptr offset_tensor; @@ -5668,20 +5759,25 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( auto [dQ, dK, dV] = graph.sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); + if (max_seg_per_batch > 1) { + dQ->set_ragged_offset(offset_q); + dK->set_ragged_offset(offset_kv); + dV->set_ragged_offset(offset_kv); + } dQ->set_output(true) - .set_dim(dq_desc.dimensions()) + .set_dim(dq_dims) .set_stride(dq_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dQ") .set_data_type(ioDataType); dK->set_output(true) - .set_dim(dk_desc.dimensions()) + .set_dim(dk_dims) .set_stride(dk_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dK") .set_data_type(ioDataType); dV->set_output(true) - .set_dim(dv_desc.dimensions()) + .set_dim(dv_dims) .set_stride(dv_desc.GetLogicalStrides()) .set_uid(next_uid()) .set_name("dV") diff --git a/xla/stream_executor/cuda/cuda_dnn.h b/xla/stream_executor/cuda/cuda_dnn.h index 78a43f654b764..9d46794e2329b 100644 --- a/xla/stream_executor/cuda/cuda_dnn.h +++ b/xla/stream_executor/cuda/cuda_dnn.h @@ -707,7 +707,8 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const std::optional bias_descriptor, const std::optional stats_descriptor, double scale, const bool use_dropout, const std::optional dropout_rate, - const dnn::FMHAMaskKind mask_type, const int sliding_window_length); + const dnn::FMHAMaskKind mask_type, const int sliding_window_length, + const int max_seg_per_batch); absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( dnn::DnnSupport& dnn_support, @@ -730,7 +731,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_bias, const dnn::FMHAMaskKind mask_type, bool force_deterministic, - const int sliding_window_length); + const int sliding_window_length, const int max_seg_per_batch); absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, 
diff --git a/xla/stream_executor/rocm/BUILD b/xla/stream_executor/rocm/BUILD
index 7634251861450..6b7d44a829faa 100644
--- a/xla/stream_executor/rocm/BUILD
+++ b/xla/stream_executor/rocm/BUILD
@@ -820,15 +820,6 @@ cc_library(
     alwayslink = 1,
 )
 
-cc_library(
-    name = "rocm_rpath",
-    linkopts = select({
-        "//conditions:default": [
-            "-Wl,-rpath,../local_config_rocm/rocm/rocm/lib",
-        ],
-    }),
-)
-
 cc_library(
     name = "stream_executor_rocm",
     tags = [
@@ -837,12 +828,12 @@ cc_library(
     ],
     deps = [
         ":rocm_platform_id",
-        ":rocm_rpath",
         "//xla/stream_executor:dnn",
         "//xla/stream_executor:platform_manager",
         "//xla/stream_executor:scratch_allocator",
         "//xla/stream_executor/cuda:cuda_platform_id",
         "//xla/stream_executor/host:host_platform_id",
+        "@local_config_rocm//rocm:rocm_rpath",
     ] + if_static(
         [":all_runtime"],
     ),
diff --git a/xla/tests/BUILD b/xla/tests/BUILD
index e8ab69dffb4dc..dbe2e5f2eebe1 100644
--- a/xla/tests/BUILD
+++ b/xla/tests/BUILD
@@ -214,35 +214,27 @@ cc_library(
     deps = [
         ":literal_test_util",
         ":test_utils",
-        "//xla:debug_options_flags",
         "//xla:error_spec",
         "//xla:literal",
-        "//xla:literal_util",
-        "//xla:shape_layout",
         "//xla:shape_util",
         "//xla:test_helpers",
         "//xla:util",
         "//xla:xla_data_proto_cc",
         "//xla/hlo/ir:hlo",
-        "//xla/hlo/ir:hlo_module_group",
-        "//xla/hlo/pass:hlo_pass",
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",
         "//xla/hlo/testlib:verified_hlo_module",
-        "//xla/hlo/utils:hlo_query",
-        "//xla/service:backend",
-        "//xla/service:computation_layout",
         "//xla/service:computation_placer_hdr",
         "//xla/service:executable",
         "//xla/service:hlo_module_config",
         "//xla/service:hlo_module_util",
-        "//xla/service:hlo_runner",
         "//xla/service:hlo_runner_interface",
         "//xla/service:hlo_verifier",
         "//xla/service:interpreter_plugin",  # reference backend
-        "//xla/service:platform_util",
-        "//xla/stream_executor:device_memory_allocator",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:logging",
+        "//xla/tsl/platform:statusor",
+        "//xla/tsl/platform:test",
         "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/base:nullability",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
@@ -252,10 +244,7 @@ cc_library(
         "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
-        "@tsl//tsl/platform:errors",
-        "@tsl//tsl/platform:logging",
-        "@tsl//tsl/platform:statusor",
-        "@tsl//tsl/platform:test",
+        "@tsl//tsl/platform:protobuf",
     ],
 )
 
@@ -979,6 +968,8 @@ xla_test(
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:platform",
         "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/tsl/platform:test",
+        "//xla/tsl/platform:test_benchmark",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
@@ -1024,7 +1015,10 @@ xla_test(
         "//xla/hlo/builder/lib:arithmetic",
         "//xla/hlo/builder/lib:matrix",
         "//xla/hlo/parser:hlo_parser",
+        "//xla/service:platform_util",
         "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/tsl/platform:test",
+        "//xla/tsl/platform:test_benchmark",
         "@com_google_absl//absl/strings",
         "@tsl//tsl/platform:ml_dtypes",
         "@tsl//tsl/platform:test",
@@ -1070,7 +1064,10 @@ xla_test(
         "//xla/hlo/builder/lib:arithmetic",
         "//xla/hlo/builder/lib:matrix",
         "//xla/hlo/parser:hlo_parser",
+        "//xla/service:platform_util",
         "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/tsl/platform:test",
+        "//xla/tsl/platform:test_benchmark",
"@com_google_absl//absl/strings", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", @@ -1158,7 +1155,10 @@ xla_test( "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/builder/lib:matrix", "//xla/hlo/parser:hlo_parser", + "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/platform:test", + "//xla/tsl/platform:test_benchmark", "@com_google_absl//absl/strings", "@tsl//tsl/platform:ml_dtypes", "@tsl//tsl/platform:test", @@ -2560,6 +2560,7 @@ xla_test( backends = ["gpu"], deps = [ ":hlo_test_base", + ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:literal", diff --git a/xla/tests/custom_call_test.cc b/xla/tests/custom_call_test.cc index 3f264f1996fc6..ff88a0de868cf 100644 --- a/xla/tests/custom_call_test.cc +++ b/xla/tests/custom_call_test.cc @@ -409,6 +409,18 @@ XLA_FFI_DEFINE_HANDLER(kAlwaysFail, AlwaysFail, XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$always_fail", PLATFORM, kAlwaysFail); +static absl::Status Tokens(ffi::Token, ffi::Result, + ffi::Result) { + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER( + kTokens, Tokens, + ffi::Ffi::Bind().Arg().Ret().Ret()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$tokens", PLATFORM, + kTokens); + static absl::Status FfiR0F32Add2(R0F32Buffer in, R0F32ResultBuffer out) { auto in_data = in.typed_data(); auto out_data = out->typed_data(); @@ -843,6 +855,24 @@ XLA_TEST_F(FfiCustomCallTest, FfiReportsSuccess) { EXPECT_EQ(status, absl::OkStatus()); } +XLA_TEST_F(FfiCustomCallTest, Tokens) { + auto module = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + + std::vector ret = {ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTokenShape()}; + + auto* token = builder.AddInstruction(HloInstruction::CreateToken()); + builder.AddInstruction(HloInstruction::CreateCustomCall( + ShapeUtil::MakeTupleShape(ret), {token}, "__xla_test$$tokens", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + module->AddEntryComputation(builder.Build()); + + auto status = Execute(std::move(module), {}).status(); + EXPECT_EQ(status, absl::OkStatus()); +} + XLA_TEST_F(FfiCustomCallTest, FfiUnknownTarget) { auto module = CreateNewVerifiedModule(); auto builder = HloComputation::Builder(TestName()); diff --git a/xla/tests/dot_operation_test.cc b/xla/tests/dot_operation_test.cc index 674ada04d96c3..2acc860804d0d 100644 --- a/xla/tests/dot_operation_test.cc +++ b/xla/tests/dot_operation_test.cc @@ -22,21 +22,21 @@ limitations under the License. 
#include "xla/array3d.h" #include "xla/client/local_client.h" #include "xla/error_spec.h" -#include "xla/hlo/builder/lib/arithmetic.h" #include "xla/hlo/builder/lib/matrix.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/reference_util.h" +#include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/platform/test.h" +#include "xla/tsl/platform/test_benchmark.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/test.h" -#include "tsl/platform/test_benchmark.h" #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" diff --git a/xla/tests/exhaustive/BUILD b/xla/tests/exhaustive/BUILD index 0b0c52554e9a5..9b60d841dd480 100644 --- a/xla/tests/exhaustive/BUILD +++ b/xla/tests/exhaustive/BUILD @@ -250,7 +250,6 @@ exhaustive_xla_test( shard_count = 50, tags = [ "optonly", - "test_xla_cpu_no_thunks", # This is a big test that we skip for capacity reasons in OSS testing. "no_oss", ], diff --git a/xla/tests/hlo_runner_agnostic_test_base.cc b/xla/tests/hlo_runner_agnostic_test_base.cc index 402159a185853..b781a0eebd37d 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.cc +++ b/xla/tests/hlo_runner_agnostic_test_base.cc @@ -30,19 +30,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" -#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" -#include "xla/debug_options_flags.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" @@ -53,11 +47,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/logging.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" +#include "tsl/platform/protobuf.h" namespace xla { diff --git a/xla/tests/hlo_runner_agnostic_test_base.h b/xla/tests/hlo_runner_agnostic_test_base.h index e43ddec3e2892..9b8ae26f615f4 100644 --- a/xla/tests/hlo_runner_agnostic_test_base.h +++ b/xla/tests/hlo_runner_agnostic_test_base.h @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -35,31 +34,17 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/testlib/verified_hlo_module.h" -#include "xla/layout.h" #include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/service/backend.h" -#include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" -#include "xla/service/hlo_verifier.h" -#include "xla/service/platform_util.h" -#include "xla/shape_layout.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/test_helpers.h" -#include "xla/tests/literal_test_util.h" +#include "xla/tsl/platform/test.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/test.h" namespace xla { @@ -189,7 +174,7 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // backend, but it might need to be tailored so that it is able to run on the // reference backend. Note that the program shape of the module must not be // modified. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, @@ -197,14 +182,14 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, absl::Span arguments, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and compares the results. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( + ::testing::AssertionResult RunAndCompare( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr, @@ -212,26 +197,26 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase { // Same as above, except that the module will be executed without Hlo // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, const std::optional& error, const std::function& reference_preprocessor = nullptr, const std::function& test_preprocessor = nullptr); // Executes an hlo module with fake inputs and checks that the execution is // successful. - [[nodiscard]] ::testing::AssertionResult Run( + ::testing::AssertionResult Run( std::unique_ptr module, bool run_hlo_passes, const std::function& test_preprocessor = nullptr); // Convenient wrappers for executing and comparing an hlo module with fake // input. Module can be passed in directly, or parsed from an hlo_string, // or loaded from a file. 
-  [[nodiscard]] ::testing::AssertionResult RunAndCompare(
+  ::testing::AssertionResult RunAndCompare(
      absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
      const std::function<void(HloModule*)>& test_preprocessor = nullptr,
      std::optional<int64_t> args_max_bits_of_precision = std::nullopt);
-  [[nodiscard]] ::testing::AssertionResult Run(
+  ::testing::AssertionResult Run(
      absl::string_view hlo_string, bool run_hlo_passes = true,
      ExecutionProfile* profile = nullptr,
      const tsl::protobuf::Message* backend_config = nullptr,
@@ -299,19 +284,19 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase {
      const std::optional<ErrorSpec>& error, bool run_hlo_passes = true);
 
   // Executes an hlo module with fake inputs on multiple replicas.
-  [[nodiscard]] ::testing::AssertionResult RunReplicated(
+  ::testing::AssertionResult RunReplicated(
      absl::string_view hlo_string, bool run_hlo_passes = true,
      int64_t num_replicas = 1,
      const tsl::protobuf::Message* backend_config = nullptr);
 
   // If assert_determinism is true, the assertion will fail unless all runs
   // produce exactly the same output.
-  [[nodiscard]] ::testing::AssertionResult RunMultipleTimes(
+  ::testing::AssertionResult RunMultipleTimes(
      absl::string_view hlo_string, bool run_hlo_passes,
      std::vector<ExecutionProfile>* profiles,
      const tsl::protobuf::Message* backend_config = nullptr,
      bool assert_determinism = false);
-  [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses(
+  ::testing::AssertionResult RunAndCompareNoHloPasses(
      absl::string_view hlo_string, const std::optional<ErrorSpec>& error,
      const std::function<void(HloModule*)>& reference_preprocessor = nullptr,
      const std::function<void(HloModule*)>& test_preprocessor = nullptr);
diff --git a/xla/tests/replicated_io_feed_test.cc b/xla/tests/replicated_io_feed_test.cc
index 415faa01ff89e..194697936e13a 100644
--- a/xla/tests/replicated_io_feed_test.cc
+++ b/xla/tests/replicated_io_feed_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "xla/test.h"
 #include "xla/test_helpers.h"
 #include "xla/tests/hlo_test_base.h"
+#include "xla/tests/literal_test_util.h"
 #include "xla/tests/test_macros.h"
 #include "xla/tsl/lib/core/status_test_util.h"
diff --git a/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/xla/tsl/distributed_runtime/coordination/coordination_service.cc
index d6175c1c1d548..9efc66bdac7a3 100644
--- a/xla/tsl/distributed_runtime/coordination/coordination_service.cc
+++ b/xla/tsl/distributed_runtime/coordination/coordination_service.cc
@@ -1350,8 +1350,9 @@ std::vector<KeyValueEntry> CoordinationServiceStandaloneImpl::GetKeyValueDir(
   for (it = begin; it != kv_store_.end(); ++it) {
     // Stop once the next key does not have the directory prefix. Since keys are
     // ordered, none of the other keys would have a matching prefix.
-    if (std::mismatch(dir.begin(), dir.end(), it->first.begin()).first !=
-        dir.end()) {
+    if (std::mismatch(dir.begin(), dir.end(), it->first.begin(),
+                      it->first.end())
+            .first != dir.end()) {
       break;
     }
     KeyValueEntry kv;
@@ -1373,8 +1374,9 @@ absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue(
   auto begin = kv_store_.lower_bound(dir);
   std::map<std::string, std::string>::iterator end;
   for (end = begin; end != kv_store_.end(); end++) {
-    if (std::mismatch(dir.begin(), dir.end(), end->first.begin()).first !=
-        dir.end())
+    if (std::mismatch(dir.begin(), dir.end(), end->first.begin(),
+                      end->first.end())
+            .first != dir.end())
       break;
   }
   kv_store_.erase(begin, end);
diff --git a/xla/tsl/platform/default/BUILD b/xla/tsl/platform/default/BUILD
index f95ba7897dde3..9f8dc1d79cb59 100644
--- a/xla/tsl/platform/default/BUILD
+++ b/xla/tsl/platform/default/BUILD
@@ -1,5 +1,6 @@
 # Tensorflow default + linux implementations of tensorflow/core/platform libraries.
 load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
+load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured")
 load(
     "//xla/tsl:tsl.bzl",
     "if_cuda_tools",
@@ -103,12 +104,16 @@ cc_library(
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@local_config_cuda//cuda:cuda_headers",
-        "@local_config_rocm//rocm:rocm_headers",
         "@local_config_tensorrt//:tensorrt_headers",
         "@tsl//tsl/platform:load_library",
         "@tsl//tsl/platform:logging",
         "@tsl//tsl/platform:path",
-    ] + if_oss(["@local_config_nccl//:nccl_config"]),
+    ] + if_oss([
+        "@local_config_nccl//:nccl_config",
+    ]) + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_config",
+        "@local_config_rocm//rocm:rocm_headers",
+    ]),
 )
 
 cc_library(
@@ -264,6 +269,7 @@ cc_library(
     name = "load_library",
     srcs = ["load_library.cc"],
     hdrs = ["@tsl//tsl/platform:load_library.h"],
+    linkstatic = True,
     tags = [
         "manual",
         "no_oss",
@@ -271,7 +277,9 @@ cc_library(
     ],
     deps = [
         "@com_google_absl//absl/status",
-    ],
+    ] + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_rpath",
+    ]),
 )
 
 cc_library(
@@ -393,6 +401,7 @@ cc_library(
         "nobuilder",
     ],
     deps = [
+        "@local_config_rocm//rocm:rocm_config",
        "@local_config_rocm//rocm:rocm_headers",
        "@tsl//tsl/platform:logging",
        "@tsl//tsl/platform:path",