Skip to content

Commit

Permalink
Merge pull request #1 from gigiblender/aot-mem-lower
Browse files Browse the repository at this point in the history
[Relax][AOT] Add AOTMemoryLower pass when USMP is disabled
  • Loading branch information
mbaret authored Dec 5, 2022
2 parents c099d06 + 834e975 commit 66eae17
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 31 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/analysis/*.cc
src/relax/usmp/*.cc
src/relax/transform/*.cc
src/relax/backend/aot/*.cc
src/relax/backend/vm/*.cc
src/relax/backend/aot/*.cc
src/relax/backend/task_extraction.cc
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/relax/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ TVM_DLL Pass VMMemoryLower();
*/
TVM_DLL Pass VMShapeLower();

/*!
* \brief Perform memory lowering in AOT. Lowers the relax.builtin.alloc_tensor intrinsic to
* relax.memory.* intrinsics.
*
* \return The Pass.
*/
TVM_DLL Pass AOTMemoryLower();

} // namespace transform
} // namespace relax
} // namespace tvm
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relax/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, redefined-builtin, no-else-return
"""The Relax virtual machine"""
"""The Relax AOT executor"""
from typing import Callable, List, Optional, Union, Dict

import tvm
Expand Down Expand Up @@ -63,7 +63,7 @@ def build(
if not isinstance(ir_mod, IRModule):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

ctxt = tvm.transform.PassContext()
ctxt = tvm.transform.PassContext.current()
config = make_compilation_config(ctxt, target, target_host)

ir_mod = lower(ir_mod)
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ def VMShapeLower() -> tvm.ir.transform.Pass:
return _ffi_api.VMShapeLower() # type: ignore


def AOTMemoryLower() -> tvm.ir.transform.Pass:
"""Perform memory lowering in AOT. Lowers the relax.builtin.alloc_tensor intrinsic to
relax.memory.* intrinsics.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.AOTMemoryLower() # type: ignore


def Normalize() -> tvm.ir.transform.Pass:
"""Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting
and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available.
Expand Down
4 changes: 2 additions & 2 deletions src/relax/backend/aot/aot_lower_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class AOTMainLowerer : public ExprVisitor {

IRModule Lower(IRModule mod, String mod_name) {
IRModule lowered_mod = GetRef<IRModule>(mod.CopyOnWrite());

auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

Expand All @@ -76,7 +76,7 @@ class AOTMainLowerer : public ExprVisitor {
.value_or(Map<GlobalVar, String>()));

VisitExpr(lowered_main_func);

// Remove the Relay main and replace it with the lowered TIR version
mod->Remove(lowered_mod->GetGlobalVar("main"));
auto tir_main_func = CreateMainFunc(mod_name);
Expand Down
116 changes: 116 additions & 0 deletions src/relax/backend/aot/aot_memory_lower.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/backend/aot/aot_memory_lower.cc
* \brief Perform memory lowering. Lowers the relax.builtin.alloc_tensor intrinsic to
* relax.memory.alloc_storage + relax.memory.alloc_tensor.
*/
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>

#include "../../../relay/transforms/pattern_utils.h"

namespace tvm {
namespace relax {

// ==================
// MemLowerMutator
// Lower the relax.builtin.alloc_tensor op to relax.memory builtin functions.
// Example:
// x = relax.builtin.alloc_tensor((m, n), relax.attrs.AllocTensorAttrs)
// -->
// gv0 = relax.memory.alloc_storage(m * n * dtype, relax.attrs.MemAllocStorageAttrs)
// gv1 = relax.memory.alloc_tensor(gv0, (m, n), relax.attrs.MemAllocTensorAttrs)

class AOTMemLowerMutator : public ExprMutator {

// TODO(gigiblender): Dedup this function with the one in VMMemoryLower.
Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const {
// Question: what if the dtype of tensor_type is unknown?
// Symbolic/static shape case
if (auto* shape_expr = shape.as<ShapeExprNode>()) {
PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes());
PrimExpr add = num + 7;
PrimExpr ret = 1;
for (PrimExpr dim : shape_expr->values) {
ret = ret * dim;
}
ret = ret * (add / PrimExpr(8));
return ShapeExpr({ret});
}
// Fully dynamic shape case will need to dedup with ComputeStorageInRelay when we upstream
Expr prod = relay::Prod(shape, Array<Integer>(nullptr), false, false);
Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7));
Expr div = relay::MakeConstantScalar(DataType::Int(64), 8);
Expr ret = relay::Multiply(prod, relay::Divide(add, div));
return ret;
}

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
Expr expr = VisitExprPostOrder_(call);
call = expr.as<CallNode>();

static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const Op& memory_alloc_storage_op = Op::Get("relax.memory.alloc_storage");
static const Op& memory_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor");
if (call->op == alloc_tensor_op) {
ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]);
auto alloc_attrs = call->attrs.as<AllocTensorAttrs>();
ICHECK(alloc_attrs != nullptr) << "must be AllocTensorAttrs";
DataType dtype = alloc_attrs->dtype;
Expr storage_size = ComputeStorageSize(output_shape, dtype);
auto storage_attr = make_object<MemAllocStorageAttrs>();
storage_attr->dtype = dtype;

Var storage =
builder_->Emit(Call(memory_alloc_storage_op, {storage_size}, Attrs(storage_attr)),
"storage");
auto tensor_attr = make_object<MemAllocTensorAttrs>();
tensor_attr->offset = 0;
tensor_attr->dtype = dtype;
Expr shape = call->args[0];
return Call(memory_alloc_tensor_op, {storage, shape}, Attrs(tensor_attr));
}

return GetRef<Expr>(call);
}
};

Expr AOTMemoryLower(const Expr& e) { return AOTMemLowerMutator().VisitExpr(e); }

namespace transform {

Pass AOTMemoryLower() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(AOTMemoryLower(f)); };
return CreateFunctionPass(pass_func, 0, "AOTMemoryLower", {});
}

TVM_REGISTER_GLOBAL("relax.transform.AOTMemoryLower").set_body_typed(AOTMemoryLower);

} // namespace transform
} // namespace relax
} // namespace tvm
13 changes: 11 additions & 2 deletions src/relax/backend/aot/codegen_aot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/runtime.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/backend.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
Expand All @@ -39,6 +40,7 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/usmp/utils.h>
#include <tvm/relax/usmp/utils.h>
#include <tvm/relay/executor.h>
#include <tvm/relay/runtime.h>

Expand All @@ -64,8 +66,15 @@ runtime::Module Build(IRModule mod, String mod_name, CompilationConfig config, r
Integer constant_byte_alignment =
executor->GetAttr<Integer>("constant-byte-alignment").value_or(16);

transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_usmp = pass_ctx->GetConfig<Bool>(kUSMPRelaxEnableOption, Bool(false)).value();

mod = LowerModule(mod);
mod = relax::transform::UnifiedStaticMemoryPlanner()(mod);
if (enable_usmp) {
mod = relax::transform::UnifiedStaticMemoryPlanner()(mod);
} else {
mod = relax::transform::AOTMemoryLower()(mod);
}
mod = AOTLowerMain(mod_name, config)(mod);
mod = tir::transform::LegalizePackedCalls()(mod);

Expand All @@ -85,4 +94,4 @@ TVM_REGISTER_GLOBAL("relax.aot.build")

} // namespace aot
} // namespace relax
} // namespace tvm
} // namespace tvm
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
*module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);

param_symbol->setAlignment(data.DataType().bits());
param_symbol->setAlignment(llvm::Align(data.DataType().bits()));
var_map_[op->buffer_var.operator->()] = param_symbol;
this->VisitStmt(op->body);
}
Expand Down
47 changes: 30 additions & 17 deletions tests/python/relax/aot/test_aot_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def _export_mod(mod):
return tvm.runtime.load_module(test_so_path)


def test_single_elementwise():
@pytest.mark.parametrize("enable_usmp", [True, False])
def test_single_elementwise(enable_usmp):
dtype = "int32"
target = "llvm"
inputs = {"x": np.array([[-10, 5], [1, 2]], dtype=dtype)}
Expand All @@ -48,21 +49,23 @@ def _relay():
def _reference(inputs):
x = inputs["x"]
return np.abs(x) # abs

relax_mod = relay_translator.from_relay(
_relay(),
target,
)

mod = build(relax_mod, target)
with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}):
mod = build(relax_mod, target)
loaded_mod = _export_mod(mod)
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
runner.set_input(**inputs)
runner.run()
assert (runner.get_output(0).numpy() == _reference(inputs)).all()


def test_scalar_constant():
@pytest.mark.parametrize("enable_usmp", [True, False])
def test_scalar_constant(enable_usmp):
dtype = "int32"
target = "llvm"
inputs = {"x": np.array([[-10, 5], [1, 2]], dtype=dtype)}
Expand All @@ -75,21 +78,23 @@ def _relay():
def _reference(inputs):
x = inputs["x"]
return np.add(x, -1) # add

relax_mod = relay_translator.from_relay(
_relay(),
target,
)

mod = build(relax_mod, target)
with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}):
mod = build(relax_mod, target)
loaded_mod = _export_mod(mod)
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
runner.set_input(**inputs)
runner.run()
assert (runner.get_output(0).numpy() == _reference(inputs)).all()


def test_tensor_constant():
@pytest.mark.parametrize("enable_usmp", [True, False])
def test_tensor_constant(enable_usmp):
dtype = "int32"
target = "llvm"
inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype)}
Expand All @@ -102,24 +107,29 @@ def _relay():
def _reference(inputs):
x = inputs["x"]
return np.add(x, np.array([[1, 2], [3, 4]])) # add

relax_mod = relay_translator.from_relay(
_relay(),
target,
)

mod = build(relax_mod, target)
with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}):
mod = build(relax_mod, target)
loaded_mod = _export_mod(mod)
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
runner.set_input(**inputs)
runner.run()
assert (runner.get_output(0).numpy() == _reference(inputs)).all()


def test_multi_input():
@pytest.mark.parametrize("enable_usmp", [True, False])
def test_multi_input(enable_usmp):
dtype = "int32"
target = "llvm"
inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype), "y": np.array([[1, 2], [3, 4]], dtype=dtype)}
inputs = {
"x": np.array([[-10, 1], [5, 1]], dtype=dtype),
"y": np.array([[1, 2], [3, 4]], dtype=dtype),
}

def _relay():
x = relay.var("x", shape=(2, 2), dtype=dtype)
Expand All @@ -131,21 +141,23 @@ def _reference(inputs):
x = inputs["x"]
y = inputs["y"]
return np.add(x, y) # add

relax_mod = relay_translator.from_relay(
_relay(),
target,
)

mod = build(relax_mod, target)
with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}):
mod = build(relax_mod, target)
loaded_mod = _export_mod(mod)
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
runner.set_input(**inputs)
runner.run()
assert (runner.get_output(0).numpy() == _reference(inputs)).all()


def test_multi_output():
@pytest.mark.parametrize("enable_usmp", [True, False])
def test_multi_output(enable_usmp):
dtype = "int32"
target = "llvm"
inputs = {"x": np.array([[-10, 1], [5, 1]], dtype=dtype)}
Expand All @@ -159,16 +171,17 @@ def _relay():

def _reference(inputs):
x = inputs["x"]
abs = np.abs(x) # abs
abs = np.abs(x) # abs
out = abs - 1
return [abs, out]

relax_mod = relay_translator.from_relay(
_relay(),
target,
)

mod = build(relax_mod, target)
with tvm.transform.PassContext(config={"relax.usmp.enable": enable_usmp}):
mod = build(relax_mod, target)
loaded_mod = _export_mod(mod)
runner = tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
runner.set_input(**inputs)
Expand Down
Loading

0 comments on commit 66eae17

Please sign in to comment.