Skip to content

Commit

Permalink
symbol+scope:: now viewID ignores one-sized dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Oct 4, 2018
1 parent 0890fa4 commit f3e482d
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 13 deletions.
23 changes: 16 additions & 7 deletions core/jitk/engines/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,25 +103,32 @@ void Engine::writeBlock(const SymbolTable &symbols,
jitk::Scope scope(symbols, parent_scope, local_tmps, scalar_replaced_reduction_outputs, srio);

// Write temporary and scalar replaced array declarations
vector<const bh_view *> scalar_replaced_to_write_back;
vector<pair<const bh_view *, int> > scalar_replaced_to_write_back; // Pair of the view and hidden_axis
for (const jitk::Block &block: kernel._block_list) {
if (block.isInstr()) {
const jitk::InstrPtr &instr = block.getInstr();
for (const bh_view &view: instr->getViews()) {
for (size_t o = 0; o < instr->operand.size(); ++o) {
const bh_view &view = instr->operand[o];
if (not scope.isDeclared(view)) {
if (scope.isTmp(view.base)) {
util::spaces(out, 8 + kernel.rank * 4);
scope.writeDeclaration(view, writeType(view.base->type), out);
out << "\n";
} else if (scope.isScalarReplaced(view)) {
// If 'instr' is a reduction we have to ignore the reduced axis when declaring the output
// array (but only if we are reducing to a non-scalar).
int hidden_axis = BH_MAXDIM; // Note, `BH_MAXDIM` means on hidden axis
if (o == 0 and bh_opcode_is_reduction(instr->opcode) and instr->operand[1].ndim > 1) {
hidden_axis = instr->sweep_axis();
}
util::spaces(out, 8 + kernel.rank * 4);
scope.writeDeclaration(view, writeType(view.base->type), out);
out << " " << scope.getName(view) << " = a" << symbols.baseID(view.base);
write_array_subscription(scope, view, out);
write_array_subscription(scope, view, out, false, hidden_axis);
out << ";";
out << "\n";
if (scope.isScalarReplaced_RW(view)) {
scalar_replaced_to_write_back.push_back(&view);
scalar_replaced_to_write_back.emplace_back(&view, hidden_axis);
}
}
}
Expand Down Expand Up @@ -191,11 +198,13 @@ void Engine::writeBlock(const SymbolTable &symbols,
}
}

// Let's copy the scalar replaced reduction outputs back to the original array
for (const bh_view *view: scalar_replaced_to_write_back) {
// Let's copy the scalar replaced back to the original array
for (const auto view_and_hidden_axis: scalar_replaced_to_write_back) {
const bh_view *view = view_and_hidden_axis.first;
const int hidden_axis = view_and_hidden_axis.second;
util::spaces(out, 8 + kernel.rank * 4);
out << "a" << symbols.baseID(view->base);
write_array_subscription(scope, *view, out, true);
write_array_subscription(scope, *view, out, true, hidden_axis);
out << " = ";
scope.getName(*view, out);
out << ";\n";
Expand Down
1 change: 0 additions & 1 deletion include/jitk/codegen_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ void create_directories(const boost::filesystem::path &path);
// This makes the source of the kernels more identical, which improve the code and compile caches.
std::vector<InstrPtr> order_sweep_set(const std::set<InstrPtr> &sweep_set, const SymbolTable &symbols);


// Returns True when `view` is accessing row major style
bool row_major_access(const bh_view &view);

Expand Down
6 changes: 3 additions & 3 deletions include/jitk/scope.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class Scope {
const Scope * const parent;
private:
std::set<const bh_base*> _tmps; // Set of temporary arrays
std::set<bh_view> _scalar_replacements_rw; // Set of scalar replaced arrays that both reads and writes
std::set<bh_view> _scalar_replacements_r; // Set of scalar replaced arrays
std::set<bh_view, IgnoreOneDim_less> _scalar_replacements_rw; // Set of scalar replaced arrays that both reads and writes
std::set<bh_view, IgnoreOneDim_less> _scalar_replacements_r; // Set of scalar replaced arrays
std::set<InstrPtr> _omp_atomic; // Set of instructions that should be guarded by OpenMP atomic
std::set<InstrPtr> _omp_critical; // Set of instructions that should be guarded by OpenMP critical
std::set<bh_base*> _declared_base; // Set of bases that have been locally declared (e.g. a temporary variable)
std::set<bh_view> _declared_view; // Set of views that have been locally declared (e.g. a temporary variable)
std::set<bh_view, IgnoreOneDim_less> _declared_view; // Set of views that have been locally declared (e.g. scalar replaced variable)
std::set<bh_view, OffsetAndStrides_less> _declared_idx; // Set of indexes that have been locally declared
public:
template<typename T1, typename T2>
Expand Down
57 changes: 55 additions & 2 deletions include/jitk/symbol_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,70 @@ struct OffsetAndStrides_less {

// Compare class for the constant_map
struct Constant_less {
// This compare tje 'origin_id' member of the instructions
// This compare the 'origin_id' member of the instructions
bool operator() (const InstrPtr &i1, const InstrPtr& i2) const {
return i1->origin_id < i2->origin_id;
}
};

// Compare class for the viewID sets and maps
struct IgnoreOneDim_less {
BhIntVec get_shape_where_shape_is_greater_than_one(const bh_view &view) const {
BhIntVec ret;
for (int64_t i = 0; i < view.ndim; ++i) {
if (view.shape[i] > 1) {
ret.push_back(view.shape[i]);
}
}
return ret;
}

BhIntVec get_stride_where_shape_is_greater_than_one(const bh_view &view) const {
BhIntVec ret;
for (int64_t i = 0; i < view.ndim; ++i) {
if (view.shape[i] > 1) {
ret.push_back(view.stride[i]);
}
}
return ret;
}

// This compare is the same as view compare ('v1 < v2') but ignoring their bases and zero or one-sized dimensions
bool operator() (const bh_view& v1, const bh_view& v2) const {
if (v1.base < v2.base) return true;
if (v2.base < v1.base) return false;
if (v1.start < v2.start) return true;
if (v2.start < v1.start) return false;

auto v1_shape = get_shape_where_shape_is_greater_than_one(v1);
auto v2_shape = get_shape_where_shape_is_greater_than_one(v2);
if (v1_shape.size() < v2_shape.size()) return true;
if (v2_shape.size() < v1_shape.size()) return false;

auto v1_stride = get_stride_where_shape_is_greater_than_one(v1);
auto v2_stride = get_stride_where_shape_is_greater_than_one(v2);
assert(v1_shape.size() == v1_stride.size());
assert(v2_shape.size() == v2_stride.size());

for (size_t i=0; i < v1_shape.size(); ++i) {
if (v1_stride[i] < v2_stride[i]) return true;
if (v2_stride[i] < v1_stride[i]) return false;
if (v1_shape[i] < v2_shape[i]) return true;
if (v2_shape[i] < v1_shape[i]) return false;
}
return false;
}
bool operator() (const bh_view* v1, const bh_view* v2) const {
return (*this)(*v1, *v2);
}
};


// The SymbolTable class contains all array meta date needed for a JIT kernel.
class SymbolTable {
private:
std::map<const bh_base*, size_t> _base_map; // Mapping a base to its ID
std::map<bh_view, size_t> _view_map; // Mapping a view to its ID
std::map<bh_view, size_t, IgnoreOneDim_less> _view_map; // Mapping a view to its ID
std::map<bh_view, size_t, OffsetAndStrides_less> _idx_map; // Mapping a index (of an array) to its ID
std::map<bh_view, size_t, OffsetAndStrides_less> _offset_strides_map; // Mapping a offset-and-strides to its ID
std::vector<const bh_view*> _offset_stride_views; // Vector of all offset-and-stride views
Expand Down

0 comments on commit f3e482d

Please sign in to comment.