From f3e482d2c8bad77d6df59f96da9152daf8a5ef22 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Thu, 4 Oct 2018 16:28:54 +0200 Subject: [PATCH] symbol+scope:: now viewID ignores one-sized dimensions --- core/jitk/engines/engine.cpp | 23 +++++++++----- include/jitk/codegen_util.hpp | 1 - include/jitk/scope.hpp | 6 ++-- include/jitk/symbol_table.hpp | 57 +++++++++++++++++++++++++++++++++-- 4 files changed, 74 insertions(+), 13 deletions(-) diff --git a/core/jitk/engines/engine.cpp b/core/jitk/engines/engine.cpp index 2f6995258..e49afbf97 100644 --- a/core/jitk/engines/engine.cpp +++ b/core/jitk/engines/engine.cpp @@ -103,25 +103,32 @@ void Engine::writeBlock(const SymbolTable &symbols, jitk::Scope scope(symbols, parent_scope, local_tmps, scalar_replaced_reduction_outputs, srio); // Write temporary and scalar replaced array declarations - vector scalar_replaced_to_write_back; + vector > scalar_replaced_to_write_back; // Pair of the view and hidden_axis for (const jitk::Block &block: kernel._block_list) { if (block.isInstr()) { const jitk::InstrPtr &instr = block.getInstr(); - for (const bh_view &view: instr->getViews()) { + for (size_t o = 0; o < instr->operand.size(); ++o) { + const bh_view &view = instr->operand[o]; if (not scope.isDeclared(view)) { if (scope.isTmp(view.base)) { util::spaces(out, 8 + kernel.rank * 4); scope.writeDeclaration(view, writeType(view.base->type), out); out << "\n"; } else if (scope.isScalarReplaced(view)) { + // If 'instr' is a reduction we have to ignore the reduced axis when declaring the output + // array (but only if we are reducing to a non-scalar). + int hidden_axis = BH_MAXDIM; // Note, `BH_MAXDIM` means on hidden axis + if (o == 0 and bh_opcode_is_reduction(instr->opcode) and instr->operand[1].ndim > 1) { + hidden_axis = instr->sweep_axis(); + } util::spaces(out, 8 + kernel.rank * 4); scope.writeDeclaration(view, writeType(view.base->type), out); out << " " << scope.getName(view) << " = a" << symbols.baseID(view.base); - write_array_subscription(scope, view, out); + write_array_subscription(scope, view, out, false, hidden_axis); out << ";"; out << "\n"; if (scope.isScalarReplaced_RW(view)) { - scalar_replaced_to_write_back.push_back(&view); + scalar_replaced_to_write_back.emplace_back(&view, hidden_axis); } } } @@ -191,11 +198,13 @@ void Engine::writeBlock(const SymbolTable &symbols, } } - // Let's copy the scalar replaced reduction outputs back to the original array - for (const bh_view *view: scalar_replaced_to_write_back) { + // Let's copy the scalar replaced back to the original array + for (const auto view_and_hidden_axis: scalar_replaced_to_write_back) { + const bh_view *view = view_and_hidden_axis.first; + const int hidden_axis = view_and_hidden_axis.second; util::spaces(out, 8 + kernel.rank * 4); out << "a" << symbols.baseID(view->base); - write_array_subscription(scope, *view, out, true); + write_array_subscription(scope, *view, out, true, hidden_axis); out << " = "; scope.getName(*view, out); out << ";\n"; diff --git a/include/jitk/codegen_util.hpp b/include/jitk/codegen_util.hpp index 9892b3ceb..70d2416fc 100644 --- a/include/jitk/codegen_util.hpp +++ b/include/jitk/codegen_util.hpp @@ -107,7 +107,6 @@ void create_directories(const boost::filesystem::path &path); // This makes the source of the kernels more identical, which improve the code and compile caches. std::vector order_sweep_set(const std::set &sweep_set, const SymbolTable &symbols); - // Returns True when `view` is accessing row major style bool row_major_access(const bh_view &view); diff --git a/include/jitk/scope.hpp b/include/jitk/scope.hpp index b9ab87b6f..b6518807a 100644 --- a/include/jitk/scope.hpp +++ b/include/jitk/scope.hpp @@ -37,12 +37,12 @@ class Scope { const Scope * const parent; private: std::set _tmps; // Set of temporary arrays - std::set _scalar_replacements_rw; // Set of scalar replaced arrays that both reads and writes - std::set _scalar_replacements_r; // Set of scalar replaced arrays + std::set _scalar_replacements_rw; // Set of scalar replaced arrays that both reads and writes + std::set _scalar_replacements_r; // Set of scalar replaced arrays std::set _omp_atomic; // Set of instructions that should be guarded by OpenMP atomic std::set _omp_critical; // Set of instructions that should be guarded by OpenMP critical std::set _declared_base; // Set of bases that have been locally declared (e.g. a temporary variable) - std::set _declared_view; // Set of views that have been locally declared (e.g. a temporary variable) + std::set _declared_view; // Set of views that have been locally declared (e.g. scalar replaced variable) std::set _declared_idx; // Set of indexes that have been locally declared public: template diff --git a/include/jitk/symbol_table.hpp b/include/jitk/symbol_table.hpp index 1c52aeae4..927acc544 100644 --- a/include/jitk/symbol_table.hpp +++ b/include/jitk/symbol_table.hpp @@ -56,17 +56,70 @@ struct OffsetAndStrides_less { // Compare class for the constant_map struct Constant_less { - // This compare tje 'origin_id' member of the instructions + // This compare the 'origin_id' member of the instructions bool operator() (const InstrPtr &i1, const InstrPtr& i2) const { return i1->origin_id < i2->origin_id; } }; +// Compare class for the viewID sets and maps +struct IgnoreOneDim_less { + BhIntVec get_shape_where_shape_is_greater_than_one(const bh_view &view) const { + BhIntVec ret; + for (int64_t i = 0; i < view.ndim; ++i) { + if (view.shape[i] > 1) { + ret.push_back(view.shape[i]); + } + } + return ret; + } + + BhIntVec get_stride_where_shape_is_greater_than_one(const bh_view &view) const { + BhIntVec ret; + for (int64_t i = 0; i < view.ndim; ++i) { + if (view.shape[i] > 1) { + ret.push_back(view.stride[i]); + } + } + return ret; + } + + // This compare is the same as view compare ('v1 < v2') but ignoring their bases and zero or one-sized dimensions + bool operator() (const bh_view& v1, const bh_view& v2) const { + if (v1.base < v2.base) return true; + if (v2.base < v1.base) return false; + if (v1.start < v2.start) return true; + if (v2.start < v1.start) return false; + + auto v1_shape = get_shape_where_shape_is_greater_than_one(v1); + auto v2_shape = get_shape_where_shape_is_greater_than_one(v2); + if (v1_shape.size() < v2_shape.size()) return true; + if (v2_shape.size() < v1_shape.size()) return false; + + auto v1_stride = get_stride_where_shape_is_greater_than_one(v1); + auto v2_stride = get_stride_where_shape_is_greater_than_one(v2); + assert(v1_shape.size() == v1_stride.size()); + assert(v2_shape.size() == v2_stride.size()); + + for (size_t i=0; i < v1_shape.size(); ++i) { + if (v1_stride[i] < v2_stride[i]) return true; + if (v2_stride[i] < v1_stride[i]) return false; + if (v1_shape[i] < v2_shape[i]) return true; + if (v2_shape[i] < v1_shape[i]) return false; + } + return false; + } + bool operator() (const bh_view* v1, const bh_view* v2) const { + return (*this)(*v1, *v2); + } +}; + + // The SymbolTable class contains all array meta date needed for a JIT kernel. class SymbolTable { private: std::map _base_map; // Mapping a base to its ID - std::map _view_map; // Mapping a view to its ID + std::map _view_map; // Mapping a view to its ID std::map _idx_map; // Mapping a index (of an array) to its ID std::map _offset_strides_map; // Mapping a offset-and-strides to its ID std::vector _offset_stride_views; // Vector of all offset-and-stride views