Add O_DIRECT support to the TFRecord reader (#4820)
- adds the `use_o_direct` option to the TFRecord reader. When enabled,
  the reader reads into an internal buffer whose chunks are shared
  with the output samples. When the buffer runs out of content, a new one
  is allocated and the old one lives as long as any sample still
  uses a piece of it.

Signed-off-by: Janusz Lisiecki <[email protected]>
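
The buffer sharing described above relies on the `shared_ptr` aliasing constructor: each sample holds a pointer into the staging buffer while sharing ownership of the whole allocation, so the allocation is freed only when the last sample referencing it goes away. Below is a minimal, standalone sketch of that mechanism — plain `std::shared_ptr` with made-up sizes, not the `mm::alloc_raw_shared` allocator or `Tensor::ShareData` call used in the actual change:

```cpp
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  // Staging buffer owned by the reader (stands in for read_buffer_).
  std::shared_ptr<char> buffer(new char[1024], std::default_delete<char[]>());
  std::memset(buffer.get(), 'x', 1024);

  // Hand out two "samples" that point at different offsets of the same buffer.
  // The aliasing constructor shares ownership of `buffer` but stores the offset pointer.
  std::vector<std::shared_ptr<void>> samples;
  samples.emplace_back(buffer, buffer.get());        // sample at offset 0
  samples.emplace_back(buffer, buffer.get() + 512);  // sample at offset 512

  // The reader drops its reference, e.g. because it allocated a new staging buffer.
  buffer.reset();

  // The allocation is still alive: the samples keep it referenced.
  std::cout << *static_cast<char*>(samples[1].get()) << std::endl;  // prints 'x'
  return 0;
}
```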
JanuszL authored Jun 12, 2023
1 parent 62a84af commit 39f885b
Showing 5 changed files with 255 additions and 22 deletions.
148 changes: 129 additions & 19 deletions dali/operators/reader/loader/indexed_file_loader.h
@@ -20,20 +20,33 @@
#include <tuple>
#include <fstream>
#include <memory>
#include <queue>
#include <mutex>
#include <utility>

#include "dali/core/common.h"
#include "dali/core/mm/memory.h"
#include "dali/operators/reader/loader/loader.h"
#include "dali/util/file.h"
#include "dali/util/odirect_file.h"

namespace dali {

class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
public:
explicit IndexedFileLoader(const OpSpec& options)
: Loader(options),
uris_(options.GetRepeatedArgument<std::string>("path")),
index_uris_(options.GetRepeatedArgument<std::string>("index_path")),
current_index_(0), current_file_index_(0), current_file_(nullptr) {
explicit IndexedFileLoader(const OpSpec& spec)
: Loader(spec),
uris_(spec.GetRepeatedArgument<std::string>("path")),
index_uris_(spec.GetRepeatedArgument<std::string>("index_path")),
current_index_(0), current_file_index_(0), current_file_(nullptr),
use_o_direct_(spec.HasArgument("use_o_direct") && spec.GetArgument<bool>("use_o_direct")) {
DALI_ENFORCE(dont_use_mmap_ || !use_o_direct_, make_string("Cannot use use_o_direct with ",
"``dont_use_mmap=False``."));
if (use_o_direct_) {
o_direct_chunk_size_ = ODirectFileStream::GetChunkSize();
o_direct_alignm_ = ODirectFileStream::GetAlignment();
o_direct_read_len_alignm_ = ODirectFileStream::GetLenAlignment();
}
}

void ReadSample(Tensor<CPUBackend>& tensor) override {
@@ -50,9 +63,12 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
meta.SetSkipSample(false);

if (file_index != current_file_index_) {
current_file_->Close();
current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_);
current_file_.reset();
current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
use_o_direct_);
current_file_index_ = file_index;
// invalidate the buffer
if (use_o_direct_) read_buffer_.reset();
}

// if image is cached, skip loading
@@ -80,21 +96,84 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
if (tensor.shares_data()) {
tensor.Reset();
}
tensor.Resize({size}, DALI_UINT8);
if (use_o_direct_) {
/*
 * ** - sample data
 * XX - buffer padding, data of other samples
 *
 * <--                        TFRecord file                        -->
 *                 |<-        read_buffer_         ->|
 *                 |<- block_start                   |
 *                     |<- seek_pos                  |
 *                     |<- size ->|                  |
 *                                     |<- block_end |
 *                 |<-aligned_len/read_buffer_size_->|
 * ----------------XXXX************XXXXXXXXXXXXXXXXXXX----------------
 */
// read again if there is no buffer or the requested piece lies outside of it
bool after_buffer_start = seek_pos >= static_cast<int64>(read_buffer_pos_);
bool before_buffer_end = seek_pos + size <
static_cast<int64>(read_buffer_pos_ + read_buffer_data_size_);
// the buffer needs to exist and the data we look for needs to be inside it
if (!read_buffer_ || !(after_buffer_start && before_buffer_end)) {
// check how much we need to allocate to house the required sample, but no less than
// o_direct_chunk_size_
auto block_start = align_down(seek_pos, o_direct_alignm_);
auto block_end = align_up(seek_pos + size, o_direct_alignm_);
auto aligned_len = align_up(block_end - block_start, o_direct_chunk_size_);
// make the staging buffer as big as the biggest sample so far
if (aligned_len > static_cast<int64>(read_buffer_size_)) {
read_buffer_size_ = aligned_len;
}
// the old memory stays alive as long as any sample still uses a piece of it, so it is safe
// to release the old buffer from read_buffer_
read_buffer_ = mm::alloc_raw_shared<char, mm::memory_kind::host>(read_buffer_size_,
o_direct_alignm_);
read_buffer_pos_ = block_start;
read_buffer_data_size_ = aligned_len;
auto file_name = uris_[file_index];
auto file = dynamic_cast<ODirectFileStream*>(current_file_.get());
auto o_direct_chunk_size_tmp = o_direct_chunk_size_;
// capture a shared_ptr to the file in the lambda to make sure it stays alive (and is not
// closed) for as long as any piece of work may still access it
shared_ptr<FileStream> tmp_file_ptr = current_file_;
// split reads into chunks
for (size_t read_off = 0; static_cast<size_t>(aligned_len) > read_off;
read_off += o_direct_chunk_size_) {
auto dst_ptr = read_buffer_.get() + read_off;
auto read_start = block_start + read_off;
// we should read either the chunk size or the remainder of the file
auto min_read = std::min(o_direct_chunk_size_tmp, seek_pos + size - read_start);
auto work = [tmp_file_ptr, file, dst_ptr, o_direct_chunk_size_tmp, min_read,
read_start, file_name]() {
auto ret = file->ReadAt(dst_ptr, o_direct_chunk_size_tmp, read_start);
DALI_ENFORCE(ret >= min_read && ret <= o_direct_chunk_size_tmp,
make_string("Failed to read file: ", file_name,
", read: ", ret, " while it should be in range [", min_read,
", ", o_direct_chunk_size_tmp, "]"));
};
// store the work lambda in the queue so the prefetch thread can pick it up later and
// execute it in multiple threads
PutReadWork(std::move(work));
}
}
shared_ptr<void> tmp_mem(read_buffer_, read_buffer_.get() + (seek_pos - read_buffer_pos_));
// make sure it is a big value in signed range
tensor.ShareData(tmp_mem, size, false, {size}, DALI_UINT8, -1);
} else {
tensor.Resize({size}, DALI_UINT8);

int64 n_read = current_file_->Read(reinterpret_cast<uint8_t*>(tensor.raw_mutable_data()),
size);
DALI_ENFORCE(n_read == size, "Error reading from a file " + uris_[current_file_index_]);
int64 n_read = current_file_->Read(reinterpret_cast<uint8_t*>(tensor.raw_mutable_data()),
size);
DALI_ENFORCE(n_read == size, "Error reading from a file " + uris_[current_file_index_]);
}
}

tensor.SetMeta(meta);
return;
}

~IndexedFileLoader() override {
if (current_file_ != nullptr) {
current_file_->Close();
}
current_file_.reset();
}

virtual void ReadIndexFile(const std::vector<std::string>& index_uris) {
@@ -140,11 +219,12 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
}
std::tie(seek_pos, size, file_index) = indices_[current_index_];
if (file_index != current_file_index_) {
if (current_file_index_ != static_cast<size_t>(INVALID_INDEX)) {
current_file_->Close();
}
current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_);
current_file_.reset();
current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
use_o_direct_);
current_file_index_ = file_index;
// invalidate the buffer
if (use_o_direct_) read_buffer_.reset();
}
current_file_->SeekRead(seek_pos);
}
@@ -154,11 +234,41 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
std::vector<std::tuple<int64, int64, size_t>> indices_;
size_t current_index_;
size_t current_file_index_;
std::unique_ptr<FileStream> current_file_;
std::shared_ptr<FileStream> current_file_;
FileStream::MappingReserver mmap_reserver_;
static constexpr int INVALID_INDEX = -1;
bool should_seek_ = false;
int64 next_seek_pos_ = 0;
bool use_o_direct_ = false;
size_t o_direct_chunk_size_ = 0;
size_t o_direct_alignm_ = 0;
size_t o_direct_read_len_alignm_ = 0;
shared_ptr<char> read_buffer_;
size_t read_buffer_pos_ = 0;
size_t read_buffer_size_ = 0;
size_t read_buffer_data_size_ = 0;

typedef std::function<void(void)> ReadWork;
std::queue<ReadWork> jobs_;
std::mutex mutex_;

void PutReadWork(ReadWork work) {
std::lock_guard<std::mutex> lock(mutex_);
jobs_.push(std::move(work));
}

public:
ReadWork GetReadWork() {
std::lock_guard<std::mutex> lock(mutex_);
auto work = std::move(jobs_.front());
jobs_.pop();
return work;
}

bool AnyWorkLeft() {
std::lock_guard<std::mutex> lock(mutex_);
return jobs_.size();
}
};

} // namespace dali
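
For reference, the alignment arithmetic used by `ReadSample` above can be exercised on its own. The sketch below mirrors the `align_down`/`align_up` computation of the read window and the per-chunk `min_read` bound; the alignment, chunk size, and offsets are made-up example values, not the ones returned by `ODirectFileStream`:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

// Simplified stand-ins for the DALI alignment helpers.
constexpr int64_t align_down(int64_t x, int64_t a) { return x / a * a; }
constexpr int64_t align_up(int64_t x, int64_t a) { return align_down(x + a - 1, a); }

int main() {
  const int64_t alignment = 4096;       // assumed O_DIRECT offset alignment
  const int64_t chunk_size = 2 << 20;   // assumed 2 MiB read chunk
  const int64_t seek_pos = 10000000;    // example sample offset within the TFRecord file
  const int64_t size = 3000000;         // example sample size in bytes

  // The read window must start and end at aligned offsets and cover whole chunks.
  const int64_t block_start = align_down(seek_pos, alignment);
  const int64_t block_end = align_up(seek_pos + size, alignment);
  const int64_t aligned_len = align_up(block_end - block_start, chunk_size);

  // Each chunk becomes one piece of work; the last read may legitimately return
  // fewer bytes than a chunk (end of sample/file), hence the lower bound checked
  // by the DALI_ENFORCE in the loader.
  for (int64_t off = 0; off < aligned_len; off += chunk_size) {
    const int64_t read_start = block_start + off;
    const int64_t min_read = std::min(chunk_size, seek_pos + size - read_start);
    std::cout << "read chunk at offset " << read_start << ", expect between "
              << min_read << " and " << chunk_size << " bytes\n";
  }
  return 0;
}
```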
2 changes: 1 addition & 1 deletion dali/operators/reader/parser/tfrecord_parser.h
@@ -60,7 +60,7 @@ class TFRecordParser : public Parser<Tensor<CPUBackend>> {
raw_data = raw_data + sizeof(length) + sizeof(crc);
DALI_ENFORCE(example.ParseFromArray(raw_data, length),
make_string("Error while parsing TFRecord file: ", data.GetSourceInfo(),
" (raw data length: ", length, "bytes)."));
" (raw data length: ", length, " bytes)."));

for (size_t i = 0; i < features_.size(); ++i) {
auto& output = ws->Output<CPUBackend>(i);
25 changes: 24 additions & 1 deletion dali/operators/reader/tfrecord_reader_op.cc
@@ -17,6 +17,7 @@

#include <vector>
#include <string>
#include <utility>

#include "dali/operators/reader/tfrecord_reader_op.h"

@@ -45,7 +46,13 @@ DALI_SCHEMA(readers___TFRecordBase)
The index files can be obtained from TFRecord files by using the ``tfrecord2idx`` script
that is distributed with DALI.)code",
DALI_STRING_VEC);
DALI_STRING_VEC)
.AddOptionalArg("use_o_direct",
R"code(If set to True, the data will be read directly from the storage bypassing the system
cache.
Mutually exclusive with ``dont_use_mmap=False``.)code",
false);

// Internal readers._tfrecord schema.
DALI_SCHEMA(readers___TFRecord)
@@ -108,6 +115,22 @@ DALI_SCHEMA(TFRecordReader)
submodule and renamed to follow a common pattern. This is a placeholder operator with identical
functionality to allow for backward compatibility.)code"); // Deprecated in 1.0;

void TFRecordReader::Prefetch() {
// We actually prepare the next batch
DomainTimeRange tr("[DALI][TFRecordReader] Prefetch #" + to_string(curr_batch_producer_),
DomainTimeRange::kRed);
DataReader<CPUBackend, Tensor<CPUBackend>>::Prefetch();

auto idx_loader = dynamic_cast<IndexedFileLoader*>(loader_.get());
while (idx_loader->AnyWorkLeft()) {
auto work = idx_loader->GetReadWork();
thread_pool_.AddWork([work = std::move(work)] (int tid) {
work();
});
}
thread_pool_.RunAll();
}

} // namespace dali

#endif // DALI_BUILD_PROTO3
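
The `Prefetch` override above drains the loader's work queue into the reader's thread pool. A simplified, self-contained model of that hand-off — a mutex-protected queue of lambdas filled by the producer and later executed concurrently, not the actual DALI `ThreadPool` API — might look like this:

```cpp
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <vector>

// Mutex-protected queue of deferred read jobs, mirroring PutReadWork/GetReadWork/AnyWorkLeft.
class WorkQueue {
 public:
  void Put(std::function<void()> work) {
    std::lock_guard<std::mutex> lock(mutex_);
    jobs_.push(std::move(work));
  }
  bool AnyWorkLeft() {
    std::lock_guard<std::mutex> lock(mutex_);
    return !jobs_.empty();
  }
  std::function<void()> Get() {
    std::lock_guard<std::mutex> lock(mutex_);
    auto work = std::move(jobs_.front());
    jobs_.pop();
    return work;
  }

 private:
  std::queue<std::function<void()>> jobs_;
  std::mutex mutex_;
};

int main() {
  WorkQueue queue;
  // "ReadSample" side: enqueue one piece of work per chunk to be read.
  for (int chunk = 0; chunk < 8; ++chunk) {
    queue.Put([chunk] {
      std::cout << "reading chunk " + std::to_string(chunk) + "\n";
    });
  }
  // "Prefetch" side: drain the queue, run every job concurrently, wait for all of them.
  std::vector<std::thread> workers;
  while (queue.AnyWorkLeft()) {
    workers.emplace_back(queue.Get());
  }
  for (auto& t : workers) t.join();
  return 0;
}
```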
20 changes: 19 additions & 1 deletion dali/operators/reader/tfrecord_reader_op.h
@@ -26,7 +26,12 @@ namespace dali {
class TFRecordReader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
public:
explicit TFRecordReader(const OpSpec& spec)
: DataReader<CPUBackend, Tensor<CPUBackend>>(spec) {
: DataReader<CPUBackend, Tensor<CPUBackend>>(spec),
dont_use_mmap_(spec.GetArgument<bool>("dont_use_mmap")),
use_o_direct_(spec.GetArgument<bool>("use_o_direct")),
thread_pool_(num_threads_, spec.GetArgument<int>("device_id"), false, "TFRecordReader") {
DALI_ENFORCE(dont_use_mmap_ || !use_o_direct_, make_string("Cannot use use_o_direct with ",
"``dont_use_mmap=False``."));
loader_ = InitLoader<IndexedFileLoader>(spec);
parser_.reset(new TFRecordParser(spec));
DALI_ENFORCE(!skip_cached_images_,
@@ -38,8 +43,21 @@ class TFRecordReader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
parser_->Parse(tensor, &ws);
}

~TFRecordReader() override {
// Stop the prefetch thread, as it uses the thread pool from this class; make sure no one
// is using the thread pool anymore before it is destroyed.
this->StopPrefetchThread();
}

void Prefetch() override;

protected:
USE_READER_OPERATOR_MEMBERS(CPUBackend, Tensor<CPUBackend>);
bool dont_use_mmap_ = false;
bool use_o_direct_ = false;
size_t o_direct_chunk_size_ = 0;
// ThreadPool used by Prefetch, which runs in a separate thread
ThreadPool thread_pool_;
};

} // namespace dali
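
The destructor added to `TFRecordReader` exists because the prefetch thread lives in the `DataReader` base class while the thread pool it uses is a member of the derived class; derived-class members are destroyed before the base, so the derived destructor has to stop the thread first. A toy model of that ordering constraint — not DALI's actual classes — is sketched below:

```cpp
#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

struct Pool {
  void DoWork() { std::cout << "work\n"; }
};

struct Base {
  // Safety net: also joins if the derived class forgot to, but by then the
  // derived members (the pool) are already gone - hence the explicit call below.
  ~Base() { Stop(); }
  void Start(Pool* pool) {
    worker_ = std::thread([this, pool] {
      while (running_) {
        pool->DoWork();
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
      }
    });
  }
  void Stop() {
    running_ = false;
    if (worker_.joinable()) worker_.join();
  }
  std::atomic<bool> running_{true};
  std::thread worker_;
};

struct Derived : Base {
  Derived() { Start(&pool_); }
  // Like ~TFRecordReader: stop the background thread before pool_ is destroyed.
  ~Derived() { Stop(); }
  Pool pool_;
};

int main() {
  Derived d;
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  return 0;
}
```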
82 changes: 82 additions & 0 deletions dali/test/python/reader/test_index.py
@@ -22,6 +22,8 @@
import numpy as np
from test_utils import compare_pipelines, get_dali_extra_path
from nose_utils import assert_raises
from nose2.tools import cartesian_params
from nose import SkipTest


def skip_second(src, dst):
@@ -72,6 +74,86 @@ def define_graph(self):
_ = pipe_org.run()


def test_tfrecord_odirect():
batch_size = 16

@pipeline_def(batch_size=batch_size, device_id=0, num_threads=4)
def tfrecord_pipe(path, index_path, dont_use_mmap, use_o_direct):
input = fn.readers.tfrecord(
path=path,
index_path=index_path,
dont_use_mmap=dont_use_mmap,
use_o_direct=use_o_direct,
features={
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)},
name="Reader")
return input["image/class/label"]

tfrecord = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train')
tfrecord_idx = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train.idx')

pipe = tfrecord_pipe(tfrecord, tfrecord_idx, True, True)
pipe_ref = tfrecord_pipe(tfrecord, tfrecord_idx, False, False)
pipe.build()
pipe_ref.build()
iters = (pipe.epoch_size("Reader") + batch_size) // batch_size
for _ in range(iters):
out = pipe.run()
out_ref = pipe_ref.run()
for a, b in zip(out, out_ref):
assert np.array_equal(a.as_array(), b.as_array())


@cartesian_params(((1, 2, 1), (3, 1, 2)),
(True, False),
(True, False))
def test_tfrecord_pad_last_batch(batch_description, dont_use_mmap, use_o_direct):
if not dont_use_mmap and use_o_direct:
raise SkipTest("Cannot use O_DIRECT with mmap")
num_samples, batch_size, num_shards = batch_description

@pipeline_def(batch_size=batch_size, device_id=0, num_threads=4)
def tfrecord_pipe(path, index_path, dont_use_mmap, use_o_direct):
input = fn.readers.tfrecord(
path=path,
index_path=index_path,
num_shards=num_shards,
dont_use_mmap=dont_use_mmap,
use_o_direct=use_o_direct,
features={
"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)},
name="Reader")
return input["image/class/label"]

tfrecord = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train')
tfrecord_idx = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train.idx')

idx_files_dir = tempfile.TemporaryDirectory()
recordio_idx = "rio_train.idx"
idx_file = os.path.join(idx_files_dir.name, recordio_idx)

def leave_only_N(src, dst, n):
with open(src, 'r') as tmp_f:
with open(dst, 'w') as f:
for i, x in enumerate(tmp_f):
if i == n:
break
f.write(x)

leave_only_N(tfrecord_idx, idx_file, num_samples)

pipe = tfrecord_pipe(tfrecord, idx_file, dont_use_mmap, use_o_direct)
pipe_ref = tfrecord_pipe(tfrecord, idx_file, False, False)
pipe.build()
pipe_ref.build()
iters = (pipe.epoch_size("Reader") + batch_size) // batch_size
for _ in range(iters):
out = pipe.run()
out_ref = pipe_ref.run()
for a, b in zip(out, out_ref):
assert np.array_equal(a.as_array(), b.as_array())


def test_recordio():
class MXNetReaderPipeline(Pipeline):
def __init__(self, batch_size, num_threads, device_id, num_gpus, data, data_idx):
