diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index e6629cbd1aa01..713d3cc387a63 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -136,6 +136,7 @@ cc_library(
     deps = [
         "//xla:shape_util",
         "//xla:util",
+        "//xla/service:platform_util",
         "//xla/stream_executor:device_description",
         "//xla/stream_executor:launch_dim",
         "@com_google_absl//absl/log",
diff --git a/xla/service/gpu/fusions/fusion_emitter.cc b/xla/service/gpu/fusions/fusion_emitter.cc
index 845c206d22de9..c2481dae3c80b 100644
--- a/xla/service/gpu/fusions/fusion_emitter.cc
+++ b/xla/service/gpu/fusions/fusion_emitter.cc
@@ -91,15 +91,11 @@ absl::Status AnnotateKernelLaunchDimensions(
     const se::DeviceDescription& device_info,
     const LaunchDimensions& launch_dims, const std::string& kernel_name,
     llvm::Module* llvm_module) {
-  TF_RET_CHECK(
-      (device_info.block_dim_limit().x == 0 ||
-       launch_dims.block_counts().x < device_info.block_dim_limit().x) &&
-      (device_info.block_dim_limit().y == 0 ||
-       launch_dims.block_counts().y < device_info.block_dim_limit().y))
+  TF_RET_CHECK(device_info.block_dim_limit().x == 0 ||
+               launch_dims.block_counts().x < device_info.block_dim_limit().x)
       << "Kernel '" << kernel_name << "' launch needs more blocks ("
-      << launch_dims.block_counts().x << ", " << launch_dims.block_counts().y
-      << ") than allowed by hardware (" << device_info.block_dim_limit().x
-      << ", " << device_info.block_dim_limit().y << ").";
+      << launch_dims.block_counts().x << ") than allowed by hardware ("
+      << device_info.block_dim_limit().x << ").";
 
   // Add __launch_bounds__ to metadata. This limits registers per thread to
   // avoid out-of-resources launching errors.
diff --git a/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo b/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo
deleted file mode 100644
index b225b378acc73..0000000000000
--- a/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \
-// RUN:   --inline="default-pipeline='cse'" | FileCheck %s
-
-bcast {
-  one = bf16[] constant(1)
-  ROOT broadcast = bf16[24,2048,2048,3,4096]{4,3,2,1,0} broadcast(one), dimensions={}
-}
-
-// CHECK: func.func @main(%[[ARG0:.*]]: tensor<24x2048x2048x3x4096xbf16>
-// CHECK: gpu.block_id x {xla.range = [0 : index, 1207959551 : index]}
-// CHECK: gpu.block_id y {xla.range = [0 : index, 1 : index]}
-// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]], %[[RC:.*]], %[[RD:.*]], %[[RE:.*]]) in
-// CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]])
-// CHECK: %[[CST:.*]] = arith.constant 1.000
-// CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ITER]][%[[RA]], %[[RB]], %[[RC]], %[[RD]], %[[RE]]]
diff --git a/xla/service/gpu/launch_dimensions.cc b/xla/service/gpu/launch_dimensions.cc
index db060f1eb4b66..21cc4758dbc56 100644
--- a/xla/service/gpu/launch_dimensions.cc
+++ b/xla/service/gpu/launch_dimensions.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <algorithm>
 #include <cstdint>
 
+#include "xla/service/platform_util.h"
 #include "xla/shape.h"
 #include "xla/shape_util.h"
 #include "xla/stream_executor/device_description.h"
@@ -37,16 +38,19 @@ LaunchDimensions CalculateLaunchDimensions(
   num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor});
 
   const int kWarpSchedulers = 4;
-  int64_t threads_per_block = std::min<int64_t>(
+  int64_t threads_per_block_x = std::min<int64_t>(
       gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements);
-  int64_t num_blocks_total = CeilOfRatio(num_elements, threads_per_block);
-
-  int64_t num_blocks_y = CeilOfRatio(
-      num_blocks_total, gpu_device_info.block_dim_limit().x);
-  int64_t num_blocks_x = CeilOfRatio(num_blocks_total, num_blocks_y);
-
-  return LaunchDimensions(se::BlockDim(num_blocks_x, num_blocks_y, 1),
-                          se::ThreadDim(threads_per_block, 1, 1));
+  int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x);
+  CHECK(num_blocks < gpu_device_info.block_dim_limit().x);
+  int threads_per_block_y = 1;
+  if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") {
+    while ((num_blocks * threads_per_block_x) > std::numeric_limits<uint32_t>::max()) {
+      threads_per_block_x /= 2;
+      threads_per_block_y *= 2;
+    }
+  }
+  return LaunchDimensions(se::BlockDim(num_blocks, 1, 1),
+                          se::ThreadDim(threads_per_block_x, threads_per_block_y, 1));
 }
 
 }  // namespace gpu
diff --git a/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
index 2810796727ea3..598277c18a9a1 100644
--- a/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
+++ b/xla/service/gpu/tests/gpu_kernel_tiling_test.cc
@@ -401,8 +401,8 @@ TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) {
   absl::Status status = CompileToExecutable(std::move(hlo_module)).status();
   EXPECT_THAT(status.message(),
               ::testing::ContainsRegex(
-                  "Kernel '.*' launch needs more blocks [(]4294967296, 1[)] "
-                  "than allowed by hardware [(]2147483647, 65535[)]"));
+                  "Kernel '.*' launch needs more blocks [(]4294967296[)] than "
+                  "allowed by hardware [(]2147483647[)]"));
 }
 
 }  // namespace
diff --git a/xla/service/gpu/tests/gpu_too_many_blocks_test.cc b/xla/service/gpu/tests/gpu_too_many_blocks_test.cc
index 6c59c8da9bf9b..bf346ed724cf8 100644
--- a/xla/service/gpu/tests/gpu_too_many_blocks_test.cc
+++ b/xla/service/gpu/tests/gpu_too_many_blocks_test.cc
@@ -40,11 +40,13 @@ TEST_F(TooManyBlocksTest, FailsWithInvalidStatus) {
 HloModule primitive_computation_mul.8
 
 ENTRY primitive_computation_mul.8 {
-  parameter.1 = s8[65536] parameter(0)
-  parameter.2 = s8[65536] parameter(1)
-  broadcast.3 = s8[65536,65536,65536,128,2] broadcast(parameter.1), dimensions={0}
-  broadcast.4 = s8[65536,65536,65536,128,2] broadcast(parameter.2), dimensions={1}
-  ROOT multiply.5 = s8[65536,65536,65536,128,2] multiply(broadcast.3, broadcast.4)
+  parameter.1 = f32[4,1048576,1,1]{3,2,1,0} parameter(0)
+  reshape.3 = f32[4,1048576,1]{2,1,0} reshape(parameter.1)
+  broadcast.4 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.3), dimensions={0,1,3}
+  parameter.2 = f32[4,1,1048576,1]{3,2,1,0} parameter(1)
+  reshape.5 = f32[4,1048576,1]{2,1,0} reshape(parameter.2)
+  broadcast.6 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.5), dimensions={0,2,3}
+  ROOT multiply.7 = f32[4,1048576,1048576,1]{3,2,1,0} multiply(broadcast.4, broadcast.6)
 }
 )";
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,