Skip to content

Commit

Permalink
[Relax] Implement relax.op.view
Browse files Browse the repository at this point in the history
This commit implements `relax.op.view` (`R.view` in TVMScript) to
produce a view into an existing array.  This returned view shares
the same backing allocation as the existing array.

Because `R.view` comes with potential trade-offs; such as increased
memory footprint, performance cost to apply a non-zero
`DLTensor::byte_offset`, and potential misalignment for vector
operators; this PR does not use `R.view` apart from unit tests.
Applications of `R.view`, either for specific compute kernels or in
optimization passes, is instead kept for follow-up PRs.
  • Loading branch information
Lunderberg committed Apr 29, 2024
1 parent c0385c7 commit ed4fd50
Show file tree
Hide file tree
Showing 10 changed files with 1,287 additions and 12 deletions.
15 changes: 10 additions & 5 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,21 +1108,26 @@ def inline_functions(


@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
class ExternFunc(BaseFunc, ExprWithOp):
"""extern function, which represents a PackedFunc."""

global_symbol: String
span: Optional[Span]

def __init__(self, global_symbol: String, span: Optional[Span] = None) -> None:
def __init__(
self,
global_symbol: String,
struct_info: Optional[StructInfo] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ExternFunc, global_symbol, span # type: ignore
_ffi_api.ExternFunc, global_symbol, struct_info, span # type: ignore
)


def extern(name: str, span: Optional[Span] = None):
def extern(name: str, struct_info: Optional[StructInfo] = None, span: Optional[Span] = None):
"""Create extern function."""
return ExternFunc(name, span)
return ExternFunc(name, struct_info, span)


def const(
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@
tan,
tanh,
)
from .view import view


def _register_op_make():
Expand Down
76 changes: 76 additions & 0 deletions python/tvm/relax/op/view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Operations that act on the DLTensor container """
from typing import Optional, Sequence, Union

from tvm.ir.expr import PrimExpr

from . import _ffi_api
from ..expr import Expr, PrimValue, ShapeExpr, DataTypeImm

PrimExprLike = Union[int, PrimExpr]


def view(
data: Expr,
shape: Optional[Union[Sequence[PrimExprLike], Expr]] = None,
dtype: Optional[Expr] = None,
relative_byte_offset: Optional[Expr] = None,
) -> Expr:
"""Broadcasts a tensor to a specified shape.
Parameters
----------
data : relax.Expr
The input data to the operator.
shape : Optional[Union[Sequence[PrimExprLike], Expr]]
The target shape. Should be a `relax.ShapeExpr`, or a
collection that can be converted to a `relax.ShapeExpr`.
dtype : Optional[Expr]
The target datatype. Should be a `relax.ShapeExpr`, or a
collection that can be converted to a `relax.ShapeExpr`.
relative_byte_offset: Optional[Expr]
The offset of the output NDArray, relative to the byte offset
of `data`. If `None`, the offset of the view is the same as
the offset of `data`.
Returns
-------
result : relax.Expr
The tensor view
"""

def _normalize(expr, relax_cls):
if expr is None or isinstance(expr, Expr):
return expr
else:
return relax_cls(expr)

shape = _normalize(shape, ShapeExpr)
dtype = _normalize(dtype, DataTypeImm)
relative_byte_offset = _normalize(relative_byte_offset, PrimValue)

return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore
7 changes: 5 additions & 2 deletions python/tvm/relax/struct_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def __init__(
def opaque_func(
*,
ret: Optional[StructInfo] = None,
derive_func: Optional[EnvFunc] = None,
derive_func: Optional[Union[str, EnvFunc]] = None,
purity: bool = False,
span: Span = None,
) -> "FuncStructInfo":
Expand All @@ -249,7 +249,7 @@ def opaque_func(
ret: Optional[StructInfo]
The struct info of the function return value.
derive_func: Optional[EnvFunc]
derive_func: Optional[Union[str,EnvFunc]]
The environment function used for derivation
purity: bool
Expand All @@ -266,4 +266,7 @@ def opaque_func(
----
We cannot specify ret and derive_func simultaneously.
"""

if isinstance(derive_func, str):
derive_func = tvm.ir.EnvFunc.get("tvm.relax.struct_info.infer_view_sinfo")
return _ffi_api.FuncStructInfoOpaqueFunc(ret, derive_func, purity, span) # type: ignore
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
sum,
take,
variance,
view,
sigmoid,
sign,
sin,
Expand Down Expand Up @@ -794,6 +795,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"tuple",
"unique",
"variance",
"view",
"vm",
"vpi",
"vulkan",
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Callable as _Callable
from typing import Dict, List, Optional, Set, TypeVar, Union

import tvm
from tvm.relax import (
Expr,
SeqExpr,
Expand Down Expand Up @@ -277,6 +278,7 @@ class CallableProxy(StructInfoProxy):
params: List[StructInfoProxy]
ret: StructInfoProxy
purity: bool
derive_func: Optional[Union[str, tvm.ir.EnvFunc]]

"""Function type.
Expand All @@ -296,13 +298,17 @@ class CallableProxy(StructInfoProxy):
purity : bool
Whether the callable is pure.
derive_func: Optional[Union[str, tvm.ir.EnvFunc]]
The derivation function for the outputq
"""

def __init__(
self,
params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
ret: Optional[StructInfoProxy] = None,
purity: Optional[bool] = None,
derive_func: Optional[Union[str, tvm.ir.EnvFunc]] = None,
) -> None:
if params is None:
self.params = params
Expand All @@ -320,6 +326,7 @@ def __init__(

self.ret = ret() if callable(ret) else ret
self.purity = purity
self.derive_func = derive_func

def get_symbolic_vars(self) -> Set[str]:
if self.params is None:
Expand All @@ -339,7 +346,9 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> FuncS
params = [param.as_struct_info(dict_globals) for param in self.params]

if params is None:
return FuncStructInfo.opaque_func(ret=ret, purity=self.purity)
return FuncStructInfo.opaque_func(
ret=ret, derive_func=self.derive_func, purity=self.purity
)
else:
return FuncStructInfo(params, ret, purity=self.purity)

Expand All @@ -348,8 +357,9 @@ def Callable(
params: Optional[Union[StructInfoProxy, List[StructInfoProxy]]] = None,
ret: Optional[StructInfoProxy] = None,
purity: Optional[bool] = None,
derive_func: Optional[Union[str, tvm.ir.EnvFunc]] = None,
) -> CallableProxy:
return CallableProxy(params, ret, purity=purity)
return CallableProxy(params, ret, purity=purity, derive_func=derive_func)


############################### R.Tuple ################################
Expand Down
11 changes: 8 additions & 3 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,9 +650,14 @@ ExternFunc::ExternFunc(String global_symbol, StructInfo struct_info, Span span)
data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relax.ExternFunc").set_body_typed([](String global_symbol, Span span) {
return ExternFunc(global_symbol, span);
});
TVM_REGISTER_GLOBAL("relax.ExternFunc")
.set_body_typed([](String global_symbol, Optional<StructInfo> struct_info, Span span) {
if (struct_info.defined()) {
return ExternFunc(global_symbol, struct_info.value(), span);
} else {
return ExternFunc(global_symbol, span);
}
});

Expr GetShapeOf(const Expr& expr) {
// default case, to be normalized.
Expand Down
Loading

0 comments on commit ed4fd50

Please sign in to comment.