Serialization for commonly used types #770

Closed
wants to merge 9 commits into from
127 changes: 127 additions & 0 deletions cpp/include/raft/core/serialization.hpp
@@ -0,0 +1,127 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/handle.hpp>
#include <raft/detail/serialization.hpp>

#include <vector>

namespace raft {

/**
* @brief Write a serializable state of an object to memory.
*
* @tparam T type of the serializable object
* @tparam ContextArgs types of context required for serialization
*
* @param[out] out a host pointer to the location where to store the state;
* when `nullptr`, the actual data is not written (only the size to be written is calculated).
* @param obj the object to be serialized
* @param args context required for serialization
* @return the number of bytes (to be) written by the pointer
*/
template <typename T, typename... ContextArgs>
auto serialize(uint8_t* out, const T& obj, ContextArgs&&... args) -> size_t
Contributor

Would you consider using std::span as the output argument? This would better express how many elements the output array has.

template <typename T, size_t OutputExtent, typename... ContextArgs>
auto serialize(std::span<uint8_t, OutputExtent> out, const T& obj, ContextArgs&&... args) -> size_t;

See C++ Core Guidelines F.24.

Contributor Author

I'd be happy to add overloads to have span as the output, but raft is currently set to C++17. It's not available till C++20, right?

Member

What if we returned a raft::span?

Contributor Author

Ah, great, I can add overloads that accept raft::span!

Contributor Author

> What if we returned a raft::span?

It does not seem to have Python bindings at the moment, right? I guess we'd need at least one version of serialize that could be easily wrapped by Cython for #752.

Member

It should be straightforward to expose raft::span in a pxd file. We can also rename the functions as needed (e.g., to_span/from_span and to_bytes/from_bytes).
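
A minimal sketch of the bounded-output idea from this thread, assuming C++17. The `byte_span` struct and `serialize_to_span` function are hypothetical stand-ins introduced for illustration; they are not raft::span or any raft API. The point is only that carrying the length alongside the pointer lets the callee refuse to write past the end of the buffer.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <type_traits>

// Hypothetical C++17 stand-in for a span of bytes: a pointer plus a length.
struct byte_span {
  uint8_t* data;
  std::size_t size;
};

// Hypothetical overload: writes an arithmetic value into `out` and throws
// instead of writing past the end of the buffer.
template <typename T>
auto serialize_to_span(byte_span out, const T& obj) -> std::size_t
{
  static_assert(std::is_arithmetic_v<T>, "this sketch covers arithmetic types only");
  if (out.size < sizeof(T)) { throw std::length_error("output span is too small"); }
  std::memcpy(out.data, &obj, sizeof(T));
  return sizeof(T);
}
```

With a raw `uint8_t*` the same check is not possible, because the callee never learns how much space the caller allocated.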

{
return detail::call_serialize<T, ContextArgs...>(out, obj, std::forward<ContextArgs>(args)...);
}

/**
* @brief Write a serializable state of an object to a host vector.
*
* @tparam T type of the serializable object
* @tparam ContextArgs types of context required for serialization
*
* @param obj the object to be serialized
* @param args context required for serialization
* @return the serialized state
*/
template <typename T, typename... ContextArgs>
auto serialize(const T& obj, ContextArgs&&... args) -> std::vector<uint8_t>
Contributor

  1. Would you consider an interface that writes to existing storage, instead of performing a new allocation?
  2. If you mean to perform a new allocation, would you consider letting the user optionally pass in a custom allocator?

The first option would imply a state machine that takes an output buffer, and returns the "rest" of the output buffer and whether it finished with output (or some other error occurred, like running out of space).

struct SerializeNoError {};
struct SerializeOutOfSpace { size_t remaining_space_required; };
struct SerializeOtherError { /* ... */ };
// SerializeNoError must be one of the alternatives, since it is the default state below.
using SerializeErrorCode = std::variant<SerializeNoError, SerializeOutOfSpace, SerializeOtherError>;

struct SerializeState {
  std::span<uint8_t> rest_of_buffer;
  SerializeInternalStateRepresentation state;  // implementation-defined progress marker
  bool done = false;
  SerializeErrorCode error_code{SerializeNoError{}};
};

bool keep_going(const SerializeState& r) {
  return !r.done && std::holds_alternative<SerializeNoError>(r.error_code);
}

template<typename T, typename... ContextArgs>
auto serialize(SerializeState r, const T& obj, ContextArgs&&... args) -> SerializeState;

You then call it in a loop.

const size_t initial_buffer_size = 100;
std::vector<uint8_t> buffer(initial_buffer_size);
SerializeState state{std::span<uint8_t>{buffer}};
while (keep_going(state)) {
  state = serialize(state, obj, args...);
}

Contributor Author

The base implementation (to_bytes) at the moment writes to existing storage, but I don't have any checks whatsoever to see if there is enough space allocated. In the current state, the user is supposed to call serialize with nullptr to find out how much space it needs.
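
The two-call convention described here (first call with `nullptr` to query the size, second call to write) can be sketched as follows. This is a standalone illustration under assumptions: the free functions `to_bytes` and `to_vector` and the choice of `std::string` as the payload are hypothetical, not taken from the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical serializer following the PR's convention:
// `out == nullptr` means "only report the size, write nothing".
inline auto to_bytes(uint8_t* out, const std::string& s) -> std::size_t
{
  const std::size_t n = s.size();
  if (out) {
    std::memcpy(out, &n, sizeof(n));            // length prefix
    std::memcpy(out + sizeof(n), s.data(), n);  // payload
  }
  return sizeof(n) + n;
}

// The caller runs two passes: a size query, then the actual write.
inline auto to_vector(const std::string& s) -> std::vector<uint8_t>
{
  std::vector<uint8_t> buf(to_bytes(nullptr, s));  // pass 1: size only
  to_bytes(buf.data(), s);                         // pass 2: write into the sized buffer
  return buf;
}
```

Nothing in this shape lets `to_bytes` verify that a non-null `out` is large enough, which is the gap the comment above acknowledges.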

Contributor

@achirkin Would you consider a design that names the space query as a separate function from serialization? Compare to cuSPARSE, which separates *_bufferSize functions from *_solve functions.

Contributor Author

I'd be totally fine separating them, though for some reason I was under the impression that in raft/cuml we follow the CUB approach of using the same function. @cjnolet, which one do you prefer?

Member

@cjnolet cjnolet Aug 24, 2022

This one is hard. @mhoemmen, we've been trying to keep the APIs in RAFT a little more C++ friendly, so generally trying to avoid the two-stage compute buffer size -> perform computation where at all possible. I'm usually not particularly in love w/ returning objects from functions either, unless we are explicitly using something like a factory for a complex object, because then the user cannot control the memory allocation.

In this particular case, if we are only ever expecting the serialized buffer to be in host memory, I could see an argument for having serialize return a vector, or more preferably, accepting a vector and populating it. I think if I was to choose between the various options presented, I most like the idea of returning/populating some sort of streamable/iterable type if we can. I find that to be the most flexible design because a large stream can go right to disk and have to buffer only a small portion into memory at any given time. Further, the stream is also nice because you can collect it into a byte buffer at any point if needed.
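
A minimal sketch of the stream-oriented option mentioned above, assuming host-only data. The names `serialize_to_stream` and `collect` are hypothetical, and `std::ostream` is used here only as one possible "streamable" target; the PR itself does not define such an interface.

```cpp
#include <cstddef>
#include <cstdint>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stream-based serializer for a host vector of floats:
// a 64-bit element count followed by the raw element bytes.
inline void serialize_to_stream(std::ostream& os, const std::vector<float>& v)
{
  const uint64_t n = v.size();
  os.write(reinterpret_cast<const char*>(&n), sizeof(n));
  os.write(reinterpret_cast<const char*>(v.data()),
           static_cast<std::streamsize>(n * sizeof(float)));
}

// Collecting the stream into an in-memory byte buffer stays possible when needed.
inline auto collect(const std::vector<float>& v) -> std::string
{
  std::ostringstream os(std::ios::binary);
  serialize_to_stream(os, v);
  return os.str();
}
```

The same `serialize_to_stream` call could be handed a `std::ofstream`, so a large object goes to disk without first being materialized as one byte buffer.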

Contributor

@cjnolet wrote:

> I think if I was to choose between the various options presented, I most like the idea of returning/populating some sort of streamable/iterable type if we can.

I do like the idea of writing to some kind of output range. If it's a random access range, it would be nice to test if it's long enough and report an error if not, rather than just clobbering it. This is always detectable, and is generally recoverable (allocate more memory) or at least worth reporting ("insufficient disk space").

Please do consider not allocating the return buffer each time. This would make it impossible to write sequential records to contiguous storage. It can also be surprisingly expensive in a tight loop.
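
The two requests above (check the output length, and allow sequential records in contiguous storage) can be sketched together. This is an illustration under assumptions: `out_cursor` and `append_record` are hypothetical names, and the record type is fixed to `int32_t` to keep the example short.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <optional>

// Hypothetical cursor over caller-owned storage.
struct out_cursor {
  uint8_t* pos;
  std::size_t remaining;
};

// Appends one record. Returns the advanced cursor, or std::nullopt when the
// remaining space is insufficient (in which case nothing is written).
inline auto append_record(out_cursor c, int32_t record) -> std::optional<out_cursor>
{
  if (c.remaining < sizeof(record)) { return std::nullopt; }
  std::memcpy(c.pos, &record, sizeof(record));
  return out_cursor{c.pos + sizeof(record), c.remaining - sizeof(record)};
}
```

Because the caller owns the buffer, consecutive calls place records back to back, and running out of space is reported rather than silently overwriting memory.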

Member

> Please do consider not allocating the return buffer each time. This would make it impossible to write sequential records to contiguous storage. It can also be surprisingly expensive in a tight loop.

I do wholeheartedly agree with this.

{
return detail::call_serialize<T, ContextArgs...>(obj, std::forward<ContextArgs>(args)...);
}

/**
* @brief Read a serializable state of an object from memory.
*
* @tparam T type of the serializable object
* @tparam ContextArgs types of context required for serialization
*
* @param[out] p an uninitialized host pointer to a location where the object should be created.
* @param[in] in a host pointer to the location where the state should be read from.
* @param args context required for serialization
* @return the number of bytes read by the pointer
*/
template <typename T, typename... ContextArgs>
auto deserialize(T* p, const void* in, ContextArgs&&... args) -> size_t
Contributor

  1. What happens if deserialization fails (e.g., because the representation is not correct)? How would you report the error?
  2. It sounds like the intent is for deserialize to do placement new in p. Is that correct?
  3. Would you instead consider an interface that writes to an existing T via T& out? The issue with a raw pointer interface is that users may forget the invariant "an uninitialized host pointer" and pass in a pointer to an existing object. In that case, your interface can totally break any invariants of T, e.g., possibly leaking any heap allocations that T might have inside.

Contributor Author

  1. At the moment, there are no safety checks at all. That being said, I think we could add them in to_bytes/from_bytes, i.e., make it the responsibility of whoever adds serialization to a particular class: just throw one of the raft errors. What do you think? CC @cjnolet
  2. Yes
  3. I'd be fine with that, but I'm not sure how to do that better, while keeping the serial::from_bytes the same. Shall I add an overload that accepts T& out, manually calls its destructor and then passes the &out pointer further?.. sounds hackish.

In general, the reason why I'm hesitant to change serial::from_bytes to take a reference to an initialized object is that some of the classes in raft/cuml don't seem to have constructors that do nothing. On the one hand, the pre-initialization would likely mean some extra overhead and the minor nuisance of passing dummy arguments when no nullary constructor is available. On the other hand, it would make the value-returning overload of deserialize (deserialize(const std::vector<uint8_t>& in, ContextArgs&&... args) -> T) impossible for types that don't have a nullary constructor.
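
A sketch of how a `from_bytes` implementation could report malformed input while keeping the "pointer stays uninitialized on failure" contract. Assumptions are labeled: `from_bytes_checked` and its extra `in_size` parameter are hypothetical (the PR's `from_bytes` takes no input length), and `std::runtime_error` stands in for whichever raft error type would actually be used.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <new>
#include <stdexcept>
#include <vector>

// Hypothetical checked variant of from_bytes; `in_size` is an added parameter.
inline auto from_bytes_checked(std::vector<int32_t>* p, const uint8_t* in, std::size_t in_size)
  -> std::size_t
{
  uint64_t n = 0;
  if (in_size < sizeof(n)) { throw std::runtime_error("truncated length prefix"); }
  std::memcpy(&n, in, sizeof(n));
  if ((in_size - sizeof(n)) / sizeof(int32_t) < n) {
    throw std::runtime_error("truncated payload");
  }
  // All checks passed: only now construct the object in the uninitialized storage,
  // so a validation failure leaves *p untouched.
  new (p) std::vector<int32_t>(static_cast<std::size_t>(n));
  if (n > 0) { std::memcpy(p->data(), in + sizeof(n), n * sizeof(int32_t)); }
  return static_cast<std::size_t>(sizeof(n) + n * sizeof(int32_t));
}
```

Validating before the placement new means the caller never has to guess whether a destructor call is owed after an exception.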

Contributor

> Shall I add an overload that accepts T& out, manually calls its destructor and then passes the &out pointer further?.. sounds hackish.

Please do not :-)

> On the other hand, that would make impossible the value-returning overload ... for the types that don't have the nullary constructor.

It's possible if you have assignment, but ugly -- https://godbolt.org/z/774Ebfvh5 -- so I understand if you don't want to do it.

Contributor Author

I meant to say, it would be impossible if I changed serial::from_bytes to accept the reference instead of the pointer (deserialize_raw in your example). Also, deserialize_using_raw in your example does not call the destructor, does it? :) (nb my union trick to address this and copy elision)

Contributor

@achirkin Right, I definitely forgot the important detail of manually invoking the destructor on a placement new'd thing.

{
return detail::call_deserialize<T, ContextArgs...>(
p, reinterpret_cast<const uint8_t*>(in), std::forward<ContextArgs>(args)...);
}

/**
* @brief Read a serializable state of an object from a vector.
*
* @tparam T type of the serializable object
* @tparam ContextArgs types of context required for serialization
*
* @param[out] p an uninitialized host pointer to a location where the object should be created.
* @param in a vector where the state should be read from.
* @param args context required for serialization
* @return the number of bytes read from the vector
*/
template <typename T, typename... ContextArgs>
auto deserialize(T* p, const std::vector<uint8_t>& in, ContextArgs&&... args) -> size_t
achirkin marked this conversation as resolved.
{
return detail::call_deserialize<T, ContextArgs...>(p, in, std::forward<ContextArgs>(args)...);
}

/**
* @brief Read a serializable state of an object.
*
* @tparam T type of the serializable object
* @tparam ContextArgs types of context required for serialization
*
* @param[in] in a host pointer to the location where the state should be read from.
* @param args context required for serialization
* @return the deserialized object;
*/
template <typename T, typename... ContextArgs>
auto deserialize(const void* in, ContextArgs&&... args) -> T
achirkin marked this conversation as resolved.
{
return detail::call_deserialize<T, ContextArgs...>(reinterpret_cast<const uint8_t*>(in),
std::forward<ContextArgs>(args)...);
}

/**
* @brief Read a serializable state of an object.
*
* @tparam T type of the serializable object
* @tparam ContextArgs types of context required for serialization
*
* @param in a vector where the state should be read from.
* @param args context required for serialization
* @return the deserialized object;
*/
template <typename T, typename... ContextArgs>
auto deserialize(const std::vector<uint8_t>& in, ContextArgs&&... args) -> T
achirkin marked this conversation as resolved.
{
return detail::call_deserialize<T, ContextArgs...>(in, std::forward<ContextArgs>(args)...);
}

} // namespace raft
233 changes: 233 additions & 0 deletions cpp/include/raft/detail/serialization.hpp
@@ -0,0 +1,233 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/core/handle.hpp>
#include <raft/core/mdarray.hpp>

#include <rmm/device_uvector.hpp>

#include <type_traits>
#include <vector>

namespace raft::detail {

/**
* The structure that holds implementation of serialization for a particular type.
* To add serialization to a given type, specialize this structure:
*
* @code{.cpp}
*
* template <>
* struct serial<YourStructure> {
* static auto to_bytes(uint8_t* out, const YourStructure& obj, ...) -> size_t { ... }
* static auto from_bytes(YourStructure* p, const uint8_t* in, ...) -> size_t { ... }
* };
* @endcode
*
* `serial::to_bytes` saves the state of the object to memory if `out` is not null;
* returns the size of the output in any case.
*
* `serial::from_bytes` constructs a new object by a given _uninitialized_ pointer (a.k.a. placement
* new), reads the state from memory (`in`), and returns the number of bytes read.
* NB: If `serial::from_bytes` throws an exception, it's assumed that the output pointer is still
* uninitialized.
*
* _See below this file for examples_.
*
* All of the `call_serialize` and `call_deserialize` are the convenience wrappers providing various
* ways to serialize/deserialize on top of `serial`. See `core/serialization.hpp` for their public
* documentation.
*/
template <typename T>
struct serial;

template <typename T, typename... ContextArgs>
auto call_serialize(uint8_t* out, const T& obj, ContextArgs&&... args) -> size_t
{
return detail::serial<T>::to_bytes(out, obj, std::forward<ContextArgs>(args)...);
}

template <typename T, typename... ContextArgs>
auto call_serialize(const T& obj, ContextArgs&&... args) -> std::vector<uint8_t>
{
std::vector<uint8_t> v(
call_serialize<T, ContextArgs...>(nullptr, obj, std::forward<ContextArgs>(args)...));
call_serialize<T, ContextArgs...>(v.data(), obj, std::forward<ContextArgs>(args)...);
return v;
}

template <typename T, typename... ContextArgs>
auto call_deserialize(T* p, const uint8_t* in, ContextArgs&&... args) -> size_t
{
return detail::serial<T>::from_bytes(p, in, std::forward<ContextArgs>(args)...);
}

template <typename T, typename... ContextArgs>
auto call_deserialize(T* p, const std::vector<uint8_t>& in, ContextArgs&&... args) -> size_t
{
return call_deserialize<T, ContextArgs...>(p, in.data(), std::forward<ContextArgs>(args)...);
}

template <typename T, typename... ContextArgs>
auto call_deserialize(const uint8_t* in, ContextArgs&&... args) -> T
{
union res {
Contributor

A more idiomatic way to avoid initialization of T would be to have call_deserialize return optional<T> or expected<T, some_error_code>.

Contributor Author

Thanks for the suggestion! Is there an official way to populate the optional using a callback accepting a pointer to the data (i.e., call_deserialize(T*, ...))?

I see I haven't explained why I've gone through such a hassle with these many overloads. My goal here is to provide an easy way for other developers to add serialization to new structures. That is, to add serialization to a new structure, one would need to write just two methods:

template <>
struct serial<T> {
  static auto to_bytes(uint8_t* out, const T& obj) -> size_t { ... }
  static auto from_bytes(T* p, const uint8_t* in) -> size_t { ... }
};

Although I love the return-value approach, I'm afraid I cannot use it here, because some of the classes in raft and cuml do not have some of the copy, move, or nullary constructors. Hence I've come up with this idea that it should be the responsibility of the serialization implementer to initialize the object in the from_bytes method.

The rest, all call_serialize/call_deserialize are implemented in terms of these two. At this moment, we don't seem to have a consensus over which target type to use for serialization - so I added a few of them here.

Assuming that from_bytes is available as-is, the union approach is the only way I could come up with to return the value in this overload. I'd be happy to use optional for this, if it's possible!
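
To make the "just two methods" contract concrete, here is a standalone sketch of a specialization for a small struct. Assumptions: `Params` is a hypothetical user type, and `serial` is redeclared locally (outside `raft::detail`) so the example compiles on its own; only the `to_bytes`/`from_bytes` shapes are taken from the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <new>

// Hypothetical user type, used only for illustration.
struct Params {
  int32_t n_rows;
  float tolerance;
};

template <typename T>
struct serial;  // primary template, declared as in the PR

template <>
struct serial<Params> {
  // Writes both fields when `out` is non-null; always returns the byte count.
  static auto to_bytes(uint8_t* out, const Params& obj) -> std::size_t
  {
    if (out) {
      std::memcpy(out, &obj.n_rows, sizeof(obj.n_rows));
      std::memcpy(out + sizeof(obj.n_rows), &obj.tolerance, sizeof(obj.tolerance));
    }
    return sizeof(obj.n_rows) + sizeof(obj.tolerance);
  }

  // Constructs a Params in the uninitialized storage at `p` (placement new).
  static auto from_bytes(Params* p, const uint8_t* in) -> std::size_t
  {
    Params tmp{};
    std::memcpy(&tmp.n_rows, in, sizeof(tmp.n_rows));
    std::memcpy(&tmp.tolerance, in + sizeof(tmp.n_rows), sizeof(tmp.tolerance));
    new (p) Params(tmp);
    return sizeof(tmp.n_rows) + sizeof(tmp.tolerance);
  }
};
```

Copying field by field with `memcpy` avoids depending on struct padding and on the alignment of the byte buffer.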

Contributor

@achirkin Thanks for explaining! I see the difficulty now in using optional. Another way to express this besides a union would be to allocate an array of bytes with the same size and alignment as T, and then placement new into it. Here is an example.

#include <charconv>
#include <cstddef>
#include <iostream>
#include <new>
#include <stdexcept>
#include <string_view>
#include <system_error>

class NotDefaultConstructible {
public:
  NotDefaultConstructible() = delete;
  NotDefaultConstructible(int i) : i_(i) {}

  int value() const { return i_; }

private:
  int i_;
};

void deserialize_raw(NotDefaultConstructible* out, std::string_view s)
{
    int value{};
    auto result = std::from_chars(s.data(), s.data()+s.size(), value);
    if (result.ec == std::errc::invalid_argument) {
        throw std::invalid_argument{"String does not represent an integer"};
    }
    else if (result.ec == std::errc::result_out_of_range) {
        throw std::out_of_range{"String represents an integer, but it does not fit in int"};
    }
    // WARNING: assumes out does NOT point to a valid object.
    new(out) NotDefaultConstructible(value);
}

int main()
{
    const std::string_view s = "42";  // example input (assumed)

    // Raw storage with the size and the alignment of the target type.
    alignas(NotDefaultConstructible) std::byte ptr[sizeof(NotDefaultConstructible)];

    NotDefaultConstructible* out = reinterpret_cast<NotDefaultConstructible*>(ptr);
    deserialize_raw(out, s);

    std::cout << "Value: " << out->value();
    out->~NotDefaultConstructible(); // oof, we have to do this by hand

    return 0;
}

Contributor Author

@achirkin achirkin Aug 4, 2022

Ouch, I wouldn't want to call the destructor by hand in the presence of exceptions in deserialize_raw :) (even though it seems to be safe in this example) I guess I'd better go with the union approach if you don't see anything criminal in that.
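
The union approach chosen here can be illustrated outside of raft. A minimal sketch under assumptions: `Meters`, the free function `from_bytes`, and `deserialize_value` are hypothetical names, not the PR's code. A union member is not constructed automatically, so `from_bytes` can placement-new into it, and the union's destructor then destroys it. If `from_bytes` throws, the union's constructor never completes and its destructor does not run, which matches the "still uninitialized" contract. In this sketch the result is move-constructed out of the temporary, so the type still needs a move or copy constructor.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <new>

// Hypothetical type without a default constructor.
struct Meters {
  Meters() = delete;
  explicit Meters(int32_t v) : value(v) {}
  int32_t value;
};

// Hypothetical from_bytes: constructs a Meters in uninitialized storage.
inline auto from_bytes(Meters* p, const uint8_t* in) -> std::size_t
{
  int32_t v = 0;
  std::memcpy(&v, in, sizeof(v));
  new (p) Meters(v);
  return sizeof(v);
}

template <typename T>
auto deserialize_value(const uint8_t* in) -> T
{
  union holder {
    T value;  // not constructed by the union itself
    explicit holder(const uint8_t* src) { from_bytes(&value, src); }
    ~holder() { value.~T(); }  // runs only if the constructor completed
  };
  return holder(in).value;  // moves the member out of the temporary
}
```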

T value;
explicit res(const uint8_t* in, ContextArgs&&... args)
{
call_deserialize<T, ContextArgs...>(&value, in, std::forward<ContextArgs>(args)...);
}
~res() { value.~T(); } // NOLINT
};
// using a union to avoid initialization of T and force copy elision.
return res(in, std::forward<ContextArgs>(args)...).value;
}

template <typename T, typename... ContextArgs>
auto call_deserialize(const std::vector<uint8_t>& in, ContextArgs&&... args) -> T
achirkin marked this conversation as resolved.
{
return call_deserialize<T, ContextArgs...>(in.data(), std::forward<ContextArgs>(args)...);
}

template <typename T>
struct serial {
// Default implementation for all arithmetic types: just write the value by the pointer.
template <typename S = T>
static auto to_bytes(uint8_t* out, const S& obj)
-> std::enable_if_t<std::is_arithmetic_v<S>, size_t>
{
if (out) { memcpy(out, &obj, sizeof(S)); }
return sizeof(S);
}

// SFINAE-style failure
template <typename S = T, typename... ContextArgs>
static auto to_bytes(uint8_t* out, const S& obj, ContextArgs&&... args)
-> std::enable_if_t<!std::is_arithmetic_v<S>, size_t>
{
static_assert(!std::is_same_v<T, S>, "Serialization is not implemented for this type.");
return 0;
}

// Default implementation for all arithmetic types: just read the value by the pointer.
template <typename S = T>
static auto from_bytes(S* p, const uint8_t* in)
-> std::enable_if_t<std::is_arithmetic_v<S>, size_t>
{
memcpy(p, in, sizeof(S));
return sizeof(S);
}

// SFINAE-style failure
template <typename S = T, typename... ContextArgs>
static auto from_bytes(S* p, const uint8_t* in, ContextArgs&&... args)
-> std::enable_if_t<!std::is_arithmetic_v<S>, size_t>
{
static_assert(!std::is_same_v<T, S>, "Deserialization is not implemented for this type.");
return 0;
}
};

template <typename IndexType, size_t... ExtentsPack>
struct serial<extents<IndexType, ExtentsPack...>> {
using obj_t = extents<IndexType, ExtentsPack...>;

static auto to_bytes(uint8_t* out, const obj_t& obj) -> size_t
{
if (out) { *reinterpret_cast<obj_t*>(out) = obj; }
return sizeof(obj_t);
}

static auto from_bytes(obj_t* p, const uint8_t* in) -> size_t
{
new (p) obj_t{*reinterpret_cast<const obj_t*>(in)};
return sizeof(obj_t);
}
};

template <typename ElementType, typename Extents, typename LayoutPolicy>
struct serial<mdarray<ElementType,
Contributor

Serializing mdarray instead of mdspan implies code duplication. Would you consider instead only defining serialization for mdspan?

Extents,
LayoutPolicy,
detail::device_accessor<detail::device_uvector_policy<ElementType>>>> {
using obj_t = mdarray<ElementType,
Extents,
LayoutPolicy,
detail::device_accessor<detail::device_uvector_policy<ElementType>>>;

static auto to_bytes(uint8_t* out, const obj_t& obj, const handle_t& handle) -> size_t
{
auto extents_size = call_serialize<Extents>(out, obj.extents());
auto total_size = obj.size() * sizeof(ElementType) + extents_size;
if (out) {
out += extents_size;
raft::copy(
reinterpret_cast<ElementType*>(out), obj.data_handle(), obj.size(), handle.get_stream());
Contributor

[INCORRECT]

This is only correct if the layout is exhaustive (contiguous in memory, not strided or otherwise skipping parts of memory). Otherwise, it will both write memory that is not part of the mdarray's elements and miss elements of the mdarray.

The right thing to do would be to define an mdspan copy operation, as in P1673. Then you can copy from the input mdspan to a temporary output mdspan that views out with the desired output layout.
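
The difference between copying the underlying memory and copying the logical elements can be shown with a small host-only sketch. Assumptions: `strided_view_2d` and `pack` are hypothetical stand-ins for a strided mdspan and an mdspan copy; the PR's data lives in device memory, which this example does not model.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical 2D row-major view with a row stride (in elements).
struct strided_view_2d {
  const float* data;
  std::size_t rows, cols, row_stride;  // row_stride >= cols

  // Exhaustive means every element of the underlying range belongs to the view.
  bool is_exhaustive() const { return row_stride == cols; }
  float operator()(std::size_t i, std::size_t j) const { return data[i * row_stride + j]; }
};

// Packs the logical elements contiguously by going through the view's mapping,
// rather than copying rows * cols consecutive values from `data`.
inline auto pack(const strided_view_2d& v) -> std::vector<float>
{
  std::vector<float> out;
  out.reserve(v.rows * v.cols);
  for (std::size_t i = 0; i < v.rows; ++i) {
    for (std::size_t j = 0; j < v.cols; ++j) {
      out.push_back(v(i, j));
    }
  }
  return out;
}
```

For a 2x2 view with `row_stride == 3`, a flat copy of four values would pick up one padding element and drop the last real one, which is the failure mode described in the comment.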

}
return total_size;
}

static auto from_bytes(obj_t* p,
const uint8_t* in,
const handle_t& handle,
rmm::mr::device_memory_resource* mr = nullptr) -> size_t
{
Extents exts;
auto extents_size = call_deserialize<Extents>(&exts, in);
Contributor

What is your representation of extents if all the extents are static? Are you able to distinguish this case from simply not having any data at all?

Contributor Author

Hmm, at this moment we don't have any sort of RTTI at all. That is, in this version the caller is expected to know that it serialized static extents, possibly by writing nothing at all; from_bytes should then read the same thing, possibly nothing.
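
One way to make the serialized extents distinguishable from "no data" is to always write them explicitly. The following is only a sketch of that idea, not what the PR does: `write_extents` and `read_extents` are hypothetical, and the format (a 64-bit rank followed by every extent as 64 bits, static or not) is an assumption made for illustration.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Hypothetical format: rank, then every extent, each stored as uint64_t.
template <std::size_t Rank>
auto write_extents(uint8_t* out, const std::array<uint64_t, Rank>& ext) -> std::size_t
{
  static_assert(Rank > 0, "this sketch assumes at least one dimension");
  const uint64_t rank = Rank;
  if (out) {
    std::memcpy(out, &rank, sizeof(rank));
    std::memcpy(out + sizeof(rank), ext.data(), Rank * sizeof(uint64_t));
  }
  return sizeof(rank) + Rank * sizeof(uint64_t);
}

// Reads the extents back and rejects input written for a different rank.
template <std::size_t Rank>
auto read_extents(const uint8_t* in) -> std::array<uint64_t, Rank>
{
  static_assert(Rank > 0, "this sketch assumes at least one dimension");
  uint64_t rank = 0;
  std::memcpy(&rank, in, sizeof(rank));
  if (rank != Rank) { throw std::runtime_error("rank mismatch in serialized extents"); }
  std::array<uint64_t, Rank> ext{};
  std::memcpy(ext.data(), in + sizeof(rank), Rank * sizeof(uint64_t));
  return ext;
}
```

The cost is a few extra bytes per object; the benefit is that a reader can validate what it is about to construct instead of trusting the caller.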

in += extents_size;
typename obj_t::mapping_type layout{exts};
typename obj_t::container_policy_type policy{handle.get_stream(), mr};
new (p) obj_t{layout, policy};
Contributor

This approach first value-initializes all the elements of the mdarray, then deserializes into the mdarray. This iterates over the mdarray twice. If your mdarray has std::vector as its container, you could replace the raft::copy call with a range over the input, do ranges::to initialization of a std::vector, and then move the vector into the mdarray.

raft::copy(
Contributor

[INCORRECT]

This is only correct if the layout is exhaustive (contiguous in memory, not strided or otherwise skipping parts of memory). Please see above note.

p->data_handle(), reinterpret_cast<const ElementType*>(in), p->size(), handle.get_stream());
return p->size() * sizeof(ElementType) + extents_size;
}
};

template <typename T>
struct serial<rmm::device_uvector<T>> {
static auto to_bytes(uint8_t* out,
const rmm::device_uvector<T>& obj,
rmm::cuda_stream_view stream) -> size_t
{
auto pref_size = call_serialize<size_t>(out, obj.size());
if (out) {
out += pref_size;
raft::copy(reinterpret_cast<T*>(out), obj.data(), obj.size(), stream);
}
return obj.size() * sizeof(T) + pref_size;
}

static auto from_bytes(rmm::device_uvector<T>* p,
const uint8_t* in,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = nullptr) -> size_t
{
size_t n;
auto pref_size = call_deserialize<size_t>(&n, in);
in += pref_size;
if (mr) {
new (p) rmm::device_uvector<T>{n, stream, mr};
} else {
new (p) rmm::device_uvector<T>{n, stream};
}
raft::copy(p->data(), reinterpret_cast<const T*>(in), p->size(), stream);
return p->size() * sizeof(T) + pref_size;
}
};

} // namespace raft::detail
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -44,6 +44,7 @@ add_executable(test_raft
test/interruptible.cu
test/nvtx.cpp
test/pow2_utils.cu
test/serialization.cpp
test/label/label.cu
test/label/merge_labels.cu
test/lap/lap.cu