Testing stream pool implementation #14437
Changes from all commits
2ac86e8
1247e6d
3fa0140
3959e9e
c612fd6
fe166a4
ca98a45
4e04b8e
a720444
362cdd1
bb54197
c4e72d2
@@ -25,6 +25,62 @@

namespace cudf::detail {

class cuda_stream_pool {
 public:
  // matching type used in rmm::cuda_stream_pool::get_stream(stream_id)
  using stream_id_type = std::size_t;

  virtual ~cuda_stream_pool() = default;

  /**
   * @brief Get a `cuda_stream_view` of a stream in the pool.
   *
   * This function is thread safe with respect to other calls to the same function.
   *
   * @return Stream view.
   */
  virtual rmm::cuda_stream_view get_stream() = 0;

  /**
   * @brief Get a `cuda_stream_view` of the stream associated with `stream_id`.
   *
   * Equivalent values of `stream_id` return a `cuda_stream_view` to the same underlying stream.
   * This function is thread safe with respect to other calls to the same function.
   *
   * @param stream_id Unique identifier for the desired stream
   * @return Requested stream view.
   */
  virtual rmm::cuda_stream_view get_stream(stream_id_type stream_id) = 0;

  /**
   * @brief Get a set of `cuda_stream_view` objects from the pool.
   *
   * An attempt is made to ensure that the returned vector does not contain duplicate
   * streams, but this cannot be guaranteed if `count` is greater than the value returned by
   * `get_stream_pool_size()`.
   *
   * This function is thread safe with respect to other calls to the same function.
   *
   * @param count The number of stream views to return.
   * @return Vector containing `count` stream views.
   */
  virtual std::vector<rmm::cuda_stream_view> get_streams(std::size_t count) = 0;

  /**
   * @brief Get the number of unique stream objects in the pool.
   *
   * This function is thread safe with respect to other calls to the same function.
   *
   * @return the number of stream objects in the pool
   */
  virtual std::size_t get_stream_pool_size() const = 0;
The name is a bit redundant. We know that this is a stream pool.
Suggested change
or just
yeah, should be
Okay, but then please address it in follow-up work.
Thank you for pointing this out! I'll work on this change in a follow-up PR.
Why isn't this a change for this PR, if all the code in this section is new in this PR?
};

/**
 * @brief Initialize global stream pool.
 */
cuda_stream_pool* create_global_cuda_stream_pool();

/**
 * @brief Acquire a set of `cuda_stream_view` objects and synchronize them to an event on another
 * stream.
|
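The `get_streams` contract documented in the header above (no duplicates are promised once `count` exceeds `get_stream_pool_size()`) can be illustrated with a small host-only model. This is a sketch under stated assumptions, not the libcudf or rmm implementation: `fake_stream` and `round_robin_pool` are hypothetical names, streams are modeled as plain integers, and both the round-robin policy and the modulo mapping of `stream_id` are assumptions made for illustration.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for rmm::cuda_stream_view: just an integer handle.
using fake_stream = int;

// Illustrative round-robin pool (assumed policy, not the libcudf code).
class round_robin_pool {
 public:
  explicit round_robin_pool(std::size_t size) : streams_(size)
  {
    assert(size > 0);
    for (std::size_t i = 0; i < size; ++i) { streams_[i] = static_cast<fake_stream>(i); }
  }

  // Hands out the next stream in rotation. The atomic counter lets
  // concurrent callers advance the rotation without a data race.
  fake_stream get_stream() { return streams_[next_.fetch_add(1) % streams_.size()]; }

  // Equal ids always map to the same stream (assumed modulo mapping).
  fake_stream get_stream(std::size_t stream_id) const
  {
    return streams_[stream_id % streams_.size()];
  }

  // Reserves `count` consecutive slots in one atomic step, so the result is
  // duplicate-free whenever count <= size(); beyond that it must wrap around.
  std::vector<fake_stream> get_streams(std::size_t count)
  {
    std::size_t const start = next_.fetch_add(count);
    std::vector<fake_stream> out;
    out.reserve(count);
    for (std::size_t i = 0; i < count; ++i) {
      out.push_back(streams_[(start + i) % streams_.size()]);
    }
    return out;
  }

  std::size_t size() const { return streams_.size(); }

 private:
  std::vector<fake_stream> streams_;
  std::atomic<std::size_t> next_{0};
};
```

With a pool of four streams, a request for three views returns three distinct handles, while a request for six necessarily repeats two of them.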
@@ -0,0 +1,33 @@
/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/default_stream.hpp>

#include <cudf/detail/utilities/stream_pool.hpp>
#include <rmm/cuda_stream_view.hpp>

class StreamPoolTest : public cudf::test::BaseFixture {};

__global__ void do_nothing_kernel() {}
I am not sure whether this kernel, which contains no code, will be optimized out, so that the launches in the for loop below are never executed?
Thanks for pointing this out! I printed the GPU trace from the nsys profiler to check this and it shows that the

TEST_F(StreamPoolTest, ForkStreams)
{
  auto streams = cudf::detail::fork_streams(cudf::test::get_default_stream(), 2);
  for (auto& stream : streams) {
    do_nothing_kernel<<<1, 32, 0, stream.value()>>>();
  }
}
This is just a copy-paste of the declarations from stream_pool.cpp, right? It's a good change, just want to make sure I'm not missing any other changes.
Yes, I have moved the parent class declaration and `create_global_cuda_stream_pool` from `stream_pool.cpp` to the header file so that `test_cuda_stream_pool` in `identify_stream_usage.cpp` can include it.
Notes from an offline discussion: `cuda_stream_pool` was "hidden" in the source file on purpose; we want the pool usage to be limited to `fork_streams`/`join_streams`. The change here exposes more of the stream pool than we'd like. However, we don't see a better solution, since the current approach at least does not require additional APIs in libcudf.
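To make the trade-off in these notes concrete, here is a minimal host-only sketch of what exposing the abstract pool class buys: a test can substitute its own pool behind the same interface, with no new public entry points. Every name below (`fake_stream`, `stream_pool_interface`, `default_pool`, `single_stream_test_pool`, `fork_streams_sketch`) is hypothetical, streams are modeled as integers, and the test pool's behavior of returning one designated stream is an assumption about how such a pool might be used, not a description of `test_cuda_stream_pool`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using fake_stream = int;  // hypothetical stand-in for rmm::cuda_stream_view

// "Header" part: only the abstract interface is visible to tests.
class stream_pool_interface {
 public:
  virtual ~stream_pool_interface() = default;
  virtual std::vector<fake_stream> get_streams(std::size_t count) = 0;
  virtual std::size_t get_stream_pool_size() const = 0;
};

// "Source file" part: a production-style pool that rotates through its streams.
class default_pool : public stream_pool_interface {
 public:
  explicit default_pool(std::size_t size) : size_(size) { assert(size > 0); }

  std::vector<fake_stream> get_streams(std::size_t count) override
  {
    std::vector<fake_stream> out;
    out.reserve(count);
    for (std::size_t i = 0; i < count; ++i) {
      out.push_back(static_cast<fake_stream>(next_++ % size_));
    }
    return out;
  }

  std::size_t get_stream_pool_size() const override { return size_; }

 private:
  std::size_t size_;
  std::size_t next_{0};  // not thread safe; kept simple for the sketch
};

// Test-side pool: every request is served by one designated stream, which
// makes it easy to check that no work lands on an unexpected stream.
class single_stream_test_pool : public stream_pool_interface {
 public:
  explicit single_stream_test_pool(fake_stream s) : stream_(s) {}

  std::vector<fake_stream> get_streams(std::size_t count) override
  {
    return std::vector<fake_stream>(count, stream_);
  }

  std::size_t get_stream_pool_size() const override { return 1; }

 private:
  fake_stream stream_;
};

// The only function callers are meant to use; it is agnostic to which pool it gets.
std::vector<fake_stream> fork_streams_sketch(stream_pool_interface& pool, std::size_t count)
{
  return pool.get_streams(count);
}
```

The cost is the one the notes describe: the interface becomes visible to anyone who includes the header, even though callers are still expected to go through the fork/join functions.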