Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dmlc stream when URI protocol is not local file. #5857

Merged
merged 3 commits into from
Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 45 additions & 26 deletions src/common/io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <algorithm>
#include <fstream>
#include <string>
#include <memory>
#include <utility>
#include <cstdio>

Expand Down Expand Up @@ -93,35 +94,53 @@ void FixedSizeStream::Take(std::string* out) {
*out = std::move(buffer_);
}

// NOTE(review): this span is a diff view whose +/- markers were lost.  It appears to
// interleave the previous implementation (parameter `fname`) with the new one
// (parameters `uri` and `stream`); the lines are kept exactly as shown.
//
// Previous signature: took a plain file name only.
std::string LoadSequentialFile(std::string fname) {
// Reports a failure to open `fname`, appending strerror(errno), through LOG(FATAL).
auto OpenErr = [&fname]() {
std::string msg;
msg = "Opening " + fname + " failed: ";
msg += strerror(errno);
LOG(FATAL) << msg;
};
// Reports a read failure on `fname` in the same way.
// NOTE(review): no call to ReadErr is visible anywhere in this span.
auto ReadErr = [&fname]() {
std::string msg {"Error in reading file: "};
msg += fname;
msg += ": ";
msg += strerror(errno);
LOG(FATAL) << msg;
};
// New signature: `uri` may be a URI or a file name; `stream` forces the dmlc::Stream
// branch below even for local files.
std::string LoadSequentialFile(std::string uri, bool stream) {
// Reports a failure to open `uri`, appending strerror(errno), through LOG(FATAL).
// The unit test in this change expects dmlc::Error to be thrown for a missing file.
auto OpenErr = [&uri]() {
std::string msg;
msg = "Opening " + uri + " failed: ";
msg += strerror(errno);
LOG(FATAL) << msg;
};

// NOTE(review): the following block (through the '\0' assignment) reads `fname` and
// looks like the previous, unconditional whole-file read.
std::string buffer;
// Open in binary mode so that correct file size can be computed with seekg().
// This accommodates Windows platform:
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
std::ifstream ifs(fname, std::ios_base::binary | std::ios_base::in);
ifs.seekg(0, std::ios_base::end);
const size_t file_size = static_cast<size_t>(ifs.tellg());
ifs.seekg(0, std::ios_base::beg);
buffer.resize(file_size + 1);
ifs.read(&buffer[0], file_size);
buffer.back() = '\0';
// Split the URI so that the protocol can be inspected.
auto parsed = dmlc::io::URI(uri.c_str());
// Read from file.
// This branch is taken when the protocol is "file://" or empty and `stream` is false.
if ((parsed.protocol == "file://" || parsed.protocol.length() == 0) && !stream) {
std::string buffer;
// Open in binary mode so that correct file size can be computed with
// seekg(). This accommodates Windows platform:
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
// NOTE(review): the stream is opened with the full `uri` string rather than a
// component of `parsed`; for a "file://" URI the scheme prefix is passed to
// std::ifstream as part of the path -- confirm this is intended.
std::ifstream ifs(uri, std::ios_base::binary | std::ios_base::in);
if (!ifs) {
// https://stackoverflow.com/a/17338934
OpenErr();
}

// Determine the file size by seeking to the end, then rewind.
// NOTE(review): the tellg() result is cast to size_t without checking for the
// failure value, and the outcome of read() is not checked afterwards.
ifs.seekg(0, std::ios_base::end);
const size_t file_size = static_cast<size_t>(ifs.tellg());
ifs.seekg(0, std::ios_base::beg);
// One extra byte is allocated and set to '\0', so the string returned from this
// branch is file_size + 1 characters long.
buffer.resize(file_size + 1);
ifs.read(&buffer[0], file_size);
buffer.back() = '\0';

return buffer;
}

// Read from remote.
// Also reached for local files when `stream` is true.
// NOTE(review): `fs` is dereferenced without a null check; presumably
// dmlc::Stream::Create reports failure itself -- confirm.
std::unique_ptr<dmlc::Stream> fs{dmlc::Stream::Create(uri.c_str(), "r")};
std::string buffer;
size_t constexpr kInitialSize = 4096;
// `size` is the chunk requested per Read call, `total` the number of bytes received.
size_t size {kInitialSize}, total {0};
while (true) {
// Grow the buffer by one chunk and read straight into the newly added tail.
buffer.resize(total + size);
size_t read = fs->Read(&buffer[total], size);
total += read;
// A short read is treated as end of data.
if (read < size) {
break;
}
// Double the chunk size after every full read.
size *= 2;
}
// Drop the unused tail; unlike the local-file branch no trailing '\0' is appended.
buffer.resize(total);
return buffer;
}

} // namespace common
} // namespace xgboost
11 changes: 10 additions & 1 deletion src/common/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,16 @@ class FixedSizeStream : public PeekableInStream {
std::string buffer_;
};

std::string LoadSequentialFile(std::string fname);
/*!
* \brief Load the whole content of a file into a string, bypassing dmlc Stream when the
* URI refers to a local file.
*
* A URI whose protocol is empty or "file://" is read with std::ifstream; anything else
* is read through dmlc::Stream.
*
* \param uri URI or local file name of the file to load.
* \param stream Use dmlc Stream unconditionally if set to true. Intended for testing the
* stream code path without a remote filesystem.
*
* \return File content. NOTE(review): the local-file path appears to append one trailing
* '\0' character to the returned string while the stream path does not.
*
* Failure to open the file is reported through LOG(FATAL); the unit test expects
* dmlc::Error in that case.
*/
std::string LoadSequentialFile(std::string uri, bool stream = false);

inline std::string FileExtension(std::string const& fname) {
auto splited = Split(fname, '.');
Expand Down
41 changes: 41 additions & 0 deletions tests/cpp/common/test_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
* Copyright (c) by XGBoost Contributors 2019
*/
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>

#include <fstream>

#include "../helpers.h"
#include "../../../src/common/io.h"

namespace xgboost {
Expand Down Expand Up @@ -39,5 +44,41 @@ TEST(IO, FixedSizeStream) {
ASSERT_EQ(huge_buffer, out_buffer);
}
}

// Round-trip test for LoadSequentialFile through the dmlc::Stream code path.
//
// A small model is trained, dumped to a JSON string, written into a temporary directory
// with dmlc::Stream, and read back with `stream = true`; the loaded text must equal the
// dumped text.  A missing file must raise dmlc::Error on both the local-file path and the
// stream path.
TEST(IO, LoadSequentialFile) {
  // Local-file path: a missing file is reported as dmlc::Error.
  EXPECT_THROW(LoadSequentialFile("non-exist"), dmlc::Error);

  // The previous version also opened an unused std::ofstream on
  // `tempdir.path + "test_file"` (no path separator) and declared an unused string; both
  // were dropped because nothing read them and the stream likely left a stray file
  // beside the temporary directory.
  dmlc::TemporaryDirectory tempdir;

  // Generate a JSON file.
  size_t constexpr kRows = 1000, kCols = 100;
  std::shared_ptr<DMatrix> p_dmat{
      RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
  std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
  learner->SetParam("tree_method", "hist");
  learner->Configure();

  for (int32_t iter = 0; iter < 10; ++iter) {
    learner->UpdateOneIter(iter, p_dmat);
  }
  Json out{Object()};
  learner->SaveModel(&out);
  std::string str;
  Json::Dump(out, &str);

  std::string tmpfile = tempdir.path + "/model.json";
  {
    // Scoped so the stream is destroyed (and the file closed) before it is read back.
    std::unique_ptr<dmlc::Stream> fo(
        dmlc::Stream::Create(tmpfile.c_str(), "w"));
    fo->Write(str.c_str(), str.size());
  }

  // Force the dmlc::Stream branch even though the file is local.
  auto loaded = LoadSequentialFile(tmpfile, true);
  ASSERT_EQ(loaded, str);

  // Stream path: a missing file is reported as dmlc::Error as well.
  ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
}
} // namespace common
} // namespace xgboost