From 521edd7aca9bd51ca636a0a4f6dcd1a071bf9b87 Mon Sep 17 00:00:00 2001
From: Janusz Lisiecki
Date: Tue, 2 May 2023 10:34:35 +0200
Subject: [PATCH 1/4] Add O_DIRECT support to the TFRecord reader

- adds the `use_o_direct` option to the TFRecord reader. With it
  enabled, the reader reads into an internal staging buffer whose
  chunks are shared with the output samples. When the buffer runs
  out of space, a new one is allocated and the old one stays alive
  as long as any sample still uses a piece of it.

Signed-off-by: Janusz Lisiecki
---
 .../reader/loader/indexed_file_loader.h       | 89 +++++++++++++++----
 .../operators/reader/parser/tfrecord_parser.h |  2 +-
 dali/operators/reader/tfrecord_reader_op.cc   |  8 +-
 dali/test/python/reader/test_index.py         | 82 +++++++++++++++++
 4 files changed, 162 insertions(+), 19 deletions(-)

diff --git a/dali/operators/reader/loader/indexed_file_loader.h b/dali/operators/reader/loader/indexed_file_loader.h
index 444783e86e4..304b873d263 100755
--- a/dali/operators/reader/loader/indexed_file_loader.h
+++ b/dali/operators/reader/loader/indexed_file_loader.h
@@ -22,18 +22,28 @@
 #include 
 #include 

 #include "dali/core/common.h"
+#include "dali/core/mm/memory.h"
 #include "dali/operators/reader/loader/loader.h"
 #include "dali/util/file.h"
+#include "dali/util/odirect_file.h"

 namespace dali {

 class IndexedFileLoader : public Loader> {
  public:
-  explicit IndexedFileLoader(const OpSpec& options)
-    : Loader(options),
-      uris_(options.GetRepeatedArgument("path")),
-      index_uris_(options.GetRepeatedArgument("index_path")),
-      current_index_(0), current_file_index_(0), current_file_(nullptr) {
+  explicit IndexedFileLoader(const OpSpec& spec)
+    : Loader(spec),
+      uris_(spec.GetRepeatedArgument("path")),
+      index_uris_(spec.GetRepeatedArgument("index_path")),
+      current_index_(0), current_file_index_(0), current_file_(nullptr),
+      use_o_direct_(spec.HasArgument("use_o_direct") && spec.GetArgument("use_o_direct")) {
+    DALI_ENFORCE(dont_use_mmap_ || !use_o_direct_, make_string("Cannot use use_o_direct with ",
+                 "``dont_use_mmap=False``."));
+    if (use_o_direct_) {
+      o_direct_chunk_size_ = ODirectFileStream::GetChunkSize();
+      o_direct_alignm_ = ODirectFileStream::GetAlignment();
+      o_direct_read_len_alignm_ = ODirectFileStream::GetLenAlignment();
+    }
   }

   void ReadSample(Tensor& tensor) override {
@@ -50,9 +60,12 @@ class IndexedFileLoader : public Loader> {
     meta.SetSkipSample(false);

     if (file_index != current_file_index_) {
-      current_file_->Close();
-      current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_);
+      current_file_.reset();
+      current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
+                                       use_o_direct_);
       current_file_index_ = file_index;
+      // invalidate the position in the tmp read buffer
+      if (use_o_direct_) read_buffer_data_size_ = static_cast(-1);
     }

     // if image is cached, skip loading
@@ -80,11 +93,44 @@ class IndexedFileLoader : public Loader> {
     if (tensor.shares_data()) {
       tensor.Reset();
     }
-    tensor.Resize({size}, DALI_UINT8);
-
-    int64 n_read = current_file_->Read(reinterpret_cast(tensor.raw_mutable_data()),
-                                       size);
-    DALI_ENFORCE(n_read == size, "Error reading from a file " + uris_[current_file_index_]);
+    if (use_o_direct_) {
+      // read again
+      if (!read_buffer_ || !(seek_pos > static_cast(read_buffer_pos_) &&
+                             seek_pos + size <
+                               static_cast(read_buffer_pos_ + read_buffer_data_size_))) {
+        // allocate
+        auto block_start = align_down(seek_pos, o_direct_alignm_);
+        auto block_end = align_up(seek_pos + size, o_direct_alignm_);
+        auto aligned_len = align_up(block_end - block_start,
o_direct_chunk_size_); + if (aligned_len > static_cast(read_buffer_size_)) { + read_buffer_size_ = aligned_len; + } + // the old memory will be used as long as any piece of it uses it + read_buffer_ = mm::alloc_raw_shared(read_buffer_size_, + o_direct_alignm_); + auto file = dynamic_cast(current_file_.get()); + auto ret = file->ReadAt(read_buffer_.get(), aligned_len, block_start); + read_buffer_pos_ = block_start; + read_buffer_data_size_ = ret; + DALI_ENFORCE(static_cast(ret) >= size && + static_cast(ret) <= aligned_len, + make_string("Failed to read file: ", uris_[file_index], + ", read: ", ret, " while it should be [", size, ", ", + aligned_len, "]")); + } + // we need to create a tmp variable that is a copy of read_buffer_ as members cannot be + // captured by value thus copied, and this is all about here + auto read_buffer_tmp = read_buffer_; + shared_ptr tmp_mem(read_buffer_, read_buffer_.get() + (seek_pos - read_buffer_pos_)); + + tensor.ShareData(tmp_mem, size, false, {size}, DALI_UINT8, -1); + } else { + tensor.Resize({size}, DALI_UINT8); + + int64 n_read = current_file_->Read(reinterpret_cast(tensor.raw_mutable_data()), + size); + DALI_ENFORCE(n_read == size, "Error reading from a file " + uris_[current_file_index_]); + } } tensor.SetMeta(meta); @@ -92,9 +138,7 @@ class IndexedFileLoader : public Loader> { } ~IndexedFileLoader() override { - if (current_file_ != nullptr) { - current_file_->Close(); - } + current_file_.reset(); } virtual void ReadIndexFile(const std::vector& index_uris) { @@ -141,10 +185,13 @@ class IndexedFileLoader : public Loader> { std::tie(seek_pos, size, file_index) = indices_[current_index_]; if (file_index != current_file_index_) { if (current_file_index_ != static_cast(INVALID_INDEX)) { - current_file_->Close(); + current_file_.reset(); } - current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_); + current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_, + use_o_direct_); current_file_index_ = file_index; + // invalidate the position in the tmp read buffer + if (use_o_direct_) read_buffer_pos_ = static_cast(-1); } current_file_->SeekRead(seek_pos); } @@ -159,6 +206,14 @@ class IndexedFileLoader : public Loader> { static constexpr int INVALID_INDEX = -1; bool should_seek_ = false; int64 next_seek_pos_ = 0; + bool use_o_direct_ = false; + size_t o_direct_chunk_size_ = 0; + size_t o_direct_alignm_ = 0; + size_t o_direct_read_len_alignm_ = 0; + shared_ptr read_buffer_; + size_t read_buffer_pos_ = 0; + size_t read_buffer_size_ = 0; + size_t read_buffer_data_size_ = 0; }; } // namespace dali diff --git a/dali/operators/reader/parser/tfrecord_parser.h b/dali/operators/reader/parser/tfrecord_parser.h index 69fc9254c00..b78dfb810dd 100644 --- a/dali/operators/reader/parser/tfrecord_parser.h +++ b/dali/operators/reader/parser/tfrecord_parser.h @@ -60,7 +60,7 @@ class TFRecordParser : public Parser> { raw_data = raw_data + sizeof(length) + sizeof(crc); DALI_ENFORCE(example.ParseFromArray(raw_data, length), make_string("Error while parsing TFRecord file: ", data.GetSourceInfo(), - " (raw data length: ", length, "bytes).")); + " (raw data length: ", length, " bytes).")); for (size_t i = 0; i < features_.size(); ++i) { auto& output = ws->Output(i); diff --git a/dali/operators/reader/tfrecord_reader_op.cc b/dali/operators/reader/tfrecord_reader_op.cc index be30d21bc06..f3275227aaa 100644 --- a/dali/operators/reader/tfrecord_reader_op.cc +++ b/dali/operators/reader/tfrecord_reader_op.cc @@ -45,7 +45,13 @@ 
DALI_SCHEMA(readers___TFRecordBase) The index files can be obtained from TFRecord files by using the ``tfrecord2idx`` script that is distributed with DALI.)code", - DALI_STRING_VEC); + DALI_STRING_VEC) + .AddOptionalArg("use_o_direct", + R"code(If set to True, the data will be read directly from the storage bypassing system +cache. + +Mutually exclusive with ``dont_use_mmap=False``.)code", + false); // Internal readers._tfrecord schema. DALI_SCHEMA(readers___TFRecord) diff --git a/dali/test/python/reader/test_index.py b/dali/test/python/reader/test_index.py index 861f371d59a..16fb5a0cebf 100644 --- a/dali/test/python/reader/test_index.py +++ b/dali/test/python/reader/test_index.py @@ -22,6 +22,8 @@ import numpy as np from test_utils import compare_pipelines, get_dali_extra_path from nose_utils import assert_raises +from nose2.tools import cartesian_params +from nose import SkipTest def skip_second(src, dst): @@ -72,6 +74,86 @@ def define_graph(self): _ = pipe_org.run() +def test_tfrecord_odirect(): + batch_size = 16 + + @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) + def tfrecord_pipe(path, index_path, dont_use_mmap, use_o_direct): + input = fn.readers.tfrecord( + path=path, + index_path=index_path, + dont_use_mmap=dont_use_mmap, + use_o_direct=use_o_direct, + features={ + "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)}, + name="Reader") + return input["image/class/label"] + + tfrecord = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train') + tfrecord_idx = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train.idx') + + pipe = tfrecord_pipe(tfrecord, tfrecord_idx, True, True) + pipe_ref = tfrecord_pipe(tfrecord, tfrecord_idx, False, False) + pipe.build() + pipe_ref.build() + iters = (pipe.epoch_size("Reader") + batch_size) // batch_size + for _ in range(iters): + out = pipe.run() + out_ref = pipe_ref.run() + for a, b in zip(out, out_ref): + assert np.array_equal(a.as_array(), b.as_array()) + + +@cartesian_params(((1, 2, 1), (3, 1, 2)), + (True, False), + (True, False)) +def test_tfrecord_pad_last_batch(batch_description, dont_use_mmap, use_o_direct): + if not dont_use_mmap and use_o_direct: + raise SkipTest("Cannot use O_DIRECT with mmap") + num_samples, batch_size, num_shards = batch_description + + @pipeline_def(batch_size=batch_size, device_id=0, num_threads=4) + def tfrecord_pipe(path, index_path, dont_use_mmap, use_o_direct): + input = fn.readers.tfrecord( + path=path, + index_path=index_path, + num_shards=num_shards, + dont_use_mmap=dont_use_mmap, + use_o_direct=use_o_direct, + features={ + "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)}, + name="Reader") + return input["image/class/label"] + + tfrecord = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train') + tfrecord_idx = os.path.join(get_dali_extra_path(), 'db', 'tfrecord', 'train.idx') + + idx_files_dir = tempfile.TemporaryDirectory() + recordio_idx = "rio_train.idx" + idx_file = os.path.join(idx_files_dir.name, recordio_idx) + + def leave_only_N(src, dst, n): + with open(src, 'r') as tmp_f: + with open(dst, 'w') as f: + for i, x in enumerate(tmp_f): + if i == n: + break + f.write(x) + + leave_only_N(tfrecord_idx, idx_file, num_samples) + + pipe = tfrecord_pipe(tfrecord, idx_file, dont_use_mmap, use_o_direct) + pipe_ref = tfrecord_pipe(tfrecord, idx_file, False, False) + pipe.build() + pipe_ref.build() + iters = (pipe.epoch_size("Reader") + batch_size) // batch_size + for _ in range(iters): + out = pipe.run() + out_ref = pipe_ref.run() + for a, b in 
zip(out, out_ref): + assert np.array_equal(a.as_array(), b.as_array()) + + def test_recordio(): class MXNetReaderPipeline(Pipeline): def __init__(self, batch_size, num_threads, device_id, num_gpus, data, data_idx): From 73a45136a45cb84f32f0ccbe729b3a4a3be5adcf Mon Sep 17 00:00:00 2001 From: Janusz Lisiecki Date: Tue, 16 May 2023 11:24:18 +0200 Subject: [PATCH 2/4] Parallelize reading Signed-off-by: Janusz Lisiecki --- .../reader/loader/indexed_file_loader.h | 66 ++++++++++++++----- dali/operators/reader/tfrecord_reader_op.cc | 17 +++++ dali/operators/reader/tfrecord_reader_op.h | 20 +++++- 3 files changed, 85 insertions(+), 18 deletions(-) diff --git a/dali/operators/reader/loader/indexed_file_loader.h b/dali/operators/reader/loader/indexed_file_loader.h index 304b873d263..d7ed0a91b4d 100755 --- a/dali/operators/reader/loader/indexed_file_loader.h +++ b/dali/operators/reader/loader/indexed_file_loader.h @@ -20,6 +20,9 @@ #include #include #include +#include +#include +#include #include "dali/core/common.h" #include "dali/core/mm/memory.h" @@ -65,7 +68,7 @@ class IndexedFileLoader : public Loader> { use_o_direct_); current_file_index_ = file_index; // invalidate the position in the tmp read buffer - if (use_o_direct_) read_buffer_data_size_ = static_cast(-1); + if (use_o_direct_) read_buffer_pos_ = static_cast(-1); } // if image is cached, skip loading @@ -107,20 +110,34 @@ class IndexedFileLoader : public Loader> { } // the old memory will be used as long as any piece of it uses it read_buffer_ = mm::alloc_raw_shared(read_buffer_size_, - o_direct_alignm_); - auto file = dynamic_cast(current_file_.get()); - auto ret = file->ReadAt(read_buffer_.get(), aligned_len, block_start); + o_direct_alignm_); read_buffer_pos_ = block_start; - read_buffer_data_size_ = ret; - DALI_ENFORCE(static_cast(ret) >= size && - static_cast(ret) <= aligned_len, - make_string("Failed to read file: ", uris_[file_index], - ", read: ", ret, " while it should be [", size, ", ", - aligned_len, "]")); + read_buffer_data_size_ = aligned_len; + auto file_name = uris_[file_index]; + auto file = dynamic_cast(current_file_.get()); + auto o_direct_chunk_size_tmp = o_direct_chunk_size_; + // capture shared ptr to file in lambda to make sure it is alive as long as we want to + // access it in any piece of work + shared_ptr tmp_file_ptr = current_file_; + for (size_t read_off = 0; static_cast(aligned_len) > read_off; + read_off += o_direct_chunk_size_) { + auto dst_ptr = read_buffer_.get() + read_off; + auto read_start = block_start + read_off; + auto min_read = std::min(o_direct_chunk_size_tmp, seek_pos + size - read_start); + auto work = [tmp_file_ptr, file, dst_ptr, o_direct_chunk_size_tmp, min_read, + read_start, file_name]() { + auto ret = file->ReadAt(dst_ptr, o_direct_chunk_size_tmp, read_start); + DALI_ENFORCE(ret >= min_read && ret <= o_direct_chunk_size_tmp, + make_string("Failed to read file: ", file_name, + ", read: ", ret, " while it should be in range [", min_read, + ", ", o_direct_chunk_size_tmp, "]")); + }; + { + std::lock_guard lock(mutex_); + jobs_.push(std::move(work)); + } + } } - // we need to create a tmp variable that is a copy of read_buffer_ as members cannot be - // captured by value thus copied, and this is all about here - auto read_buffer_tmp = read_buffer_; shared_ptr tmp_mem(read_buffer_, read_buffer_.get() + (seek_pos - read_buffer_pos_)); tensor.ShareData(tmp_mem, size, false, {size}, DALI_UINT8, -1); @@ -184,9 +201,7 @@ class IndexedFileLoader : public Loader> { } std::tie(seek_pos, size, 
file_index) = indices_[current_index_]; if (file_index != current_file_index_) { - if (current_file_index_ != static_cast(INVALID_INDEX)) { - current_file_.reset(); - } + current_file_.reset(); current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_, use_o_direct_); current_file_index_ = file_index; @@ -201,7 +216,7 @@ class IndexedFileLoader : public Loader> { std::vector> indices_; size_t current_index_; size_t current_file_index_; - std::unique_ptr current_file_; + std::shared_ptr current_file_; FileStream::MappingReserver mmap_reserver_; static constexpr int INVALID_INDEX = -1; bool should_seek_ = false; @@ -214,6 +229,23 @@ class IndexedFileLoader : public Loader> { size_t read_buffer_pos_ = 0; size_t read_buffer_size_ = 0; size_t read_buffer_data_size_ = 0; + + typedef std::function ReadWork; + std::queue jobs_; + std::mutex mutex_; + + public: + ReadWork GetReadWork() { + std::lock_guard lock(mutex_); + auto work = std::move(jobs_.front()); + jobs_.pop(); + return work; + } + + bool AnyWorkLeft() { + std::lock_guard lock(mutex_); + return jobs_.size(); + } }; } // namespace dali diff --git a/dali/operators/reader/tfrecord_reader_op.cc b/dali/operators/reader/tfrecord_reader_op.cc index f3275227aaa..07caab2c047 100644 --- a/dali/operators/reader/tfrecord_reader_op.cc +++ b/dali/operators/reader/tfrecord_reader_op.cc @@ -17,6 +17,7 @@ #include #include +#include #include "dali/operators/reader/tfrecord_reader_op.h" @@ -114,6 +115,22 @@ DALI_SCHEMA(TFRecordReader) submodule and renamed to follow a common pattern. This is a placeholder operator with identical functionality to allow for backward compatibility.)code"); // Deprecated in 1.0; +void TFRecordReader::Prefetch() { + // We actually prepare the next batch + DomainTimeRange tr("[DALI][TFRecordReader] Prefetch #" + to_string(curr_batch_producer_), + DomainTimeRange::kRed); + DataReader>::Prefetch(); + + auto idx_loader = dynamic_cast(loader_.get()); + while (idx_loader->AnyWorkLeft()) { + auto work = idx_loader->GetReadWork(); + thread_pool_.AddWork([work = std::move(work)] (int tid) { + work(); + }); + } + thread_pool_.RunAll(); +} + } // namespace dali #endif // DALI_BUILD_PROTO3 diff --git a/dali/operators/reader/tfrecord_reader_op.h b/dali/operators/reader/tfrecord_reader_op.h index be8bd4bb5a0..a7c080e90fb 100644 --- a/dali/operators/reader/tfrecord_reader_op.h +++ b/dali/operators/reader/tfrecord_reader_op.h @@ -26,7 +26,12 @@ namespace dali { class TFRecordReader : public DataReader> { public: explicit TFRecordReader(const OpSpec& spec) - : DataReader>(spec) { + : DataReader>(spec), + dont_use_mmap_(spec.GetArgument("dont_use_mmap")), + use_o_direct_(spec.GetArgument("use_o_direct")), + thread_pool_(num_threads_, spec.GetArgument("device_id"), false, "TFRecordReader") { + DALI_ENFORCE(dont_use_mmap_ || !use_o_direct_, make_string("Cannot use use_o_direct with ", + "``dont_use_mmap=False``.")); loader_ = InitLoader(spec); parser_.reset(new TFRecordParser(spec)); DALI_ENFORCE(!skip_cached_images_, @@ -38,8 +43,21 @@ class TFRecordReader : public DataReader> { parser_->Parse(tensor, &ws); } + ~TFRecordReader() override { + // Stop the prefetch thread as it uses the thread pool from this class. So before we can + // destroy the thread pool make sure no one is using it anymore. 
+    this->StopPrefetchThread();
+  }
+
+  void Prefetch() override;
+
  protected:
   USE_READER_OPERATOR_MEMBERS(CPUBackend, Tensor);
+  bool dont_use_mmap_ = false;
+  bool use_o_direct_ = false;
+  size_t o_direct_chunk_size_ = 0;
+  // ThreadPool for the prefetch, which runs in a separate thread
+  ThreadPool thread_pool_;
 };

 } // namespace dali

From 49e36eda59d9c9ce986240da2dd4a1eddbeeaf56 Mon Sep 17 00:00:00 2001
From: Janusz Lisiecki
Date: Tue, 16 May 2023 21:14:33 +0200
Subject: [PATCH 3/4] Fix

Signed-off-by: Janusz Lisiecki
---
 .../reader/loader/indexed_file_loader.h       | 49 +++++++++++++------
 1 file changed, 35 insertions(+), 14 deletions(-)

diff --git a/dali/operators/reader/loader/indexed_file_loader.h b/dali/operators/reader/loader/indexed_file_loader.h
index d7ed0a91b4d..c7ca1641919 100755
--- a/dali/operators/reader/loader/indexed_file_loader.h
+++ b/dali/operators/reader/loader/indexed_file_loader.h
@@ -67,8 +67,8 @@ class IndexedFileLoader : public Loader> {
       current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
                                        use_o_direct_);
       current_file_index_ = file_index;
-      // invalidate the position in the tmp read buffer
-      if (use_o_direct_) read_buffer_pos_ = static_cast(-1);
+      // invalidate the buffer
+      if (use_o_direct_) read_buffer_.reset();
     }

     // if image is cached, skip loading
@@ -97,18 +97,33 @@ class IndexedFileLoader : public Loader> {
       tensor.Reset();
     }
     if (use_o_direct_) {
+      /*
+       * ** - sample data
+       * XX - buffer padding, data of other samples
+       *
+       * <-- TFRecord file -->
+       * | <- read_buffer_ -> |
+       * |<- seek_pos -><- size -> | |
+       * |<-block_start -> | | | |
+       * |<- | block_end | -> |
+       * | <- aligned_len/read_buffer_size_ -> |
+       * ----------------XXXX************XXXXXXXXXXXXXXXXXXX----------------
+       */
-      // read again
-      if (!read_buffer_ || !(seek_pos > static_cast(read_buffer_pos_) &&
+      // read again if there is no buffer or the requested piece is outside of it
+      if (!read_buffer_ || !(seek_pos >= static_cast(read_buffer_pos_) &&
                              seek_pos + size <
                                static_cast(read_buffer_pos_ + read_buffer_data_size_))) {
-        // allocate
+        // check how much we need to allocate to house the required sample, but no less than
+        // o_direct_chunk_size_
         auto block_start = align_down(seek_pos, o_direct_alignm_);
         auto block_end = align_up(seek_pos + size, o_direct_alignm_);
         auto aligned_len = align_up(block_end - block_start, o_direct_chunk_size_);
+        // make the staging buffer as big as the biggest sample so far
         if (aligned_len > static_cast(read_buffer_size_)) {
           read_buffer_size_ = aligned_len;
         }
-        // the old memory will be used as long as any piece of it uses it
+        // the old memory stays alive as long as any sample still uses a piece of it, so it is
+        // safe to release the old buffer from read_buffer_
         read_buffer_ = mm::alloc_raw_shared(read_buffer_size_,
                                             o_direct_alignm_);
         read_buffer_pos_ = block_start;
@@ -117,12 +132,14 @@ class IndexedFileLoader : public Loader> {
         auto file = dynamic_cast(current_file_.get());
         auto o_direct_chunk_size_tmp = o_direct_chunk_size_;
         // capture shared ptr to file in lambda to make sure it is alive as long as we want to
-        // access it in any piece of work
+        // access it in any piece of work and is not closed in the meantime
         shared_ptr tmp_file_ptr = current_file_;
+        // split reads into chunks
         for (size_t read_off = 0; static_cast(aligned_len) > read_off;
             read_off += o_direct_chunk_size_) {
           auto dst_ptr = read_buffer_.get() + read_off;
           auto read_start = block_start + read_off;
+          // we should read either the chunk size or the remainder of the file
           auto min_read = std::min(o_direct_chunk_size_tmp, seek_pos + size - read_start);
           auto work = [tmp_file_ptr, file, dst_ptr, o_direct_chunk_size_tmp, min_read,
                        read_start, file_name]() {
@@ -132,14 +149,13 @@ class IndexedFileLoader : public Loader> {
                          ", read: ", ret, " while it should be in range [", min_read,
                          ", ", o_direct_chunk_size_tmp, "]"));
           };
-          {
-            std::lock_guard lock(mutex_);
-            jobs_.push(std::move(work));
-          }
+          // store the work lambdas in the queue so the prefetch thread can pick them up later and
+          // execute them in multiple threads
+          PutReadWork(work);
        }
      }
      shared_ptr tmp_mem(read_buffer_, read_buffer_.get() + (seek_pos - read_buffer_pos_));
-
+     // make sure it is a big value in signed range
      tensor.ShareData(tmp_mem, size, false, {size}, DALI_UINT8, -1);
    } else {
      tensor.Resize({size}, DALI_UINT8);
@@ -205,8 +221,8 @@ class IndexedFileLoader : public Loader> {
       current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
                                        use_o_direct_);
       current_file_index_ = file_index;
-      // invalidate the position in the tmp read buffer
-      if (use_o_direct_) read_buffer_pos_ = static_cast(-1);
+      // invalidate the buffer
+      if (use_o_direct_) read_buffer_.reset();
     }
     current_file_->SeekRead(seek_pos);
   }
@@ -234,6 +250,11 @@ class IndexedFileLoader : public Loader> {
   std::queue jobs_;
   std::mutex mutex_;

+  void PutReadWork(ReadWork work) {
+    std::lock_guard lock(mutex_);
+    jobs_.push(std::move(work));
+  }
+
  public:
   ReadWork GetReadWork() {
     std::lock_guard lock(mutex_);

From a9b400bfc492ccdcaec2bae60176d8266d02a686 Mon Sep 17 00:00:00 2001
From: Janusz Lisiecki
Date: Tue, 6 Jun 2023 13:32:20 +0200
Subject: [PATCH 4/4] Review fixes

Signed-off-by: Janusz Lisiecki
---
 dali/operators/reader/loader/indexed_file_loader.h | 10 ++++++----
 dali/operators/reader/tfrecord_reader_op.cc        |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/dali/operators/reader/loader/indexed_file_loader.h b/dali/operators/reader/loader/indexed_file_loader.h
index c7ca1641919..9932a29f9b7 100755
--- a/dali/operators/reader/loader/indexed_file_loader.h
+++ b/dali/operators/reader/loader/indexed_file_loader.h
@@ -110,9 +110,11 @@ class IndexedFileLoader : public Loader> {
        * ----------------XXXX************XXXXXXXXXXXXXXXXXXX----------------
        */
       // read again if there is no buffer or the requested piece is outside of it
-      if (!read_buffer_ || !(seek_pos >= static_cast(read_buffer_pos_) &&
-                             seek_pos + size <
-                               static_cast(read_buffer_pos_ + read_buffer_data_size_))) {
+      bool after_buffer_start = seek_pos >= static_cast(read_buffer_pos_);
+      bool before_buffer_end = seek_pos + size <
+                               static_cast(read_buffer_pos_ + read_buffer_data_size_);
+      // the buffer needs to exist and the data we look for needs to be inside it
+      if (!read_buffer_ || !(after_buffer_start && before_buffer_end)) {
         // check how much we need to allocate to house the required sample, but no less than
         // o_direct_chunk_size_
         auto block_start = align_down(seek_pos, o_direct_alignm_);
         auto block_end = align_up(seek_pos + size, o_direct_alignm_);
@@ -151,7 +153,7 @@ class IndexedFileLoader : public Loader> {
           };
           // store the work lambdas in the queue so the prefetch thread can pick them up later and
           // execute them in multiple threads
-          PutReadWork(work);
+          PutReadWork(std::move(work));
        }
      }
      shared_ptr tmp_mem(read_buffer_, read_buffer_.get() + (seek_pos - read_buffer_pos_));
diff --git a/dali/operators/reader/tfrecord_reader_op.cc b/dali/operators/reader/tfrecord_reader_op.cc
index 07caab2c047..ab0188a378b 100644
--- a/dali/operators/reader/tfrecord_reader_op.cc
+++ b/dali/operators/reader/tfrecord_reader_op.cc
@@ -48,7 +48,7 @@ The index files can be obtained from TFRecord files by using the ``tfrecord2idx`
 that is distributed with DALI.)code",
       DALI_STRING_VEC)
   .AddOptionalArg("use_o_direct",
-      R"code(If set to True, the data will be read directly from the storage bypassing system
+      R"code(If set to True, the data will be read directly from the storage bypassing the system
 cache.

 Mutually exclusive with ``dont_use_mmap=False``.)code",
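For reference, a minimal usage sketch of the option introduced by this series, mirroring the
test_tfrecord_odirect test added in patch 1. The data paths below are placeholders, not files
shipped with DALI; use_o_direct=True is only valid together with dont_use_mmap=True, as enforced
by the reader schema:

import nvidia.dali.fn as fn
import nvidia.dali.tfrecord as tfrec
from nvidia.dali import pipeline_def


@pipeline_def(batch_size=16, num_threads=4, device_id=0)
def tfrecord_odirect_pipe(path, index_path):
    # O_DIRECT bypasses the page cache, so the file cannot be memory mapped
    # at the same time; dont_use_mmap=True is therefore required.
    inputs = fn.readers.tfrecord(
        path=path,
        index_path=index_path,
        dont_use_mmap=True,
        use_o_direct=True,
        features={"image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)},
        name="Reader")
    return inputs["image/class/label"]


# Placeholder paths; the index file can be generated with the tfrecord2idx script.
# pipe = tfrecord_odirect_pipe("/data/train", "/data/train.idx")
# pipe.build()
# labels, = pipe.run()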