Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster RecordIO Scanner #11116

Merged
merged 2 commits into from
Jun 5, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 37 additions & 21 deletions paddle/fluid/recordio/chunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,40 +119,56 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
}

// Reads one chunk from `sin` into this Chunk by draining a ChunkParser.
// Returns false when ChunkParser::Init() fails; otherwise every record the
// parser yields is passed to Add() and true is returned.
// NOTE(review): this span is taken from a flattened diff view. The
// `Header hdr` / `hdr.Parse(sin)` pair looks like the removed pre-change lines
// and the ChunkParser-based lines like their replacement — confirm against the
// real file before relying on either.
bool Chunk::Parse(std::istream& sin) {
Header hdr;
bool ok = hdr.Parse(sin);
ChunkParser parser(sin);
if (!parser.Init()) {
return false;
}
// Clear() is only reached after Init() has succeeded.
Clear();
while (parser.HasNext()) {
Add(parser.Next());
}
return true;
}

// Stores a reference to `sin` in in_. The constructor body is empty, so
// nothing is read from the stream here.
ChunkParser::ChunkParser(std::istream& sin)
    : in_(sin) {
}
// Prepares the parser for a new chunk: resets pos_ to 0, parses the chunk
// header from in_, checks the checksum, rewinds to the position right after
// the header, and selects the stream records will be read from.
// Returns false if the header cannot be parsed, true otherwise. A checksum
// mismatch goes through PADDLE_ENFORCE_EQ and an unknown compressor through
// PADDLE_THROW("Not implemented").
// NOTE(review): this span is taken from a flattened diff view. The lines that
// use `sin`, `hdr`, `Clear()` and the local `compressed_stream` look like the
// removed pre-change body of Chunk::Parse; the lines that use `in_`,
// `header_` and `compressed_stream_` look like the new body — confirm.
bool ChunkParser::Init() {
pos_ = 0;
bool ok = header_.Parse(in_);
if (!ok) {
return ok;
}
auto beg_pos = sin.tellg();
uint32_t crc = Crc32Stream(sin, hdr.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc);
Clear();
sin.seekg(beg_pos, sin.beg);
std::unique_ptr<std::istream> compressed_stream;
switch (hdr.CompressType()) {
// Remember where the payload starts so the stream can be rewound after the
// checksum pass (presumably Crc32Stream consumes the bytes it hashes).
auto beg_pos = in_.tellg();
uint32_t crc = Crc32Stream(in_, header_.CompressSize());
PADDLE_ENFORCE_EQ(header_.Checksum(), crc);
in_.seekg(beg_pos, in_.beg);

switch (header_.CompressType()) {
case Compressor::kNoCompress:
// No wrapper stream is created for uncompressed chunks.
break;
case Compressor::kSnappy:
compressed_stream.reset(new snappy::iSnappyStream(sin));
// Records are read through a snappy stream layered over in_.
compressed_stream_.reset(new snappy::iSnappyStream(in_));
break;
default:
PADDLE_THROW("Not implemented");
}
return true;
}

std::istream& stream = compressed_stream ? *compressed_stream : sin;
// True while pos_ is still below the record count stored in the parsed header.
bool ChunkParser::HasNext() const {
  const bool has_remaining = pos_ < header_.NumRecords();
  return has_remaining;
}

for (uint32_t i = 0; i < hdr.NumRecords(); ++i) {
uint32_t rec_len;
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
Add(buf);
// Returns the next record of the current chunk, or an empty string when
// HasNext() is false. Each record is read as a uint32_t length followed by
// that many bytes, taken from compressed_stream_ when it is set and from in_
// otherwise. PADDLE_ENFORCE_EQ requires the payload read to deliver exactly
// rec_len bytes.
// NOTE(review): this span is taken from a flattened diff view. The
// `return true;` line looks like a removed line from the old Chunk::Parse and
// not part of this function — confirm.
std::string ChunkParser::Next() {
if (!HasNext()) {
return "";
}
return true;
// pos_ is advanced before the read, so it counts records requested so far.
++pos_;
std::istream& stream = compressed_stream_ ? *compressed_stream_ : in_;
uint32_t rec_len;
// NOTE(review): the outcome of this length read is not checked before
// rec_len is used to size the buffer.
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
return buf;
}

} // namespace recordio
} // namespace paddle
16 changes: 14 additions & 2 deletions paddle/fluid/recordio/chunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <vector>

Expand Down Expand Up @@ -53,9 +54,20 @@ class Chunk {
DISABLE_COPY_AND_ASSIGN(Chunk);
};

size_t CompressData(const char* in, size_t in_length, Compressor ct, char* out);
// Record-at-a-time reader over a chunk stored in an std::istream.
// Declares Init(), then HasNext()/Next() for iterating the chunk's records.
// The parser keeps a reference to the stream rather than owning it.
// NOTE(review): this span is taken from a flattened diff view. The
// `DeflateData` declaration looks like a removed free-function declaration
// that landed inside the class body — confirm.
class ChunkParser {
public:
// Binds the parser to `sin`.
explicit ChunkParser(std::istream& sin);

// Returns bool; the definition parses the chunk header from the stream.
bool Init();
// Returns one record as a std::string.
std::string Next();
// Reports whether another record is available.
bool HasNext() const;

void DeflateData(const char* in, size_t in_length, Compressor ct, char* out);
private:
// Header of the chunk being parsed.
Header header_;
// Record position counter, default-initialized to 0.
uint32_t pos_{0};
// Underlying input stream, held by reference.
std::istream& in_;
// Optional owned stream wrapper; null unless one is installed.
std::unique_ptr<std::istream> compressed_stream_;
};

} // namespace recordio
} // namespace paddle
26 changes: 12 additions & 14 deletions paddle/fluid/recordio/scanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,33 @@ namespace paddle {
namespace recordio {

// Takes ownership of `stream`, binds parser_ to the owned stream, and calls
// Reset().
// NOTE(review): this span is taken from a flattened diff view. It contains two
// initializer lists; the one without `parser_` looks like the removed
// pre-change line — confirm.
// NOTE(review): `parser_(*stream_)` dereferences stream_ during member
// initialization, so it relies on stream_ being declared before parser_ in
// the class.
Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
: stream_(std::move(stream)) {
: stream_(std::move(stream)), parser_(*stream_) {
Reset();
}

// Opens `filename` as an std::ifstream owned by stream_, binds parser_ to it,
// and calls Reset().
// NOTE(review): this span is taken from a flattened diff view. The first
// signature line and the `stream_.reset(...)` statement look like the removed
// pre-change version; the initializer-list form looks like the replacement —
// confirm.
Scanner::Scanner(const std::string &filename) {
stream_.reset(new std::ifstream(filename));
Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) {
Reset();
}

// Rewinds the scanner: clears the stream's state flags, seeks to the start of
// the stream, then prepares the first chunk.
// NOTE(review): this span is taken from a flattened diff view. The
// `ParseNextChunk()` call looks like the removed pre-change line and
// `parser_.Init()` like its replacement — confirm.
void Scanner::Reset() {
stream_->clear();
stream_->seekg(0, std::ios::beg);
ParseNextChunk();
// NOTE(review): the bool returned by Init() is not inspected here.
parser_.Init();
}

// Returns the next record from the stream.
// NOTE(review): this span is taken from a flattened diff view and its braces
// do not balance as written. The lines that use `eof_`, `cur_chunk_`,
// `offset_`, `rec` and the `ParseNextChunk` definition look like the removed
// pre-change code. The lines that test `stream_->eof()`, call `parser_.Next()`
// and re-run `parser_.Init()` look like the replacement — confirm against the
// real file.
std::string Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration");
auto rec = cur_chunk_.Record(offset_++);
if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk();
// Yields an empty string once the underlying stream reports eof().
if (stream_->eof()) {
return "";
}
return rec;
}

void Scanner::ParseNextChunk() {
eof_ = !cur_chunk_.Parse(*stream_);
offset_ = 0;
auto res = parser_.Next();
// When the current chunk has no records left but the stream is not at eof,
// initialize the parser again for the following chunk.
if (!parser_.HasNext() && HasNext()) {
parser_.Init();
}
return res;
}

// Reports whether the scanner can still produce records.
// NOTE(review): this span is taken from a flattened diff view and holds two
// definitions of the same function. The `eof_`-based one looks like the
// removed pre-change line and the `stream_->eof()`-based one like its
// replacement — confirm.
bool Scanner::HasNext() const { return !eof_; }
bool Scanner::HasNext() const { return !stream_->eof(); }
} // namespace recordio
} // namespace paddle
6 changes: 1 addition & 5 deletions paddle/fluid/recordio/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ class Scanner {

private:
std::unique_ptr<std::istream> stream_;
Chunk cur_chunk_;
size_t offset_;
bool eof_;

void ParseNextChunk();
ChunkParser parser_;
};
} // namespace recordio
} // namespace paddle