diff --git a/test/mpi/impls/mpich/cuda/stream.cu b/test/mpi/impls/mpich/cuda/stream.cu
index 1cb62c83982..08964211411 100644
--- a/test/mpi/impls/mpich/cuda/stream.cu
+++ b/test/mpi/impls/mpich/cuda/stream.cu
@@ -36,10 +36,24 @@ void saxpy(int n, float a, float *x, float *y)
     if (i < n) y[i] = a*x[i] + y[i];
 }
 
-int main(void)
+/* Set to 1 when "-progress-thread" is given on the command line. */
+static int need_progress_thread = 0;
+/* Scan argv for "-progress-thread"; any other argument is ignored. */
+static void parse_args(int argc, char **argv)
+{
+    for (int i = 1; i < argc; i++) {
+        if (strcmp(argv[i], "-progress-thread") == 0) {
+            need_progress_thread = 1;
+        }
+    }
+}
+
+int main(int argc, char **argv)
 {
     int errs = 0;
 
+    parse_args(argc, argv);
+
     cudaStream_t stream;
     cudaStreamCreate(&stream);
 
@@ -71,6 +85,10 @@ int main(void)
 
     MPI_Info_free(&info);
 
+    if (need_progress_thread) {
+        MPIX_Start_progress_thread(mpi_stream);
+    }
+
     MPI_Comm stream_comm;
     MPIX_Stream_comm_create(MPI_COMM_WORLD, mpi_stream, &stream_comm);
 
@@ -139,6 +157,9 @@ int main(void)
         errs += check_result(y);
     }
 
+    if (need_progress_thread) {
+        MPIX_Stop_progress_thread(mpi_stream);
+    }
     MPI_Comm_free(&stream_comm);
     MPIX_Stream_free(&mpi_stream);
 
diff --git a/test/mpi/impls/mpich/cuda/testlist b/test/mpi/impls/mpich/cuda/testlist
index 67d29c2f951..006e11573b6 100644
--- a/test/mpi/impls/mpich/cuda/testlist
+++ b/test/mpi/impls/mpich/cuda/testlist
@@ -1,3 +1,4 @@
 saxpy 2
 stream 2 env=MPIR_CVAR_CH4_RESERVE_VCIS=1
+stream 2 env=MPIR_CVAR_CH4_RESERVE_VCIS=1 env=MPIR_CVAR_CH4_ENABLE_STREAM_WORKQ=1 env=MPIR_CVAR_GPU_HAS_WAIT_KERNEL=1 arg=-progress-thread
 stream_allred 4 env=MPIR_CVAR_CH4_RESERVE_VCIS=1