Skip to content

Commit

Permalink
[Lang] Simplify dense.bitmasked to bitmasked (#670)
Browse files Browse the repository at this point in the history
* [Lang] Simplify dense.bitmasked to bitmasked

* fix tests

* put back

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
k-ye and taichi-gardener authored Mar 28, 2020
1 parent 1371a07 commit 4443534
Show file tree
Hide file tree
Showing 20 changed files with 130 additions and 99 deletions.
11 changes: 5 additions & 6 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def dynamic(self, index, dimension, chunk_size=None):
chunk_size = dimension
return SNode(self.ptr.dynamic(index[0], dimension, chunk_size))

def bitmasked(self, val=True):
self.ptr.bitmasked(val)
return self
def bitmasked(self, indices, dimensions):
if isinstance(dimensions, int):
dimensions = [dimensions] * len(indices)
return SNode(self.ptr.bitmasked(indices, dimensions))

def place(self, *args):
from .expr import Expr
Expand Down Expand Up @@ -72,8 +73,6 @@ def deactivate_all(self):
for c in ch:
c.deactivate_all()
import taichi as ti
if self.ptr.type == ti.core.SNodeType.pointer or (
self.ptr.type == ti.core.SNodeType.dense
and self.ptr.is_bitmasked):
if self.ptr.type == ti.core.SNodeType.pointer or self.ptr.type == ti.core.SNodeType.bitmasked:
from .meta import snode_deactivate
snode_deactivate(self)
11 changes: 5 additions & 6 deletions taichi/backends/metal/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,15 +509,14 @@ class MetalRuntime::Impl {
rtm_meta->num_slots = sn_meta.num_slots;
rtm_meta->mem_offset_in_parent = sn_meta.mem_offset_in_parent;
switch (sn_meta.snode->type) {
case SNodeType::dense:
rtm_meta->type = SNodeMeta::Dense;
break;
case SNodeType::root:
rtm_meta->type = SNodeMeta::Root;
break;
case SNodeType::dense:
if (sn_meta.snode->_bitmasked) {
rtm_meta->type = SNodeMeta::DenseBitmask;
} else {
rtm_meta->type = SNodeMeta::Dense;
}
case SNodeType::bitmasked:
rtm_meta->type = SNodeMeta::Bitmasked;
break;
default:
TI_ERROR("Unsupported SNode type={}",
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ STR(
};

struct SNodeMeta {
enum Type { Root = 0, Dense = 1, DenseBitmask = 2 };
enum Type { Root = 0, Dense = 1, Bitmasked = 2 };
int32_t element_stride = 0;
int32_t num_slots = 0;
int32_t mem_offset_in_parent = 0;
Expand Down
10 changes: 8 additions & 2 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
if (snode->type == SNodeType::dense) {
meta = std::make_unique<RuntimeObject>("DenseMeta", this, builder.get());
emit_struct_meta_base("Dense", meta->ptr, snode);
meta->call("set_bitmasked", tlctx->get_constant(snode->_bitmasked));
meta->call("set_morton_dim", tlctx->get_constant((int)snode->_morton));
} else if (snode->type == SNodeType::pointer) {
meta =
Expand All @@ -200,6 +199,10 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
std::make_unique<RuntimeObject>("DynamicMeta", this, builder.get());
emit_struct_meta_base("Dynamic", meta->ptr, snode);
meta->call("set_chunk_size", tlctx->get_constant(snode->chunk_size));
} else if (snode->type == SNodeType::bitmasked) {
meta =
std::make_unique<RuntimeObject>("BitmaskedMeta", this, builder.get());
emit_struct_meta_base("Bitmasked", meta->ptr, snode);
} else {
TI_P(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED;
Expand Down Expand Up @@ -1067,6 +1070,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
return "pointer";
} else if (snode->type == SNodeType::hash) {
return "Hash";
} else if (snode->type == SNodeType::bitmasked) {
return "Bitmasked";
} else {
TI_P(snode_type_name(snode->type));
TI_NOT_IMPLEMENTED
Expand Down Expand Up @@ -1141,7 +1146,8 @@ class CodeGenLLVM : public IRVisitor, public ModuleBuilder {
stmt->value = builder->CreateGEP(parent, stmt->input_index->value);
} else if (snode->type == SNodeType::dense ||
snode->type == SNodeType::pointer ||
snode->type == SNodeType::dynamic) {
snode->type == SNodeType::dynamic ||
snode->type == SNodeType::bitmasked) {
if (stmt->activate) {
call(snode, stmt->input_snode->value, "activate",
{stmt->input_index->value});
Expand Down
6 changes: 3 additions & 3 deletions taichi/codegen/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class MetalKernelCodegen : public IRVisitor {
emit(R"({}_ch {} = {}.children({});)", sn->node_type_name, stmt->raw_name(),
parent, index_name);
if (stmt->activate) {
TI_ASSERT(sn->type == SNodeType::dense && sn->_bitmasked);
TI_ASSERT(sn->type == SNodeType::bitmasked);
emit("{{");
{
ScopedIndent s(line_appender_);
Expand Down Expand Up @@ -707,7 +707,7 @@ class MetalKernelCodegen : public IRVisitor {

std::string make_snode_meta_bm(const SNode *sn,
const std::string &var_name) const {
TI_ASSERT(sn->type == SNodeType::dense && sn->_bitmasked);
TI_ASSERT(sn->type == SNodeType::bitmasked);
const auto &meta = compiled_snodes_->snode_descriptors.find(sn->id)->second;
LineAppender la = line_appender_;
// Keep the indentation settings only
Expand All @@ -716,7 +716,7 @@ class MetalKernelCodegen : public IRVisitor {
la.append("SNodeMeta {};", var_name);
la.append("{}.element_stride = {};", var_name, meta.element_stride);
la.append("{}.num_slots = {};", var_name, meta.num_slots);
la.append("{}.type = {};", var_name, meta.num_slots);
la.append("{}.type = {};", var_name, (int)shaders::SNodeMeta::Bitmasked);
return la.lines();
}

Expand Down
1 change: 1 addition & 0 deletions taichi/inc/snodes.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ PER_SNODE(root)
PER_SNODE(dense)
PER_SNODE(dynamic)
PER_SNODE(pointer)
PER_SNODE(bitmasked)
PER_SNODE(hash)
PER_SNODE(place)
PER_SNODE(undefined)
2 changes: 1 addition & 1 deletion taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,7 @@ class SNodeOpExpression : public Expression {
// It should be lowered into a pointer to parent and an index.
TI_ERROR_IF(
snode->type != SNodeType::pointer && snode->type != SNodeType::hash &&
!(snode->type == SNodeType::dense && snode->_bitmasked),
snode->type != SNodeType::bitmasked,
"ti.is_active only works on pointer, hash or bitmasked nodes.");
ret.push_back<SNodeOpStmt>(SNodeOpType::is_active, snode, indices_stmt);
} else {
Expand Down
1 change: 0 additions & 1 deletion taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ SNode::SNode(int depth, SNodeType t) : depth(depth), type(t) {
has_ambient = false;
dt = DataType::gen;
_morton = false;
_bitmasked = false;

reader_kernel = nullptr;
writer_kernel = nullptr;
Expand Down
19 changes: 13 additions & 6 deletions taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ class SNode {
std::string node_type_name;
SNodeType type;
bool _morton{};
bool _bitmasked{};

std::string get_node_type_name() const {
return fmt::format("S{}", id);
Expand Down Expand Up @@ -140,6 +139,19 @@ class SNode {
return SNode::pointer(std::vector<Index>{index}, size);
}

SNode &bitmasked(const std::vector<Index> &indices,
const std::vector<int> &sizes) {
return create_node(indices, sizes, SNodeType::bitmasked);
}

SNode &bitmasked(const std::vector<Index> &indices, int sizes) {
return create_node(indices, std::vector<int>{sizes}, SNodeType::bitmasked);
}

SNode &bitmasked(const Index &index, int size) {
return SNode::bitmasked(std::vector<Index>{index}, size);
}

SNode &hash(const std::vector<Index> &indices,
const std::vector<int> &sizes) {
return create_node(indices, sizes, SNodeType::hash);
Expand Down Expand Up @@ -194,11 +206,6 @@ class SNode {
return *this;
}

SNode &bitmasked(bool val = true) {
_bitmasked = val;
return *this;
}

// for float and double
void write_float(const std::vector<int> &I, float64);
float64 read_float(const std::vector<int> &I);
Expand Down
6 changes: 3 additions & 3 deletions taichi/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ void compile_runtime_bitcode(Arch arch) {
if (ret) {
TI_ERROR("LLVMRuntime compilation failed.");
}
std::system(fmt::format("llvm-as {}runtime.ll -o {}{}", runtime_folder,
runtime_folder, get_runtime_fn(arch))
.c_str());
cmd = fmt::format("llvm-as {}runtime.ll -o {}{}", runtime_folder,
runtime_folder, get_runtime_fn(arch));
std::system(cmd.c_str());
TI_TRACE("runtime module bitcode compiled.");
runtime_compiled.insert((int)arch);
}
Expand Down
6 changes: 4 additions & 2 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ void export_lang(py::module &m) {
.def(py::init<>())
.def_readwrite("parent", &SNode::parent)
.def_readonly("type", &SNode::type)
.def_readonly("is_bitmasked", &SNode::_bitmasked)
.def("dense",
(SNode & (SNode::*)(const std::vector<Index> &,
const std::vector<int> &))(&SNode::dense),
Expand All @@ -158,7 +157,10 @@ void export_lang(py::module &m) {
py::return_value_policy::reference)
.def("dynamic", &SNode::dynamic_chunked,
py::return_value_policy::reference)
.def("bitmasked", &SNode::bitmasked)
.def("bitmasked",
(SNode & (SNode::*)(const std::vector<Index> &,
const std::vector<int> &))(&SNode::bitmasked),
py::return_value_policy::reference)
.def("place", (SNode & (SNode::*)(Expr &))(&SNode::place),
py::return_value_policy::reference)
.def("data_type", [](SNode *snode) { return snode->dt; })
Expand Down
34 changes: 34 additions & 0 deletions taichi/runtime/llvm/node_bitmasked.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#pragma once

// Specialized Attributes and functions
struct BitmaskedMeta : public StructMeta {
bool _;
};

STRUCT_FIELD(BitmaskedMeta, _);

i32 Bitmasked_get_num_elements(Ptr meta, Ptr node) {
return ((StructMeta *)meta)->max_num_elements;
}

void Bitmasked_activate(Ptr meta, Ptr node, int i) {
auto smeta = (StructMeta *)meta;
auto element_size = StructMeta_get_element_size(smeta);
auto num_elements = Bitmasked_get_num_elements(meta, node);
auto data_section_size = element_size * num_elements;
auto mask_begin = (uint64 *)(node + data_section_size);
atomic_or_u64(&mask_begin[i / 64], 1UL << (i % 64));
}

i32 Bitmasked_is_active(Ptr meta, Ptr node, int i) {
auto smeta = (StructMeta *)meta;
auto element_size = StructMeta_get_element_size(smeta);
auto num_elements = Dense_get_num_elements(meta, node);
auto data_section_size = element_size * num_elements;
auto mask_begin = node + data_section_size;
return i32(bool((mask_begin[i / 8] >> (i % 8)) & 1));
}

Ptr Bitmasked_lookup_element(Ptr meta, Ptr node, int i) {
return node + ((StructMeta *)meta)->element_size * i;
}
24 changes: 2 additions & 22 deletions taichi/runtime/llvm/node_dense.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,21 @@

// Specialized Attributes and functions
struct DenseMeta : public StructMeta {
bool bitmasked;
int morton_dim;
};

STRUCT_FIELD(DenseMeta, bitmasked)
STRUCT_FIELD(DenseMeta, morton_dim)

i32 Dense_get_num_elements(Ptr meta, Ptr node) {
return ((StructMeta *)meta)->max_num_elements;
}

void Dense_activate(Ptr meta, Ptr node, int i) {
auto smeta = (StructMeta *)meta;
auto dmeta = (DenseMeta *)meta;
if (DenseMeta_get_bitmasked(dmeta)) {
auto element_size = StructMeta_get_element_size(smeta);
auto num_elements = Dense_get_num_elements(meta, node);
auto data_section_size = element_size * num_elements;
auto mask_begin = (uint64 *)(node + data_section_size);
atomic_or_u64(&mask_begin[i / 64], 1UL << (i % 64));
}
// Dense elements are always active
}

i32 Dense_is_active(Ptr meta, Ptr node, int i) {
auto smeta = (StructMeta *)meta;
auto dmeta = (DenseMeta *)meta;
if (DenseMeta_get_bitmasked(dmeta)) {
auto element_size = StructMeta_get_element_size(smeta);
auto num_elements = Dense_get_num_elements(meta, node);
auto data_section_size = element_size * num_elements;
auto mask_begin = node + data_section_size;
return i32(bool((mask_begin[i / 8] >> (i % 8)) & 1));
} else {
return 1;
}
return 1;
}

Ptr Dense_lookup_element(Ptr meta, Ptr node, int i) {
Expand Down
1 change: 1 addition & 0 deletions taichi/runtime/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,7 @@ i32 linear_thread_idx() {
#include "node_dynamic.h"
#include "node_pointer.h"
#include "node_root.h"
#include "node_bitmasked.h"

void ListManager::touch_chunk(int chunk_id) {
if (!chunks[chunk_id]) {
Expand Down
6 changes: 3 additions & 3 deletions taichi/struct/struct_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ void StructCompilerLLVM::generate_types(SNode &snode) {
snode_attr[snode].llvm_element_type = ch_type;

llvm::Type *body_type = nullptr, *aux_type = nullptr;
if (type == SNodeType::dense) {
if (type == SNodeType::dense || type == SNodeType::bitmasked) {
TI_ASSERT(snode._morton == false);
body_type = llvm::ArrayType::get(ch_type, snode.max_num_elements());
if (snode._bitmasked) {
if (type == SNodeType::bitmasked) {
aux_type = llvm::ArrayType::get(llvm::Type::getInt32Ty(*llvm_ctx),
(snode.max_num_elements() + 31) / 32);
}
Expand Down Expand Up @@ -206,7 +206,7 @@ std::unique_ptr<StructCompiler> StructCompiler::make(Program *prog, Arch arch) {

bool SNode::need_activation() const {
return type == SNodeType::pointer || type == SNodeType::hash ||
(type == SNodeType::dense && _bitmasked) || type == SNodeType::dynamic;
type == SNodeType::bitmasked || type == SNodeType::dynamic;
}

TLANG_NAMESPACE_END
23 changes: 13 additions & 10 deletions taichi/struct/struct_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@ constexpr size_t kListManagerSize = sizeof(shaders::ListManager);
constexpr size_t kSNodeMetaSize = sizeof(shaders::SNodeMeta);
constexpr size_t kSNodeExtractorsSize = sizeof(shaders::SNodeExtractors);

inline bool is_bitmasked(const SNode &sn) {
return (sn.type == SNodeType::dense && sn._bitmasked);
}

inline size_t bitmasks_stride(int n) {
constexpr int kBitsPerByte = 8;
const int bytes_needed = iroundup(n, kBitsPerByte) / kBitsPerByte;
// The roundup is to align the stride to 8-bytes.
return iroundup(bytes_needed, 8);
}

inline int get_n(const SNode &sn) {
// For root, sn.n is 0.
return sn.type == SNodeType::root ? 1 : sn.n;
}

class StructCompiler {
public:
StructCompiledResult run(SNode &root) {
Expand All @@ -54,7 +55,8 @@ class StructCompiler {
{
max_snodes_ = 0;
for (const auto &sn : snodes_) {
if (sn->type == SNodeType::root || sn->type == SNodeType::dense) {
if (sn->type == SNodeType::root || sn->type == SNodeType::dense ||
sn->type == SNodeType::bitmasked) {
max_snodes_ = std::max(max_snodes_, sn->id);
}
}
Expand Down Expand Up @@ -132,12 +134,13 @@ class StructCompiler {
emit(" device {}* val;", dt_name);
emit("}};");
} else if (snode.type == SNodeType::dense ||
snode.type == SNodeType::root) {
const bool bitmasked = is_bitmasked(snode);
snode.type == SNodeType::root ||
snode.type == SNodeType::bitmasked) {
const bool bitmasked = snode.type == SNodeType::bitmasked;
const std::string ch_name = fmt::format("{}_ch", node_name);
emit("struct {} {{", node_name);
emit(" // {}", snode_type_name(snode.type));
const int n = (snode.type == SNodeType::dense) ? snode.n : 1;
const int n = get_n(snode);
emit(" constant static constexpr int n = {};", n);
if (bitmasked) {
emit(
Expand Down Expand Up @@ -170,7 +173,7 @@ class StructCompiler {
return metal_data_type_bytes(to_metal_type(sn->dt));
}

const int n = (sn->type == SNodeType::dense) ? sn->n : 1;
const int n = get_n(*sn);
size_t ch_size = 0;
for (const auto &ch : sn->ch) {
const size_t ch_offset = ch_size;
Expand All @@ -186,7 +189,7 @@ class StructCompiler {
sn_desc.element_stride = ch_size;
sn_desc.num_slots = n;
sn_desc.stride = ch_size * n;
if (is_bitmasked(*sn)) {
if (sn->type == SNodeType::bitmasked) {
sn_desc.stride += bitmasks_stride(n);
}
sn_desc.total_num_elems_from_root = 1;
Expand Down
Loading

0 comments on commit 4443534

Please sign in to comment.