-
Notifications
You must be signed in to change notification settings - Fork 629
/
Copy pathdaliop.cc
419 lines (378 loc) · 17.4 KB
/
daliop.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime_api.h>
#include <algorithm>
#include <chrono>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/public/version.h"
// for Eigen::GpuDevice
#define EIGEN_USE_GPU
#if TF_MAJOR_VERSION == 2 && TF_MINOR_VERSION >= 16
#include "unsupported/Eigen/CXX11/Tensor"
#else
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#endif
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#define USE_TF_ALLOCATOR 0
#if USE_TF_ALLOCATOR
#include "dali_tf_plugin/tfallocator.h"
#endif
#include "dali/c_api.h"
#include "dali/core/common.h"
#include "dali_tf_plugin/dali_helper.h"
// Clock used for the coarse per-call timings emitted through LOG_LINE by the kernel below.
using Clock = std::chrono::high_resolution_clock;
// Short alias for the TensorFlow namespace, used throughout this file.
namespace tf = tensorflow;
// Evaluates FUNC (typically a DALI C API call) and turns any std::exception it throws into a
// TensorFlow Internal error:
//   * the message "DALI <FUNC text> failed: <what()>" is printed to stdout,
//   * the same message is set as the status of `context`,
//   * the enclosing function returns immediately.
// Requirements at the expansion site: a variable named `context` exposing SetStatus() must be in
// scope, and the enclosing function must return void. Exceptions that do not derive from
// std::exception are not caught here and propagate to the caller.
#define TF_DALI_CALL(FUNC) \
  do { \
    try { \
      FUNC; \
    } catch (const std::exception &e) { \
      std::string error = "DALI " + std::string(#FUNC) \
                          + " failed: " + std::string(e.what()); \
      std::cout << error << std::endl; \
      context->SetStatus(tf::errors::Internal(error)); \
      return; \
    } \
  } while (0)
namespace dali_tf_impl {
// Registers the "Dali" op: its attributes, its variadic `data` output (typed by `dtypes`) and a
// shape function that propagates every non-empty entry of the `shapes` attribute to the output
// with the same index. Entries with rank 0 are skipped, so those outputs get no shape set here.
REGISTER_OP("Dali")
    .Attr("serialized_pipeline: string")
    .Attr("shapes: list(shape) >= 1")
    .Attr("num_threads: int = -1")
    .Attr("device_id: int = -1")
    .Attr("exec_separated: bool = false")
    .Attr("exec_dynamic: bool = false")
    .Attr("gpu_prefetch_queue_depth: int = 2")
    .Attr("cpu_prefetch_queue_depth: int = 2")
    .Attr("sparse: list(bool) = []")
    .Attr("batch_size: int = -1")
    .Attr("enable_memory_stats: bool = false")
    .Output("data: dtypes")
    .Attr("dtypes: list({half, float, uint8, int16, int32, int64}) >= 1")
    // To prevent replacing DALI op with constant tensor during TF constant folding process
    .SetIsStateful()
    .SetShapeFn([](tf::shape_inference::InferenceContext* c) {
      std::vector<tf::PartialTensorShape> shapes;
      TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
      for (unsigned i = 0; i < shapes.size(); ++i) {
        // A rank-0 entry is treated as "no constraint" and is not propagated.
        if (shapes[i].dims() > 0) {
          tf::shape_inference::ShapeHandle passed_shape;
          TF_RETURN_IF_ERROR(
              c->MakeShapeFromPartialTensorShape(shapes[i], &passed_shape));
          TF_RETURN_IF_ERROR(
              c->WithRank(passed_shape, shapes[i].dims(), &passed_shape));
          c->set_output(i, passed_shape);
        }
      }
      return tf::Status();
    })
    .Doc(R"doc(
DALI TensorFlow plugin
Creates a DALI pipeline from a serialized pipeline, obtained from `serialized_pipeline` argument.
`shapes` must match the shape of the corresponding DALI Pipeline output tensor shape.
`dtypes` must match the type of the corresponding DALI Pipeline output tensors type.
)doc");
/**
 * TensorFlow kernel that runs a serialized DALI pipeline and exposes its outputs as TF tensors.
 *
 * The constructor reads the op attributes, creates the pipeline with daliCreatePipeline3 and
 * issues the initial daliPrefetch. Each Compute() call then:
 *   1. shares the current pipeline outputs (daliShareOutput),
 *   2. copies every DALI output into newly allocated TF output tensors,
 *   3. releases the DALI outputs (daliOutputRelease) and schedules the next run (daliRun).
 *
 * A DALI output whose `sparse` flag is true is emitted as three consecutive TF outputs:
 * an int64 indices tensor of shape {num_elements, max_ndim + 1}, a values tensor of shape
 * {num_elements} and an int64 dense-shape tensor of shape {max_ndim + 1}. Every other output
 * becomes one dense tensor, which requires all samples in the batch to share a single shape.
 */
class DaliOp : public tf::OpKernel {
 public:
  explicit DaliOp(tf::OpKernelConstruction* context)
      : OpKernel(context) {
    std::string serialized_pipeline;
    OP_REQUIRES_OK(context, context->GetAttr("serialized_pipeline", &serialized_pipeline));

    int num_threads;
    int device_id;
    int max_batch_size;
    bool exec_separated;
    bool exec_dynamic;
    int cpu_prefetch_queue_depth;
    OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_));
    OP_REQUIRES_OK(context, context->GetAttr("dtypes", &types_));
    OP_REQUIRES_OK(context, context->GetAttr("num_threads", &num_threads));
    OP_REQUIRES_OK(context, context->GetAttr("device_id", &device_id));
    OP_REQUIRES_OK(context, context->GetAttr("exec_separated", &exec_separated));
    OP_REQUIRES_OK(context, context->GetAttr("exec_dynamic", &exec_dynamic));
    // In exec_separated==false case, gpu_prefetch_queue_depth is the global prefetch_queue_depth_
    OP_REQUIRES_OK(context, context->GetAttr("gpu_prefetch_queue_depth", &prefetch_queue_depth_));
    OP_REQUIRES_OK(context, context->GetAttr("sparse", &sparse_));
    OP_REQUIRES_OK(context, context->GetAttr("batch_size", &max_batch_size));
    OP_REQUIRES_OK(context, context->GetAttr("cpu_prefetch_queue_depth",
                                             &cpu_prefetch_queue_depth));
    OP_REQUIRES_OK(context, context->GetAttr("enable_memory_stats", &enable_memory_stats_));

    // TF doing constant propagation runs all operators on the CPU first, so we need to provide
    // ability to copy memory from the GPU pipeline to the CPU seamlessly
    this->device_type_ = (context->device_type() == "CPU") ?
                         device_type_t::CPU : device_type_t::GPU;
    if (std::any_of(sparse_.begin(), sparse_.end(), [] (const bool &v) {return v;}) &&
        this->device_type_ == device_type_t::GPU) {
      OP_REQUIRES_OK(context, tf::errors::Internal("Cannot output sparse tensors on the GPU"));
    }
    this->device_id_ = device_id;
    LOG_LINE << "Initializing...\n";
    if (max_batch_size < 0) {
      // No explicit `batch_size`: infer it from the leading dimension of the first declared
      // output shape (`shapes` is registered as `list(shape) >= 1`, so shapes_[0] exists).
      // A rank-0 entry has no leading dimension, so reject it instead of reading dim_size(0).
      OP_REQUIRES(context, shapes_[0].dims() > 0,
                  tf::errors::InvalidArgument(
                      "`batch_size` was not provided and cannot be inferred from `shapes`: "
                      "the first entry of `shapes` has no dimensions"));
      max_batch_size = shapes_[0].dim_size(0);
    }
    // Store the effective value, i.e. after the fallback above, so it is never the -1 sentinel.
    this->batch_size_ = max_batch_size;

    dali_exec_flags_t flags = DALI_EXEC_ASYNC_PIPELINED;
    if (exec_dynamic)
      flags = flags | DALI_EXEC_IS_DYNAMIC;
    if (exec_separated)
      flags = flags | DALI_EXEC_IS_SEPARATED;
    TF_DALI_CALL(daliCreatePipeline3(&pipe_handle_,
                                     serialized_pipeline.c_str(),
                                     serialized_pipeline.length(),
                                     max_batch_size,
                                     num_threads,
                                     device_id,
                                     flags,
                                     prefetch_queue_depth_,
                                     cpu_prefetch_queue_depth,
                                     prefetch_queue_depth_,
                                     enable_memory_stats_));
#if USE_TF_ALLOCATOR
    SetupTFAllocator(device_id_);
    UpdateTFAllocaterContext<tf::OpKernelConstruction>(context, device_id_);
#endif
    LOG_LINE << "Pipeline created\n";
    LOG_LINE << "Prefetching...\n";
    TF_DALI_CALL(daliPrefetch(&pipe_handle_));
    LOG_LINE << "After first run\n";
  }

  /// Optionally prints per-operator memory statistics, then deletes the pipeline (if created).
  ~DaliOp() override {
    if (pipe_handle_) {
      if (enable_memory_stats_) {
        size_t N;
        daliExecutorMetadata *meta;
        daliGetExecutorMetadata(&pipe_handle_, &meta, &N);
        std::cout << "DALI operator memory statistics: " << std::endl;
        for (size_t i = 0; i < N; ++i) {
          std::cout << "Operator " << meta[i].operator_name;
          for (size_t j = 0; j < meta[i].out_num; ++j) {
            std::cout << " output [ " << j << " ] : "
                      << meta[i].real_size[j] << "B allocated "
                      << meta[i].max_real_size[j] << "B max allocated "
                      << meta[i].reserved[j] << "B reserved "
                      << meta[i].max_reserved[j] << "B max reserved";
            if (j != meta[i].out_num - 1) {
              std::cout << ",";
            }
          }
          std::cout << std::endl;
        }
        daliFreeExecutorMetadata(meta, N);
      }
      daliDeletePipeline(&pipe_handle_);
    }
  }

  /// Copies the current DALI outputs into the TF `data` outputs and schedules the next run.
  void Compute(tf::OpKernelContext* context) override {
    auto total_s = Clock::now();
#if USE_TF_ALLOCATOR
    UpdateTFAllocaterContext<tf::OpKernelContext>(context, device_id_);
    LOG_LINE << "Updated context\n";
#endif
    LOG_LINE << "Before output...\n";
    auto s = Clock::now();
    TF_DALI_CALL(daliShareOutput(&pipe_handle_));
    int64_t output_time = std::chrono::duration_cast<std::chrono::microseconds>(
        Clock::now() - s).count();
    LOG_LINE << "After output...\n";

    s = Clock::now();
    tf::OpOutputList outputs;
    std::vector<tf::Tensor*> data_output_tensors;
    // Each sparse output needs 3 TF tensors in total (indices, values and dense shape),
    // i.e. 2 more than a dense output.
    unsigned additional_sparse_tensors = std::accumulate(sparse_.begin(), sparse_.end(), 0) * 2;
    unsigned dali_num_out = 0;
    TF_DALI_CALL(dali_num_out = daliGetNumOutput(&pipe_handle_));
    data_output_tensors.resize(dali_num_out + additional_sparse_tensors);
    OP_REQUIRES_OK(context, context->output_list("data", &outputs));

    cudaStream_t stream = 0;
    if (this->device_type_ == device_type_t::GPU) {
      stream = context->eigen_device<Eigen::GpuDevice>().stream();
    }

    // `i` walks the DALI outputs, `j` the TF outputs; `j` advances twice more per sparse output.
    for (unsigned i = 0, j = 0; i < dali_num_out; ++i, ++j) {
      bool should_be_sparse_tensor = i < sparse_.size() && sparse_[i];
      unsigned elms = 0;
      unsigned dims = 0;
      std::vector<tf::int64> max_dims;
      if (!should_be_sparse_tensor) {
        bool is_uniform = false;
        TF_DALI_CALL(is_uniform = daliOutputHasUniformShape(&pipe_handle_, i));
        if (!is_uniform) {
          // Collect the per-sample shapes only to build an informative error message. Iterate
          // over the number of samples DALI actually produced for this output.
          unsigned num_samples = 0;
          TF_DALI_CALL(num_samples = daliNumTensors(&pipe_handle_, i));
          std::stringstream shapes;
          for (unsigned sample_id = 0; sample_id < num_samples; ++sample_id) {
            AutoCPtr<int64_t> dali_shape;
            TF_DALI_CALL(dali_shape = AutoCPtr<int64_t>(
                daliShapeAtSample(&pipe_handle_, i, sample_id)));
            if (sample_id > 0) {
              shapes << ", ";
            }
            shapes << DaliToShape(dali_shape);
          }
          OP_REQUIRES(
              context,
              is_uniform,
              tensorflow::errors::FailedPrecondition(
                  "Batch output at index '", i,
                  "' from DALI pipeline is not uniform - individual samples have different "
                  "dimensions. This output cannot be represented as single, dense Tensor, which is "
                  "required by TensorFlow. Ensure that all the samples that you produce in given "
                  "batch have equal shape. Or use sparse output representation. Got shapes: ",
                  shapes.str()));
        }
        tf::TensorShape data_output_shape;
        TF_DALI_CALL(data_output_shape = DaliToShape(AutoCPtr<int64_t>(
            daliShapeAt(&pipe_handle_, i))));
        // If a non-empty shape was provided for this output, the produced shape must match it.
        // Outputs beyond the end of `shapes_` are not validated (and never read out of bounds).
        OP_REQUIRES(context,
                    i >= shapes_.size() || shapes_[i].dims() <= 0 ||
                        data_output_shape == shapes_[i],
                    tf::errors::InvalidArgument("DALI pipeline output shape at " + std::to_string(i) +
                                                " " + data_output_shape.DebugString() + " != "
                                                + shapes_[i].DebugString() + " plugin `shapes` argument"));
        OP_REQUIRES_OK(context, outputs.allocate(j, data_output_shape, &data_output_tensors[j]));
      } else {
        TF_DALI_CALL(elms = daliNumTensors(&pipe_handle_, i));
        // maximum number of dimension + one additional to hold tensor list number
        TF_DALI_CALL(dims = daliMaxDimTensors(&pipe_handle_, i) + 1);
        max_dims.resize(dims, 0);
        // first dim is number of elements in the tensor list
        max_dims[0] = elms;
        tf::TensorShape data_output_shape;
        tf::int64 total_elms = 0;
        TF_DALI_CALL(total_elms = daliNumElements(&pipe_handle_, i));
        // TF output j: indices, one row of `dims` coordinates per element.
        OP_REQUIRES_OK(context, outputs.allocate(j, tf::TensorShape({total_elms, dims}),
                                                 &data_output_tensors[j]));
        tf::Tensor* out_tensor = data_output_tensors[j];
        auto p_out_indices = out_tensor->flat<tf::int64>().data();
        for (unsigned n = 0; n < elms; ++n) {
          TF_DALI_CALL(data_output_shape = DaliToShape(AutoCPtr<int64_t>(
              daliShapeAtSample(&pipe_handle_, i, n))));
          // A rank-0 shape reports num_elements() == 1, so skip it explicitly.
          if (data_output_shape.dims() == 0) {
            continue;
          }
          // squeeze a trailing dimension of size 1
          if (data_output_shape.dim_size(data_output_shape.dims() - 1) == 1) {
            data_output_shape.RemoveLastDims(1);
          }
          for (unsigned elm_idx = 0; elm_idx < data_output_shape.num_elements(); ++elm_idx) {
            unsigned idx_val = elm_idx;
            // first value of indices is tensor index
            *p_out_indices = n;
            ++p_out_indices;
            for (unsigned k = 0; k < dims - 1; ++k, ++p_out_indices) {
              const int dims_idx = k - (dims - 1 - data_output_shape.dims());
              // if current element has less dims than max then set first indices to 0
              if (k + data_output_shape.dims() < dims - 1) {
                *p_out_indices = 0;
              } else {
                max_dims[k + 1] = std::max(max_dims[k + 1], data_output_shape.dim_size(dims_idx));
                if (dims_idx < data_output_shape.dims() - 1) {
                  *p_out_indices = idx_val / data_output_shape.dim_size(dims_idx + 1);
                  idx_val %= data_output_shape.dim_size(dims_idx + 1);
                } else {
                  *p_out_indices = idx_val;
                }
              }
            }
          }
        }
        ++j;
        // TF output j + 1: flat values, filled by daliOutputCopy below.
        OP_REQUIRES_OK(context, outputs.allocate(j, tf::TensorShape({total_elms}),
                                                 &data_output_tensors[j]));
      }

      void *dst = nullptr;
      tf::Tensor* out_tensor = data_output_tensors[j];
      size_t dali_tensor_size = 0;
      TF_DALI_CALL(dali_tensor_size = daliTensorSize(&pipe_handle_, i));
      // Must stop here on mismatch: continuing would let daliOutputCopy overrun the TF buffer.
      OP_REQUIRES(context, dali_tensor_size <= out_tensor->TotalBytes(),
                  tf::errors::InvalidArgument("Output " + std::to_string(i) +
                      " has bigger size than allocated by TensorFlow - check if type requested matches" +
                      " with one produced by the DALI pipeline"));
      switch (types_[j]) {
        case tf::DT_HALF:
          dst = reinterpret_cast<void*>(out_tensor->flat<Eigen::half>().data());
          break;
        case tf::DT_FLOAT:
          dst = reinterpret_cast<void*>(out_tensor->flat<float>().data());
          break;
        case tf::DT_UINT8:
          dst = reinterpret_cast<void*>(out_tensor->flat<uint8_t>().data());
          break;
        case tf::DT_INT16:
          dst = reinterpret_cast<void*>(out_tensor->flat<int16_t>().data());
          break;
        case tf::DT_INT32:
          dst = reinterpret_cast<void*>(out_tensor->flat<int32_t>().data());
          break;
        case tf::DT_INT64:
          dst = reinterpret_cast<void*>(out_tensor->flat<tf::int64>().data());
          break;
        default:
          // Must return: otherwise daliOutputCopy would be called with dst == nullptr.
          context->CtxFailure(__FILE__, __LINE__,
              tf::errors::InvalidArgument("Unsupported type: " + tf::DataTypeString(types_[j]) +
                                          " for tensor " + std::to_string(i)));
          return;
      }
      // Synchronize with the dataset()->stream_ when doing the last copy, so the outputs
      // are fully finished before we release the output buffers for reuse.
      // if the OP runs on the CPU the output memory is not pinned and we don't need to sync
      unsigned int wait_flag =
          (this->device_type_ != device_type_t::CPU && i == dali_num_out - 1) ?
          DALI_ext_force_sync :
          DALI_ext_default;
      TF_DALI_CALL(
          daliOutputCopy(&pipe_handle_, dst, i, this->device_type_, stream, wait_flag));
      if (should_be_sparse_tensor) {
        ++j;
        // TF output j + 2: dense shape of the sparse tensor.
        OP_REQUIRES_OK(context, outputs.allocate(j, tf::TensorShape({dims}),
                                                 &data_output_tensors[j]));
        auto out_tensor = data_output_tensors[j];
        auto out_shape = out_tensor->flat<tf::int64>().data();
        for (unsigned k = 0; k < dims; ++k) {
          out_shape[k] = max_dims[k];
        }
      }
    }
    int64_t copy_time = std::chrono::duration_cast<std::chrono::microseconds>(
        Clock::now() - s).count();

    TF_DALI_CALL(daliOutputRelease(&pipe_handle_));

    LOG_LINE << "Computing...\n";
    s = Clock::now();
    TF_DALI_CALL(daliRun(&pipe_handle_));
    int64_t run_time = std::chrono::duration_cast<std::chrono::microseconds>(
        Clock::now() - s).count();

    int64_t total_time = std::chrono::duration_cast<std::chrono::microseconds>(
        Clock::now() - total_s).count();
    LOG_LINE << "[TIMES] TOTAL " << total_time << " RUN " << run_time
             << " - OUTPUT " << output_time << " - ALLOC + COPY " << copy_time << std::endl;
  }

 private:
  daliPipelineHandle pipe_handle_ = nullptr;
  // Expected output shapes (`shapes` attr), indexed by DALI output in Compute(); rank-0 entries
  // are not validated.
  std::vector<tf::TensorShape> shapes_;
  // Output types (`dtypes` attr), indexed by TF output index in Compute().
  tf::DataTypeVector types_;
  int device_id_ = -1;
  // Effective max batch size passed to daliCreatePipeline3 (after the fallback to shapes_[0]).
  int batch_size_ = 0;
  int prefetch_queue_depth_ = -1;
  // CPU when this kernel instance runs on the CPU device, GPU otherwise.
  device_type_t device_type_ = CPU;
  // Per-DALI-output flag (`sparse` attr) selecting the sparse (indices, values, shape) layout.
  std::vector<bool> sparse_;
  bool enable_memory_stats_ = false;
};
// NOTE(review): nothing after this using-declaration in the file refers to the unqualified
// `int64`; it looks like a leftover — confirm before removing.
using tf::int64;
// Register the same DaliOp kernel for both GPU and CPU devices. The constructor picks the copy
// target (device_type_) from context->device_type(), so a single class serves both.
REGISTER_KERNEL_BUILDER(Name("Dali").Device(tf::DEVICE_GPU), DaliOp)
REGISTER_KERNEL_BUILDER(Name("Dali").Device(tf::DEVICE_CPU), DaliOp)
} // namespace dali_tf_impl