Skip to content

Commit

Permalink
Hiding implementation details for comms (#409)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet authored Feb 1, 2022
1 parent 727419d commit ffe08fe
Show file tree
Hide file tree
Showing 12 changed files with 1,294 additions and 957 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_thrust.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand Down
42 changes: 31 additions & 11 deletions cpp/include/raft/comms/comms.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,52 +38,70 @@ enum class status_t {
};

template <typename value_t>
constexpr datatype_t get_type();
constexpr datatype_t

get_type();

template <>
constexpr datatype_t get_type<char>()
constexpr datatype_t

get_type<char>()
{
return datatype_t::CHAR;
}

template <>
constexpr datatype_t get_type<uint8_t>()
constexpr datatype_t

get_type<uint8_t>()
{
return datatype_t::UINT8;
}

template <>
constexpr datatype_t get_type<int>()
constexpr datatype_t

get_type<int>()
{
return datatype_t::INT32;
}

template <>
constexpr datatype_t get_type<uint32_t>()
constexpr datatype_t

get_type<uint32_t>()
{
return datatype_t::UINT32;
}

template <>
constexpr datatype_t get_type<int64_t>()
constexpr datatype_t

get_type<int64_t>()
{
return datatype_t::INT64;
}

template <>
constexpr datatype_t get_type<uint64_t>()
constexpr datatype_t

get_type<uint64_t>()
{
return datatype_t::UINT64;
}

template <>
constexpr datatype_t get_type<float>()
constexpr datatype_t

get_type<float>()
{
return datatype_t::FLOAT32;
}

template <>
constexpr datatype_t get_type<double>()
constexpr datatype_t

get_type<double>()
{
return datatype_t::FLOAT64;
}
Expand All @@ -93,10 +111,12 @@ class comms_iface {
virtual ~comms_iface() {}

virtual int get_size() const = 0;

virtual int get_rank() const = 0;

virtual std::unique_ptr<comms_iface> comm_split(int color, int key) const = 0;
virtual void barrier() const = 0;

virtual void barrier() const = 0;

virtual status_t sync_stream(cudaStream_t stream) const = 0;

Expand Down
168 changes: 168 additions & 0 deletions cpp/include/raft/comms/comms_test.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/comms/comms.hpp>
#include <raft/comms/detail/test.hpp>

#include <raft/handle.hpp>

namespace raft {
namespace comms {

/**
* @brief A simple sanity check that NCCL is able to perform a collective operation
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_allreduce(const handle_t& handle, int root)
{
return detail::test_collective_allreduce(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective operation
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_broadcast(const handle_t& handle, int root)
{
return detail::test_collective_broadcast(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective reduce
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_reduce(const handle_t& handle, int root)
{
return detail::test_collective_reduce(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective allgather
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_allgather(const handle_t& handle, int root)
{
return detail::test_collective_allgather(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective gather
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_gather(const handle_t& handle, int root)
{
return detail::test_collective_gather(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective gatherv
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_gatherv(const handle_t& handle, int root)
{
return detail::test_collective_gatherv(handle, root);
}

/**
* @brief A simple sanity check that NCCL is able to perform a collective reducescatter
*
* @param[in] handle the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] root the root rank id
*/
bool test_collective_reducescatter(const handle_t& handle, int root)
{
return detail::test_collective_reducescatter(handle, root);
}

/**
* A simple sanity check that UCX is able to send messages between all ranks
*
* @param[in] h the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param[in] numTrials number of iterations of all-to-all messaging to perform
*/
bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials)
{
return detail::test_pointToPoint_simple_send_recv(h, numTrials);
}

/**
* A simple sanity check that device is able to send OR receive.
*
* @param h the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param numTrials number of iterations of send or receive messaging to perform
*/
bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials)
{
return detail::test_pointToPoint_device_send_or_recv(h, numTrials);
}

/**
* A simple sanity check that device is able to send and receive at the same time.
*
* @param h the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param numTrials number of iterations of send or receive messaging to perform
*/
bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials)
{
return detail::test_pointToPoint_device_sendrecv(h, numTrials);
}

/**
* A simple sanity check that device is able to perform multiple concurrent sends and receives.
*
* @param h the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param numTrials number of iterations of send or receive messaging to perform
*/
bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrials)
{
return detail::test_pointToPoint_device_multicast_sendrecv(h, numTrials);
}

/**
* A simple test that the comms can be split into 2 separate subcommunicators
*
* @param h the raft handle to use. This is expected to already have an
* initialized comms instance.
* @param n_colors number of different colors to test
*/
bool test_commsplit(const handle_t& h, int n_colors) { return detail::test_commsplit(h, n_colors); }
} // namespace comms
}; // namespace raft
Loading

0 comments on commit ffe08fe

Please sign in to comment.