-
Notifications
You must be signed in to change notification settings - Fork 197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Serialization for commonly used types #770
Changes from 3 commits
7ef0042
fafe5aa
d21d9d7
5c72bec
5de6132
466660f
0cc1b51
2ce8c96
453f782
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <raft/core/handle.hpp> | ||
#include <raft/detail/serialization.hpp> | ||
|
||
#include <vector> | ||
|
||
namespace raft { | ||
|
||
/** | ||
* @brief Write a serializable state of an object to memory. | ||
* | ||
* @tparam T type of the serializable object | ||
* @tparam ContextArgs types of context required for serialization | ||
* | ||
* @param[out] out a host pointer to the location where to store the state; | ||
* when `nullptr`, the actual data is not written (only the size-to-written is calculated). | ||
* @param obj the object to be serialized | ||
* @param args context required for serialization | ||
* @return the number of bytes (to be) written by the pointer | ||
*/ | ||
template <typename T, typename... ContextArgs> | ||
auto serialize(uint8_t* out, const T& obj, ContextArgs&&... args) -> size_t | ||
{ | ||
return detail::call_serialize<T, ContextArgs...>(out, obj, std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
/** | ||
* @brief Write a serializable state of an object to a host vector. | ||
* | ||
* @tparam T type of the serializable object | ||
* @tparam ContextArgs types of context required for serialization | ||
* | ||
* @param obj the object to be serialized | ||
* @param args context required for serialization | ||
* @return the serialized state | ||
*/ | ||
template <typename T, typename... ContextArgs> | ||
auto serialize(const T& obj, ContextArgs&&... args) -> std::vector<uint8_t> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The first option would imply a state machine that takes an output buffer, and returns the "rest" of the output buffer and whether it finished with output (or some other error occurred, like running out of space). struct SerializeNoError;
struct SerializeOutOfSpace { size_t remaining_space_required; }
struct SerializeOtherError { /* ... */ };
using SerializeErrorCode = std::variant<SerializeOutOfSpace, SerializeOtherError>;
struct SerializeState {
std::span<uint8_t> rest_of_buffer;
SerializeInternalStateRepresentation state;
bool done = false;
SerializeErrorCode error_code{SerializeNoError{}};
};
bool keep_going(const SerializeReturn& r) {
return ! result.done && std::holds_alternative<SerializeNoError>(result.error_code);
}
template<typename T, typename... ContextArgs>
serialize(SerializeState r, const T& obj, ContextArgs&&... args) -> SerializeState; You then call it in a loop. const size_t initial_buffer_size = 100;
std::vector<uint8_t> buffer(initial_buffer_size);
SerializeState state{span{buffer.data()}};
while(keep_going(state)) {
state = serialize(state, obj, args...);
}; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The base implementation ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @achirkin Would you consider a design that names the space query as a separate function from serialization? Compare to cuSPARSE, which separates There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd be totally fine separating them, though for some reason I've had an impression that in raft/cuml we follow the CUB approach of using the same function. @cjnolet , which one do you prefer? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This one is hard. @mhoemmen, we've been trying to keep the APIs in RAFT a little more C++ friendly, so generally trying to avoid the two-stage In this particular case, if we are only ever expecting the serialized buffer to be in host memory, I could see an argument for having serialize return a vector, or more preferably, accepting a vector and populating it. I think if I was to choose between the various options presented, I most like the idea of returning/populating some sort of streamable/iterable type if we can. I find that to be the most flexible design because a large stream can go right to disk and have to buffer only a small portion into memory at any given time. Further, the stream is also nice because you can collect it into a byte buffer at any point if needed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cjnolet wrote:
I do like the idea of writing to some kind of output range. If it's a random access range, it would be nice to test if it's long enough and report an error if not, rather than just clobbering it. This is always detectable, and is generally recoverable (allocate more memory) or at least worth reporting ("insufficient disk space"). Please do consider not allocating the return buffer each time. This would make it impossible to write sequential records to contiguous storage. It can also be surprisingly expensive in a tight loop. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I do wholeheartedly agree with this. |
||
{ | ||
return detail::call_serialize<T, ContextArgs...>(obj, std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
/** | ||
* @brief Read a serializable state of an object from memory. | ||
* | ||
* @tparam T type of the serializable object | ||
* @tparam ContextArgs types of context required for serialization | ||
* | ||
* @param[out] p an unitialized host pointer to a location where the object should be created. | ||
achirkin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* @param[in] in a host pointer to the location where the state should be read from. | ||
* @param args context required for serialization | ||
* @return the number of bytes read by the pointer | ||
*/ | ||
template <typename T, typename... ContextArgs> | ||
auto deserialize(T* p, const void* in, ContextArgs&&... args) -> size_t | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In general, the reason why I'm hesitant to change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Please do not : - )
It's possible if you have assignment, but ugly -- https://godbolt.org/z/774Ebfvh5 -- so I understand if you don't want to do it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant to say, it would be impossible if I changed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @achirkin Right, I definitely forgot the important detail of manually invoking the destructor on a placement new'd thing. |
||
{ | ||
return detail::call_deserialize<T, ContextArgs...>( | ||
p, reinterpret_cast<const uint8_t*>(in), std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
/** | ||
* @brief Read a serializable state of an object from a vector. | ||
* | ||
* @tparam T type of the serializable object | ||
* @tparam ContextArgs types of context required for serialization | ||
* | ||
* @param[out] p an unitialized host pointer to a location where the object should be created. | ||
achirkin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* @param in a vector where the state should be read from. | ||
* @param args context required for serialization | ||
* @return the number of bytes read from the vector | ||
*/ | ||
template <typename T, typename... ContextArgs> | ||
auto deserialize(T* p, const std::vector<uint8_t>& in, ContextArgs&&... args) -> size_t | ||
achirkin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return detail::call_deserialize<T, ContextArgs...>(p, in, std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
/** | ||
* @brief Read a serializable state of an object. | ||
* | ||
* @tparam T type of the serializable object | ||
* @tparam ContextArgs types of context required for serialization | ||
* | ||
* @param[in] in a host pointer to the location where the state should be read from. | ||
* @param args context required for serialization | ||
* @return the deserialized object; | ||
*/ | ||
template <typename T, typename... ContextArgs> | ||
auto deserialize(const void* in, ContextArgs&&... args) -> T | ||
achirkin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return detail::call_deserialize<T, ContextArgs...>(reinterpret_cast<const uint8_t*>(in), | ||
std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
/** | ||
* @brief Read a serializable state of an object. | ||
* | ||
* @tparam T type of the serializable object | ||
* @tparam ContextArgs types of context required for serialization | ||
* | ||
* @param in a vector where the state should be read from. | ||
* @param args context required for serialization | ||
* @return the deserialized object; | ||
*/ | ||
template <typename T, typename... ContextArgs> | ||
auto deserialize(const std::vector<uint8_t>& in, ContextArgs&&... args) -> T | ||
achirkin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return detail::call_deserialize<T, ContextArgs...>(in, std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
} // namespace raft |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <raft/core/handle.hpp> | ||
#include <raft/core/mdarray.hpp> | ||
|
||
#include <rmm/device_uvector.hpp> | ||
|
||
#include <type_traits> | ||
#include <vector> | ||
|
||
namespace raft::detail { | ||
|
||
template <typename T> | ||
struct serial; | ||
|
||
template <typename T, typename... ContextArgs> | ||
auto call_serialize(uint8_t* out, const T& obj, ContextArgs&&... args) -> size_t | ||
{ | ||
return detail::serial<T>::to_bytes(out, obj, std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
template <typename T, typename... ContextArgs> | ||
auto call_serialize(const T& obj, ContextArgs&&... args) -> std::vector<uint8_t> | ||
{ | ||
std::vector<uint8_t> v( | ||
call_serialize<T, ContextArgs...>(nullptr, obj, std::forward<ContextArgs>(args)...)); | ||
call_serialize<T, ContextArgs...>(v.data(), obj, std::forward<ContextArgs>(args)...); | ||
return v; | ||
} | ||
|
||
template <typename T, typename... ContextArgs> | ||
auto call_deserialize(T* p, const uint8_t* in, ContextArgs&&... args) -> size_t | ||
{ | ||
return detail::serial<T>::from_bytes(p, in, std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
template <typename T, typename... ContextArgs> | ||
auto call_deserialize(T* p, const std::vector<uint8_t>& in, ContextArgs&&... args) -> size_t | ||
{ | ||
return call_deserialize<T, ContextArgs...>(p, in.data(), std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
template <typename T, typename... ContextArgs> | ||
auto call_deserialize(const uint8_t* in, ContextArgs&&... args) -> T | ||
{ | ||
union res { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A more idiomatic way to avoid initialization of T would be to have There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion! Is there an official way to populate the optional using a callback accepting a pointer to the data (i.e. I see I haven't explained why I've gone through such a hussle with these many overloads. My goal here is to provide an easy way for other developers to add serialization to new structures. That is, to add serialization to a new structure, one would need to write just two methods:
Although I love the return-value approach, I'm afraid I cannot use it here, because some of the classes in raft and cuml do not have some of the copy, move, or nullary constructors. Hence I've come up with this idea that it should be the responsibility of the serialization implementer to initialize the object in the The rest, all Assuming that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @achirkin Thanks for explaining! I see the difficulty now in using class NotDefaultConstructible {
public:
NotDefaultConstructible() = delete;
NotDefaultConstructible(int i) : i_(i) {}
int value() const { return i_; }
private:
int i_;
};
void deserialize_raw(NotDefaultConstructible* out, std::string_view s)
{
int value{};
auto result = std::from_chars(s.data(), s.data()+s.size(), value);
if (result.ec == std::errc::invalid_argument) {
throw std::invalid_argument{"String does not represent an integer"};
}
else if (result.ec == std::errc::result_out_of_range) {
throw std::out_of_range{"String represents an integer, but it does not fit in int"};
}
// WARNING: assumes out does NOT point to a valid object.
new(out) NotDefaultConstructible(value);
}
int main()
{
constexpr size_t num_bytes = sizeof(NotDefaultConstructible);
alignas(num_bytes) std::byte ptr[num_bytes];
NotDefaultConstructible* out = reinterpret_cast<NotDefaultConstructible*>(ptr);
deserialize_raw(out, s);
std::cout << "Value: " << out->value();
out->~NotDefaultConstructible(); // oof, we have to do this by hand
return 0;
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ouch, I wouldn't want to call the destructor by hand in the presence of exceptions in |
||
T value; | ||
explicit res(const uint8_t* in, ContextArgs&&... args) | ||
{ | ||
call_deserialize<T, ContextArgs...>(&value, in, std::forward<ContextArgs>(args)...); | ||
} | ||
~res() { value.~T(); } // NOLINT | ||
}; | ||
// using a union to avoid initialization of T and force copy elision. | ||
return res(in, std::forward<ContextArgs>(args)...).value; | ||
} | ||
|
||
template <typename T, typename... ContextArgs> | ||
auto call_deserialize(const std::vector<uint8_t>& in, ContextArgs&&... args) -> T | ||
achirkin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
return call_deserialize<T, ContextArgs...>(in.data(), std::forward<ContextArgs>(args)...); | ||
} | ||
|
||
template <typename T> | ||
struct serial { | ||
// Default implementation for all arithmetic types: just write the value by the pointer. | ||
template <typename S = T> | ||
static auto to_bytes(uint8_t* out, const S& obj) | ||
-> std::enable_if_t<std::is_arithmetic_v<S>, size_t> | ||
{ | ||
if (out) { *reinterpret_cast<T*>(out) = obj; } | ||
achirkin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return sizeof(S); | ||
} | ||
|
||
// SFINAE-style failure | ||
template <typename S = T, typename... ContextArgs> | ||
static auto to_bytes(uint8_t* out, const S& obj, ContextArgs&&... args) | ||
-> std::enable_if_t<!std::is_arithmetic_v<S>, size_t> | ||
{ | ||
static_assert(!std::is_same_v<T, S>, "Serialization is not implemented for this type."); | ||
return 0; | ||
} | ||
|
||
// Default implementation for all arithmetic types: just read the value by the pointer. | ||
template <typename S = T> | ||
static auto from_bytes(S* p, const uint8_t* in) | ||
-> std::enable_if_t<std::is_arithmetic_v<S>, size_t> | ||
{ | ||
*p = *reinterpret_cast<const S*>(in); | ||
return sizeof(S); | ||
} | ||
|
||
// SFINAE-style failure | ||
template <typename S = T, typename... ContextArgs> | ||
static auto from_bytes(S* p, const uint8_t* in, ContextArgs&&... args) | ||
-> std::enable_if_t<!std::is_arithmetic_v<S>, size_t> | ||
{ | ||
static_assert(!std::is_same_v<T, S>, "Deserialization is not implemented for this type."); | ||
return 0; | ||
} | ||
}; | ||
|
||
template <typename IndexType, size_t... ExtentsPack> | ||
struct serial<extents<IndexType, ExtentsPack...>> { | ||
using obj_t = extents<IndexType, ExtentsPack...>; | ||
|
||
static auto to_bytes(uint8_t* out, const obj_t& obj) -> size_t | ||
{ | ||
if (out) { *reinterpret_cast<obj_t*>(out) = obj; } | ||
return sizeof(obj_t); | ||
} | ||
|
||
static auto from_bytes(obj_t* p, const uint8_t* in) -> size_t | ||
{ | ||
new (p) obj_t{*reinterpret_cast<const obj_t*>(in)}; | ||
return sizeof(obj_t); | ||
} | ||
}; | ||
|
||
template <typename ElementType, typename Extents, typename LayoutPolicy> | ||
struct serial<mdarray<ElementType, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Serializing |
||
Extents, | ||
LayoutPolicy, | ||
detail::device_accessor<detail::device_uvector_policy<ElementType>>>> { | ||
using obj_t = mdarray<ElementType, | ||
Extents, | ||
LayoutPolicy, | ||
detail::device_accessor<detail::device_uvector_policy<ElementType>>>; | ||
|
||
static auto to_bytes(uint8_t* out, const obj_t& obj, const handle_t& handle) -> size_t | ||
{ | ||
auto extents_size = call_serialize<Extents>(out, obj.extents()); | ||
auto total_size = obj.size() * sizeof(ElementType) + extents_size; | ||
if (out) { | ||
out += extents_size; | ||
raft::copy( | ||
reinterpret_cast<ElementType*>(out), obj.data_handle(), obj.size(), handle.get_stream()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [INCORRECT] This is only correct if the layout is exhaustive (contiguous in memory, not strided or otherwise skipping parts of memory). Otherwise, it will both write memory that are not part of the The right thing to do would be to define an |
||
} | ||
return total_size; | ||
} | ||
|
||
static auto from_bytes(obj_t* p, | ||
const uint8_t* in, | ||
const handle_t& handle, | ||
rmm::mr::device_memory_resource* mr = nullptr) -> size_t | ||
{ | ||
Extents exts; | ||
auto extents_size = call_deserialize<Extents>(&exts, in); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is your representation of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, at this moment we don't have any sorts of RTTI at all. That is, in this version the caller should know it has serialized the static extents, by possibly not writing anything at all; then |
||
in += extents_size; | ||
typename obj_t::mapping_type layout{exts}; | ||
typename obj_t::container_policy_type policy{handle.get_stream(), mr}; | ||
new (p) obj_t{layout, policy}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This approach first value-initializes all the elements of the mdarray, then deserializes into the mdarray. This iterates over the mdarray twice. If your mdarray has |
||
raft::copy( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [INCORRECT] This is only correct if the layout is exhaustive (contiguous in memory, not strided or otherwise skipping parts of memory). Please see above note. |
||
p->data_handle(), reinterpret_cast<const ElementType*>(in), p->size(), handle.get_stream()); | ||
return p->size() * sizeof(ElementType) + extents_size; | ||
} | ||
}; | ||
|
||
template <typename T> | ||
struct serial<rmm::device_uvector<T>> { | ||
static auto to_bytes(uint8_t* out, | ||
const rmm::device_uvector<T>& obj, | ||
rmm::cuda_stream_view stream) -> size_t | ||
{ | ||
if (out) { | ||
*reinterpret_cast<size_t*>(out) = obj.size(); | ||
out += sizeof(size_t); | ||
raft::copy(reinterpret_cast<T*>(out), obj.data(), obj.size(), stream); | ||
} | ||
return obj.size() * sizeof(T) + sizeof(size_t); | ||
} | ||
|
||
static auto from_bytes(rmm::device_uvector<T>* p, | ||
const uint8_t* in, | ||
rmm::cuda_stream_view stream, | ||
rmm::mr::device_memory_resource* mr = nullptr) -> size_t | ||
{ | ||
auto n = *reinterpret_cast<const size_t*>(in); | ||
in += sizeof(size_t); | ||
if (mr) { | ||
new (p) rmm::device_uvector<T>{n, stream, mr}; | ||
} else { | ||
new (p) rmm::device_uvector<T>{n, stream}; | ||
} | ||
raft::copy(p->data(), reinterpret_cast<const T*>(in), p->size(), stream); | ||
return p->size() * sizeof(T) + sizeof(size_t); | ||
} | ||
}; | ||
|
||
} // namespace raft::detail |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you consider using
std::span
as the output argument? This would better express how many elements the output array has.See C++ Core Guidelines F.24.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd be happy to add overloads to have span as the output, but raft is currently set to
C++17
. It's not available till C++20, right?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if we returned a raft::span?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, great, I can add overloads that accept
raft::span
!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does not seem to have python bindings at the moment, right? I guess, we'd need at least one version of serialize that could be easily wrapped by cython for #752
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be straightforward to expose
raft::span
in apxd
file. We can also rename the functions as needed (for e.g.to_span
/from_span
andto_bytes
/from_bytes
).