From f3990d79516e59557a7e37669cda1f12ecd3487b Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Tue, 11 Feb 2020 21:55:42 -0800 Subject: [PATCH] Use `const Value*` where possible (#146) --- torch/csrc/jit/tensorexpr/kernel.cpp | 34 ++++++++++++++-------------- torch/csrc/jit/tensorexpr/kernel.h | 22 ++++++++++-------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 989e42781d8cd2..ce5f87d20d0808 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -35,7 +35,7 @@ static std::vector texprSizes(const c10::VaryingShape& shape) { return dims; } -static std::vector texprDims(torch::jit::Value* v) { +static std::vector texprDims(const torch::jit::Value* v) { CHECK(v->type()->kind() == TypeKind::TensorType); auto tt = v->type()->cast(); std::vector dimArgs; @@ -64,7 +64,7 @@ int64_t bufferSize(T t) { return size; } -Expr TensorExprKernel::constant(torch::jit::Value* v) { +Expr TensorExprKernel::constant(const torch::jit::Value* v) { if (v->node()->kind() == prim::Constant) { const auto val = toIValue(v).value(); if (val.isDouble()) { @@ -94,7 +94,7 @@ void TensorExprKernel::promoteInputs(std::vector& inputs) { } } -Expr TensorExprKernel::demoteOutput(const Expr& e, torch::jit::Value* v) { +Expr TensorExprKernel::demoteOutput(const Expr& e, const torch::jit::Value* v) { CHECK(v->type()->kind() == TypeKind::TensorType); auto tt = v->type()->cast()->scalarType(); if (e.dtype() == kFloat32 && tt == at::ScalarType::Int) { @@ -106,11 +106,11 @@ Expr TensorExprKernel::demoteOutput(const Expr& e, torch::jit::Value* v) { Tensor TensorExprKernel::ComputeOneOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr) { return Compute( name, texprDims(v), [this, v, inner_expr](const std::vector& axes) { - Node* n = v->node(); + auto const& n = v->node(); std::vector inputs = {tensorOrConstant(n->inputs()[0], axes)}; promoteInputs(inputs); @@ -121,11 +121,11 @@ Tensor TensorExprKernel::ComputeOneOperand( Tensor TensorExprKernel::ComputeTwoOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr) { return Compute( name, texprDims(v), [this, v, inner_expr](const std::vector& axes) { - Node* n = v->node(); + auto const& n = v->node(); std::vector inputs = { tensorOrConstant(n->inputs()[0], axes), tensorOrConstant(n->inputs()[1], axes), @@ -139,11 +139,11 @@ Tensor TensorExprKernel::ComputeTwoOperand( Tensor TensorExprKernel::ComputeTwoOperandWithAlpha( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr) { return Compute( name, texprDims(v), [this, v, inner_expr](const std::vector& axes) { - Node* n = v->node(); + auto const& n = v->node(); std::vector inputs = { tensorOrConstant(n->inputs()[0], axes), tensorOrConstant(n->inputs()[1], axes), @@ -158,11 +158,11 @@ Tensor TensorExprKernel::ComputeTwoOperandWithAlpha( Tensor TensorExprKernel::ComputeThreeOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr) { return Compute( name, texprDims(v), [this, v, inner_expr](const std::vector& axes) { - Node* n = v->node(); + auto const& n = v->node(); std::vector inputs = { tensorOrConstant(n->inputs()[0], axes), tensorOrConstant(n->inputs()[1], axes), @@ -177,11 +177,11 @@ Tensor TensorExprKernel::ComputeThreeOperand( Tensor TensorExprKernel::ComputeFourOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr) { return Compute( name, texprDims(v), [this, v, inner_expr](const std::vector& axes) { - Node* n = v->node(); + auto const& n = v->node(); std::vector inputs = { tensorOrConstant(n->inputs()[0], axes), tensorOrConstant(n->inputs()[1], axes), @@ -195,7 +195,7 @@ Tensor TensorExprKernel::ComputeFourOperand( }); } -Tensor TensorExprKernel::ComputeValue(torch::jit::Value* v) { +Tensor TensorExprKernel::ComputeValue(const torch::jit::Value* v) { switch (v->node()->kind()) { case aten::add: { return ComputeTwoOperandWithAlpha( @@ -524,7 +524,7 @@ Tensor TensorExprKernel::ComputeValue(torch::jit::Value* v) { "prim_constantchunk", texprDims(v), [this, v](const std::vector& axes) { - Node* n = v->node(); + auto const& n = v->node(); int64_t dim = n->i(attr::dim); int64_t chunks = n->i(attr::chunks); return chunk( @@ -539,7 +539,7 @@ Tensor TensorExprKernel::ComputeValue(torch::jit::Value* v) { case aten::cat: { return Compute( "aten_cat", texprDims(v), [this, v](const std::vector& axes) { - Node* n = v->node(); + auto const& n = v->node(); auto inputs = n->inputs()[0]->node()->inputs(); size_t dim = n->inputs()[1]->node()->i(attr::value); @@ -698,7 +698,7 @@ void TensorExprKernel::CodeGenRun( } } -void TensorExprKernel::bindInput(torch::jit::Value* input) { +void TensorExprKernel::bindInput(const torch::jit::Value* input) { auto const& t = input->type(); switch (t->kind()) { case TypeKind::TensorType: { diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index 088b9ccf009209..f0b15cf6d7c709 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -54,7 +54,7 @@ class TensorExprKernel { kCudaCodeGen, }; - Expr constant(torch::jit::Value* v); + Expr constant(const torch::jit::Value* v); template Expr broadcast(const T& t, const std::vector& axes) { @@ -85,10 +85,12 @@ class TensorExprKernel { void promoteInputs(std::vector& inputs); - Expr demoteOutput(const Expr& e, torch::jit::Value* v); + Expr demoteOutput(const Expr& e, const torch::jit::Value* v); template - Expr tensorOrConstant(torch::jit::Value* v, const std::vector& axes) { + Expr tensorOrConstant( + const torch::jit::Value* v, + const std::vector& axes) { auto ti = tensors_.find(v->unique()); if (ti != tensors_.end()) { return broadcast(ti->second, axes); @@ -98,31 +100,31 @@ class TensorExprKernel { Tensor ComputeOneOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr); Tensor ComputeTwoOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr); Tensor ComputeTwoOperandWithAlpha( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr); Tensor ComputeThreeOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr); Tensor ComputeFourOperand( const std::string& name, - torch::jit::Value* v, + const torch::jit::Value* v, std::function inner_expr); - Tensor ComputeValue(torch::jit::Value* v); + Tensor ComputeValue(const torch::jit::Value* v); void LowerToBackend(BackendType backend_type); @@ -130,7 +132,7 @@ class TensorExprKernel { void CodeGenRun(const std::vector& run_args); - void bindInput(torch::jit::Value* input); + void bindInput(const torch::jit::Value* input); private: std::vector buffer_args_;