Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
Signed-off-by: Janusz Lisiecki <[email protected]>
  • Loading branch information
JanuszL committed May 17, 2023
1 parent 3c877ed commit 8ae5221
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions dali/operators/reader/loader/indexed_file_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
use_o_direct_);
current_file_index_ = file_index;
// invalidate the position in the tmp read buffer
if (use_o_direct_) read_buffer_pos_ = static_cast<size_t>(-1);
// invalidate the buffer
if (use_o_direct_) read_buffer_.reset();
}

// if image is cached, skip loading
Expand Down Expand Up @@ -97,18 +97,33 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
tensor.Reset();
}
if (use_o_direct_) {
// read again
if (!read_buffer_ || !(seek_pos > static_cast<int64>(read_buffer_pos_) &&
/*
* ** - sample data
* XX - buffer padding, data of other samples
*
* <-- TFRecord file -->
* | <- read_buffer_ -> |
* |<- seek_pos -><- size -> | |
* |<-block_start -> | | | |
* |<- | block_end | -> |
* | <- aligned_len/read_buffer_size_ -> |
* ----------------XXXX************XXXXXXXXXXXXXXXXXXX----------------
*/
// read again if there is no buffer or the requested piece is outside of it
if (!read_buffer_ || !(seek_pos >= static_cast<int64>(read_buffer_pos_) &&
seek_pos + size <
static_cast<int64>(read_buffer_pos_ + read_buffer_data_size_))) {
// allocate
// check how much we need to allocate to house the required sample, but no less than
// o_direct_chunk_size_
auto block_start = align_down(seek_pos, o_direct_alignm_);
auto block_end = align_up(seek_pos + size, o_direct_alignm_);
auto aligned_len = align_up(block_end - block_start, o_direct_chunk_size_);
// make the staging buffer as big as the biggest sample so far
if (aligned_len > static_cast<int64>(read_buffer_size_)) {
read_buffer_size_ = aligned_len;
}
// the old memory will be used as long as any piece of it uses it
// the old memory stays alive as long as anything still uses a piece of it, so it is safe
// to release the old buffer from read_buffer_
read_buffer_ = mm::alloc_raw_shared<char, mm::memory_kind::host>(read_buffer_size_,
o_direct_alignm_);
read_buffer_pos_ = block_start;
Expand All @@ -117,12 +132,14 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
auto file = dynamic_cast<ODirectFileStream*>(current_file_.get());
auto o_direct_chunk_size_tmp = o_direct_chunk_size_;
// capture shared ptr to file in lambda to make sure it is alive as long as we want to
// access it in any piece of work
// access it in any piece of work and it is not closed
shared_ptr<FileStream> tmp_file_ptr = current_file_;
// split reads into chunks
for (size_t read_off = 0; static_cast<size_t>(aligned_len) > read_off;
read_off += o_direct_chunk_size_) {
auto dst_ptr = read_buffer_.get() + read_off;
auto read_start = block_start + read_off;
// we should read either the chunk size or the remainder of the file
auto min_read = std::min(o_direct_chunk_size_tmp, seek_pos + size - read_start);
auto work = [tmp_file_ptr, file, dst_ptr, o_direct_chunk_size_tmp, min_read,
read_start, file_name]() {
Expand All @@ -132,14 +149,13 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
", read: ", ret, " while it should be in range [", min_read,
", ", o_direct_chunk_size_tmp, "]"));
};
{
std::lock_guard<std::mutex> lock(mutex_);
jobs_.push(std::move(work));
}
// store the work lambda in the queue so the prefetch thread can pick it up later and
// execute it in multiple threads
PutReadWork(work);
}
}
shared_ptr<void> tmp_mem(read_buffer_, read_buffer_.get() + (seek_pos - read_buffer_pos_));

// make sure it is a big value in signed range
tensor.ShareData(tmp_mem, size, false, {size}, DALI_UINT8, -1);
} else {
tensor.Resize({size}, DALI_UINT8);
Expand Down Expand Up @@ -205,8 +221,8 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
current_file_ = FileStream::Open(uris_[file_index], read_ahead_, !copy_read_data_,
use_o_direct_);
current_file_index_ = file_index;
// invalidate the position in the tmp read buffer
if (use_o_direct_) read_buffer_pos_ = static_cast<size_t>(-1);
// invalidate the buffer
if (use_o_direct_) read_buffer_.reset();
}
current_file_->SeekRead(seek_pos);
}
Expand Down Expand Up @@ -234,6 +250,11 @@ class IndexedFileLoader : public Loader<CPUBackend, Tensor<CPUBackend>> {
std::queue<ReadWork> jobs_;
std::mutex mutex_;

/**
 * @brief Appends a single read job to the pending-jobs queue.
 *
 * The queue is guarded by mutex_ for the duration of the insertion.
 *
 * @param work read job to enqueue; it is moved into the queue
 */
void PutReadWork(ReadWork work) {
  const std::lock_guard<std::mutex> guard(mutex_);
  jobs_.emplace(std::move(work));
}

public:
ReadWork GetReadWork() {
std::lock_guard<std::mutex> lock(mutex_);
Expand Down

0 comments on commit 8ae5221

Please sign in to comment.