diff --git a/include/tvm/relax/ir_builder.h b/include/tvm/relax/ir_builder.h index 352882116916..f97981a6cfef 100644 --- a/include/tvm/relax/ir_builder.h +++ b/include/tvm/relax/ir_builder.h @@ -82,7 +82,7 @@ class IRBuilderNode : public Object { /*! * \brief Generate an output for the current dataflow block or function. * \param output The output variable of the block/function. - * \return The variable being binded to \p ouput. + * \return The variable being binded to \p output. */ Var EmitOutput(const Expr& output); /*! @@ -107,13 +107,15 @@ class IRBuilderNode : public Object { private: /*! \brief The state of the function currently being built. */ - RelaxFunction func; + RelaxFunction func_; /*! \brief A flag tracking if currently inside a dataflow block or not. */ - bool is_dataflow = false; + bool is_dataflow_ = false; /*! \brief A global variable counter for naming global variables. */ - int global_var_counter = 0; + int global_var_counter_ = 0; /*! \brief A dataflow variable counter for naming dataflow variables. */ - int dataflow_var_counter = 0; + int dataflow_var_counter_ = 0; + /*! \brief A diagnostic context for reporting errors. */ + DiagnosticContext diag_ctx_ = DiagnosticContext::Default(IRModule({}, {})); }; class IRBuilder : public ObjectRef { diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h new file mode 100644 index 000000000000..d38b2729717d --- /dev/null +++ b/include/tvm/relax/op_attr_types.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/op_attr_types.h + * \brief Data structures that can appear in operator attributes. + */ +#ifndef TVM_RELAX_OP_ATTR_TYPES_H_ +#define TVM_RELAX_OP_ATTR_TYPES_H_ + +#include <tvm/relay/expr.h> +#include <tvm/relay/type.h> +#include <tvm/te/schedule.h> +#include <tvm/te/tensor.h> + +#include <string> + +namespace tvm { +namespace relax { + +using relay::Call; + +/*! + * \brief Infer the output shape for operators. This function will + * be invoked to fill the \p shape_ field of expressions. + * \param call The call node. + * \param diag_ctx The diagnostic context for reporting errors. + * \return The inferred output shape expression. + */ +using FInferShape = runtime::TypedPackedFunc<Optional<RelayExpr>(const Call& call, DiagnosticContext diag_ctx)>; + +/*! + * \brief Infer the output type for operators. This function will + * be invoked to fill the \p checked_type_ field of expressions. + * \param call The call node. + * \param diag_ctx The diagnostic context for reporting errors. + * \return The inferred output type. + */ +using FInferType = runtime::TypedPackedFunc<Type(const Call& call, DiagnosticContext diag_ctx)>; + +} // namespace relax +} // namespace tvm +#endif // TVM_RELAX_OP_ATTR_TYPES_H_ diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index b9edbdfe521a..232c65be759e 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -46,7 +46,7 @@ def checked_type(self): checked_type : tvm.relay.Type The checked type. """ - ret = self._checked_type_ + ret = self.checked_type_ if ret is None: raise ValueError("The type checker has not populated" " the checked_type for this node") return ret diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index d4ed3cc7976f..d7743c002626 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -66,6 +66,8 @@ def __init__( type_annotation: Optional[Type] = None, span: Span = None, ) -> None: + if shape_annotation is not None: + shape_annotation = make_shape(shape_annotation) self.__init_handle_by_constructor__( _ffi_api.Var, name_hint, shape_annotation, type_annotation, span ) @@ -86,6 +88,8 @@ def __init__( type_annotation: Optional[Type] = None, span: Span = None, ) -> None: + if shape_annotation is not None: + shape_annotation = make_shape(shape_annotation) self.__init_handle_by_constructor__( _ffi_api.DataflowVar, name_hint, shape_annotation, type_annotation, span ) diff --git a/python/tvm/relax/op/tensor.py b/python/tvm/relax/op/tensor.py index 0e5854bdb586..896b2c4d3d0d 100644 --- a/python/tvm/relax/op/tensor.py +++ b/python/tvm/relax/op/tensor.py @@ -1,3 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""Basic tensor operations.""" from . import _ffi_api from ..expr import Expr diff --git a/src/relax/ir_builder.cc b/src/relax/ir_builder.cc index e5c09ccebd1d..c2cc4fa5500b 100644 --- a/src/relax/ir_builder.cc +++ b/src/relax/ir_builder.cc @@ -22,6 +22,7 @@ */ #include <tvm/relax/ir_builder.h> +#include <tvm/relax/op_attr_types.h> #include <tvm/relay/op.h> namespace tvm { @@ -38,59 +39,84 @@ IRBuilder IRBuilderNode::Create() { void IRBuilderNode::FillFuncNameParam(const Array<Var>& params, const std::string& func_name) { if (!func_name.empty()) { - this->func.func_name = GlobalVar(func_name); + this->func_.func_name = GlobalVar(func_name); } - - this->func.params = params; + + this->func_.params = params; } void IRBuilderNode::BuildFunction() { - SeqExpr seq = SeqExpr(this->func.binding_blocks, this->func.ret); - this->func.func = Function(this->func.func_name, this->func.params, seq, {}); - this->global_var_counter = 0; + SeqExpr seq = SeqExpr(this->func_.binding_blocks, this->func_.ret); + this->func_.func = Function(this->func_.func_name, this->func_.params, seq, {}); + this->global_var_counter_ = 0; } void IRBuilderNode::BuildBlock() { - if (!this->func.bindings.empty()) { - if (is_dataflow) { - this->func.binding_blocks.emplace_back(DataflowBlock(this->func.bindings)); + if (!this->func_.bindings.empty()) { + if (is_dataflow_) { + this->func_.binding_blocks.emplace_back(DataflowBlock(this->func_.bindings)); } else { - this->func.binding_blocks.emplace_back(BindingBlock(this->func.bindings)); + this->func_.binding_blocks.emplace_back(BindingBlock(this->func_.bindings)); } - this->func.bindings.clear(); + this->func_.bindings.clear(); } - this->dataflow_var_counter = 0; - this->is_dataflow = !this->is_dataflow; + this->dataflow_var_counter_ = 0; + this->is_dataflow_ = !this->is_dataflow_; +} + +Optional<RelayExpr> InferShape(const Call& call, DiagnosticContext diag_ctx) { + auto op_map = Op::GetAttrMap<FInferShape>("FInferShape"); + Op op = Downcast<Op>(call->op); + return op_map[op](call, diag_ctx); +} + +Type InferType(const Call& call, DiagnosticContext diag_ctx) { + auto op_map = Op::GetAttrMap<FInferType>("FInferType"); + Op op = Downcast<Op>(call->op); + return op_map[op](call, diag_ctx); } Var IRBuilderNode::Emit(const Call& call) { Var var; - if (is_dataflow) { - var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter++)), NullOpt, NullOpt); + if (is_dataflow_) { + var = DataflowVar(Id("lv" + std::to_string(dataflow_var_counter_++)), NullOpt, NullOpt); } else { - var = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt); + var = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt); + } + + // Shape inference + auto inferred_shape = InferShape(call, this->diag_ctx_); + if (inferred_shape.defined()) { + if (auto* shape_expr = inferred_shape.value().as<ShapeExprNode>()) { + call->shape_ = GetRef<Expr>(shape_expr); + var->shape_ = call->shape_; + } } + // Type inference + auto inferred_type = InferType(call, this->diag_ctx_); + call->checked_type_ = inferred_type; + var->checked_type_ = inferred_type; - this->func.bindings.emplace_back(VarBinding(var, call)); + this->func_.bindings.emplace_back(VarBinding(var, call)); return var; } Var IRBuilderNode::EmitOutput(const Expr& output) { Var ret; - if (is_dataflow) { - ret = Var(Id("gv" + std::to_string(global_var_counter++)), NullOpt, NullOpt); + if (is_dataflow_) { + ret = Var(Id("gv" + std::to_string(global_var_counter_++)), NullOpt, NullOpt); ret->shape_ = output->shape_; ret->checked_type_ = output->checked_type_; - this->func.bindings.emplace_back(VarBinding(ret, output)); + this->func_.bindings.emplace_back(VarBinding(ret, output)); } else { - this->func.ret = output; + this->func_.ret = output; } return ret; } -Function IRBuilderNode::Get() { return this->func.func; } +Function IRBuilderNode::Get() { return this->func_.func; } -std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func.binding_blocks; } +std::vector<BindingBlock> IRBuilderNode::GetBlocks() { return this->func_.binding_blocks; } class FunctionScope::Internal { public: @@ -121,20 +147,16 @@ DataflowScope::DataflowScope(IRBuilder ib) { data_ = std::move(n); } -void DataflowScope::EnterWithScope() { - this->get()->ir_builder->BuildBlock(); -} +void DataflowScope::EnterWithScope() { this->get()->ir_builder->BuildBlock(); } -void DataflowScope::ExitWithScope() { - this->get()->ir_builder->BuildBlock(); -} +void DataflowScope::ExitWithScope() { this->get()->ir_builder->BuildBlock(); } TVM_REGISTER_GLOBAL("relax.IRBuilderCreate").set_body_typed(IRBuilderNode::Create); TVM_REGISTER_GLOBAL("relax.IRBuilderFillFuncNameParam") -.set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) { - return builder->FillFuncNameParam(params, func_name); -}); + .set_body_typed([](IRBuilder builder, const Array<Var>& params, const std::string& func_name) { + return builder->FillFuncNameParam(params, func_name); + }); TVM_REGISTER_GLOBAL("relax.IRBuilderBuildFunction").set_body_typed([](IRBuilder builder) { return builder->BuildFunction(); @@ -145,9 +167,9 @@ TVM_REGISTER_GLOBAL("relax.IRBuilderEmit").set_body_typed([](IRBuilder builder, }); TVM_REGISTER_GLOBAL("relax.IRBuilderEmitOutput") -.set_body_typed([](IRBuilder builder, const Expr& output) { - return builder->EmitOutput(output); -}); + .set_body_typed([](IRBuilder builder, const Expr& output) { + return builder->EmitOutput(output); + }); TVM_REGISTER_GLOBAL("relax.IRBuilderGet").set_body_typed([](IRBuilder builder) { return builder->Get(); diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 4e1ce826f9b2..074544cfc298 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -25,9 +25,9 @@ #ifndef TVM_RELAX_OP_OP_COMMON_H_ #define TVM_RELAX_OP_OP_COMMON_H_ +#include <tvm/relax/op_attr_types.h> #include <tvm/relay/expr.h> #include <tvm/relay/op.h> -#include <tvm/relay/op_attr_types.h> namespace tvm { namespace relax { @@ -42,15 +42,17 @@ namespace relax { * * \param OpName the name of registry. */ -#define RELAX_REGISTER_BINARY_OP(OpName) \ +#define RELAX_REGISTER_BINARY_BROADCAST_OP(OpName) \ TVM_REGISTER_GLOBAL("relax.op." OpName).set_body_typed([](Expr lhs, Expr rhs) { \ - static const Op& op = Op::Get(OpName); \ + static const Op& op = Op::Get("relax." OpName); \ return Call(op, {lhs, rhs}, Attrs(), {}); \ }); \ RELAY_REGISTER_OP("relax." OpName) \ .set_num_inputs(2) \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \ - .add_argument("rhs", "Tensor", "The right hand side tensor.") + .add_argument("rhs", "Tensor", "The right hand side tensor.") \ + .set_attr<FInferShape>("FInferShape", InferShapeBinaryBroadcast) \ + .set_attr<FInferType>("FInferType", InferTypeBinaryBroadcast) } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index df053027aff2..6818ff12deb0 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -22,32 +22,18 @@ * \brief binary broadcast operators. */ -#include <tvm/arith/analyzer.h> -#include <tvm/relax/expr.h> -#include <tvm/relax/type.h> -#include <tvm/tir/op.h> -#include <tvm/topi/broadcast.h> +#include "binary.h" #include "../op_common.h" namespace tvm { namespace relax { -using Expr = tvm::RelayExpr; -using relay::Call; - -#define RELAX_BINARY_COMPUTE(FTOPI) \ - [](const Attrs& attrs, const Array<te::Tensor>& inputs, \ - const Type& out_type) -> Array<te::Tensor> { \ - ICHECK_EQ(inputs.size(), 2U); \ - return {FTOPI(inputs[0], inputs[1])}; \ - } - -RELAX_REGISTER_BINARY_OP("add") +RELAX_REGISTER_BINARY_BROADCAST_OP("add") .describe("Elementwise add with broadcasting") .set_support_level(1); -RELAX_REGISTER_BINARY_OP("multiply") +RELAX_REGISTER_BINARY_BROADCAST_OP("multiply") .describe("Elementwise multiply with broadcasting") .set_support_level(1); diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h new file mode 100644 index 000000000000..75957b530506 --- /dev/null +++ b/src/relax/op/tensor/binary.h @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file binary.h + * \brief shape and type deduction for binary broadcast operators. + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/op_attr_types.h> +#include <tvm/relax/type.h> +#include <tvm/tir/op.h> +#include <tvm/topi/broadcast.h> + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +bool EqualConstInt(const PrimExpr& lhs, int64_t value) { + if (const int64_t* pvalue = tir::as_const_int(lhs)) { + return pvalue[0] == value; + } + return false; +} + +bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { + PrimExpr diff = lhs - rhs; + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); + if (const int64_t* pdiff = tir::as_const_int(diff)) { + return pdiff[0] == 0; + } + return false; +} + +Optional<Expr> InferShapeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Binary broadcast op should have 2 arguments"); + } + Expr lhs_shape = call->args[0]->shape(); + Expr rhs_shape = call->args[1]->shape(); + auto* s0 = lhs_shape.as<ShapeExprNode>(); + auto* s1 = rhs_shape.as<ShapeExprNode>(); + if (s0 && s1) { + std::vector<PrimExpr> output_shape; + size_t ndim0 = s0->values.size(); + size_t ndim1 = s1->values.size(); + size_t i = 1; + for (; i <= std::min(ndim0, ndim1); ++i) { + PrimExpr dim0 = s0->values[ndim0 - i]; + PrimExpr dim1 = s1->values[ndim1 - i]; + if (EqualConstInt(dim0, 1)) { + output_shape.push_back(dim1); + } else if (EqualConstInt(dim1, 1)) { + output_shape.push_back(dim0); + } else if (EqualCheck(dim0, dim1)) { + output_shape.push_back(dim0); + } else { + // defer the computation of output shapes to runtime + // e.g., broadcast Tensor([m, n]), Tensor([k]) -> defer to runtime + return Call(ExternFunc(String("vm.binary_broadcast_shape_infer")), + {call->args[0], call->args[1]}, {}, {}); + } + } + size_t max_ndim = std::max(ndim0, ndim1); + auto& longer_shape = (ndim0 > ndim1) ? s0 : s1; + for (; i <= max_ndim; ++i) { + output_shape.push_back(longer_shape->values[max_ndim - i]); + } + return ShapeExpr(Array<PrimExpr>(output_shape.rbegin(), output_shape.rend())); + } else { + return NullOpt; + } +} + +Type InferTypeBinaryBroadcast(const Call& call, DiagnosticContext diag_ctx) { + if (call->args.size() != 2) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Binary broadcast op should have 2 arguments"); + } + Type lhs_type = call->args[0]->checked_type(); + Type rhs_type = call->args[1]->checked_type(); + auto* t0 = lhs_type.as<DynTensorTypeNode>(); + auto* t1 = rhs_type.as<DynTensorTypeNode>(); + if (!t0 || !t1) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Both lhs and rhs should be DynTensor for broadcasting"); + } + + DataType output_dtype; + if (t0->IsUnknownDtype() || t1->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (t0->dtype != t1->dtype) { + diag_ctx.EmitFatal(Diagnostic::Error(call->span) + << "Data types " << t0->dtype << " and " << t1->dtype + << " must be equal for broadcasting operators"); + } else { + output_dtype = t0->dtype; + } + + int output_rank; + if (t0->IsUnknownRank() || t1->IsUnknownRank()) { + output_rank = -1; + } else { + output_rank = std::max(t0->rank, t1->rank); + } + return DynTensorType(output_rank, output_dtype); +} + +} // namespace relax +} // namespace tvm diff --git a/src/relax/vm/builtin.cc b/src/relax/vm/builtin.cc index 5553a08108b9..96392b6912ba 100644 --- a/src/relax/vm/builtin.cc +++ b/src/relax/vm/builtin.cc @@ -91,6 +91,30 @@ TVM_REGISTER_GLOBAL("vm.builtin.alloc_tensor") return tensor; }); +TVM_REGISTER_GLOBAL("vm.binary_broadcast_shape_infer") +.set_body_typed([](ShapeTuple lhs_shape, ShapeTuple rhs_shape) { + std::vector<int64_t> output_shape; + size_t ndim0 = lhs_shape.size(); + size_t ndim1 = rhs_shape.size(); + size_t i = 1; + for (; i <= std::min(ndim0, ndim1); ++i) { + int64_t lhs_dim = lhs_shape[ndim0 - i]; + int64_t rhs_dim = rhs_shape[ndim1 - i]; + if (lhs_dim == 1 || rhs_dim == 1 || lhs_dim == rhs_dim) { + output_shape.push_back(std::max(lhs_dim, rhs_dim)); + } else { + LOG(FATAL) << "Incompatible shapes " << lhs_shape << " and " << rhs_shape + << " for broadcasting"; + } + } + size_t max_ndim = std::max(ndim0, ndim1); + ShapeTuple& longer_shape = (ndim0 > ndim1) ? lhs_shape : rhs_shape; + for (; i <= max_ndim; ++i) { + output_shape.push_back(longer_shape[max_ndim - i]); + } + return ShapeTuple(output_shape.rbegin(), output_shape.rend()); +}); + } // namespace relax_vm } // namespace runtime } // namespace tvm diff --git a/tests/python/relax/test_irbuilder.py b/tests/python/relax/test_irbuilder.py index 28bbc9756b67..4e00049882e8 100644 --- a/tests/python/relax/test_irbuilder.py +++ b/tests/python/relax/test_irbuilder.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import tvm from tvm import tir from tvm import relax as rx @@ -27,13 +28,26 @@ def test_dataflow_block(): x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) ib = rx.IRBuilder() + with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) + assert lv0.name_hint == "lv0" + assert lv0.shape[0] == m + assert lv0.shape[1] == n + assert lv0.checked_type.rank == 2 + assert lv0.checked_type.dtype == "float16" + lv1 = ib.emit(rx.op.multiply(lv0, y)) assert lv1.name_hint == "lv1" gv0 = ib.emit_output(lv1) + assert gv0.name_hint == "gv0" + assert gv0.shape[0] == m + assert gv0.shape[1] == n + assert gv0.checked_type.rank == 2 + assert gv0.checked_type.dtype == "float16" + blocks = ib.get_blocks() assert len(blocks) == 1 assert len(blocks[-1].bindings) == 3 @@ -47,6 +61,7 @@ def test_function_single_block(): x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) ib = rx.IRBuilder() + with ib.function([x, y]): with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) @@ -56,10 +71,15 @@ def test_function_single_block(): gv0 = ib.emit_output(lv1) assert gv0.name_hint == "gv0" ib.emit_output(gv0) + func = ib.get() assert func.params[0] == x assert func.params[1] == y assert func.body.body == gv0 + assert gv0.shape[0] == m + assert gv0.shape[1] == n + assert gv0.checked_type.rank == 2 + assert gv0.checked_type.dtype == "float16" assert len(func.body.blocks) == 1 assert len(func.body.blocks[0].bindings) == 3 @@ -72,6 +92,7 @@ def test_function_multi_blocks(): x = rx.Var("x", [m, n], dtype0) y = rx.Var("y", [n], dtype1) ib = rx.IRBuilder() + with ib.function([x, y], "func"): with ib.dataflow() as df: lv0 = ib.emit(rx.op.add(x, y)) @@ -85,7 +106,12 @@ def test_function_multi_blocks(): assert lv0.name_hint == "lv0" gv2 = ib.emit_output(gv1) ib.emit_output(gv2) + func = ib.get() + assert gv2.shape[0] == m + assert gv2.shape[1] == n + assert gv2.checked_type.rank == 2 + assert gv2.checked_type.dtype == "float16" assert func.params[0] == x assert func.params[1] == y assert func.name.name_hint == "func" @@ -96,7 +122,40 @@ def test_function_multi_blocks(): assert len(func.body.blocks[2].bindings) == 2 +def test_binary_shape_deduction(): + m = tir.Var("m", "int32") + n = tir.Var("n", "int32") + k = tir.Var("k", "int32") + dtype0 = rx.DynTensorType(rank=2, dtype="float16") + dtype1 = rx.DynTensorType(rank=1, dtype="float16") + x = rx.Var("x", [m, 1], dtype0) + y = rx.Var("y", [n], dtype1) + z = rx.Var("z", [5], dtype0) + w = rx.Var("w", [k], dtype1) + ib = rx.IRBuilder() + + with ib.function([x, y, z, w]): + with ib.dataflow() as df: + lv0 = ib.emit(rx.op.add(x, y)) + assert lv0.shape[0] == m + assert lv0.shape[1] == n + + lv1 = ib.emit(rx.op.multiply(x, z)) + assert lv1.shape[0] == m + assert lv1.shape[1] == 5 + + lv2 = ib.emit(rx.op.multiply(z, w)) + assert isinstance(lv2.shape, tvm.relay.Call) + + lv3 = ib.emit(rx.op.multiply(y, w)) + assert isinstance(lv3.shape, tvm.relay.Call) + gv0 = ib.emit_output(lv3) + ib.emit_output(gv0) + assert isinstance(gv0.shape, tvm.relay.Call) + + if __name__ == "__main__": test_dataflow_block() test_function_single_block() test_function_multi_blocks() + test_binary_shape_deduction()