Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use resource_ref for upstream in stream_checking_resource_adaptor #16187

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions cpp/include/cudf_test/stream_checking_resource_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@

#include <iostream>

namespace cudf::test {

/**
* @brief Resource that verifies that the default stream is not used in any allocation.
*
* @tparam Upstream Type of the upstream resource used for
* allocation/deallocation.
*/
template <typename Upstream>
class stream_checking_resource_adaptor final : public rmm::mr::device_memory_resource {
public:
/**
Expand All @@ -40,14 +38,13 @@ class stream_checking_resource_adaptor final : public rmm::mr::device_memory_res
*
* @param upstream The resource used for allocating/deallocating device memory
*/
stream_checking_resource_adaptor(Upstream* upstream,
stream_checking_resource_adaptor(rmm::device_async_resource_ref upstream,
bool error_on_invalid_stream,
bool check_default_stream)
: upstream_{upstream},
error_on_invalid_stream_{error_on_invalid_stream},
check_default_stream_{check_default_stream}
{
CUDF_EXPECTS(nullptr != upstream, "Unexpected null upstream resource pointer.");
}

stream_checking_resource_adaptor() = delete;
Expand Down Expand Up @@ -86,7 +83,7 @@ class stream_checking_resource_adaptor final : public rmm::mr::device_memory_res
void* do_allocate(std::size_t bytes, rmm::cuda_stream_view stream) override
{
verify_stream(stream);
return upstream_->allocate(bytes, stream);
return upstream_.allocate_async(bytes, rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
}

/**
Expand All @@ -101,7 +98,7 @@ class stream_checking_resource_adaptor final : public rmm::mr::device_memory_res
void do_deallocate(void* ptr, std::size_t bytes, rmm::cuda_stream_view stream) override
{
verify_stream(stream);
upstream_->deallocate(ptr, bytes, stream);
upstream_.deallocate_async(ptr, bytes, rmm::CUDA_ALLOCATION_ALIGNMENT, stream);
}

/**
Expand All @@ -113,8 +110,8 @@ class stream_checking_resource_adaptor final : public rmm::mr::device_memory_res
[[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override
{
if (this == &other) { return true; }
auto cast = dynamic_cast<stream_checking_resource_adaptor<Upstream> const*>(&other);
if (cast == nullptr) { return upstream_->is_equal(other); }
auto cast = dynamic_cast<stream_checking_resource_adaptor const*>(&other);
if (cast == nullptr) { return false; }
return get_upstream_resource() == cast->get_upstream_resource();
}

Expand Down Expand Up @@ -150,7 +147,8 @@ class stream_checking_resource_adaptor final : public rmm::mr::device_memory_res
}
}

Upstream* upstream_; // the upstream resource used for satisfying allocation requests
rmm::device_async_resource_ref
upstream_; // the upstream resource used for satisfying allocation requests
bool error_on_invalid_stream_; // If true, throw an exception when the wrong stream is detected.
// If false, simply print to stdout.
bool check_default_stream_; // If true, throw an exception when the default stream is observed.
Expand All @@ -162,13 +160,12 @@ class stream_checking_resource_adaptor final : public rmm::mr::device_memory_res
* @brief Convenience factory to return a `stream_checking_resource_adaptor` around the
* upstream resource `upstream`.
*
* @tparam Upstream Type of the upstream `device_memory_resource`.
* @param upstream Pointer to the upstream resource
* @param upstream Reference to the upstream resource
*/
template <typename Upstream>
stream_checking_resource_adaptor<Upstream> make_stream_checking_resource_adaptor(
Upstream* upstream, bool error_on_invalid_stream, bool check_default_stream)
inline stream_checking_resource_adaptor make_stream_checking_resource_adaptor(
rmm::device_async_resource_ref upstream, bool error_on_invalid_stream, bool check_default_stream)
{
return stream_checking_resource_adaptor<Upstream>{
upstream, error_on_invalid_stream, check_default_stream};
return stream_checking_resource_adaptor{upstream, error_on_invalid_stream, check_default_stream};
}

} // namespace cudf::test
10 changes: 4 additions & 6 deletions cpp/include/cudf_test/testing_main.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

namespace cudf {
namespace test {
namespace cudf::test {

/// MR factory functions
inline auto make_cuda() { return std::make_shared<rmm::mr::cuda_memory_resource>(); }
Expand Down Expand Up @@ -91,8 +90,7 @@ inline std::shared_ptr<rmm::mr::device_memory_resource> create_memory_resource(
CUDF_FAIL("Invalid RMM allocation mode: " + allocation_mode);
}

} // namespace test
} // namespace cudf
} // namespace cudf::test

/**
* @brief Parses the cuDF test command line options.
Expand Down Expand Up @@ -182,8 +180,8 @@ inline auto make_stream_mode_adaptor(cxxopts::ParseResult const& cmd_opts)
auto const stream_error_mode = cmd_opts["stream_error_mode"].as<std::string>();
auto const error_on_invalid_stream = (stream_error_mode == "error");
auto const check_default_stream = (stream_mode == "new_cudf_default");
auto adaptor =
make_stream_checking_resource_adaptor(resource, error_on_invalid_stream, check_default_stream);
auto adaptor = cudf::test::make_stream_checking_resource_adaptor(
resource, error_on_invalid_stream, check_default_stream);
if ((stream_mode == "new_cudf_default") || (stream_mode == "new_testing_default")) {
rmm::mr::set_current_device_resource(&adaptor);
}
Expand Down
Loading