diff --git a/dali/operators/reader/loader/indexed_file_loader.h b/dali/operators/reader/loader/indexed_file_loader.h
index 444783e86e4..9932a29f9b7 100755
--- a/dali/operators/reader/loader/indexed_file_loader.h
+++ b/dali/operators/reader/loader/indexed_file_loader.h
@@ -20,20 +20,33 @@
 #include <string>
 #include <tuple>
 #include <vector>
+#include <memory>
+#include <mutex>
+#include <queue>
 
 #include "dali/core/common.h"
+#include "dali/core/mm/memory.h"
 #include "dali/operators/reader/loader/loader.h"
 #include "dali/util/file.h"
+#include "dali/util/odirect_file.h"
 
 namespace dali {
 
 class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
  public:
-  explicit IndexedFileLoader(const OpSpec& options)
-    : Loader(options),
-      uris_(options.GetRepeatedArgument<std::string>("path")),
-      index_uris_(options.GetRepeatedArgument<std::string>("index_path")),
-      current_index_(0), current_file_index_(0), current_file_(nullptr) {
+  explicit IndexedFileLoader(const OpSpec& spec)
+    : Loader(spec),
+      uris_(spec.GetRepeatedArgument<std::string>("path")),
+      index_uris_(spec.GetRepeatedArgument<std::string>("index_path")),
+      current_index_(0), current_file_index_(0), current_file_(nullptr),
+      use_o_direct_(spec.HasArgument("use_o_direct") && spec.GetArgument<bool>("use_o_direct")) {
+    DALI_ENFORCE(dont_use_mmap_ || !use_o_direct_, make_string("Cannot use use_o_direct with ",
+                 "``dont_use_mmap=False``."));
+    if (use_o_direct_) {
+      o_direct_chunk_size_ = ODirectFileStream::GetChunkSize();
+      o_direct_alignm_ = ODirectFileStream::GetAlignment();
+      o_direct_read_len_alignm_ = ODirectFileStream::GetLenAlignment();
+    }
   }
 
   void ReadSample(Tensor<CPUBackend>& tensor) override {
@@ -50,9 +63,12 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
     meta.SetSkipSample(false);
 
     if (file_index != current_file_index_) {
-      current_file_->Close();
-      current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_);
+      current_file_.reset();
+      current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
+                                       use_o_direct_);
       current_file_index_ = file_index;
+      // invalidate the buffer
+      if (use_o_direct_) read_buffer_.reset();
     }
 
     // if image is cached, skip loading
@@ -80,11 +96,76 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
     if (tensor.shares_data()) {
       tensor.Reset();
     }
-    tensor.Resize({size}, DALI_UINT8);
+    if (use_o_direct_) {
+      /*
+       * ** - sample data
+       * XX - buffer padding, data of other samples
+       *
+       * <--                        TFRecord file                        -->
+       *                 | <-             read_buffer_                -> |
+       * |<-  seek_pos  ->|<-   size   ->|                               |
+       * |<- block_start ->|             |                               |
+       * |<-          block_end         ->|                              |
+       *                 | <-    aligned_len / read_buffer_size_      -> |
+       * ----------------XXXX************XXXXXXXXXXXXXXXXXXX----------------
+       */
+      // read again if there is no buffer yet or the requested piece lies outside of it
+      bool after_buffer_start = seek_pos >= static_cast<int64_t>(read_buffer_pos_);
+      bool before_buffer_end = seek_pos + size <
+                               static_cast<int64_t>(read_buffer_pos_ + read_buffer_data_size_);
+      // the buffer needs to exist and the data we look for needs to be inside it
+      if (!read_buffer_ || !(after_buffer_start && before_buffer_end)) {
+        // check how much we need to allocate to hold the required sample, but no less than
+        // o_direct_chunk_size_
+        auto block_start = align_down(seek_pos, o_direct_alignm_);
+        auto block_end = align_up(seek_pos + size, o_direct_alignm_);
+        auto aligned_len = align_up(block_end - block_start, o_direct_chunk_size_);
+        // make the staging buffer as big as the biggest sample so far
+        if (aligned_len > static_cast<int64_t>(read_buffer_size_)) {
+          read_buffer_size_ = aligned_len;
+        }
+        // the old memory stays alive as long as any sample still references it, so it is safe
+        // to release the old buffer from read_buffer_
+        read_buffer_ = mm::alloc_raw_shared<uint8_t, mm::memory_kind::host>(read_buffer_size_,
+                                                                            o_direct_alignm_);
+        read_buffer_pos_ = block_start;
+        read_buffer_data_size_ = aligned_len;
+        auto file_name = uris_[file_index];
+        auto file = dynamic_cast<ODirectFileStream*>(current_file_.get());
+        auto o_direct_chunk_size_tmp = o_direct_chunk_size_;
+        // capture a shared_ptr to the file in the lambda to make sure it stays alive and open
+        // as long as any piece of work may still access it
+        shared_ptr<FileStream> tmp_file_ptr = current_file_;
+        // split the read into chunks
+        for (size_t read_off = 0; static_cast<size_t>(aligned_len) > read_off;
+             read_off += o_direct_chunk_size_) {
+          auto dst_ptr = read_buffer_.get() + read_off;
+          auto read_start = block_start + read_off;
+          // we should read either the chunk size or the remainder of the file
+          auto min_read = std::min<int64_t>(o_direct_chunk_size_tmp, seek_pos + size - read_start);
+          auto work = [tmp_file_ptr, file, dst_ptr, o_direct_chunk_size_tmp, min_read,
+                       read_start, file_name]() {
+            auto ret = file->ReadAt(dst_ptr, o_direct_chunk_size_tmp, read_start);
+            DALI_ENFORCE(ret >= min_read && ret <= o_direct_chunk_size_tmp,
+                         make_string("Failed to read file: ", file_name,
+                                     ", read: ", ret, " while it should be in range [", min_read,
+                                     ", ", o_direct_chunk_size_tmp, "]"));
+          };
+          // store the work lambda in the queue so the prefetch thread can pick it up later and
+          // execute it on multiple threads
+          PutReadWork(std::move(work));
+        }
+      }
+      shared_ptr<void> tmp_mem(read_buffer_, read_buffer_.get() + (seek_pos - read_buffer_pos_));
+      // device_id = -1 marks the shared data as CPU-only
+      tensor.ShareData(tmp_mem, size, false, {size}, DALI_UINT8, -1);
+    } else {
+      tensor.Resize({size}, DALI_UINT8);
 
-    int64 n_read = current_file_->Read(reinterpret_cast<uint8_t*>(tensor.raw_mutable_data()),
-                                       size);
-    DALI_ENFORCE(n_read == size, "Error reading from a file " + uris_[current_file_index_]);
+      int64 n_read = current_file_->Read(reinterpret_cast<uint8_t*>(tensor.raw_mutable_data()),
+                                         size);
+      DALI_ENFORCE(n_read == size, "Error reading from a file " + uris_[current_file_index_]);
+    }
 
     tensor.SetMeta(meta);
@@ -92,9 +173,7 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
   }
 
   ~IndexedFileLoader() override {
-    if (current_file_ != nullptr) {
-      current_file_->Close();
-    }
+    current_file_.reset();
   }
 
   virtual void ReadIndexFile(const std::vector<std::string>& index_uris) {
@@ -140,11 +219,12 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
     }
     std::tie(seek_pos, size, file_index) = indices_[current_index_];
     if (file_index != current_file_index_) {
-      if (current_file_index_ != static_cast<size_t>(INVALID_INDEX)) {
-        current_file_->Close();
-      }
-      current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_);
+      current_file_.reset();
+      current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
+                                       use_o_direct_);
       current_file_index_ = file_index;
+      // invalidate the buffer
+      if (use_o_direct_) read_buffer_.reset();
     }
     current_file_->SeekRead(seek_pos);
   }
@@ -154,11 +234,41 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
   std::vector<std::tuple<int64, int64, size_t>> indices_;
   size_t current_index_;
   size_t current_file_index_;
-  std::unique_ptr<FileStream> current_file_;
+  std::shared_ptr<FileStream> current_file_;
   FileStream::MappingReserver mmap_reserver_;
   static constexpr int INVALID_INDEX = -1;
   bool should_seek_ = false;
   int64 next_seek_pos_ = 0;
+  bool use_o_direct_ = false;
+  size_t o_direct_chunk_size_ = 0;
+  size_t o_direct_alignm_ = 0;
+  size_t o_direct_read_len_alignm_ = 0;
+  shared_ptr<uint8_t> read_buffer_;
+  size_t read_buffer_pos_ = 0;
+  size_t read_buffer_size_ = 0;
+  size_t read_buffer_data_size_ = 0;
+
+  typedef std::function<void(void)> ReadWork;
+  std::queue<ReadWork> jobs_;
+  std::mutex mutex_;
+
+  void PutReadWork(ReadWork work) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    jobs_.push(std::move(work));
+  }
+
+ public:
+  ReadWork GetReadWork() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    auto work = std::move(jobs_.front());
+    jobs_.pop();
+    return work;
+  }
+
+  bool AnyWorkLeft() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    return jobs_.size() > 0;
+  }
 };
 
 }  // namespace dali
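The new `ReadSample` branch above turns an arbitrary `(seek_pos, size)` record into an O_DIRECT-friendly read: the offset is rounded down and the end rounded up to the device alignment, and the whole window is padded to a multiple of the chunk size that each queued work item reads. Below is a minimal, self-contained sketch of that arithmetic; the alignment and chunk values are made up here (the patch obtains them from `ODirectFileStream`), and the helper names only mirror the ones used in the diff.

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical values; the real ones come from ODirectFileStream::GetAlignment()
// and ODirectFileStream::GetChunkSize() and depend on the filesystem/device.
constexpr int64_t kAlignment = 512;      // O_DIRECT offset/length alignment
constexpr int64_t kChunkSize = 2 << 20;  // size of a single scheduled read (2 MiB)

// Round down/up to a multiple of `alignment` (assumed to be a power of two).
constexpr int64_t align_down(int64_t x, int64_t alignment) {
  return x & ~(alignment - 1);
}
constexpr int64_t align_up(int64_t x, int64_t alignment) {
  return align_down(x + alignment - 1, alignment);
}

struct ReadWindow {
  int64_t block_start;    // aligned file offset where the buffered read begins
  int64_t aligned_len;    // total bytes to read, a multiple of kChunkSize
  int64_t sample_offset;  // where the requested sample starts inside the buffer
};

// Mirrors the arithmetic in IndexedFileLoader::ReadSample: grow the request
// [seek_pos, seek_pos + size) to aligned boundaries, then pad it to whole chunks.
ReadWindow PlanRead(int64_t seek_pos, int64_t size) {
  int64_t block_start = align_down(seek_pos, kAlignment);
  int64_t block_end = align_up(seek_pos + size, kAlignment);
  int64_t aligned_len = align_up(block_end - block_start, kChunkSize);
  return {block_start, aligned_len, seek_pos - block_start};
}

int main() {
  // A sample that starts at byte 1'000'003 and is 70'000 bytes long.
  ReadWindow w = PlanRead(1000003, 70000);
  std::printf("read at %lld, %lld bytes (%lld chunk(s)), sample at buffer offset %lld\n",
              static_cast<long long>(w.block_start), static_cast<long long>(w.aligned_len),
              static_cast<long long>(w.aligned_len / kChunkSize),
              static_cast<long long>(w.sample_offset));
  return 0;
}
```

Each chunk of the planned window becomes one work item in the loader's queue, so a single oversized sample is still read as several independent, parallelizable chunk reads.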
diff --git a/dali/operators/reader/parser/tfrecord_parser.h b/dali/operators/reader/parser/tfrecord_parser.h
index 69fc9254c00..b78dfb810dd 100644
--- a/dali/operators/reader/parser/tfrecord_parser.h
+++ b/dali/operators/reader/parser/tfrecord_parser.h
@@ -60,7 +60,7 @@ class TFRecordParser : public Parser<Tensor<CPUBackend>> {
     raw_data = raw_data + sizeof(length) + sizeof(crc);
     DALI_ENFORCE(example.ParseFromArray(raw_data, length),
                  make_string("Error while parsing TFRecord file: ", data.GetSourceInfo(),
-                             " (raw data length: ", length, "bytes)."));
+                             " (raw data length: ", length, " bytes)."));
 
     for (size_t i = 0; i < features_.size(); ++i) {
       auto& output = ws->Output<CPUBackend>(i);
diff --git a/dali/operators/reader/tfrecord_reader_op.cc b/dali/operators/reader/tfrecord_reader_op.cc
index be30d21bc06..ab0188a378b 100644
--- a/dali/operators/reader/tfrecord_reader_op.cc
+++ b/dali/operators/reader/tfrecord_reader_op.cc
@@ -17,6 +17,7 @@
 
 #include <string>
 #include <vector>
+#include <memory>
 
 #include "dali/operators/reader/tfrecord_reader_op.h"
 
@@ -45,7 +46,13 @@ DALI_SCHEMA(readers___TFRecordBase)
 
 The index files can be obtained from TFRecord files by using the ``tfrecord2idx`` script
 that is distributed with DALI.)code",
-      DALI_STRING_VEC);
+      DALI_STRING_VEC)
+  .AddOptionalArg("use_o_direct",
+      R"code(If set to True, the data will be read directly from the storage, bypassing the
+system cache.
+
+Mutually exclusive with ``dont_use_mmap=False``.)code",
+      false);
 
 // Internal readers._tfrecord schema.
 DALI_SCHEMA(readers___TFRecord)
@@ -108,6 +115,22 @@ DALI_SCHEMA(TFRecordReader)
 submodule and renamed to follow a common pattern.
 
 This is a placeholder operator with identical functionality to allow for backward
compatibility.)code");  // Deprecated in 1.0
 
+void TFRecordReader::Prefetch() {
+  // We actually prepare the next batch
+  DomainTimeRange tr("[DALI][TFRecordReader] Prefetch #" + to_string(curr_batch_producer_),
+                     DomainTimeRange::kRed);
+  DataReader<CPUBackend, Tensor<CPUBackend>>::Prefetch();
+
+  auto idx_loader = dynamic_cast<IndexedFileLoader*>(loader_.get());
+  while (idx_loader->AnyWorkLeft()) {
+    auto work = idx_loader->GetReadWork();
+    thread_pool_.AddWork([work = std::move(work)] (int tid) {
+      work();
+    });
+  }
+  thread_pool_.RunAll();
+}
+
 }  // namespace dali
 
 #endif  // DALI_BUILD_PROTO3
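The loader only schedules the chunk reads; it is `TFRecordReader::Prefetch` above that drains the queue and executes them, so the actual I/O for the whole prefetched batch is fanned out across the reader's thread pool before the records are parsed. A rough, self-contained sketch of that producer/consumer handoff follows; plain `std::thread` stands in for DALI's `ThreadPool`, and all names are illustrative rather than taken from the patch.

```cpp
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// A toy stand-in for the queue added to IndexedFileLoader: the producer
// (ReadSample) pushes closures, the consumer (Prefetch) drains them and
// fans them out to worker threads.
class ReadWorkQueue {
 public:
  using ReadWork = std::function<void()>;

  void PutReadWork(ReadWork work) {
    std::lock_guard<std::mutex> lock(mutex_);
    jobs_.push(std::move(work));
  }

  bool AnyWorkLeft() {
    std::lock_guard<std::mutex> lock(mutex_);
    return !jobs_.empty();
  }

  ReadWork GetReadWork() {
    std::lock_guard<std::mutex> lock(mutex_);
    ReadWork work = std::move(jobs_.front());
    jobs_.pop();
    return work;
  }

 private:
  std::queue<ReadWork> jobs_;
  std::mutex mutex_;
};

int main() {
  ReadWorkQueue queue;
  // "ReadSample" side: queue a few chunk reads.
  for (int chunk = 0; chunk < 4; ++chunk) {
    queue.PutReadWork([chunk] { std::printf("reading chunk %d\n", chunk); });
  }
  // "Prefetch" side: drain the queue, run everything on worker threads,
  // then wait for completion before the batch is handed over to the parser.
  std::vector<std::thread> workers;
  while (queue.AnyWorkLeft()) {
    workers.emplace_back(queue.GetReadWork());
  }
  for (auto& t : workers) {
    t.join();
  }
  return 0;
}
```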
diff --git a/dali/operators/reader/tfrecord_reader_op.h b/dali/operators/reader/tfrecord_reader_op.h
index be8bd4bb5a0..a7c080e90fb 100644
--- a/dali/operators/reader/tfrecord_reader_op.h
+++ b/dali/operators/reader/tfrecord_reader_op.h
@@ -26,7 +26,12 @@ namespace dali {
 class TFRecordReader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
  public:
   explicit TFRecordReader(const OpSpec& spec)
-    : DataReader<CPUBackend, Tensor<CPUBackend>>(spec) {
+    : DataReader<CPUBackend, Tensor<CPUBackend>>(spec),
+      dont_use_mmap_(spec.GetArgument<bool>("dont_use_mmap")),
+      use_o_direct_(spec.GetArgument<bool>("use_o_direct")),
+      thread_pool_(num_threads_, spec.GetArgument<int>("device_id"), false, "TFRecordReader") {
+    DALI_ENFORCE(dont_use_mmap_ || !use_o_direct_, make_string("Cannot use use_o_direct with ",
+                 "``dont_use_mmap=False``."));
     loader_ = InitLoader<IndexedFileLoader>(spec);
     parser_.reset(new TFRecordParser(spec));
     DALI_ENFORCE(!skip_cached_images_,
@@ -38,8 +43,21 @@ class TFRecordReader : public DataReader<CPUBackend, Tensor<CPUBackend>> {
     parser_->Parse(tensor, &ws);
   }
 
+  ~TFRecordReader() override {
+    // The prefetch thread uses the thread pool owned by this class, so stop it here,
+    // before the thread pool is destroyed, to make sure nothing is still using the pool.
+    this->StopPrefetchThread();
+  }
+
+  void Prefetch() override;
+
  protected:
   USE_READER_OPERATOR_MEMBERS(CPUBackend, Tensor<CPUBackend>);
+  bool dont_use_mmap_ = false;
+  bool use_o_direct_ = false;
+  size_t o_direct_chunk_size_ = 0;
+  // ThreadPool for Prefetch, which runs on a separate thread
+  ThreadPool thread_pool_;
 };
 
 }  // namespace dali
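The explicit `StopPrefetchThread()` call in the destructor above matters because of C++ destruction order: the derived class's members (including `thread_pool_`) are destroyed before the `DataReader` base destructor runs, so a prefetch thread stopped only by the base could still be touching an already-destroyed pool. A tiny illustration of that ordering, with purely illustrative names rather than DALI code:

```cpp
#include <cstdio>

// Members of the derived class are destroyed after the derived destructor body
// but before the base destructor, so cleanup that needs those members must
// happen in the derived destructor.
struct Base {
  ~Base() { std::printf("3. Base dtor: too late to use derived members\n"); }
};

struct Pool {
  ~Pool() { std::printf("2. Pool (derived member) destroyed\n"); }
};

struct Derived : Base {
  ~Derived() { std::printf("1. Derived dtor: last point where the pool is still valid\n"); }
  Pool pool_;
};

int main() {
  Derived d;
  return 0;
}  // prints 1, 2, 3 in that order
```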
diff --git a/dali/test/python/reader/test_index.py b/dali/test/python/reader/test_index.py
index 861f371d59a..16fb5a0cebf 100644
--- a/dali/test/python/reader/test_index.py
+++ b/dali/test/python/reader/test_index.py
@@ -22,6 +22,8 @@ import numpy as np
 
 from test_utils import compare_pipelines, get_dali_extra_path
 from nose_utils import assert_raises
+from nose2.tools import cartesian_params
+from nose import SkipTest
 
 
 def skip_second(src, dst):
@@ -72,6 +74,86 @@ def define_graph(self):
     _ = pipe_org.run()
 
 
+def test_tfrecord_odirect():
+    batch_size = 16
+
+    @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4)
+    def tfrecord_pipe(path, index_path, dont_use_mmap, use_o_direct):
+        input = fn.readers.tfrecord(
+            path=path,
+            index_path=index_path,
+            dont_use_mmap=dont_use_mmap,
+            use_o_direct=use_o_direct,
+            features={
+                "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)},
+            name="Reader")
+        return input["image/class/label"]
+
+    tfrecord = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train')
+    tfrecord_idx = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train.idx')
+
+    pipe = tfrecord_pipe(tfrecord, tfrecord_idx, True, True)
+    pipe_ref = tfrecord_pipe(tfrecord, tfrecord_idx, False, False)
+    pipe.build()
+    pipe_ref.build()
+    iters = (pipe.epoch_size("Reader") + batch_size) // batch_size
+    for _ in range(iters):
+        out = pipe.run()
+        out_ref = pipe_ref.run()
+        for a, b in zip(out, out_ref):
+            assert np.array_equal(a.as_array(), b.as_array())
+
+
+@cartesian_params(((1, 2, 1), (3, 1, 2)),
+                  (True, False),
+                  (True, False))
+def test_tfrecord_pad_last_batch(batch_description, dont_use_mmap, use_o_direct):
+    if not dont_use_mmap and use_o_direct:
+        raise SkipTest("Cannot use O_DIRECT with mmap")
+    num_samples, batch_size, num_shards = batch_description
+
+    @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4)
+    def tfrecord_pipe(path, index_path, dont_use_mmap, use_o_direct):
+        input = fn.readers.tfrecord(
+            path=path,
+            index_path=index_path,
+            num_shards=num_shards,
+            dont_use_mmap=dont_use_mmap,
+            use_o_direct=use_o_direct,
+            features={
+                "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)},
+            name="Reader")
+        return input["image/class/label"]
+
+    tfrecord = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train')
+    tfrecord_idx = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train.idx')
+
+    idx_files_dir = tempfile.TemporaryDirectory()
+    recordio_idx = "rio_train.idx"
+    idx_file = os.path.join(idx_files_dir.name, recordio_idx)
+
+    def leave_only_N(src, dst, n):
+        with open(src, 'r') as tmp_f:
+            with open(dst, 'w') as f:
+                for i, x in enumerate(tmp_f):
+                    if i == n:
+                        break
+                    f.write(x)
+
+    leave_only_N(tfrecord_idx, idx_file, num_samples)
+
+    pipe = tfrecord_pipe(tfrecord, idx_file, dont_use_mmap, use_o_direct)
+    pipe_ref = tfrecord_pipe(tfrecord, idx_file, False, False)
+    pipe.build()
+    pipe_ref.build()
+    iters = (pipe.epoch_size("Reader") + batch_size) // batch_size
+    for _ in range(iters):
+        out = pipe.run()
+        out_ref = pipe_ref.run()
+        for a, b in zip(out, out_ref):
+            assert np.array_equal(a.as_array(), b.as_array())
+
+
 def test_recordio():
     class MXNetReaderPipeline(Pipeline):
         def __init__(self, batch_size, num_threads, device_id, num_gpus, data, data_idx):