Skip to content

Commit

Permalink
Address reviewer's feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed May 11, 2021
1 parent 1056b88 commit 561cb0e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
16 changes: 9 additions & 7 deletions include/treelite/tree_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ template <typename ThresholdType, typename LeafOutputType>
template <typename HeaderFieldHandlerFunc, typename TreeHandlerFunc>
inline void
ModelImpl<ThresholdType, LeafOutputType>::DeserializeTemplate(
size_t num_tree,
std::size_t num_tree,
HeaderFieldHandlerFunc header_field_handler,
TreeHandlerFunc tree_handler) {
/* Header */
Expand All @@ -870,7 +870,7 @@ ModelImpl<ThresholdType, LeafOutputType>::DeserializeTemplate(
header_field_handler(&param);
/* Body */
trees.clear();
for (size_t i = 0; i < num_tree; ++i) {
for (std::size_t i = 0; i < num_tree; ++i) {
trees.emplace_back();
tree_handler(trees.back());
}
Expand Down Expand Up @@ -917,18 +917,20 @@ ModelImpl<ThresholdType, LeafOutputType>::InitFromPyBuffer(
if (num_frame < kNumFrameInHeader || (num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) {
throw std::runtime_error("Wrong number of frames");
}
const size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree;
const std::size_t num_tree = (num_frame - kNumFrameInHeader) / kNumFramePerTree;

auto header_field_handler = [&begin](auto* field) {
InitScalarFromPyBuffer(field, *begin++);
};

auto tree_hanlder = [&begin](Tree<ThresholdType, LeafOutputType>& tree) {
auto tree_handler = [&begin](Tree<ThresholdType, LeafOutputType>& tree) {
// Read the frames in the range [begin, begin + kNumFramePerTree) into the tree
tree.InitFromPyBuffer(begin, begin + kNumFramePerTree);
begin += kNumFramePerTree;
// Advance the iterator so that the next tree reads the next kNumFramePerTree frames
};

DeserializeTemplate(num_tree, header_field_handler, tree_hanlder);
DeserializeTemplate(num_tree, header_field_handler, tree_handler);
}

template <typename ThresholdType, typename LeafOutputType>
Expand All @@ -941,11 +943,11 @@ ModelImpl<ThresholdType, LeafOutputType>::DeserializeFromFileImpl(FILE* src_fp)
ReadScalarFromFile(field, src_fp);
};

auto tree_hanlder = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
auto tree_handler = [src_fp](Tree<ThresholdType, LeafOutputType>& tree) {
tree.DeserializeFromFile(src_fp);
};

DeserializeTemplate(num_tree, header_field_handler, tree_hanlder);
DeserializeTemplate(num_tree, header_field_handler, tree_handler);
}

inline void InitParamAndCheck(ModelParam* param,
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/test_serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ inline void TestRoundTrip(treelite::Model* model) {
auto buffer = model->GetPyBuffer();
std::unique_ptr<treelite::Model> received_model = treelite::Model::CreateFromPyBuffer(buffer);

// Use ASSERT_TRUE, since ASSERT_EQ will dump all the raw bytes into a string, potentially
// causing an OOM error
ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get()));
}

Expand All @@ -44,6 +46,8 @@ inline void TestRoundTrip(treelite::Model* model) {
std::unique_ptr<treelite::Model> received_model = treelite::Model::DeserializeFromFile(fp);
std::fclose(fp);

// Use ASSERT_TRUE, since ASSERT_EQ will dump all the raw bytes into a string, potentially
// causing an OOM error
ASSERT_TRUE(TreeliteToBytes(model) == TreeliteToBytes(received_model.get()));
}
}
Expand Down

0 comments on commit 561cb0e

Please sign in to comment.