Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert edf18ce and fix launch dimension triplet #19582

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ cc_library(
deps = [
"//xla:shape_util",
"//xla:util",
"//xla/service:platform_util",
"//xla/stream_executor:device_description",
"//xla/stream_executor:launch_dim",
"@com_google_absl//absl/log",
Expand Down
12 changes: 4 additions & 8 deletions xla/service/gpu/fusions/fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,11 @@ absl::Status AnnotateKernelLaunchDimensions(
const se::DeviceDescription& device_info,
const LaunchDimensions& launch_dims, const std::string& kernel_name,
llvm::Module* llvm_module) {
TF_RET_CHECK(
(device_info.block_dim_limit().x == 0 ||
launch_dims.block_counts().x < device_info.block_dim_limit().x) &&
(device_info.block_dim_limit().y == 0 ||
launch_dims.block_counts().y < device_info.block_dim_limit().y))
TF_RET_CHECK(device_info.block_dim_limit().x == 0 ||
launch_dims.block_counts().x < device_info.block_dim_limit().x)
<< "Kernel '" << kernel_name << "' launch needs more blocks ("
<< launch_dims.block_counts().x << ", " << launch_dims.block_counts().y
<< ") than allowed by hardware (" << device_info.block_dim_limit().x
<< ", " << device_info.block_dim_limit().y << ").";
<< launch_dims.block_counts().x << ") than allowed by hardware ("
<< device_info.block_dim_limit().x << ").";
// Add __launch_bounds__ to metadata. This limits registers per thread to
// avoid out-of-resources launching errors.

Expand Down

This file was deleted.

22 changes: 13 additions & 9 deletions xla/service/gpu/launch_dimensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cstdint>

#include "xla/service/platform_util.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_description.h"
Expand All @@ -37,16 +38,19 @@ LaunchDimensions CalculateLaunchDimensions(
num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor});

const int kWarpSchedulers = 4;
int64_t threads_per_block = std::min<int64_t>(
int64_t threads_per_block_x = std::min<int64_t>(
gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements);
int64_t num_blocks_total = CeilOfRatio(num_elements, threads_per_block);

int64_t num_blocks_y = CeilOfRatio<uint64_t>(
num_blocks_total, gpu_device_info.block_dim_limit().x);
int64_t num_blocks_x = CeilOfRatio(num_blocks_total, num_blocks_y);

return LaunchDimensions(se::BlockDim(num_blocks_x, num_blocks_y, 1),
se::ThreadDim(threads_per_block, 1, 1));
int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block_x);
CHECK(num_blocks < gpu_device_info.block_dim_limit().x);
int threads_per_block_y = 1;
if (xla::PlatformUtil::CanonicalPlatformName("gpu").value() == "rocm") {
while ((num_blocks * threads_per_block_x) > std::numeric_limits<uint32_t>::max()) {
threads_per_block_x /= 2;
threads_per_block_y *= 2;
}
}
return LaunchDimensions(se::BlockDim(num_blocks, 1, 1),
se::ThreadDim(threads_per_block_x, threads_per_block_y, 1));
}

} // namespace gpu
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/tests/gpu_kernel_tiling_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,8 +401,8 @@ TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) {
absl::Status status = CompileToExecutable(std::move(hlo_module)).status();
EXPECT_THAT(status.message(),
::testing::ContainsRegex(
"Kernel '.*' launch needs more blocks [(]4294967296, 1[)] "
"than allowed by hardware [(]2147483647, 65535[)]"));
"Kernel '.*' launch needs more blocks [(]4294967296[)] than "
"allowed by hardware [(]2147483647[)]"));
}

} // namespace
Expand Down
12 changes: 7 additions & 5 deletions xla/service/gpu/tests/gpu_too_many_blocks_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ TEST_F(TooManyBlocksTest, FailsWithInvalidStatus) {
HloModule primitive_computation_mul.8

ENTRY primitive_computation_mul.8 {
parameter.1 = s8[65536] parameter(0)
parameter.2 = s8[65536] parameter(1)
broadcast.3 = s8[65536,65536,65536,128,2] broadcast(parameter.1), dimensions={0}
broadcast.4 = s8[65536,65536,65536,128,2] broadcast(parameter.2), dimensions={1}
ROOT multiply.5 = s8[65536,65536,65536,128,2] multiply(broadcast.3, broadcast.4)
parameter.1 = f32[4,1048576,1,1]{3,2,1,0} parameter(0)
reshape.3 = f32[4,1048576,1]{2,1,0} reshape(parameter.1)
broadcast.4 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.3), dimensions={0,1,3}
parameter.2 = f32[4,1,1048576,1]{3,2,1,0} parameter(1)
reshape.5 = f32[4,1048576,1]{2,1,0} reshape(parameter.2)
broadcast.6 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.5), dimensions={0,2,3}
ROOT multiply.7 = f32[4,1048576,1048576,1]{3,2,1,0} multiply(broadcast.4, broadcast.6)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
Expand Down