Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#171 from graphcore/resize_tensor_insi…
Browse files Browse the repository at this point in the history
…de_run

Resize tensor inside ipu_backend::Run()
  • Loading branch information
yiakwy-xpu-ml-framework-team authored Sep 22, 2021
2 parents 3eb17d6 + 99dbb32 commit 4bab0d4
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 63 deletions.
6 changes: 2 additions & 4 deletions paddle/fluid/framework/ipu/ipu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,15 @@ void IpuBackend::Compile(ir::Graph* graph,
compiler_->LowerWeights(graph, scope_);
compiler_->LowerBody(graph);
compiler_->InitOutputs(fetch_list);
executor_->SetOutputTensorId(compiler_->GetOutputTensors());
executor_->SetWeights(compiler_->GetWeights());
VLOG(10) << "leave IpuBackend::Compile";
}

void IpuBackend::Run(const std::vector<const Tensor*>& inputs,
const std::vector<Tensor*>& outputs) {
void IpuBackend::Run(const framework::ExecutionContext& ctx) {
Prepare();
auto inputs_id = compiler_->GetInputs();
auto outputs_id = compiler_->GetOutputs();
executor_->Run(inputs_id, inputs, outputs_id, outputs);
executor_->Run(inputs_id, outputs_id, ctx);
}

void IpuBackend::Prepare() {
Expand Down
8 changes: 3 additions & 5 deletions paddle/fluid/framework/ipu/ipu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ipu/ipu_compiler.h"
#include "paddle/fluid/framework/ipu/ipu_executor.h"
#include "paddle/fluid/framework/ipu/ipu_strategy.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"
Expand Down Expand Up @@ -53,15 +54,11 @@ class IpuBackend {
void Compile(ir::Graph *graph, const std::vector<std::string> &feed_list,
const std::vector<std::string> &fetch_list);

// need doc
void Prepare();

// what run does include:
// 1. construct forward onnx graph
// 2. graph-level optimization
// 3. autodiff
void Run(const std::vector<const Tensor *> &inputs,
const std::vector<Tensor *> &outputs);
void Run(const framework::ExecutionContext &ctx);

Executor &GetExecutor() { return *executor_; }

Expand All @@ -78,6 +75,7 @@ class IpuBackend {

private:
int UpperIpuNum();
void Prepare();

private:
std::shared_ptr<Compiler> compiler_;
Expand Down
9 changes: 0 additions & 9 deletions paddle/fluid/framework/ipu/ipu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,15 +367,6 @@ std::vector<int64_t> Compiler::GetTensorShape(const std::string& name) {
return builder_->getTensorShape(tensors_[name]);
}

std::map<std::string, std::string> Compiler::GetOutputTensors() {
std::map<std::string, std::string> outputs;
for (const auto& fetch_name : fetch_list_) {
auto tensorid = tensors_[fetch_name];
outputs[fetch_name] = tensorid;
}
return outputs;
}

// Returns a mutable reference to the compiler's weights_ list (no copy is
// made, so callers observe and can modify the compiler's own container).
std::vector<popart::TensorId>& Compiler::GetWeights() { return weights_; }

// Returns the model proto string obtained from builder_->getModelProto().
std::string Compiler::GetModelProto() { return builder_->getModelProto(); }
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/framework/ipu/ipu_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class Compiler {
std::vector<popart::TensorId> GetOutputs() { return outputs_; }
std::map<std::string, popart::TensorId> GetTensors() { return tensors_; }
std::vector<int64_t> GetTensorShape(const std::string &name);
std::map<std::string, std::string> GetOutputTensors();
std::vector<popart::TensorId> &GetWeights();

std::string GetModelProto();
Expand Down
34 changes: 16 additions & 18 deletions paddle/fluid/framework/ipu/ipu_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ void Executor::Prepare(const std::string &proto,
}

void Executor::Run(const std::vector<popart::TensorId> &inputs_id,
const std::vector<const Tensor *> &inputs,
const std::vector<popart::TensorId> &outputs_id,
const std::vector<Tensor *> &outputs) {
const framework::ExecutionContext &ctx) {
auto inputs = ctx.MultiInput<framework::Tensor>("FeedList");
auto outputs = ctx.MultiOutput<framework::Tensor>("FetchList");
// inputs
std::map<popart::TensorId, popart::IArray &> popart_inputs;
std::map<popart::TensorId, PaddleIArray> input_wrappers;
for (size_t i = 0; i < inputs.size(); i++) {
Expand All @@ -92,12 +94,23 @@ void Executor::Run(const std::vector<popart::TensorId> &inputs_id,
input_wrappers.emplace(tensor_id, PaddleIArray(tensor));
popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id));
}

// anchors
std::map<popart::TensorId, popart::IArray &> popart_anchors;
std::map<popart::TensorId, PaddleIArray> anchor_wrappers;
for (size_t i = 0; i < outputs.size(); i++) {
auto tensor_id = outputs_id[i];
auto tensor = const_cast<Tensor *>(outputs[i]);
// get dims & dtype from session
auto fetch_info = session_->getInfo(tensor_id);
auto output_shape = fetch_info.shape();
if (ipu_strategy_->batches_per_step > 1) {
output_shape.insert(output_shape.begin(),
ipu_strategy_->batches_per_step);
}
tensor->Resize(framework::make_ddim(output_shape));
auto fetch_dtype = fetch_info.dataType();
auto paddle_type = PopartType2VarType(fetch_dtype);
tensor->mutable_data(ctx.GetPlace(), paddle_type);
anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor));
popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id));
}
Expand Down Expand Up @@ -191,21 +204,6 @@ void Executor::SetIpuStrategy(const IpuStrategy &strategy) {
ipu_strategy_ = &strategy;
}

// Replaces the executor's outputs_ map with a copy of `outputs`.
// The map is later indexed by fetch name to obtain a tensor id (see
// GetOutputShape), i.e. keys are fetch names and values are tensor ids.
void Executor::SetOutputTensorId(
const std::map<std::string, std::string> &outputs) {
outputs_ = outputs;
}

// Returns the shape of the output registered under `fetch_name`.
// The tensor id is looked up in outputs_, its shape is queried from the
// session, and when the strategy's batches_per_step is greater than 1 that
// value is prepended as an extra leading dimension.
std::vector<int64_t> Executor::GetOutputShape(const std::string &fetch_name) {
  auto anchor_id = outputs_[fetch_name];
  auto anchor_info = session_->getInfo(anchor_id);
  auto shape = anchor_info.shape();
  const bool multi_batch_step = ipu_strategy_->batches_per_step > 1;
  if (multi_batch_step) {
    // One result per batch in the step, so add a leading dimension for it.
    shape.insert(shape.begin(), ipu_strategy_->batches_per_step);
  }
  return shape;
}

float Executor::GetLRFromScope() {
auto lr_var = scope_->GetVar(opt_info.GetLRVarName());
auto tensor = lr_var->Get<framework::LoDTensor>();
Expand Down
9 changes: 2 additions & 7 deletions paddle/fluid/framework/ipu/ipu_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ipu/ipu_optimizer.h"
#include "paddle/fluid/framework/ipu/ipu_strategy.h"
#include "paddle/fluid/framework/ipu/ipu_utils.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
Expand All @@ -38,9 +39,8 @@ class Executor {
const std::vector<popart::TensorId> &outputs,
std::shared_ptr<popart::DeviceInfo> device);
void Run(const std::vector<popart::TensorId> &inputs_id,
const std::vector<const Tensor *> &inputs,
const std::vector<popart::TensorId> &outputs_id,
const std::vector<Tensor *> &outputs);
const framework::ExecutionContext &ctx);

// Optimizer
void SetOptimizerType(const std::string &type);
Expand All @@ -61,10 +61,6 @@ class Executor {
// Strategy
void SetIpuStrategy(const IpuStrategy &strategy);

// Outputs
void SetOutputTensorId(const std::map<std::string, std::string> &outputs);
std::vector<int64_t> GetOutputShape(const std::string &fetch_name);

private:
float GetLRFromScope();

Expand All @@ -77,7 +73,6 @@ class Executor {
const IpuStrategy *ipu_strategy_ = nullptr;
popart::WeightsIO weights_io_;
std::vector<popart::TensorId> weights_;
std::map<std::string, std::string> outputs_;
};

} // namespace ipu
Expand Down
36 changes: 34 additions & 2 deletions paddle/fluid/framework/ipu/ipu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ std::size_t PaddleIArray::nelms() const {

// Returns a copy of the shape_ member held by this wrapper.
const popart::Shape PaddleIArray::shape() const { return shape_; }

popart::DataType VarType2PopartType(proto::VarType::Type type) {
popart::DataType VarType2PopartType(const proto::VarType::Type type) {
switch (type) {
case proto::VarType::UINT8:
return popart::DataType::UINT8;
Expand Down Expand Up @@ -69,7 +69,39 @@ popart::DataType VarType2PopartType(proto::VarType::Type type) {
}
}

popart::DataType OnnxDtype2PopartType(int type) {
// Converts a popart::DataType into the matching Paddle proto::VarType::Type
// (the opposite direction of VarType2PopartType).
// Throws an Unavailable error through PADDLE_THROW for any popart data type
// that has no case below.
proto::VarType::Type PopartType2VarType(const popart::DataType type) {
  switch (type) {
    case popart::DataType::UINT8:
      return proto::VarType::UINT8;
    case popart::DataType::INT8:
      return proto::VarType::INT8;
    case popart::DataType::INT16:
      return proto::VarType::INT16;
    case popart::DataType::INT32:
      return proto::VarType::INT32;
    case popart::DataType::INT64:
      return proto::VarType::INT64;
    case popart::DataType::BOOL:
      return proto::VarType::BOOL;
    case popart::DataType::DOUBLE:
      return proto::VarType::FP64;
    case popart::DataType::FLOAT:
      return proto::VarType::FP32;
    case popart::DataType::FLOAT16:
      return proto::VarType::FP16;
    case popart::DataType::BFLOAT16:
      return proto::VarType::BF16;
    case popart::DataType::COMPLEX64:
      return proto::VarType::COMPLEX64;
    case popart::DataType::COMPLEX128:
      return proto::VarType::COMPLEX128;
    default:
      // The rejected value is a popart type, so the message must say so;
      // blaming a "Paddle var type" here would point debugging the wrong way.
      PADDLE_THROW(paddle::platform::errors::Unavailable(
          "Unsupported popart data type."));
  }
}

popart::DataType OnnxDtype2PopartType(const int type) {
auto dtype = static_cast<ONNXDataType>(type);
switch (dtype) {
case ONNXDataType::BOOL:
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/framework/ipu/ipu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ class PaddleIArray final : public popart::IArray {
std::vector<int64_t> shape_;
};

popart::DataType VarType2PopartType(proto::VarType::Type type);
popart::DataType OnnxDtype2PopartType(int type);
popart::DataType VarType2PopartType(const proto::VarType::Type type);
proto::VarType::Type PopartType2VarType(const popart::DataType type);
popart::DataType OnnxDtype2PopartType(const int type);
bool GetBoolEnv(std::string str);

template <typename T>
Expand Down
18 changes: 4 additions & 14 deletions paddle/fluid/operators/ipu_runtime_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,13 @@ class IpuRuntimeKernel : public framework::OpKernel<T> {
ctx.device_context());
ipu_backend->AttachDevice(ipu_ctx.DeviceId());
}
VLOG(4) << "IpuBackend prepare session";
ipu_backend->Prepare();

VLOG(4) << "IpuRuntime Kernel, begin to run graph";
auto inputs = ctx.MultiInput<framework::Tensor>("FeedList");
ipu_backend->Run(ctx);

// post-run
auto outputs = ctx.MultiOutput<framework::Tensor>("FetchList");
auto output_names = ctx.OutputNames("FetchList");
for (size_t i = 0; i < outputs.size(); ++i) {
auto* out = outputs[i];
auto oshape = ipu_backend->GetExecutor().GetOutputShape(output_names[i]);
out->Resize(framework::make_ddim(oshape));
// TODO(alleng) support muti-output dtypes
// maybe get dtype from ipu_backend
out->mutable_data<T>(ctx.GetPlace());
}

ipu_backend->Run(inputs, outputs);

// resize tensor when tensor.dims() is empty
for (size_t i = 0; i < outputs.size(); ++i) {
auto* out = outputs[i];
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/operators/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ endif()

cc_library(jit_kernel_helper INTERFACE SRCS ${jit_kernel_cc_srcs} DEPS jit_kernel_base ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
if(NOT WIN32)
# TODO(alleng) fix error when WITH_IPU
if(NOT WIN32 AND NOT WITH_IPU)
cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper device_tracer tensor)
endif()
if(WITH_TESTING AND TEST jit_kernel_test)
Expand Down

0 comments on commit 4bab0d4

Please sign in to comment.