Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Add op uniform_random (#1149)
Browse files Browse the repository at this point in the history
* Add custom call op uniform random

* Reformat code

* Use direct function call instead of CustomInstr
  • Loading branch information
FisherWY authored Jan 17, 2023
1 parent f3dfd27 commit 9ba68f3
Show file tree
Hide file tree
Showing 17 changed files with 502 additions and 1 deletion.
12 changes: 12 additions & 0 deletions cinn/frontend/net_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -656,5 +656,17 @@ Variable NetBuilder::GaussianRandom(
.front();
}

Variable NetBuilder::UniformRandom(
const std::vector<int>& shape, float min, float max, int seed, const std::string& dtype) {
auto uniform_out =
CustomInstr(
"uniform_random", {}, {{"shape", shape}, {"min", min}, {"max", max}, {"seed", seed}, {"dtype", dtype}})
.front();
auto uniform_range = FillConstant(shape, max - min, UniqName("uniform_range"), dtype);
auto uniform_mul_out = Multiply(uniform_out, uniform_range);
auto uniform_min = FillConstant(shape, min, UniqName("uniform_min"), dtype);
return Add(uniform_mul_out, uniform_min);
}

} // namespace frontend
} // namespace cinn
14 changes: 14 additions & 0 deletions cinn/frontend/net_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -975,6 +975,20 @@ class NetBuilder {
int seed = 0,
const std::string& dtype = "float32");

/**
* @brief Uniform random
* @param shape Shape of the variable to be created.
* @param min The lower bound of the range of random values ​​generated, min is included in the range.
* @param max The upper bound of the range of random values ​​generated, max is not included in the range.
* @param seed Random seed of generator, default is 0.
* @param dtype Data tpye of output variable, supported data types: float32, float64.
*/
Variable UniformRandom(const std::vector<int>& shape,
float min = -1.0f,
float max = 1.0f,
int seed = 0,
const std::string& dtype = "float32");

private:
CINN_DISALLOW_COPY_AND_ASSIGN(NetBuilder);
};
Expand Down
51 changes: 51 additions & 0 deletions cinn/frontend/op_mappers/paddle/uniform_random.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/common_utils.h"
#include "cinn/frontend/var_type_utils.h"

namespace cinn {
namespace frontend {
namespace paddle_mappers {

void UniformRandomOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
auto out_name = op_desc.Output("Out").front();

auto shape_origin = utils::GetAttrOrDefault<std::vector<int64_t>>(op_desc, "shape");
auto shape = utils::ToShapeType(shape_origin);

auto min = utils::GetAttrOrDefault<float>(op_desc, "min", -1.0f);
auto max = utils::GetAttrOrDefault<float>(op_desc, "max", 1.0f);
auto seed = utils::GetAttrOrDefault<int>(op_desc, "seed", 0);

auto dtype_id = utils::GetAttrOrDefault<int>(op_desc, "dtype", static_cast<int>(paddle::cpp::VarDescAPI::Type::FP32));
auto dtype_pd = static_cast<paddle::cpp::VarDescAPI::Type>(dtype_id);
auto dtype_cinn = utils::CppVarType2CommonType(dtype_pd);
auto dtype = common::Type2Str(dtype_cinn);

auto out = ctx.Builder()->UniformRandom(shape, min, max, seed, dtype);
ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
}

} // namespace paddle_mappers
} // namespace frontend
} // namespace cinn

CINN_REGISTER_HELPER(paddle_uniform_random) {
CINN_REGISTER_OP_MAPPER(uniform_random, cinn::frontend::paddle_mappers::UniformRandomOpMapper)
return true;
}
1 change: 1 addition & 0 deletions cinn/frontend/op_mappers/use_op_mappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ CINN_USE_REGISTER(paddle_gather)
CINN_USE_REGISTER(paddle_reduce)
CINN_USE_REGISTER(paddle_atan)
CINN_USE_REGISTER(paddle_gaussian_random)
CINN_USE_REGISTER(paddle_uniform_random)

CINN_USE_REGISTER(science_broadcast)
CINN_USE_REGISTER(science_transform)
Expand Down
1 change: 1 addition & 0 deletions cinn/frontend/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ OptimizeOptions DefaultTrainingOptimizeOptions() {
options.graph_passes.push_back("MatmulToCublasCustomCallPass");
}
options.graph_passes.push_back("GaussianRandomToCustomCallPass");
options.graph_passes.push_back("UniformRandomToCustomCallPass");
#ifdef CINN_WITH_CUDNN
if (FLAGS_cinn_use_cudnn_conv) {
options.graph_passes.push_back("ConvToCudnnCustomCallPass");
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/op/contrib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ gather_srcs(cinnapi_src SRCS
one_hot.cc
reciprocal.cc
gaussian_random.cc
uniform_random.cc
)

cc_test(test_gather_nd SRCS gather_nd_test.cc DEPS cinncore)
Expand Down
109 changes: 109 additions & 0 deletions cinn/hlir/op/contrib/uniform_random.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gflags/gflags.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/variant.h"
#include "cinn/common/cas.h"
#include "cinn/common/cinn_value.h"
#include "cinn/common/common.h"
#include "cinn/common/context.h"
#include "cinn/common/ir_util.h"
#include "cinn/common/macros.h"
#include "cinn/common/target.h"
#include "cinn/hlir/framework/node.h"
#include "cinn/hlir/framework/op.h"
#include "cinn/hlir/framework/op_strategy.h"
#include "cinn/hlir/op/op_util.h"
#include "cinn/hlir/pe/elementwise.h"
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/hlir/pe/schedule.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_base.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/ir/tensor.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"
#include "cinn/lang/packed_func.h"
#include "cinn/poly/stage.h"
#include "glog/logging.h"

namespace cinn {
namespace hlir {
namespace op {

using common::CINNValue;
using common::CINNValuePack;

std::shared_ptr<framework::OpStrategy> StrategyForUniformRandom(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
framework::CINNCompute uniform_random_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(attrs.attr_store.count("shape"));
ir::Tensor shape_tensor;
std::string tensor_name = "uniform_random_out";
auto out = pe::Identity(shape_tensor, tensor_name).front();
auto stages = CreateStages({out});
std::vector<CINNValue> res{CINNValue(out), CINNValue(stages)};
*ret = CINNValuePack{res};
});
auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
uniform_random_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.uniform_random.x86", 1);
return strategy;
}

std::vector<framework::shape_t> InferShapeForUniformRandom(const std::vector<framework::shape_t> &inputs_shape,
const framework::AttrMapType &attrs) {
CHECK(attrs.count("shape"));
auto shape = absl::get<std::vector<int>>(attrs.at("shape"));
CHECK(!shape.empty()) << "shape attr is empty!";
return {shape};
}

std::vector<Type> InferDtypeForUniformRandom(const std::vector<Type> &inputs_type,
const framework::AttrMapType &attrs) {
std::string dtype = "float32";
if (attrs.find("dtype") != attrs.end()) {
dtype = absl::get<std::string>(attrs.at("dtype"));
}
std::vector<Type> res{common::Str2Type(dtype)};
return res;
}

} // namespace op
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(uniform_random_ops) {
CINN_REGISTER_OP(uniform_random)
.describe("UniformRandom")
.set_num_inputs(0)
.set_num_outputs(1)
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForUniformRandom)
.set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForUniformRandom))
.set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForUniformRandom))
.set_attr<cinn::hlir::framework::OpPatternKind>("OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise)
.set_support_level(4);

return true;
}
20 changes: 20 additions & 0 deletions cinn/hlir/op/custom_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,12 +470,32 @@ std::vector<ir::Expr> CustomCallArgsForGaussianRandom(const framework::NodeAttr
return args;
}

std::vector<ir::Expr> CustomCallArgsForUniformRandom(const framework::NodeAttr &attrs,
const std::vector<ir::Tensor> &inputs,
const std::vector<std::vector<int>> &output_shapes) {
CHECK_EQ(output_shapes.size(), 1UL);

auto attr_store = attrs.attr_store;

float min = attr_store.count("min") ? absl::get<float>(attrs.attr_store.at("min")) : -1.0f;
float max = attr_store.count("max") ? absl::get<float>(attrs.attr_store.at("max")) : 1.0f;
int seed = attr_store.count("seed") ? absl::get<int>(attrs.attr_store.at("seed")) : 0;

CHECK_GE(max, min) << "Arg max must greater than min, please check.";

std::vector<ir::Expr> args = {ir::Expr(min), ir::Expr(max), ir::Expr(seed)};

return args;
}

bool RegisteryCustomCallArgsFunc() {
#ifdef CINN_WITH_CUDA
CustomCallArgsFuncRegistry::Global().Register(
"cinn_call_cublas", common::DefaultNVGPUTarget(), CustomCallArgsForCublas);
CustomCallArgsFuncRegistry::Global().Register(
"cinn_call_gaussian_random", common::DefaultNVGPUTarget(), CustomCallArgsForGaussianRandom);
CustomCallArgsFuncRegistry::Global().Register(
"cinn_call_uniform_random", common::DefaultNVGPUTarget(), CustomCallArgsForUniformRandom);
#endif

#ifdef CINN_WITH_CUDNN
Expand Down
1 change: 1 addition & 0 deletions cinn/hlir/op/use_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@ CINN_USE_REGISTER(one_hot_ops)
CINN_USE_REGISTER(lookup_table_ops)
CINN_USE_REGISTER(reciprocal_ops)
CINN_USE_REGISTER(gaussian_random_ops)
CINN_USE_REGISTER(uniform_random_ops)
30 changes: 30 additions & 0 deletions cinn/hlir/pass/custom_call_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,26 @@ class GraphAlterHelper {
}
}

void UniformRandomToCustomCall() {
auto nodes = graph_->CollectNodes([](const common::GraphNode* graph_node) -> bool {
if (graph_node->safe_as<Node>()) {
auto node = graph_node->safe_as<Node>();
if (node->op()->name == "uniform_random") {
return true;
}
}

return false;
});

for (auto gnode : nodes) {
auto src = gnode->safe_as<Node>();
CHECK(src);
src->attrs.op = framework::Operator::Get("custom_call");
src->attrs.attr_store["custom_call"] = std::string("cinn_call_uniform_random");
}
}

private:
Graph* graph_;
};
Expand All @@ -138,6 +158,12 @@ void GaussianRandomToCustomCallInternal(Graph* graph) {
VLOG(3) << "GaussianRandomToCustomCall Finish...!";
}

void UniformRandomToCustomCallInternal(Graph* graph) {
VLOG(3) << "UniformRandomToCustomCall...!";
GraphAlterHelper(graph).UniformRandomToCustomCall();
VLOG(3) << "UniformRandomToCustomCall Finish...!";
}

} // namespace pass
} // namespace hlir
} // namespace cinn
Expand All @@ -152,6 +178,10 @@ CINN_REGISTER_HELPER(CustomCallPass) {
.describe("This pass which convert gaussian random op to custom call pass.")
.set_change_structure(false)
.set_body(cinn::hlir::pass::GaussianRandomToCustomCallInternal);
CINN_REGISTER_PASS(UniformRandomToCustomCallPass)
.describe("This pass which convert uniform random op to custom call pass.")
.set_change_structure(false)
.set_body(cinn::hlir::pass::UniformRandomToCustomCallInternal);
#endif
#ifdef CINN_WITH_CUDNN
CINN_REGISTER_PASS(ConvToCudnnCustomCallPass)
Expand Down
7 changes: 7 additions & 0 deletions cinn/pybind/frontend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,13 @@ void BindFrontend(pybind11::module *m) {
py::arg("mean") = 0.0f,
py::arg("std") = 1.0f,
py::arg("seed") = 0,
py::arg("dtype") = "float32")
.def("uniform_random",
&NetBuilder::UniformRandom,
py::arg("shape"),
py::arg("min") = -1.0f,
py::arg("max") = 1.0f,
py::arg("seed") = 0,
py::arg("dtype") = "float32");

auto computation = py::class_<CinnComputation, std::shared_ptr<CinnComputation>>(*m, "Computation");
Expand Down
11 changes: 11 additions & 0 deletions cinn/runtime/cuda/cuda_intrinsics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ CINN_REGISTER_HELPER(cinn_cuda_host_api) {
.AddInputType<void *>() // stream
.End();

using cinn::runtime::cuda::cinn_call_uniform_random;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_uniform_random, cinn::common::DefaultHostTarget())
.SetRetType<void>()
.AddInputType<void *>() // v_args
.AddInputType<int>() // num_args
.AddInputType<float>() // min
.AddInputType<float>() // max
.AddInputType<int>() // seed
.AddInputType<void *>() // stream
.End();

#ifdef CINN_WITH_CUDNN
using cinn::runtime::cuda::cinn_call_cudnn_conv2d_forward;
REGISTER_EXTERN_FUNC_HELPER(cinn_call_cudnn_conv2d_forward, cinn::common::DefaultHostTarget())
Expand Down
23 changes: 22 additions & 1 deletion cinn/runtime/cuda/cuda_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,9 @@ void cinn_call_gaussian_random(void *v_args, int num_args, float mean, float std
size_t numel = output->num_elements();
curandGenerator_t generator;
CURAND_CALL(curandCreateGenerator(&generator, CURAND_RNG_PSEUDO_PHILOX4_32_10));
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(generator, seed));
if (seed != 0) {
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(generator, static_cast<unsigned long long>(seed)));
}
if (dtype == cinn_float32_t()) {
float *ptr = reinterpret_cast<float *>(output->memory);
CURAND_CALL(curandGenerateNormal(generator, ptr, numel, mean, std));
Expand All @@ -1065,6 +1067,25 @@ void cinn_call_gaussian_random(void *v_args, int num_args, float mean, float std
}
}

void cinn_call_uniform_random(void *v_args, int num_args, float min, float max, int seed, void *stream) {
cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
cinn_buffer_t *output = args[0].operator cinn_buffer_t *();
cinn_type_t dtype = output->type;
size_t numel = output->num_elements();
curandGenerator_t generator;
CURAND_CALL(curandCreateGenerator(&generator, CURAND_RNG_PSEUDO_PHILOX4_32_10));
if (seed != 0) {
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(generator, static_cast<unsigned long long>(seed)));
}
if (dtype == cinn_float32_t()) {
float *ptr = reinterpret_cast<float *>(output->memory);
CURAND_CALL(curandGenerateUniform(generator, ptr, numel));
} else if (dtype == cinn_float64_t()) {
double *ptr = reinterpret_cast<double *>(output->memory);
CURAND_CALL(curandGenerateUniformDouble(generator, ptr, numel));
}
}

#ifdef CINN_WITH_CUDNN

namespace {
Expand Down
2 changes: 2 additions & 0 deletions cinn/runtime/cuda/cuda_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ void cinn_gpu_cublas_gemm(const std::vector<int>& attrs,

void cinn_call_gaussian_random(void* v_args, int num_args, float mean, float std, int seed, void* stream = nullptr);

void cinn_call_uniform_random(void* v_args, int num_args, float min, float max, int seed, void* stream = nullptr);

#ifdef CINN_WITH_CUDNN
void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map<std::string, int>& attr,
cinn_buffer_t* x,
Expand Down
Loading

0 comments on commit 9ba68f3

Please sign in to comment.