Skip to content

Commit

Permalink
Avoid invalid memory access in experimental FIL for large output size (
Browse files Browse the repository at this point in the history
…#5365)

If the output size for a prediction exceeds the maximum available shared memory, the size of the output workspace was previously being set to 0, resulting in an invalid memory access. The correct behavior is to simply stop trying to fit the output in shared memory and fall back to using global memory.

This change causes FIL to only go through the process of reducing rows per block iteration if it has a chance of changing the outcome. If we have already determined that we cannot store output to shared memory, we simply skip that step, knowing that we will fall back to global memory.

Authors:
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #5365
  • Loading branch information
wphicks authored May 11, 2023
1 parent 95ca779 commit b03f9f1
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,19 @@ std::enable_if_t<D == raft_proto::device_type::gpu, void> infer(
rows_per_block_iteration = index_type{32};
}

do {
output_workspace_size =
compute_output_size(row_output_size, threads_per_block, rows_per_block_iteration, infer_type);
output_workspace_size_bytes = output_item_bytes * output_workspace_size;
if (row_output_size != 0) {
do {
output_workspace_size = compute_output_size(
row_output_size, threads_per_block, rows_per_block_iteration, infer_type);
output_workspace_size_bytes = output_item_bytes * output_workspace_size;

shared_mem_per_block =
(rows_per_block_iteration * row_size_bytes + output_workspace_size_bytes);
if (shared_mem_per_block > max_overall_shared_mem) {
rows_per_block_iteration >>= index_type{1};
}
} while (shared_mem_per_block > max_overall_shared_mem && rows_per_block_iteration > 1);
shared_mem_per_block =
(rows_per_block_iteration * row_size_bytes + output_workspace_size_bytes);
if (shared_mem_per_block > max_overall_shared_mem) {
rows_per_block_iteration >>= index_type{1};
}
} while (shared_mem_per_block > max_overall_shared_mem && rows_per_block_iteration > 1);
}

shared_mem_per_block = std::min(shared_mem_per_block, max_overall_shared_mem);

Expand Down

0 comments on commit b03f9f1

Please sign in to comment.