From 62a4c735fb7710c6fd0ef7d2f99245637e8e8acd Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Thu, 23 Feb 2023 05:21:20 -0500
Subject: [PATCH] [Unity][Layout] Add layout transformation analysis for
 PrimFunc (#14066)

* [Layout] Add layout transformation analysis for PrimFunc.

This change adds a PrimFunc-level analysis to suggest layout
transformations to blocks and buffers in the PrimFunc based on the
layout transformations to PrimFunc outputs.

* Add support for multiple blocks such as split op.

* Add negative tests and increase coverage.

* fix warning message

* fix lint

* remove unused header

* Address comments.

Moved some utility functions to support/array.h
improve doc

* fix deprecation warning: T.var("int64") to T.int64()

* address comments
---
 include/tvm/relax/analysis.h                  |  13 +
 python/tvm/relax/analysis/analysis.py         |  32 +-
 src/relax/analysis/layout_transformation.cc   | 621 +++++++++++++
 src/support/array.h                           |  27 +-
 ...test_analysis_suggest_layout_transforms.py | 831 ++++++++++++++++++
 5 files changed, 1522 insertions(+), 2 deletions(-)
 create mode 100644 src/relax/analysis/layout_transformation.cc
 create mode 100644 tests/python/relax/test_analysis_suggest_layout_transforms.py

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 39ecfd9e13a7..2b771b9708ab 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -403,6 +403,19 @@ TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);
  */
 TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);
 
+/*!
+ * \brief Using the layout transforms on the outputs, suggest layout transformations for the
+ * blocks and buffers of the PrimFunc.
+ *
+ * \param fn The PrimFunc to be analyzed.
+ * \param write_buffer_transformations Array of IndexMap transformations on PrimFunc outputs.
+ * \return Suggested transforms per block in `fn`. For each block the returned value is a map
+ * from the object (block or buffer) to its index map transformation.
+ */
+TVM_DLL Map<tir::Block, Map<ObjectRef, tir::IndexMap>> SuggestLayoutTransforms(
+    const tir::PrimFunc& fn, Array<tir::IndexMap> write_buffer_transformations);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py
index ffcdaceb4076..efd1b51f11de 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -21,7 +21,7 @@
 configuring the passes and scripting them in Python.
 """
 
-from typing import Dict, List
+from typing import Dict, List, Union, Callable
 from enum import IntEnum
 
 from tvm import tir
@@ -29,6 +29,7 @@
 from tvm.relax.ty import Type
 from tvm.relax.struct_info import StructInfo, FuncStructInfo
 from tvm.relax.expr import DataflowBlock, Var, Expr, Function, Call, Binding
+from tvm.tir import IndexMap, PrimFunc, Block, Buffer
 
 from . import _ffi_api
 
@@ -289,3 +290,32 @@ def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
         will be well tested and will not be blocked by not having structure info.
     """
     return _ffi_api.well_formed(mod, check_struct_info)  # type: ignore
+
+
+def suggest_layout_transforms(
+    func: PrimFunc, write_buffer_transforms: List[Union[IndexMap, Callable]]
+) -> Dict[Block, Dict[Union[Block, Buffer], IndexMap]]:
+    """Suggest layout transformations of blocks and buffers in a PrimFunc.
+
+    Parameters
+    ----------
+    func: PrimFunc
+        PrimFunc on which analysis will be performed and transformations suggested.
+
+    write_buffer_transforms: List[Union[IndexMap, Callable]]
+        List of layout transformations on the output buffers. The number of layout
+        transformations must match the number of outputs of the PrimFunc.
+
+    Returns
+    -------
+    ret: Dict[Block, Dict[Union[Block, Buffer], IndexMap]]
+        Suggested transforms per block in `func`. For each block the returned value is a map
+        from the object (block or buffer) to its index map transformation.
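+
+    Examples
+    --------
+    A minimal usage sketch (illustrative; assumes a TVMScript PrimFunc ``func`` with a single
+    output buffer, and applies the result with tir.Schedule the same way the unit tests do):
+
+    .. code-block:: python
+
+        from tvm import tir
+
+        transforms = suggest_layout_transforms(
+            func, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)]
+        )
+        sch = tir.Schedule(func)
+        for block, per_block in transforms.items():
+            block_rv = sch.get_block(block.name_hint)
+            for obj, index_map in per_block.items():
+                if isinstance(obj, tir.Block):
+                    sch.transform_block_layout(block_rv, index_map)
+                else:  # tir.Buffer
+                    sch.transform_layout(block_rv, obj, index_map)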
+    """
+    write_buffer_index_maps = []
+    for transform in write_buffer_transforms:
+        if callable(transform):
+            transform = IndexMap.from_func(transform)
+        assert isinstance(transform, IndexMap)
+        write_buffer_index_maps.append(transform)
+    return _ffi_api.suggest_layout_transforms(func, write_buffer_index_maps)  # type: ignore
diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc
new file mode 100644
index 000000000000..44538fea98e5
--- /dev/null
+++ b/src/relax/analysis/layout_transformation.cc
@@ -0,0 +1,621 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/analysis/layout_transformation.cc
+ * \brief Analyze the PrimFunc and suggest layout transformations for its blocks and buffers
+ * based on the user-provided layout transformations on its outputs.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../support/array.h"
+
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+
+/********** Helper Functions **********/
+
+/*! \brief Checks if a transformation is bijective affine over the given ranges */
+static bool IsBijectiveAffine(const IndexMap& m, const Array<Range>& ranges) {
+  Map<Var, Range> input_iters;
+  ICHECK_EQ(m->initial_indices.size(), ranges.size());
+  for (size_t i = 0; i < ranges.size(); i++) {
+    input_iters.Set(m->initial_indices[i], ranges[i]);
+  }
+  arith::Analyzer analyzer;
+  auto iter_map_result = DetectIterMap(m->final_indices, input_iters, /*predicate=*/1,
+                                       /*check_level=*/arith::IterMapLevel::Bijective, &analyzer,
+                                       /*simplify_trivial_iterators=*/true);
+  return !iter_map_result->indices.empty();
+}
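+
+// For example, over c in [0, 64), lambda n, c: (n, c // 4, c % 4) is bijective affine, whereas
+// lambda n, c: (n, c // 5, c % 5) is not, since 64 is not divisible by 5 (see
+// test_non_bijective_block_transform in the unit tests).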
+
+/*!
+ * \brief Analyzer to collect iterators from IterSumExpr.
+ * \details Analyzes the indices from DetectIterMap analysis to collect the spatial iterators
+ * that are used in them. This is important to get which spatial iterators are accessed in each
+ * index of a buffer access.
+ */
+class IndexAnalyzer : public ExprVisitor {
+ public:
+  Array<tir::Var> Analyze(const arith::IterSumExpr& expr) {
+    VisitExpr(expr);
+    return iterators_;
+  }
+
+ private:
+  /*! \brief Override VisitExpr to handle the iter-map expression types */
+  void VisitExpr(const PrimExpr& expr) override {
+    if (const auto* op = expr.as<arith::IterSumExprNode>()) {
+      for (const auto& arg : op->args) VisitExpr(arg);
+      VisitExpr(op->base);
+      return;
+    }
+    if (const auto* op = expr.as<arith::IterSplitExprNode>()) {
+      VisitIterMark(op->source);
+      VisitExpr(op->lower_factor);
+      VisitExpr(op->extent);
+      VisitExpr(op->scale);
+      return;
+    }
+    return ExprVisitor::VisitExpr(expr);
+  }
+
+  void VisitIterMark(const arith::IterMark& op) {
+    if (const auto* var = op->source.as<tir::VarNode>())
+      iterators_.push_back(GetRef<tir::Var>(var));
+    else
+      VisitExpr(op->source);
+    VisitExpr(op->extent);
+  }
+
+ private:
+  Array<tir::Var> iterators_;
+};
+
+/*!
+ * \brief Analyzes an IterMapResult to get the spatial layout of a buffer access.
+ * \details We define the spatial layout of a buffer access as an array whose length equals the
+ * number of dimensions of the buffer. The i-th element of the spatial layout holds the spatial
+ * iter var (from the block iteration domain) used in the i-th index of the access. For indices
+ * where no spatial iter var is used, the spatial layout element is empty. If any of the buffer
+ * access indices uses multiple spatial iter vars, the spatial layout is undefined.
+ *
+ * Here are a few examples of spatial layouts inferred from buffer accesses. si denotes the i-th
+ * spatial iter var, and ri denotes the i-th reduction iter var.
+ *
+ * SpatialLayout(A[s0 * constant, s1]) = {s0, s1}
+ * SpatialLayout(A[s0, constant, r0, s1]) = {s0, null, null, s1}
+ * SpatialLayout(A[s0 * c + s1]) = undefined
+ */
+using SpatialLayout = Array<Optional<tir::Var>>;
+static SpatialLayout GetSpatialLayout(const arith::IterMapResult& iter_map_result) {
+  ICHECK(!iter_map_result->indices.empty());
+  SpatialLayout result;
+  for (const arith::IterSumExpr& index : iter_map_result->indices) {
+    IndexAnalyzer index_analyzer;
+    Array<tir::Var> iter_vars = index_analyzer.Analyze(index);
+    if (iter_vars.size() >= 2) {
+      LOG(WARNING) << "[LayoutInference] Unable to get spatial layout of access: "
+                   << arith::NormalizeIterMapToExpr(index);
+      return {};
+    }
+    if (iter_vars.empty()) {
+      result.push_back({});
+      continue;
+    }
+    result.push_back(iter_vars[0]);
+  }
+  return result;
+}
+
+/*!
+ * \brief Checks if the two spatial layouts are identical. Two empty spatial layouts are treated
+ * as unequal.
+ */
+static bool AreIdenticalSpatialAccess(const SpatialLayout& s0, const SpatialLayout& s1) {
+  if (s0.empty() || s1.empty()) return false;
+  if (s0.size() != s1.size()) return false;
+  for (size_t i = 0; i < s0.size(); ++i) {
+    if ((!s0[i].defined() && s1[i].defined()) || (s0[i].defined() && !s1[i].defined()))
+      return false;
+    if (!s0[i].same_as(s1[i])) return false;
+  }
+  return true;
+}
+
+/*!
+ * \brief Checks if the block accesses a buffer sequentially in terms of spatial dimensions
+ * (ignoring reduction dimensions). It checks that the order of spatial iter vars in the spatial
+ * layout of a buffer access is the same as the order of spatial iter vars in the block domain.
+ */
+using VarToBlockIndexMap = std::unordered_map<tir::Var, int, ObjectPtrHash, ObjectPtrEqual>;
+static bool IsSequentialAccess(const SpatialLayout& iterators,
+                               const VarToBlockIndexMap& iter_to_block_index) {
+  int last_value = -1;
+  for (const auto& i : iterators) {
+    if (!i.defined()) continue;
+    auto it = iter_to_block_index.find(i.value());
+    ICHECK(it != iter_to_block_index.end());
+    int blk_index = it->second;
+    if (blk_index <= last_value) return false;
+    last_value = blk_index;
+  }
+  return true;
+}
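+
+// For example, with block iter vars (s0, s1, r0), the access A[s0, r0, s1] has spatial layout
+// {s0, null, s1} and is sequential, whereas A[s1, s0] is not, since s0 appears after s1 in the
+// block iteration domain.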
+
+/*! \brief Checks if two IndexMaps represent identical transforms */
+static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) {
+  if (t0->initial_indices.size() != t1->initial_indices.size()) return false;
+  if (t0->final_indices.size() != t1->final_indices.size()) return false;
+
+  // Map t1's initial indices through t0 and check that the results can be proven equal to t1's
+  // final indices.
+  Array<PrimExpr> t1_initial_indices =
+      t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; });
+  auto t0_output = t0->MapIndices(t1_initial_indices);
+  arith::Analyzer analyzer;
+  for (size_t i = 0; i < t0_output.size(); ++i) {
+    if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return false;
+  }
+  return true;
+}
+
+/*!
+ * \brief Returns the layout transformation for a target spatial layout from the source spatial
+ * layout and transformation.
+ * \details Given the source buffer spatial layout \p src_spatial_layout and its transformation
+ * \p src_transformation, this function constructs the transformation for the target buffer
+ * whose spatial layout is given as \p tgt_spatial_layout.
+ *
+ * The algorithm is explained below using an example:
+ *
+ * Let's say the source transformation is lambda N, C, H, W -> (N, H, W, C // 4, C % 4), the
+ * source spatial layout is 'NCHW' and the target spatial layout is 'KCHW'.
+ *
+ * Step 1: Copy over the source transformation initial & final indices to the target
+ * transformation initial and final indices.
+ * target transformation = lambda N, C, H, W -> (N, H, W, C // 4, C % 4)
+ *
+ * Step 2: Drop any vars from the initial indices which do not occur in the target buffer, using
+ * the source and target spatial layouts.
+ * target transformation = lambda C, H, W -> (N, H, W, C // 4, C % 4)
+ *
+ * Step 3: Erase any expression from the final indices which depends on a var not present in the
+ * initial indices.
+ * target transformation = lambda C, H, W -> (H, W, C // 4, C % 4)
+ *
+ * Step 4: Go over the target spatial layout and add any missing dims to both initial and final
+ * indices. This is done by checking if any iterator in the target spatial layout is not present
+ * in the source spatial layout.
+ * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C % 4)
+ */
+using VarSet = std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>;
+static Optional<IndexMap> InferLayoutTransformation(const SpatialLayout& src_spatial_layout,
+                                                    const IndexMap& src_transformation,
+                                                    const SpatialLayout& tgt_spatial_layout) {
+  // Copy over the src transformation initial and final indices.
+  auto initial_indices = support::AsList(src_transformation->initial_indices);
+  auto final_indices = support::AsList(src_transformation->final_indices);
+
+  // Get the iterator var set used in the target spatial layout.
+  VarSet tgt_var_set;
+  for (const auto& i : tgt_spatial_layout) {
+    if (i.defined()) tgt_var_set.insert(i.value());
+  }
+
+  // Erase initial indices corresponding to iter vars that do not occur in the target spatial
+  // layout. Also compute the var set of the remaining initial indices.
+  auto initial_indices_it = initial_indices.begin();
+  VarSet initial_indices_var_set;
+  for (const auto& i : src_spatial_layout) {
+    ICHECK(i.defined());
+    if (tgt_var_set.count(i.value())) {
+      initial_indices_var_set.insert(*initial_indices_it);
+      initial_indices_it++;
+      continue;
+    }
+    initial_indices_it = initial_indices.erase(initial_indices_it);
+  }
+
+  // Erase any expressions in the final indices that use undefined vars.
+  auto final_indices_it = final_indices.begin();
+  while (final_indices_it != final_indices.end()) {
+    // Collect all the vars used in this final index.
+    Array<tir::Var> used_vars = tir::UndefinedVars(*final_indices_it);
+    ICHECK(!used_vars.empty())
+        << "IndexMap expression must always contain tir::Var nodes but found none in: "
+        << *final_indices_it;
+
+    bool has_undefined_vars = std::any_of(used_vars.begin(), used_vars.end(),
+                                          [&initial_indices_var_set](const tir::Var& v) {
+                                            return initial_indices_var_set.count(v) == 0;
+                                          });
+
+    // If all vars are from the initial indices, nothing to do for this final index.
+    if (!has_undefined_vars) {
+      final_indices_it++;
+      continue;
+    }
+    // We are about to drop this expr from the final indices since it has undefined vars. Check
+    // if it also depends on any of the initial indices. If it does, it cannot be dropped and we
+    // bail by returning null.
+    // This captures the scenario where the source transformation is unpacking a dimension
+    // (e.g., "H4h" -> "H*4+h") and the buffer we are trying to infer the transformation of has
+    // the 'h' dimension, but not 'H'. So the expr depends on the undefined var 'H' and the
+    // defined var 'h'.
+    bool depends_on_initial_indices = std::any_of(used_vars.begin(), used_vars.end(),
+                                                  [&initial_indices_var_set](const tir::Var& v) {
+                                                    return initial_indices_var_set.count(v) != 0;
+                                                  });
+    if (depends_on_initial_indices) {
+      LOG(WARNING)
+          << "[LayoutInference] Buffer access is dependent on both defined and undefined vars";
+      return {};
+    }
+    // It is ok to erase this final index expression as it only depends on undefined vars.
+    final_indices_it = final_indices.erase(final_indices_it);
+  }
+
+  // Go over the target spatial layout and add any missing dims to both initial and final
+  // indices. This is done by checking if any iterator in the target spatial layout is not
+  // present in the source spatial layout.
+  VarSet src_var_set;
+  for (const auto& i : src_spatial_layout) {
+    ICHECK(i.defined());
+    src_var_set.insert(i.value());
+  }
+
+  initial_indices_it = initial_indices.begin();
+  final_indices_it = final_indices.begin();
+  for (const auto& i : tgt_spatial_layout) {
+    if (i.defined() && src_var_set.count(i.value())) {
+      initial_indices_it++;
+      if (final_indices_it != final_indices.end()) final_indices_it++;
+      continue;
+    }
+
+    auto new_dim = tir::Var("d");
+    initial_indices.insert(initial_indices_it, new_dim);
+    final_indices.insert(final_indices_it, new_dim);
+  }
+
+  return IndexMap(support::AsArray(initial_indices), support::AsArray(final_indices));
+}
+
+/*!
+ * \brief Analyzes the Block and the given output buffer transformation to propose
+ * transformations of the block and its read buffers.
+ * \details It does a best-effort analysis to propose transformations which would preserve
+ * sequential access to buffers (especially output buffers). Since this is best effort, it is
+ * possible that the Block is too complex for analysis. In such a case, no transformations are
+ * proposed. Limitations:
+ * 1. Expects exactly one write buffer in the block, whose transformation is given by
+ * `write_transformation`.
+ * 2. Expects the write buffer access to be affine and to only use spatial iterators of the
+ * block.
+ * 3. Proposes transformations to a read buffer only if all accesses to it are affine.
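+ *
+ * For example, for a reduction block sum[n, c] += arg[n, c, h, w] with write transformation
+ * lambda n, c: (n, c // 16, c % 16), the block iteration and the read buffer `arg` are assigned
+ * transformations that tile `c` the same way while leaving the reduction dims in place (see
+ * test_op_reduce in the accompanying unit tests).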
+ */
+class BlockAnalyzer : public StmtExprVisitor {
 public:
+  explicit BlockAnalyzer(const Block& block, const Map<Buffer, IndexMap>& transformation_cache,
+                         IndexMap write_transformation)
+      : can_transform_block_(true),
+        write_transformation_(write_transformation),
+        block_(block),
+        buffer_transformation_cache_(transformation_cache) {
+    ICHECK(block_->writes.size() == 1);
+    auto write_buffer = block_->writes[0]->buffer;
+
+    ComputeBlockSpatialDomain();
+
+    // Visit the block body to collect load/store access patterns of different buffers.
+    VisitStmt(block_->body);
+
+    // While visiting the load/store accesses it is possible we see an unexpected pattern, such
+    // as a nested block or write access to multiple buffers. In such a case, we can return
+    // early as we would not be making any layout suggestions.
+    if (!can_transform_block_) {
+      LOG(WARNING) << "[LayoutInference] Unable to transform block " << block->name_hint;
+      return;
+    }
+
+    // Get the iterator ordering and the block's spatial layout.
+    VarToBlockIndexMap iter_var_to_block_index;
+    SpatialLayout block_spatial_layout;
+    int index = 0;
+    for (const auto& iter_var : block->iter_vars) {
+      auto var = iter_var->var;
+      iter_var_to_block_index[var] = index++;
+      block_spatial_layout.push_back(var);
+    }
+
+    // Helper to get the spatial layout of a buffer from the buffer access map.
+    auto get_spatial_layout = [&](Buffer b) -> SpatialLayout {
+      auto it = buffer_access_info_.find(b);
+      if (it == buffer_access_info_.end()) {
+        return {};
+      }
+      auto access_info = it->second;
+      return access_info.GetValidSpatialLayout();
+    };
+
+    // Check that the write has sequential access within the block.
+    SpatialLayout write_spatial_layout = get_spatial_layout(write_buffer);
+    if (write_spatial_layout.empty()) {
+      can_transform_block_ = false;
+      return;
+    }
+    if (!IsSequentialAccess(write_spatial_layout, iter_var_to_block_index)) {
+      can_transform_block_ = false;
+      return;
+    }
+
+    // Infer the block transformation from the write buffer transformation.
+    auto maybe_block_transformation = InferLayoutTransformation(
+        write_spatial_layout, write_transformation_, block_spatial_layout);
+    if (!maybe_block_transformation.defined()) {
+      can_transform_block_ = false;
+      return;
+    }
+    block_transformation_ = maybe_block_transformation.value();
+
+    Array<Range> block_ranges = block_->iter_vars.Map([](const IterVar& i) { return i->dom; });
+    if (!IsBijectiveAffine(block_transformation_, block_ranges)) {
+      can_transform_block_ = false;
+      LOG(WARNING) << "[LayoutInference] Inferred block transformation is not bijective affine, "
+                      "transformation: ("
+                   << block_transformation_ << ") over range (" << block_ranges << ")";
+      return;
+    }
+
+    // Infer read buffer transformations from the write buffer transformation.
+    for (const auto& r : block->reads) {
+      SpatialLayout read_spatial_layout = get_spatial_layout(r->buffer);
+      if (read_spatial_layout.empty()) continue;
+      if (!IsSequentialAccess(read_spatial_layout, iter_var_to_block_index)) continue;
+
+      auto maybe_read_transformation = InferLayoutTransformation(
+          write_spatial_layout, write_transformation_, read_spatial_layout);
+      if (!maybe_read_transformation.defined()) continue;
+      IndexMap read_transformation = maybe_read_transformation.value();
+      if (buffer_transformation_cache_.count(r->buffer) != 0) {
+        if (!AreIdenticalTransforms(read_transformation, buffer_transformation_cache_[r->buffer]))
+          LOG(WARNING) << "[LayoutInference] Buffer: " << r->buffer
+                       << " has conflicting transform proposals -- (preferred) "
+                       << buffer_transformation_cache_[r->buffer] << " vs. "
+                       << read_transformation;
+        continue;
+      }
" << read_transformation; + continue; + } + read_buffer_transformations_.Set(r->buffer, read_transformation); + } + } + + private: + // Helper class to keep track of spatial layout of buffer as we visit multiple accesses to this + // buffer within the block. + class BufferAccessInfo { + public: + BufferAccessInfo() : is_valid_(true) {} + void Update(SpatialLayout s) { + if (!IsValid()) return; + if (spatial_layout_.empty()) spatial_layout_ = s; + if (!AreIdenticalSpatialAccess(s, spatial_layout_)) { + Invalidate(); + return; + } + } + bool IsValid() { return is_valid_; } + void Invalidate() { is_valid_ = false; } + SpatialLayout GetValidSpatialLayout() { + if (!IsValid()) return {}; + return spatial_layout_; + } + + private: + bool is_valid_; + SpatialLayout spatial_layout_; + }; + + // Helper to break down the indices of buffer access. + SpatialLayout DetectBufferAccessIterMap(Array indices) { + auto result = arith::DetectIterMap( + /*indices=*/indices, /*input_iters*/ spatial_dom_, + /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck, &arith_analyzer_); + if (result->indices.empty()) { + LOG(WARNING) << "[LayoutInference] Failed to analyze indices " << indices + << ", error: " << result->errors; + return {}; + } + return GetSpatialLayout(result); + } + + // Compute the spatial domain map of block + void ComputeBlockSpatialDomain() { + for (const IterVar& v : block_->iter_vars) { + if (v->iter_type == kDataPar) { + spatial_dom_.Set(v->var, v->dom); + continue; + } + if (v->iter_type == kCommReduce) continue; + LOG(WARNING) << "[LayoutInference] Cannot compute block spatial domain in presence of " + "unknown block iter_type : " + << v->iter_type; + can_transform_block_ = false; + return; + } + } + + void VisitStmt_(const BlockNode* op) final { + // Blocks with nested blocks cannot be handled yet. + LOG(WARNING) << "[LayoutInference] Nested blocks are not supported for layout inference yet"; + can_transform_block_ = false; + } + void VisitStmt_(const BufferStoreNode* op) final { + StmtExprVisitor::VisitStmt_(op); + + BufferAccessInfo& access_info = buffer_access_info_[op->buffer]; + + // Fast path to ignore further analysis if we know that the buffer access is invalid. + if (!access_info.IsValid()) return; + + // Only single write buffer is supported for each block. + if (!op->buffer.same_as(block_->writes[0]->buffer)) { + access_info.Invalidate(); + LOG(WARNING) << "[LayoutInference] Exactly one write buffer is supported for layout " + "inference, found two: " + << op->buffer << " and " << block_->writes[0]->buffer; + can_transform_block_ = false; + return; + } + + // If the write buffer access cannot be analyzed, no transformation to the block will be made. + auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices); + if (detected_spatial_layout.empty()) { + access_info.Invalidate(); + return; + } + + // Check if we have access info for this buffer, if present, the two accesses must be + // identical. 
+    access_info.Update(detected_spatial_layout);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    Buffer read_buffer = op->buffer;
+    BufferAccessInfo& access_info = buffer_access_info_[read_buffer];
+
+    auto detected_spatial_layout = DetectBufferAccessIterMap(op->indices);
+    if (detected_spatial_layout.empty()) {
+      access_info.Invalidate();
+      return;
+    }
+    access_info.Update(detected_spatial_layout);
+  }
+
 public:
+  bool CanBeTransformed() { return can_transform_block_; }
+  IndexMap GetBlockTransformation() { return block_transformation_; }
+  Map<Buffer, IndexMap> GetReadBufferTransformations() { return read_buffer_transformations_; }
+
 private:
+  bool can_transform_block_;
+  IndexMap write_transformation_;
+  Map<Var, Range> spatial_dom_;
+  arith::Analyzer arith_analyzer_;
+
+  Block block_;
+  IndexMap block_transformation_;
+
+  Map<Buffer, IndexMap> read_buffer_transformations_;
+  const Map<Buffer, IndexMap>& buffer_transformation_cache_;
+  std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual> buffer_access_info_;
+};
+
+/*!
+ * \brief Analyzes the PrimFunc and the user-provided output buffer transformations to propose
+ * transformations of blocks and buffers within the PrimFunc.
+ * \details It does a best-effort analysis to propose transformations which would preserve
+ * sequential access to buffers (especially output buffers). Since this is best effort, it is
+ * possible that the PrimFunc is too complex for analysis. In such a case, no transformations
+ * are proposed.
+ */
+class PrimFuncAnalyzer : public StmtExprVisitor {
 public:
+  explicit PrimFuncAnalyzer(const PrimFunc& func, Array<IndexMap> write_transformations) {
+    ICHECK_LE(write_transformations.size(), func->params.size())
+        << "Incompatible PrimFunc and write_transformations";
+
+    size_t first_write_index = func->params.size() - write_transformations.size();
+    for (size_t i = 0; i < write_transformations.size(); ++i) {
+      auto param = func->params[first_write_index + i];
+      Optional<Buffer> param_buf = func->buffer_map.Get(param);
+      ICHECK(param_buf.defined());
+      ICHECK_EQ(param_buf.value()->shape.size(), write_transformations[i]->initial_indices.size())
+          << "Mismatch between output buffer shape and index map";
+      buffer_transformation_cache_.Set(param_buf.value(), write_transformations[i]);
+    }
+    VisitStmt(func->body);
+  }
+
+  Map<Block, Map<ObjectRef, IndexMap>> GetSuggestedTransforms() {
+    Map<Block, Map<ObjectRef, IndexMap>> result;
+    for (const auto& [block, index_map] : block_transformations_) {
+      Map<ObjectRef, IndexMap> block_transformations;
+      block_transformations.Set(block, index_map);
+      for (const auto& buffer : block_to_buffer_[block]) {
+        block_transformations.Set(buffer, buffer_transformation_cache_[buffer]);
+      }
+      result.Set(block, block_transformations);
+    }
+    return result;
+  }
+
 private:
+  void VisitStmt_(const BlockNode* op) final {
+    if (op->name_hint == "root") {
+      // Skip the root block.
+      StmtVisitor::VisitStmt_(op);
+      return;
+    }
+
+    Block block = GetRef<Block>(op);
+    // Get the block's write buffer transformation.
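+    // Only blocks with a single write buffer are analyzed; BlockAnalyzer checks the same
+    // invariant on the accesses inside the block. An op with multiple outputs, such as split,
+    // is handled here block by block (see test_op_split in the unit tests).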
+    if (block->writes.size() != 1) return;
+    auto write_buffer = block->writes[0]->buffer;
+    block_to_buffer_[block].push_back(write_buffer);
+    BlockAnalyzer block_analyzer(block, buffer_transformation_cache_,
+                                 buffer_transformation_cache_[write_buffer]);
+
+    if (!block_analyzer.CanBeTransformed()) return;
+    // Collect the suggested transformations.
+    block_transformations_.Set(block, block_analyzer.GetBlockTransformation());
+
+    for (const auto& [buffer, index_map] : block_analyzer.GetReadBufferTransformations()) {
+      // BlockAnalyzer makes sure that it does not propose a transformation for a buffer for
+      // which a transformation has already been proposed by other blocks or by the
+      // write_transformations that are input to this analysis.
+      ICHECK_EQ(buffer_transformation_cache_.count(buffer), 0);
+      buffer_transformation_cache_.Set(buffer, index_map);
+      block_to_buffer_[block].push_back(buffer);
+    }
+  }
+
 private:
+  Map<Buffer, IndexMap> buffer_transformation_cache_;
+  Map<Block, IndexMap> block_transformations_;
+  std::unordered_map<Block, Array<Buffer>, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_;
+};
+
+Map<Block, Map<ObjectRef, IndexMap>> SuggestLayoutTransforms(
+    const PrimFunc& prim_func, Array<IndexMap> write_buffer_transformations) {
+  // No changes to the PrimFunc are required if there are no transformations on the output
+  // buffers.
+  if (write_buffer_transformations.empty()) return {};
+
+  PrimFuncAnalyzer analyzer(prim_func, write_buffer_transformations);
+  return analyzer.GetSuggestedTransforms();
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.suggest_layout_transforms")
+    .set_body_typed([](PrimFunc fn, Array<IndexMap> write_buffer_transformations) {
+      return SuggestLayoutTransforms(fn, write_buffer_transformations);
+    });
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/support/array.h b/src/support/array.h
index 218150f9dba0..0ca57a2410c5 100644
--- a/src/support/array.h
+++ b/src/support/array.h
@@ -21,6 +21,7 @@
 #include <tvm/ir/expr.h>
 #include <tvm/runtime/container/array.h>
 
+#include <list>
 #include <vector>
 
 namespace tvm {
@@ -81,11 +82,35 @@ inline std::vector<TDst> AsVector(const Array<TSrc>& vec);
  * \brief Convert a std::vector to tvm::runtime::Array
 * \tparam TSrc The type of elements in the source vector
 * \tparam TDst The type of elements in the result Array
- * \return The result vector
+ * \return The result Array
 */
 template <class TSrc, class TDst>
 inline Array<TDst> AsArray(const std::vector<TSrc>& vec);
 
+/*!
+ * \brief Convert a tvm::runtime::Array to std::list
+ * \tparam T The type of elements in the source array
+ * \return The result list
+ */
+template <class T>
+inline std::list<T> AsList(const Array<T>& array) {
+  std::list<T> list;
+  for (const auto& v : array) list.push_back(v);
+  return list;
+}
+
+/*!
+ * \brief Convert a std::list to tvm::runtime::Array
+ * \tparam T The type of elements in the source list
+ * \return The result Array
+ */
+template <class T>
+inline Array<T> AsArray(const std::list<T>& list) {
+  Array<T> array;
+  for (const auto& v : list) array.push_back(v);
+  return array;
+}
+
 /*!
 * \brief Get the shape tuple as array
 * \param shape The shape tuple
diff --git a/tests/python/relax/test_analysis_suggest_layout_transforms.py b/tests/python/relax/test_analysis_suggest_layout_transforms.py
new file mode 100644
index 000000000000..2850f0ed9f94
--- /dev/null
+++ b/tests/python/relax/test_analysis_suggest_layout_transforms.py
@@ -0,0 +1,831 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm.testing
+
+from tvm import relax, tir
+from tvm.script import tir as T
+
+
+def apply_transformations(func, suggested_transforms, print_transformation=False):
+    sch = tir.Schedule(func)
+    for block, per_block_transformations in suggested_transforms.items():
+        blockrv = sch.get_block(block.name_hint)
+        for obj, index_map in per_block_transformations.items():
+            if isinstance(obj, tir.Block):
+                block_name = obj.name_hint
+                if print_transformation:
+                    print("Block transformation: ", block_name, " :: ", index_map)
+                sch.transform_block_layout(block_name, index_map)
+            else:
+                assert isinstance(obj, tir.Buffer)
+                buffer = obj
+                if print_transformation:
+                    print("Buffer transformation: ", buffer, " :: ", index_map)
+                sch.transform_layout(blockrv, buffer, index_map)
+    return sch.mod["main"]
+
+
+def test_nested_blocks():
+    @T.prim_func
+    def nested_block(
+        arg: T.Buffer((32, 64, 224, 224), "float32"),
+        relu: T.Buffer((32, 64, 224, 224), "float32"),
+    ):
+        for i, j in T.grid(32, 64):
+            with T.block("outer"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(arg[v_i, v_j, 0:224, 0:224])
+                T.writes(relu[v_i, v_j, 0:224, 0:224])
+                for k, l in T.grid(224, 224):
+                    with T.block("inner"):
+                        v_k, v_l = T.axis.remap("SS", [k, l])
+                        T.reads(arg[v_i, v_j, v_k, v_l])
+                        T.writes(relu[v_i, v_j, v_k, v_l])
+                        relu[v_i, v_j, v_k, v_l] = T.max(arg[v_i, v_j, v_k, v_l], T.float32(0))
+
+    suggested_transforms = relax.analysis.suggest_layout_transforms(
+        func=nested_block, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)]
+    )
+    # No suggestions are expected for nested blocks.
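+    # BlockAnalyzer bails out as soon as it sees a block nested inside another block, so the
+    # returned mapping is empty.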
+ assert len(suggested_transforms.items()) == 0 + + +def test_mismatch_transformations_and_num_params(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + with pytest.raises(tvm.TVMError, match="Incompatible PrimFunc and write_transformations"): + _ = relax.analysis.suggest_layout_transforms( + func=elemwise, + write_buffer_transforms=[ + lambda n, c, h, w: (n, h, w, c), + lambda n, c, h, w: (n, h, w, c), + lambda n, c, h, w: (n, h, w, c), + ], + ) + + +def test_empty_write_transformations(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=elemwise, write_buffer_transforms=[] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_non_bijective_block_transform(): + @T.prim_func + def before( + arg: T.Buffer((32, 64), "float32"), + output: T.Buffer((32, 64), "float32"), + ): + for ax0, ax1 in T.grid(32, 64): + with T.block("compute"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 5, c % 5)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_non_affine_access(): + @T.prim_func + def before( + arg: T.Buffer((32, 64), "float32"), + output: T.Buffer((32 * 64, 10), "float32"), + ): + for ax0, ax1, ax2 in T.grid(32, 64, 10): + with T.block("compute"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0 * v_ax1, v_ax2]) + output[v_ax0 * v_ax1, v_ax2] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a, b: (b, a)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_unsupported_write_spatial_layout(): + @T.prim_func + def before( + arg: T.Buffer((4, 4), "float32"), + output: T.Buffer((16), "float32"), + ): + for ax0, ax1 in T.grid(4, 4): + with T.block("flatten"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(arg[v_ax0, v_ax1]) + T.writes(output[v_ax0 * 4 + v_ax1]) + output[v_ax0 * 4 + v_ax1] = arg[v_ax0, v_ax1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a: (a // 4, a % 4)] + ) + assert len(suggested_transforms.items()) == 0 + + +def test_unpacked_iter_used_in_read_access(): + @T.prim_func + def before( + arg: T.Buffer((8, 4), "float32"), + output: T.Buffer((4, 8), "float32"), + ): + for ax0, ax1, ax2 in T.grid(4, 8, 4): + with T.block("compute"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(arg[v_ax1, v_ax2]) + 
T.writes(output[v_ax0, v_ax1]) + output[v_ax0, v_ax1] = arg[v_ax1, v_ax2] + + @T.prim_func + def expected( + arg: T.Buffer((8, 4), "float32"), + output: T.Buffer((32), "float32"), + ): + for ax0, ax2 in T.grid(32, 4): + with T.block("compute"): + v_ax0, v_ax2 = T.axis.remap("SS", [ax0, ax2]) + T.reads(arg[v_ax0 % 8, v_ax2]) + T.writes(output[v_ax0]) + output[v_ax0] = arg[v_ax0 % 8, v_ax2] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda a, b: (a * 8 + b)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_invalid_index_map(): + @T.prim_func + def elemwise( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + with pytest.raises(tvm.TVMError, match="Mismatch between output buffer shape and index map"): + _ = relax.analysis.suggest_layout_transforms( + func=elemwise, write_buffer_transforms=[lambda n, h, w: (n, w, h)] + ) + with pytest.raises(AssertionError): + _ = relax.analysis.suggest_layout_transforms(func=elemwise, write_buffer_transforms=[2]) + + +def test_SRSR_block(): + @T.prim_func + def before( + arg: T.Buffer((32, 224, 64, 224), "float32"), + sum: T.Buffer((32, 64), "float32"), + ): + for ax0, k2, ax1, k3 in T.grid(32, 224, 64, 224): + with T.block("rxplaceholder_red"): + v_ax0, v_k2, v_ax1, v_k3 = T.axis.remap("SRSR", [ax0, k2, ax1, k3]) + T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(sum[v_ax0, v_ax1]) + with T.init(): + sum[v_ax0, v_ax1] = T.float32(0) + sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_k2, v_ax1, v_k3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 16, 224, 4), "float32"), + sum: T.Buffer((32, 16, 4), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 16, 224, 4): + with T.block("rxplaceholder_red"): + v0, v1, v2, v3, v4 = T.axis.remap("SRSRS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(sum[v0, v2, v4]) + with T.init(): + sum[v0, v2, v4] = T.float32(0) + sum[v0, v2, v4] = sum[v0, v2, v4] + arg[v0, v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_elemwise_symbolic(): + @T.prim_func + def before(arg: T.handle, relu: T.handle): + N = T.int64() + C = T.int64() + H = T.int64() + W = T.int64() + Arg = T.match_buffer(arg, (N, C, H, W)) + Relu = T.match_buffer(relu, (N, C, H, W)) + for i0, i1, i2, i3 in T.grid(N, C, H, W): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(Arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(Relu[v_i0, v_i1, v_i2, v_i3]) + Relu[v_i0, v_i1, v_i2, v_i3] = T.max(Arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @T.prim_func + def expected(arg: T.handle, relu: T.handle): + N = T.int64() + C = T.int64() + H = T.int64() + W = T.int64() + Arg = T.match_buffer(arg, (N, H, W, C)) + Relu = T.match_buffer(relu, (N, H, W, C)) + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(N, H, W, C): + with 
T.block("compute"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(Arg[v0, v1, v2, v3]) + T.writes(Relu[v0, v1, v2, v3]) + Relu[v0, v1, v2, v3] = T.max(Arg[v0, v1, v2, v3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_elemwise(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + relu: T.Buffer((32, 64, 224, 224), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 224, 224): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2, v_i3]) + T.writes(relu[v_i0, v_i1, v_i2, v_i3]) + relu[v_i0, v_i1, v_i2, v_i3] = T.max(arg[v_i0, v_i1, v_i2, v_i3], T.float32(0)) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + relu: T.Buffer((32, 224, 224, 64), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): + with T.block("compute"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3]) + T.writes(relu[v0, v1, v2, v3]) + relu[v0, v1, v2, v3] = T.max(arg[v0, v1, v2, v3], T.float32(0)) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pool_nchw_nhwc(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + pool_max: T.Buffer((32, 64, 111, 223), "float32"), + ): + for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(32, 64, 111, 223, 2, 2): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap( + "SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1] + ) + T.reads( + arg[ + v_ax0, + v_ax1, + v_ax2 * 2 + v_rv0 * 2, + v_ax3 + v_rv1, + ] + ) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(-3.4028234663852886e38) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3], + arg[ + v_ax0, + v_ax1, + v_ax2 * 2 + v_rv0 * 2, + v_ax3 + v_rv1, + ], + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + pool_max: T.Buffer((32, 111, 223, 64), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 111, 223, 64, 2, 2): + with T.block("pool_max"): + v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) + T.reads(arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3]) + T.writes(pool_max[v0, v1, v2, v3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) + pool_max[v0, v1, v2, v3] = T.max( + pool_max[v0, v1, v2, v3], + arg[v0, v1 * 2 + v4 * 2, v2 + v5, v3], + ) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pool_nchw16c_nhwc(): + @T.prim_func + def before( + arg: T.Buffer( + (32, 4, 224, 224, 16), + "float32", + ), + pool_max: T.Buffer( + (32, 4, 110, 220, 16), + "float32", 
+ ), + ): + for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(32, 4, 110, 220, 16, 5, 5): + with T.block("pool_max"): + v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap( + "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1] + ) + T.reads(arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4]) + T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(-3.4028234663852886e38) + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max( + pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], + arg[v_ax0, v_ax1, v_ax2 * 2 + v_rv0, v_ax3 + v_rv1, v_ax4], + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + pool_max: T.Buffer((32, 110, 220, 64), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(32, 110, 220, 64, 5, 5): + with T.block("pool_max"): + v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5]) + T.reads(arg[v0, v1 * 2 + v4, v2 + v5, v3]) + T.writes(pool_max[v0, v1, v2, v3]) + T.block_attr({"schedule_rule": "meta_schedule.pool_max"}) + with T.init(): + pool_max[v0, v1, v2, v3] = T.float32(-3.4028234663852886e38) + pool_max[v0, v1, v2, v3] = T.max( + pool_max[v0, v1, v2, v3], + arg[v0, v1 * 2 + v4, v2 + v5, v3], + ) + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, C, h, w, c: (n, h, w, C * 16 + c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_reduce(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + sum: T.Buffer((32, 64), "float32"), + ): + for ax0, ax1, k2, k3 in T.grid(32, 64, 224, 224): + with T.block("rxplaceholder_red"): + v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3]) + T.reads(arg[v_ax0, v_ax1, v_k2, v_k3]) + T.writes(sum[v_ax0, v_ax1]) + with T.init(): + sum[v_ax0, v_ax1] = T.float32(0) + sum[v_ax0, v_ax1] = sum[v_ax0, v_ax1] + arg[v_ax0, v_ax1, v_k2, v_k3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 4, 224, 224, 16), "float32"), + sum: T.Buffer((32, 4, 16), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 4, 224, 224, 16): + with T.block("rxplaceholder_red"): + v0, v1, v2, v3, v4 = T.axis.remap("SSRRS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(sum[v0, v1, v4]) + with T.init(): + sum[v0, v1, v4] = T.float32(0) + sum[v0, v1, v4] = sum[v0, v1, v4] + arg[v0, v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c: (n, c // 16, c % 16)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_upsampling(): + # relay materializes the layout if H, W or D dimensions are moved or tiled. 
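+    # Here the input access uses non-affine floor/cast indexing on H and W, so only the block
+    # and the output buffer are transformed below; `arg` keeps its original NCHW layout.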
+ @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + resize: T.Buffer((32, 64, 202, 246), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 202, 246): + with T.block("resize"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, 0:224, 0:224]) + T.writes(resize[v_i0, v_i1, v_i2, v_i3]) + resize[v_i0, v_i1, v_i2, v_i3] = arg[ + v_i0, + v_i1, + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(1.1089109182357788) * T.Cast("float32", v_i2) + + T.float32(1.0000000000000001e-05) + ), + ), + 223, + ), + 0, + ), + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(0.91056913137435913) * T.Cast("float32", v_i3) + + T.float32(1.0000000000000001e-05) + ), + ), + 223, + ), + 0, + ), + ] + + @T.prim_func + def expected( + arg: T.Buffer((32, 64, 224, 224), "float32"), + resize: T.Buffer((32, 202, 246, 64), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(32, 202, 246, 64): + with T.block("resize"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v3, 0:224, 0:224]) + T.writes(resize[v0, v1, v2, v3]) + resize[v0, v1, v2, v3] = arg[ + v0, + v3, + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(1.1089109182357788) * T.Cast("float32", v1) + + T.float32(1.0000000000000001e-05) + ), + ), + T.int64(223), + ), + T.int64(0), + ), + T.max( + T.min( + T.Cast( + "int64", + T.floor( + T.float32(0.91056913137435913) * T.Cast("float32", v2) + + T.float32(1.0000000000000001e-05) + ), + ), + T.int64(223), + ), + T.int64(0), + ), + ] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_strided_slice(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_strided_slice_with_axes: T.Buffer((32, 64, 10, 8), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 64, 10, 8): + with T.block("T_strided_slice_with_axes"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + arg[ + v_ax0, + v_ax1, + v_ax2 * 5 + 2, + v_ax3 * 7 + 4, + ] + ) + T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3]) + T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2, v_ax3] = arg[ + v_ax0, + v_ax1, + v_ax2 * 5 + 2, + v_ax3 * 7 + 4, + ] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + T_strided_slice_with_axes: T.Buffer((32, 10, 8, 16, 4), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 10, 8, 16, 4): + with T.block("T_strided_slice_with_axes"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4]) + T.writes(T_strided_slice_with_axes[v0, v1, v2, v3, v4]) + T_strided_slice_with_axes[v0, v1, v2, v3, v4] = arg[ + v0, v1 * 5 + 2, v2 * 7 + 4, v3, v4 + ] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_binary_broadcast(): + @T.prim_func + def before( + arg0: T.Buffer((32, 64, 224, 224), "float32"), + arg1: T.Buffer((64, 224, 224), "float32"), + T_add: T.Buffer((32, 64, 224, 224), "float32"), + ): + T.func_attr({"tir.noalias": True}) + 
# with T.block("root"): + for ax0, ax1, ax2, ax3 in T.grid(32, 64, 224, 224): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads( + arg0[v_ax0, v_ax1, v_ax2, v_ax3], + arg1[v_ax1, v_ax2, v_ax3], + ) + T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3]) + T_add[v_ax0, v_ax1, v_ax2, v_ax3] = ( + arg0[v_ax0, v_ax1, v_ax2, v_ax3] + arg1[v_ax1, v_ax2, v_ax3] + ) + + @T.prim_func + def expected( + arg0: T.Buffer((32, 224, 224, 16, 4), "float32"), + arg1: T.Buffer((224, 224, 16, 4), "float32"), + T_add: T.Buffer((32, 224, 224, 16, 4), "float32"), + ): + T.func_attr({"tir.noalias": True}) + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4): + with T.block("T_add"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg0[v0, v1, v2, v3, v4], arg1[v1, v2, v3, v4]) + T.writes(T_add[v0, v1, v2, v3, v4]) + T_add[v0, v1, v2, v3, v4] = arg0[v0, v1, v2, v3, v4] + arg1[v1, v2, v3, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_transpose(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_transpose: T.Buffer((32, 224, 224, 64), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 64): + with T.block("T_transpose"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax3, v_ax1, v_ax2]) + T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3]) + T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax3, v_ax1, v_ax2] + + @T.prim_func + def expected( + arg: T.Buffer((32, 64, 224, 224), "float32"), + T_transpose: T.Buffer((32, 224, 64, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 64, 224): + with T.block("T_transpose"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v2, v3, v1]) + T.writes(T_transpose[v0, v1, v2, v3]) + T_transpose[v0, v1, v2, v3] = arg[v0, v2, v3, v1] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_pad(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + PadInput: T.Buffer((32, 64, 230, 230), "float32"), + ): + for i0, i1, i2, i3 in T.grid(32, 64, 230, 230): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + 2 <= v_i2 and v_i2 < 226 and 2 <= v_i3 and v_i3 < 226, + arg[v_i0, v_i1, v_i2 - 2, v_i3 - 2], + T.float32(2), + ) + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + PadInput: T.Buffer((32, 230, 230, 16, 4), "float32"), + ): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 230, 230, 16, 4): + with T.block("PadInput"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1 - 2, v2 - 2, v3, v4]) + T.writes(PadInput[v0, v1, v2, v3, v4]) + PadInput[v0, v1, v2, v3, v4] = T.if_then_else( + 2 <= v1 and v1 < 226 and 2 <= v2 and v2 < 226, + arg[v0, v1 - 2, v2 - 2, v3, v4], + T.float32(2), + ) + + 
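+    # Tile the channel dimension by 4 (NCHW -> NHWC4c); the pad condition moves with the
+    # spatial dimensions in the transformed block.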
suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c // 4, c % 4)] + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_split(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + split0: T.Buffer((32, 32, 224, 224), "float32"), + split1: T.Buffer((32, 32, 224, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) + split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) + T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) + split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 64), "float32"), + split0: T.Buffer((32, 224, 224, 32), "float32"), + split1: T.Buffer((32, 224, 224, 32), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): + with T.block("T_split_sections"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3]) + T.writes(split0[v0, v1, v2, v3]) + split0[v0, v1, v2, v3] = arg[v0, v1, v2, v3] + for ax0, ax1, ax2, ax3 in T.grid(32, 224, 224, 32): + with T.block("T_split_sections_1"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v0, v1, v2, v3 + 32]) + T.writes(split1[v0, v1, v2, v3]) + split1[v0, v1, v2, v3] = arg[v0, v1, v2, v3 + 32] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[lambda n, c, h, w: (n, h, w, c), lambda n, c, h, w: (n, h, w, c)], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +def test_op_split_tiling_split_dim(): + @T.prim_func + def before( + arg: T.Buffer((32, 64, 224, 224), "float32"), + split0: T.Buffer((32, 32, 224, 224), "float32"), + split1: T.Buffer((32, 32, 224, 224), "float32"), + ): + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(split0[v_ax0, v_ax1, v_ax2, v_ax3]) + split0[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1, v_ax2, v_ax3] + for ax0, ax1, ax2, ax3 in T.grid(32, 32, 224, 224): + with T.block("T_split_sections_1"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3]) + T.writes(split1[v_ax0, v_ax1, v_ax2, v_ax3]) + split1[v_ax0, v_ax1, v_ax2, v_ax3] = arg[v_ax0, v_ax1 + 32, v_ax2, v_ax3] + + @T.prim_func + def expected( + arg: T.Buffer((32, 224, 224, 16, 4), "float32"), + split0: T.Buffer((32, 224, 224, 8, 4), "float32"), + split1: T.Buffer((32, 224, 224, 8, 4), "float32"), + ): + # with T.block("root"): + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): + with T.block("T_split_sections"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3, v4]) + T.writes(split0[v0, v1, v2, v3, v4]) + split0[v0, v1, v2, v3, v4] = arg[v0, 
v1, v2, v3, v4] + for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 8, 4): + with T.block("T_split_sections_1"): + v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4]) + T.reads(arg[v0, v1, v2, v3 + 8, v4]) + T.writes(split1[v0, v1, v2, v3, v4]) + split1[v0, v1, v2, v3, v4] = arg[v0, v1, v2, v3 + 8, v4] + + suggested_transforms = relax.analysis.suggest_layout_transforms( + func=before, + write_buffer_transforms=[ + lambda n, c, h, w: (n, h, w, c // 4, c % 4), + lambda n, c, h, w: (n, h, w, c // 4, c % 4), + ], + ) + after = apply_transformations(before, suggested_transforms) + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main()