Use const Value* where possible (pytorch#146)
bertmaher authored Feb 12, 2020
1 parent 8cfdd14 commit f3990d7
Showing 2 changed files with 29 additions and 27 deletions.
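
The change is a const-correctness pass: every helper that only inspects a torch::jit::Value now takes const torch::jit::Value* instead of a mutable pointer, so the compiler enforces that kernel construction never modifies the JIT graph. A minimal compilable sketch of the pattern, using hypothetical stand-in types rather than the real torch::jit API:

#include <iostream>
#include <string>

// Hypothetical stand-ins for torch::jit::Node / torch::jit::Value,
// only to illustrate the const-correctness pattern in this commit.
struct Node {
  std::string kind;
};

struct Value {
  Node* producer;
  // const accessor: callable through a const Value*.
  const Node* node() const { return producer; }
};

// Before: the signature took Value* even though nothing was mutated.
// After (this commit's pattern): const Value* both documents and
// enforces read-only access.
std::string describe(const Value* v) {
  return "produced by " + v->node()->kind;
}

int main() {
  Node n{"aten::add"};
  Value v{&n};
  std::cout << describe(&v) << "\n";  // prints: produced by aten::add
}
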
34 changes: 17 additions & 17 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -35,7 +35,7 @@ static std::vector<Expr> texprSizes(const c10::VaryingShape& shape) {
   return dims;
 }
 
-static std::vector<DimArg> texprDims(torch::jit::Value* v) {
+static std::vector<DimArg> texprDims(const torch::jit::Value* v) {
   CHECK(v->type()->kind() == TypeKind::TensorType);
   auto tt = v->type()->cast<TensorType>();
   std::vector<DimArg> dimArgs;
@@ -64,7 +64,7 @@ int64_t bufferSize(T t) {
   return size;
 }
 
-Expr TensorExprKernel::constant(torch::jit::Value* v) {
+Expr TensorExprKernel::constant(const torch::jit::Value* v) {
   if (v->node()->kind() == prim::Constant) {
     const auto val = toIValue(v).value();
     if (val.isDouble()) {
@@ -94,7 +94,7 @@ void TensorExprKernel::promoteInputs(std::vector<Expr>& inputs) {
   }
 }
 
-Expr TensorExprKernel::demoteOutput(const Expr& e, torch::jit::Value* v) {
+Expr TensorExprKernel::demoteOutput(const Expr& e, const torch::jit::Value* v) {
   CHECK(v->type()->kind() == TypeKind::TensorType);
   auto tt = v->type()->cast<TensorType>()->scalarType();
   if (e.dtype() == kFloat32 && tt == at::ScalarType::Int) {
@@ -106,11 +106,11 @@ Expr TensorExprKernel::demoteOutput(const Expr& e, torch::jit::Value* v) {
 
 Tensor TensorExprKernel::ComputeOneOperand(
     const std::string& name,
-    torch::jit::Value* v,
+    const torch::jit::Value* v,
     std::function<Expr(const Expr&)> inner_expr) {
   return Compute(
       name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
-        Node* n = v->node();
+        auto const& n = v->node();
         std::vector<Expr> inputs = {tensorOrConstant(n->inputs()[0], axes)};
 
         promoteInputs(inputs);
@@ -121,11 +121,11 @@ Tensor TensorExprKernel::ComputeOneOperand(
 
 Tensor TensorExprKernel::ComputeTwoOperand(
     const std::string& name,
-    torch::jit::Value* v,
+    const torch::jit::Value* v,
     std::function<Expr(const Expr&, const Expr&)> inner_expr) {
   return Compute(
       name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
-        Node* n = v->node();
+        auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
             tensorOrConstant(n->inputs()[1], axes),
@@ -139,11 +139,11 @@ Tensor TensorExprKernel::ComputeTwoOperand(
 
 Tensor TensorExprKernel::ComputeTwoOperandWithAlpha(
     const std::string& name,
-    torch::jit::Value* v,
+    const torch::jit::Value* v,
     std::function<Expr(const Expr&, const Expr&)> inner_expr) {
   return Compute(
       name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
-        Node* n = v->node();
+        auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
             tensorOrConstant(n->inputs()[1], axes),
@@ -158,11 +158,11 @@ Tensor TensorExprKernel::ComputeTwoOperandWithAlpha(
 
 Tensor TensorExprKernel::ComputeThreeOperand(
     const std::string& name,
-    torch::jit::Value* v,
+    const torch::jit::Value* v,
     std::function<Expr(const Expr&, const Expr&, const Expr&)> inner_expr) {
   return Compute(
       name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
-        Node* n = v->node();
+        auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
             tensorOrConstant(n->inputs()[1], axes),
@@ -177,11 +177,11 @@ Tensor TensorExprKernel::ComputeThreeOperand(
 
 Tensor TensorExprKernel::ComputeFourOperand(
     const std::string& name,
-    torch::jit::Value* v,
+    const torch::jit::Value* v,
     std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)> inner_expr) {
   return Compute(
       name, texprDims(v), [this, v, inner_expr](const std::vector<Var>& axes) {
-        Node* n = v->node();
+        auto const& n = v->node();
         std::vector<Expr> inputs = {
             tensorOrConstant(n->inputs()[0], axes),
             tensorOrConstant(n->inputs()[1], axes),
@@ -195,7 +195,7 @@ Tensor TensorExprKernel::ComputeFourOperand(
       });
 }
 
-Tensor TensorExprKernel::ComputeValue(torch::jit::Value* v) {
+Tensor TensorExprKernel::ComputeValue(const torch::jit::Value* v) {
   switch (v->node()->kind()) {
     case aten::add: {
       return ComputeTwoOperandWithAlpha(
@@ -524,7 +524,7 @@ Tensor TensorExprKernel::ComputeValue(torch::jit::Value* v) {
           "prim_constantchunk",
           texprDims(v),
           [this, v](const std::vector<Var>& axes) {
-            Node* n = v->node();
+            auto const& n = v->node();
            int64_t dim = n->i(attr::dim);
             int64_t chunks = n->i(attr::chunks);
             return chunk(
@@ -539,7 +539,7 @@ Tensor TensorExprKernel::ComputeValue(torch::jit::Value* v) {
     case aten::cat: {
       return Compute(
           "aten_cat", texprDims(v), [this, v](const std::vector<Var>& axes) {
-            Node* n = v->node();
+            auto const& n = v->node();
             auto inputs = n->inputs()[0]->node()->inputs();
             size_t dim = n->inputs()[1]->node()->i(attr::value);
 
@@ -698,7 +698,7 @@ void TensorExprKernel::CodeGenRun(
   }
 }
 
-void TensorExprKernel::bindInput(torch::jit::Value* input) {
+void TensorExprKernel::bindInput(const torch::jit::Value* input) {
   auto const& t = input->type();
   switch (t->kind()) {
     case TypeKind::TensorType: {
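
One recurring edit above deserves a note: inside the Compute lambdas, Node* n = v->node(); no longer compiles once v is a const torch::jit::Value*, so the commit switches to auto const& n = v->node(); and lets deduction pick up the appropriately const-qualified result. A compilable sketch of the deduction, again with hypothetical stand-in types (the real Value's overload set is an assumption here):

#include <type_traits>

struct Node {};

struct Value {
  Node* producer;
  Node* node() { return producer; }
  // Assumed const overload returning a pointer-to-const Node.
  const Node* node() const { return producer; }
};

int main() {
  Node n;
  const Value v{&n};
  // Through a const Value, node() yields const Node*; auto const&
  // binds a lifetime-extended const reference to that pointer.
  auto const& p = v.node();
  static_assert(std::is_same<decltype(p), const Node* const&>::value,
                "deduced as const reference to pointer-to-const Node");
  (void)p;
  return 0;
}
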
22 changes: 12 additions & 10 deletions torch/csrc/jit/tensorexpr/kernel.h
@@ -54,7 +54,7 @@ class TensorExprKernel {
     kCudaCodeGen,
   };
 
-  Expr constant(torch::jit::Value* v);
+  Expr constant(const torch::jit::Value* v);
 
   template <typename T, typename T1>
   Expr broadcast(const T& t, const std::vector<T1>& axes) {
@@ -85,10 +85,12 @@ class TensorExprKernel {
 
   void promoteInputs(std::vector<Expr>& inputs);
 
-  Expr demoteOutput(const Expr& e, torch::jit::Value* v);
+  Expr demoteOutput(const Expr& e, const torch::jit::Value* v);
 
   template <typename T>
-  Expr tensorOrConstant(torch::jit::Value* v, const std::vector<T>& axes) {
+  Expr tensorOrConstant(
+      const torch::jit::Value* v,
+      const std::vector<T>& axes) {
     auto ti = tensors_.find(v->unique());
     if (ti != tensors_.end()) {
       return broadcast(ti->second, axes);
@@ -98,39 +100,39 @@ class TensorExprKernel {
 
   Tensor ComputeOneOperand(
       const std::string& name,
-      torch::jit::Value* v,
+      const torch::jit::Value* v,
      std::function<Expr(const Expr&)> inner_expr);
 
   Tensor ComputeTwoOperand(
       const std::string& name,
-      torch::jit::Value* v,
+      const torch::jit::Value* v,
       std::function<Expr(const Expr&, const Expr&)> inner_expr);
 
   Tensor ComputeTwoOperandWithAlpha(
       const std::string& name,
-      torch::jit::Value* v,
+      const torch::jit::Value* v,
       std::function<Expr(const Expr&, const Expr&)> inner_expr);
 
   Tensor ComputeThreeOperand(
       const std::string& name,
-      torch::jit::Value* v,
+      const torch::jit::Value* v,
       std::function<Expr(const Expr&, const Expr&, const Expr&)> inner_expr);
 
   Tensor ComputeFourOperand(
       const std::string& name,
-      torch::jit::Value* v,
+      const torch::jit::Value* v,
       std::function<Expr(const Expr&, const Expr&, const Expr&, const Expr&)>
           inner_expr);
 
-  Tensor ComputeValue(torch::jit::Value* v);
+  Tensor ComputeValue(const torch::jit::Value* v);
 
   void LowerToBackend(BackendType backend_type);
 
   void PickAndCheckBackendType(const at::ArrayRef<IValue>& inputs);
 
   void CodeGenRun(const std::vector<CodeGen::CallArg>& run_args);
 
-  void bindInput(torch::jit::Value* input);
+  void bindInput(const torch::jit::Value* input);
 
  private:
   std::vector<CodeGen::BufferArg> buffer_args_;
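
For context on the declarations above: each ComputeNOperand helper wraps Compute with a lambda that gathers N operand expressions from the node's inputs, promotes them to a common type, applies the caller-supplied inner_expr, and demotes the result to the output's scalar type. A drastically simplified sketch of that shape (Expr reduced to double, promotion and demotion elided; not the real tensorexpr API):

#include <functional>
#include <iostream>
#include <vector>

// Stand-in for tensorexpr's Expr, just to show the helper's shape.
using Expr = double;

// Simplified analogue of ComputeTwoOperand: fetch the two operand
// expressions, then apply the caller-supplied element-wise body.
Expr computeTwoOperand(
    const std::vector<Expr>& operands,
    std::function<Expr(const Expr&, const Expr&)> inner_expr) {
  // kernel.cpp also calls promoteInputs(...) before and
  // demoteOutput(...) after; both are elided here.
  return inner_expr(operands[0], operands[1]);
}

int main() {
  // aten::add lowers through such a helper with inner_expr = lhs + rhs.
  Expr out = computeTwoOperand(
      {2.0, 3.0},
      [](const Expr& lhs, const Expr& rhs) { return lhs + rhs; });
  std::cout << out << "\n";  // prints 5
}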
