From b03f9f1cf86379e7fc1f098968e0fe4f8a741ba0 Mon Sep 17 00:00:00 2001
From: William Hicks
Date: Thu, 11 May 2023 17:30:24 -0400
Subject: [PATCH] Avoid invalid memory access in experimental FIL for large output size (#5365)

If the output size for a prediction exceeds the maximum available shared
memory, the size of the output workspace was previously being set to 0,
resulting in an invalid memory access. The correct behavior is to simply
stop trying to fit the output in shared memory and fall back to using
global memory. This change causes FIL to only go through the process of
reducing rows per block iteration if it has a chance of changing the
outcome. If we have already determined that we cannot store output to
shared memory, we simply skip that step, knowing that we will fall back
to global memory.

Authors:
  - William Hicks (https://github.com/wphicks)

Approvers:
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: https://github.com/rapidsai/cuml/pull/5365
---
 .../experimental/fil/detail/infer/gpu.cuh     | 22 ++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh b/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh
index 792ec9ad98..66ce25677f 100644
--- a/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh
+++ b/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh
@@ -178,17 +178,19 @@ std::enable_if_t infer(
     rows_per_block_iteration = index_type{32};
   }
 
-  do {
-    output_workspace_size =
-      compute_output_size(row_output_size, threads_per_block, rows_per_block_iteration, infer_type);
-    output_workspace_size_bytes = output_item_bytes * output_workspace_size;
+  if (row_output_size != 0) {
+    do {
+      output_workspace_size = compute_output_size(
+        row_output_size, threads_per_block, rows_per_block_iteration, infer_type);
+      output_workspace_size_bytes = output_item_bytes * output_workspace_size;
 
-    shared_mem_per_block =
-      (rows_per_block_iteration * row_size_bytes + output_workspace_size_bytes);
-    if (shared_mem_per_block > max_overall_shared_mem) {
-      rows_per_block_iteration >>= index_type{1};
-    }
-  } while (shared_mem_per_block > max_overall_shared_mem && rows_per_block_iteration > 1);
+      shared_mem_per_block =
+        (rows_per_block_iteration * row_size_bytes + output_workspace_size_bytes);
+      if (shared_mem_per_block > max_overall_shared_mem) {
+        rows_per_block_iteration >>= index_type{1};
+      }
+    } while (shared_mem_per_block > max_overall_shared_mem && rows_per_block_iteration > 1);
+  }
 
   shared_mem_per_block = std::min(shared_mem_per_block, max_overall_shared_mem);