[BUG] nightly test failed ROW_CONVERSION constantly #1567
I tried to git bisect this, but it fails even on commits from over two weeks ago. That makes me think the failure was triggered not by a change in spark-rapids-jni or thirdparty/cudf, but by some dependency that always gets downloaded at its latest version (e.g. rapids-cmake, rmm, etc.).
The issue reproduces on my laptop deterministically, e.g. for:
./target/cmake-build/gtests/ROW_CONVERSION --gtest_filter=ColumnToRowTests.Tall --rerun-failed --output-on-failure --gtest_recreate_environments_when_repeating --gtest_repeat=10
Rerunning with initcheck:
[ RUN      ] ColumnToRowTests.Tall
========= Uninitialized __global__ memory read of size 16 bytes
========= at 0x2760 in void cub::CUB_101702_860_NS::DeviceScanKernel<cub::CUB_101702_860_NS::DeviceScanPolicy<unsigned long>::Policy600, thrust::constant_iterator<unsigned long, thrust::use_default, thrust::use_default>, unsigned long *, cub::CUB_101702_860_NS::ScanTileState<unsigned long, (bool)1>, thrust::plus<void>, cub::CUB_101702_860_NS::NullType, int>(T2, T3, T4, int, T5, T6, T7)
========= by thread (1,0,0) in block (1,0,0)
========= Address 0x7f0a6ec041f0
========= Saved host backtrace up to driver entry point at kernel launch time
========= Host Frame: [0x2fa190]
========= in /lib/x86_64-linux-gnu/libcuda.so.1
========= Host Frame:__cudart1071 [0x31f4beb]
========= in /path/issue1567/target/cmake-build/libcudf.so
========= Host Frame:cudaLaunchKernel_ptsz [0x3234838]
========= in /path/issue1567/target/cmake-build/libcudf.so
========= Host Frame:spark_rapids_jni::detail::batch_data spark_rapids_jni::detail::build_batches<thrust::constant_iterator<unsigned long, thrust::use_default, thrust::use_default> >(int, thrust::constant_iterator<unsigned long, thrust::use_default, thrust::use_default>, bool, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) [0x1146686]
========= in /path/issue1567/target/cmake-build/libcudf.so
========= Host Frame:spark_rapids_jni::convert_to_rows(cudf::table_view const&, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) [0x113539d]
========= in /path/issue1567/target/cmake-build/libcudf.so
========= Host Frame:ColumnToRowTests_Tall_Test::TestBody() [0x1f4435]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) [0x82995d]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:testing::Test::Run() [0x81a0ae]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:testing::TestInfo::Run() [0x81a24d]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:testing::TestSuite::Run() [0x81a7dd]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:testing::internal::UnitTestImpl::RunAllTests() [0x82036f]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:testing::UnitTest::Run() [0x81a320]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:main [0x1bc24b]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
========= Host Frame:../sysdeps/nptl/libc_start_call_main.h:58:__libc_start_call_main [0x29d90]
========= in /lib/x86_64-linux-gnu/libc.so.6
========= Host Frame:../csu/libc-start.c:379:__libc_start_main [0x29e40]
========= in /lib/x86_64-linux-gnu/libc.so.6
========= Host Frame: [0x1e414e]
========= in /path/issue1567/./target/cmake-build/gtests/ROW_CONVERSION
=====
With memcheck the test actually passes after intercepting
Or it was always a bug and we are only now hitting it for some reason. That doesn't seem to be the case if it is deterministic, though. I originally thought this was due to the rand() inside this test. If I had to pick a line to investigate, it would be
Can you try reverting thrust/libcxx to their older versions and testing again?
I tried reverting libcudacxx from version 2.1.0 back to 1.9.1. That also requires reverting rmm, since a recent rmm commit depends on libcudacxx 2.1.0. However, that doesn't help. I also tried adding a bunch of code to retrieve CUDA errors throughout the code, which also adds
It seems that
So there seems to be something else wrong, likely due to missing stream synchronization?
There seems to be a bug in
Here is a rough (not correct) diff for fixing the issue above; I'm still investigating it...
diff --git a/src/main/cpp/src/row_conversion.cu b/src/main/cpp/src/row_conversion.cu
index c1f94598d0..7da2c74c2f 100644
--- a/src/main/cpp/src/row_conversion.cu
+++ b/src/main/cpp/src/row_conversion.cu
@@ -1516,7 +1516,7 @@ struct row_size_functor {
__device__ inline uint64_t operator()(int i) const
{
- return i >= _row_end ? 0 : _row_sizes[i + _last_row_end];
+ return i < 0 || i >= _row_end ? 0 : _row_sizes[i + _last_row_end];
}
size_type _row_end;
@@ -1556,8 +1556,14 @@ batch_data build_batches(size_type num_rows,
batch_row_boundaries.push_back(0);
size_type last_row_end = 0;
device_uvector<uint64_t> cumulative_row_sizes(num_rows, stream);
- thrust::inclusive_scan(
- rmm::exec_policy(stream), row_sizes, row_sizes + num_rows, cumulative_row_sizes.begin());
+
+ thrust::inclusive_scan(rmm::exec_policy(stream),
+ thrust::make_counting_iterator<int64_t>(0L),
+ thrust::make_counting_iterator<int64_t>((int64_t) num_rows),
+ cumulative_row_sizes.begin(),
+ [row_sizes]__device__(auto i, auto j) -> uint64_t {
+ return row_sizes[i] + row_sizes[j];
+ });
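The diff above is described as rough and not correct. One hedged reading of why: in an inclusive scan, the binary operator combines running partial sums with the next element, so after the first step the lambda's arguments are accumulated byte counts rather than row indices, and `row_sizes[i]` would then index out of range. The host-only C++17 sketch below mirrors the two ideas in the diff: a bounds-checked size functor, and a scan that first maps each index to its row size and only then adds, which is what `thrust::transform_inclusive_scan` would do on the device. The struct layout, the `batch_boundaries` helper, and its `max_batch_bytes` budget are hypothetical illustrations, not the spark-rapids-jni implementation of `build_batches`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Host-side analogue of the bounds-checked functor in the diff (hypothetical
// layout): indices outside [0, row_end) contribute a size of 0.
struct row_size_functor {
  int row_end;                     // number of rows in the current window
  int last_row_end;                // offset of the window into row_sizes
  const std::uint64_t* row_sizes;  // per-row sizes in bytes

  std::uint64_t operator()(int i) const
  {
    return (i < 0 || i >= row_end) ? 0 : row_sizes[i + last_row_end];
  }
};

// Prefix-sum the row sizes. The binary op only ever sees sizes and partial
// sums, because each index is mapped to its size (unary op) before adding.
std::vector<std::uint64_t> cumulative_sizes(const row_size_functor& f, int num_rows)
{
  std::vector<int> indices(num_rows);
  std::iota(indices.begin(), indices.end(), 0);
  std::vector<std::uint64_t> out(num_rows);
  std::transform_inclusive_scan(indices.begin(), indices.end(), out.begin(),
                                std::plus<std::uint64_t>{}, f);
  return out;
}

// Hypothetical batching step: split rows so each batch's byte total stays at
// or below max_batch_bytes (a single oversized row gets a batch of its own).
// Returns the starting row of each batch followed by num_rows as the end.
std::vector<int> batch_boundaries(const std::vector<std::uint64_t>& cumulative,
                                  std::uint64_t max_batch_bytes)
{
  assert(max_batch_bytes > 0);
  std::vector<int> bounds{0};
  std::uint64_t bytes_before_batch = 0;
  int const n = static_cast<int>(cumulative.size());
  while (bounds.back() < n) {
    // First row whose cumulative size exceeds this batch's budget.
    auto it  = std::upper_bound(cumulative.begin() + bounds.back(), cumulative.end(),
                               bytes_before_batch + max_batch_bytes);
    int end = static_cast<int>(it - cumulative.begin());
    if (end == bounds.back()) { end = bounds.back() + 1; }  // oversized row
    bounds.push_back(end);
    bytes_before_batch = cumulative[end - 1];
  }
  return bounds;
}
```

For example, with row sizes {8, 16, 8, 24} and a 32-byte budget, the cumulative sizes are {8, 24, 32, 56} and the boundaries come out as {0, 3, 4}: rows 0 to 2 form one batch and row 3 forms the next.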
Hit another overflow bug. Still investigating...
I've filed a follow-on issue: #1579.
This temporarily moves the row conversion code from spark-rapids-jni into libcudf. The row conversion code needs to be compiled into a static library to overcome a CCCL issue that triggers an invalid memory access when calling `thrust::in(ex)clusive_scan` (NVIDIA/spark-rapids-jni#1567). In the future, once CCCL is updated to fix the issue (1567), we may need to move the code back into spark-rapids-jni.
Authors:
- Nghia Truong (https://github.com/ttnghia)
Approvers:
- Mike Wilson (https://github.com/hyperbolic2346)
- Vyas Ramasubramani (https://github.com/vyasr)
- MithunR (https://github.com/mythrocks)
URL: #14664
This removes the row conversion code from libcudf. It was moved from spark-rapids-jni (by #14664) to temporarily work around a conflict of kernel names that causes an invalid memory access when calling `thrust::in(ex)clusive_scan` (NVIDIA/spark-rapids-jni#1567). Now that we have fixes for the namespace visibility issue (marking all libcudf kernels private in rapidsai/rapids-cmake#523 and NVIDIA/cuCollections#422), we need to move the code back. Closes #14853.
Authors:
- Nghia Truong (https://github.com/ttnghia)
Approvers:
- David Wendt (https://github.com/davidwendt)
- Bradley Dice (https://github.com/bdice)
URL: #15234
Describe the bug
spark-rapids-jni_nightly-pre_release, build ID: 204
Currently we have only seen this once in CUDA 12 (jni ref: ff59e68, cudf ref: rapidsai/cudf@330d389). We also started seeing it more frequently in the submodule sync-up pipeline (CUDA 11 env):
spark-rapids-jni_submodule-sync-pre_release, build IDs 593, 591, 592 (constantly failing now)
failed
from target/cmake-build/Testing/Temporary/LastTest.log (full log 204.log)
Steps/Code to reproduce bug
Run the test with CUDA 12 (CUDA 12.2, driver 535.104, GPU: A30)
Expected behavior
Pass the test