diff --git a/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh b/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh index 792ec9ad98..66ce25677f 100644 --- a/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh +++ b/cpp/include/cuml/experimental/fil/detail/infer/gpu.cuh @@ -178,17 +178,19 @@ std::enable_if_t infer( rows_per_block_iteration = index_type{32}; } - do { - output_workspace_size = - compute_output_size(row_output_size, threads_per_block, rows_per_block_iteration, infer_type); - output_workspace_size_bytes = output_item_bytes * output_workspace_size; + if (row_output_size != 0) { + do { + output_workspace_size = compute_output_size( + row_output_size, threads_per_block, rows_per_block_iteration, infer_type); + output_workspace_size_bytes = output_item_bytes * output_workspace_size; - shared_mem_per_block = - (rows_per_block_iteration * row_size_bytes + output_workspace_size_bytes); - if (shared_mem_per_block > max_overall_shared_mem) { - rows_per_block_iteration >>= index_type{1}; - } - } while (shared_mem_per_block > max_overall_shared_mem && rows_per_block_iteration > 1); + shared_mem_per_block = + (rows_per_block_iteration * row_size_bytes + output_workspace_size_bytes); + if (shared_mem_per_block > max_overall_shared_mem) { + rows_per_block_iteration >>= index_type{1}; + } + } while (shared_mem_per_block > max_overall_shared_mem && rows_per_block_iteration > 1); + } shared_mem_per_block = std::min(shared_mem_per_block, max_overall_shared_mem);