Merge branch 'master' of github.com:taichi-dev/taichi into fix_8057
jim19930609 committed Jun 1, 2023
2 parents da6255e + b3b7c64 commit b9de29f
Showing 12 changed files with 241 additions and 821 deletions.
8 changes: 8 additions & 0 deletions python/taichi/lang/matrix.py
@@ -1501,6 +1501,11 @@ def field(self, **kwargs):
kwargs.update({"ndim": self.ndim})
return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)

def ndarray(self, **kwargs):
assert kwargs.get("ndim", self.ndim) == self.ndim
kwargs.update({"ndim": self.ndim})
return Matrix.ndarray(self.n, self.m, dtype=self.dtype, **kwargs)

def get_shape(self):
if self.ndim == 1:
return (self.n,)
@@ -1598,6 +1603,9 @@ def _instantiate(self, entries):
def field(self, **kwargs):
return Vector.field(self.n, dtype=self.dtype, **kwargs)

def ndarray(self, **kwargs):
return Vector.ndarray(self.n, dtype=self.dtype, **kwargs)

def to_string(self):
dtype_str = self.dtype.to_string() if self.dtype is not None else ""
return f"VectorType[{self.n}, {dtype_str}]"
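The two additions above give MatrixType and VectorType an `ndarray()` helper mirroring the existing `field()` one. A minimal usage sketch (shapes and variable names are illustrative assumptions, not part of this commit):

```python
import taichi as ti

ti.init(arch=ti.cpu)

vec3 = ti.types.vector(3, ti.f32)
mat2 = ti.types.matrix(2, 2, ti.f32)

# Equivalent to ti.Vector.ndarray(3, ti.f32, shape=(8,)) and
# ti.Matrix.ndarray(2, 2, ti.f32, shape=(4, 4)), respectively.
v = vec3.ndarray(shape=(8,))
m = mat2.ndarray(shape=(4, 4))
```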
10 changes: 9 additions & 1 deletion python/taichi/linalg/matrixfree_cg.py
@@ -47,7 +47,15 @@ def MatrixFreeCG(A, b, x, tol=1e-6, maxiter=5000, quiet=True):
p = ti.field(dtype=solver_dtype)
r = ti.field(dtype=solver_dtype)
Ap = ti.field(dtype=solver_dtype)
- vector_fields_builder.dense(ti.ij, size).place(p, r, Ap)
+ if len(size) == 1:
+     axes = ti.i
+ elif len(size) == 2:
+     axes = ti.ij
+ elif len(size) == 3:
+     axes = ti.ijk
+ else:
+     raise TaichiRuntimeError(f"MatrixFreeCG only supports 1D, 2D, and 3D inputs; your input is {len(size)}-D.")
+ vector_fields_builder.dense(axes, size).place(p, r, Ap)
vector_fields_snode_tree = vector_fields_builder.finalize()

scalar_builder = ti.FieldsBuilder()
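With the dispatch above, MatrixFreeCG is no longer hard-wired to 2-D fields. A hedged sketch of a 3-D solve, assuming LinearOperator and MatrixFreeCG are importable from taichi.linalg as in the Taichi docs; the diagonal stand-in operator and sizes are illustrative, not from this commit:

```python
import taichi as ti
from taichi.linalg import LinearOperator, MatrixFreeCG

ti.init(arch=ti.cpu)

n = 16
x = ti.field(dtype=ti.f32, shape=(n, n, n))
b = ti.field(dtype=ti.f32, shape=(n, n, n))

@ti.kernel
def compute_Ax(v: ti.template(), mv: ti.template()):
    # Stand-in SPD operator: A = 6 * I, applied element-wise.
    for i, j, k in v:
        mv[i, j, k] = 6.0 * v[i, j, k]

A = LinearOperator(compute_Ax)
b.fill(1.0)
MatrixFreeCG(A, b, x, maxiter=100, quiet=True)  # now accepts 1-D/2-D/3-D fields
```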
6 changes: 3 additions & 3 deletions python/taichi/linalg/sparse_matrix.py
@@ -284,12 +284,12 @@ def build(self, dtype=f32, _format="CSR"):
taichi_arch = get_runtime().prog.config().arch
if taichi_arch in [_ti_core.Arch.x64, _ti_core.Arch.arm64]:
sm = self.ptr.build()
- return SparseMatrix(sm=sm, dtype=dtype)
+ return SparseMatrix(sm=sm, dtype=self.dtype)
if taichi_arch == _ti_core.Arch.cuda:
- if dtype != f32:
+ if self.dtype != f32:
raise TaichiRuntimeError("CUDA sparse matrix only supports f32.")
sm = self.ptr.build_cuda()
- return SparseMatrix(sm=sm, dtype=dtype)
+ return SparseMatrix(sm=sm, dtype=self.dtype)
raise TaichiRuntimeError("Sparse matrix only supports CPU and CUDA backends.")


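The fix above makes build() report the builder's own dtype instead of the dtype argument passed to build(). A minimal CPU sketch of the behavior (matrix size and values are illustrative assumptions):

```python
import taichi as ti

ti.init(arch=ti.cpu)

n = 4
builder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=n, dtype=ti.f64)

@ti.kernel
def fill(triplets: ti.types.sparse_matrix_builder()):
    for i in range(n):
        triplets[i, i] += 1.0  # identity matrix

fill(builder)
A = builder.build()  # the resulting SparseMatrix now carries the builder's f64 dtype
```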
5 changes: 5 additions & 0 deletions python/taichi/linalg/sparse_solver.py
@@ -21,6 +21,7 @@ class SparseSolver:

def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
self.matrix = None
self.dtype = dtype
solver_type_list = ["LLT", "LDLT", "LU"]
solver_ordering = ["AMD", "COLAMD"]
if solver_type in solver_type_list and ordering in solver_ordering:
@@ -70,6 +71,10 @@ def analyze_pattern(self, sparse_matrix):
"""
if isinstance(sparse_matrix, SparseMatrix):
self.matrix = sparse_matrix
if self.matrix.dtype != self.dtype:
raise TaichiRuntimeError(
f"The SparseSolver's dtype {self.dtype} is not consistent with the SparseMatrix's dtype {self.matrix.dtype}."
)
self.solver.analyze_pattern(sparse_matrix.matrix)
else:
self._type_assert(sparse_matrix)
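A hedged sketch of the new check: pairing an f32 solver with an f64 matrix now fails loudly in analyze_pattern() instead of silently proceeding with mismatched precision (the setup values are illustrative assumptions):

```python
import taichi as ti

ti.init(arch=ti.cpu)

n = 4
builder = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=n, dtype=ti.f64)

@ti.kernel
def fill(triplets: ti.types.sparse_matrix_builder()):
    for i in range(n):
        triplets[i, i] += 1.0

fill(builder)
A = builder.build()

solver = ti.linalg.SparseSolver(dtype=ti.f32, solver_type="LLT")
try:
    solver.analyze_pattern(A)  # f32 solver vs. f64 matrix -> TaichiRuntimeError
except Exception as e:
    print(e)
```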
15 changes: 14 additions & 1 deletion taichi/ir/control_flow_graph.cpp
@@ -384,7 +384,10 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) {
auto data_source_ptrs = irpass::analysis::get_store_destination(stmt);
for (auto data_source_ptr : data_source_ptrs) {
// stmt provides a data source
- if (after_lower_access && !(data_source_ptr->is<AllocaStmt>())) {
+ if (after_lower_access &&
+     !((data_source_ptr->is<MatrixPtrStmt>() &&
+        data_source_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
+       data_source_ptr->is<AllocaStmt>())) {
// After lower_access, we only analyze local variables.
continue;
}
@@ -552,6 +555,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
irpass::analysis::get_load_pointers(stmt, true /*get_alias*/);
for (auto &load_ptr : load_ptrs) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
if (!contain_variable(live_kill, load_ptr)) {
@@ -576,6 +581,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) {
}
for (auto store_ptr : store_ptrs) {
if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
live_kill.insert(store_ptr);
@@ -707,6 +714,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
auto store_ptr = *store_ptrs.begin();

if (!after_lower_access ||
(store_ptr->is<MatrixPtrStmt>() &&
store_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<AdStackAllocaStmt>())) {
// !may_contain_variable(live_in_this_node, store_ptr): address is not
// loaded after this store
@@ -806,6 +815,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
auto load_ptr = load_ptrs.begin()[0];

if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// live_load_in_this_node[addr]: tracks the
// next load to the same address
@@ -832,6 +843,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
// Update live_in_this_node
for (auto &load_ptr : load_ptrs) {
if (!after_lower_access ||
(load_ptr->is<MatrixPtrStmt>() &&
load_ptr->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<AdStackAllocaStmt>())) {
// Addr is used in this node, so it's live in this node
update_container_with_alias(tensor_to_matrix_ptrs_map,
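These passes operate on Taichi's IR rather than user code, but the pattern they now cover looks like the kernel below: element accesses into a function-local vector lower to a MatrixPtrStmt over an AllocaStmt, and with the change above they participate in store-to-load forwarding and dead-store elimination like plain local scalars. The kernel is an illustrative assumption, not taken from this commit:

```python
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def local_vec_opt() -> ti.f32:
    v = ti.math.vec3(0.0)  # lowered to an AllocaStmt of a TensorType
    v[0] = 1.0             # store through a MatrixPtrStmt; dead after the next line
    v[0] = 2.0             # this store survives dead-store elimination
    return v[0]            # load through a MatrixPtrStmt

print(local_vec_opt())  # 2.0
```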
2 changes: 1 addition & 1 deletion taichi/ir/ir.h
@@ -511,7 +511,7 @@ class Block : public IRNode {
// variables, and AllocaStmt for other variables.
std::map<Identifier, Stmt *> local_var_to_stmt;

- Block(Kernel *kernel = nullptr) {
+ explicit Block(Kernel *kernel = nullptr) {
parent_ = kernel;
}

3 changes: 1 addition & 2 deletions taichi/ir/transforms.h
@@ -30,8 +30,7 @@ namespace irpass {
void re_id(IRNode *root);
void flag_access(IRNode *root);
void eliminate_immutable_local_vars(IRNode *root);
- bool scalarize(IRNode *root);
- void vectorize_half2(IRNode *root);
+ bool scalarize(IRNode *root, bool half2_optimization_enabled = false);
void lower_matrix_ptr(IRNode *root);
bool die(IRNode *root);
bool simplify(IRNode *root, const CompileConfig &config);
11 changes: 11 additions & 0 deletions taichi/program/ndarray.cpp
@@ -2,6 +2,7 @@

#include "taichi/program/ndarray.h"
#include "taichi/program/program.h"
#include "fp16.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/llvm/llvm_context.h"
@@ -168,10 +169,20 @@ TypedConstant Ndarray::read(const std::vector<int> &I) const {
TypedConstant data(get_element_data_type());
std::memcpy(&data.value_bits, device_arr_ptr, size);
staging_buf_->device->unmap(*staging_buf_);

if (get_element_data_type()->is_primitive(PrimitiveTypeID::f16)) {
float float32 = fp16_ieee_to_fp32_value(data.val_u16);
data.val_f32 = float32;
}
return data;
}

void Ndarray::write(const std::vector<int> &I, TypedConstant val) const {
if (get_element_data_type()->is_primitive(PrimitiveTypeID::f16)) {
uint16_t float16 = fp16_ieee_from_fp32_value(val.val_f32);
std::memcpy(&val.value_bits, &float16, sizeof(float16));  // copy only the two half-precision bytes
}

size_t index = flatten_index(total_shape_, I);
size_t size_ = data_type_size(get_element_data_type());
taichi::lang::Device::AllocParams alloc_params;
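The read/write paths above convert between f32 and IEEE half using the bundled fp16 helpers, so single-element access to an f16 ndarray from Python scope round-trips correctly. A minimal sketch (shape and value are illustrative assumptions):

```python
import taichi as ti

ti.init(arch=ti.cpu)

a = ti.ndarray(dtype=ti.f16, shape=(4,))
a[0] = 0.5    # Ndarray::write packs the f32 value into half bits
print(a[0])   # Ndarray::read expands the stored half back to f32 -> 0.5
```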
30 changes: 10 additions & 20 deletions taichi/transforms/compile_to_offloads.cpp
@@ -271,15 +271,6 @@ void offload_to_executable(IRNode *ir,
irpass::demote_operations(ir, config);
print("Operations demoted");

- if (config.real_matrix_scalarize) {
-   if (irpass::scalarize(ir)) {
-     // Remove redundant MatrixInitStmt inserted during scalarization
-     irpass::full_simplify(ir, config,
-                           {lower_global_access, /*autodiff_enabled*/ false});
-     print("Scalarized");
-   }
- }

irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Simplified IV");
@@ -294,17 +285,16 @@
print("Bit struct stores optimized");
}

- if (config.arch == Arch::cuda && config.half2_vectorization &&
-     !get_custom_cuda_library_path().empty()) {
-   irpass::vectorize_half2(ir);
-
-   irpass::type_check(ir, config);
-
-   irpass::full_simplify(ir, config,
-                         {lower_global_access, /*autodiff_enabled*/ false});
-
-   irpass::flag_access(ir);
-   print("Half2 vectorized");
+ bool half2_optimization_enabled =
+     (config.arch == Arch::cuda && config.half2_vectorization &&
+      !get_custom_cuda_library_path().empty());
+ if (config.real_matrix_scalarize) {
+   if (irpass::scalarize(ir, half2_optimization_enabled)) {
+     // Remove redundant MatrixInitStmt inserted during scalarization
+     irpass::full_simplify(ir, config,
+                           {lower_global_access, /*autodiff_enabled*/ false});
+     print("Scalarized");
+   }
}

// Final field registration correctness & type checking
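A hedged sketch of the gating above from the user's side: the half2 path is only attempted on CUDA with both options on. Whether half2_vectorization and real_matrix_scalarize are accepted as ti.init() keyword arguments is an assumption here; they mirror the CompileConfig fields this pass reads.

```python
import taichi as ti

# Assumed flag exposure; requires a CUDA device and the custom CUDA library.
ti.init(arch=ti.cuda, real_matrix_scalarize=True, half2_vectorization=True)
```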
98 changes: 92 additions & 6 deletions taichi/transforms/scalarize.cpp
@@ -16,8 +16,11 @@ class Scalarize : public BasicStmtVisitor {
public:
ImmediateIRModifier immediate_modifier_;
DelayedIRModifier delayed_modifier_;
bool half2_optimization_enabled_ = false;

- explicit Scalarize(IRNode *node) : immediate_modifier_(node) {
+ explicit Scalarize(IRNode *node, bool half2_optimization)
+     : immediate_modifier_(node),
+       half2_optimization_enabled_(half2_optimization) {
}

/*
@@ -414,7 +417,90 @@
void visit(AtomicOpStmt *stmt) override {
auto dest_dtype = stmt->dest->ret_type.ptr_removed();
auto val_dtype = stmt->val->ret_type;
- if (dest_dtype->is<TensorType>() || val_dtype->is<TensorType>()) {

bool half2_optimizable = half2_optimization_enabled_;
bool is_tensor_type = false;
if (dest_dtype->is<TensorType>()) {
is_tensor_type = true;
half2_optimizable &=
(dest_dtype.get_element_type() == PrimitiveType::f16);
half2_optimizable &=
(dest_dtype->as<TensorType>()->get_num_elements() == 2);
} else {
half2_optimizable = false;
}
is_tensor_type |= val_dtype->is<TensorType>();

if (half2_optimizable) {
/*
Before:
TensorType<2 x f16> old_val = AtomicStmt(TensorType<2 x f16>* dest,
TensorType<2 x f16> val)
After:
TensorType<2 x f16> old_val = AtomicStmt(TensorType<2 x f16>* dest,
TensorType<2 x f16> val)
TensorType(2, f16)* old_val_alloc = AllocaStmt(TensorType(2, f16))
LocalStoreStmt(old_val_alloc, old_val)
f16* old_val_ptr0 = MatrixPtrStmt(old_val_alloc, 0)
f16* old_val_ptr1 = MatrixPtrStmt(old_val_alloc, 1)
f16 old_val0 = LoadStmt(old_val_ptr0)
f16 old_val1 = LoadStmt(old_val_ptr1)
tmp = MatrixInitStmt(old_val0, old_val1)
stmt->replace_all_usages_with(tmp)
*/
auto atomic_stmt =
std::make_unique<AtomicOpStmt>(stmt->op_type, stmt->dest, stmt->val);
atomic_stmt->ret_type = stmt->ret_type;

auto alloca_stmt = std::make_unique<AllocaStmt>(dest_dtype);

auto local_store_stmt = std::make_unique<LocalStoreStmt>(
alloca_stmt.get(), atomic_stmt.get());

auto zero =
std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 0));
auto one =
std::make_unique<ConstStmt>(TypedConstant(PrimitiveType::i32, 1));

auto matrix_ptr_0 =
std::make_unique<MatrixPtrStmt>(alloca_stmt.get(), zero.get());
auto matrix_ptr_1 =
std::make_unique<MatrixPtrStmt>(alloca_stmt.get(), one.get());
matrix_ptr_0->ret_type = PrimitiveType::f16;
matrix_ptr_0->ret_type.set_is_pointer(true);
matrix_ptr_1->ret_type = PrimitiveType::f16;
matrix_ptr_1->ret_type.set_is_pointer(true);

auto load_stmt_0 = std::make_unique<LocalLoadStmt>(matrix_ptr_0.get());
auto load_stmt_1 = std::make_unique<LocalLoadStmt>(matrix_ptr_1.get());
load_stmt_0->ret_type = PrimitiveType::f16;
load_stmt_1->ret_type = PrimitiveType::f16;

auto matrix_init_stmt = std::make_unique<MatrixInitStmt>(
std::vector<Stmt *>{load_stmt_0.get(), load_stmt_1.get()});
matrix_init_stmt->ret_type = stmt->ret_type;

immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(atomic_stmt));
delayed_modifier_.insert_before(stmt, std::move(alloca_stmt));
delayed_modifier_.insert_before(stmt, std::move(local_store_stmt));
delayed_modifier_.insert_before(stmt, std::move(zero));
delayed_modifier_.insert_before(stmt, std::move(one));
delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_0));
delayed_modifier_.insert_before(stmt, std::move(matrix_ptr_1));
delayed_modifier_.insert_before(stmt, std::move(load_stmt_0));
delayed_modifier_.insert_before(stmt, std::move(load_stmt_1));
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

delayed_modifier_.erase(stmt);

} else if (is_tensor_type) {
// Make sure broadcasting has been correctly applied by
// AtomicOpExpression::type_check().
TI_ASSERT(dest_dtype->is<TensorType>() && val_dtype->is<TensorType>());
@@ -838,8 +924,8 @@
}
}

- static bool run(IRNode *node) {
-   Scalarize pass(node);
+ static bool run(IRNode *node, bool half2_optimization_enabled) {
+   Scalarize pass(node, half2_optimization_enabled);
node->accept(&pass);
return pass.delayed_modifier_.modify_ir();
}
@@ -1192,11 +1278,11 @@ class MergeExternalAndMatrixPtr : public BasicStmtVisitor {

namespace irpass {

- bool scalarize(IRNode *root) {
+ bool scalarize(IRNode *root, bool half2_optimization_enabled) {
TI_AUTO_PROF;
bool modified = false;

- modified |= Scalarize::run(root);
+ modified |= Scalarize::run(root, half2_optimization_enabled);
auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root);
modified |= ScalarizePointers::run(root, scalarizable_allocas);
modified |= ExtractLocalPointers::run(root);
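A hedged sketch of the kind of kernel the half2 branch above targets: an atomic add of a 2-element f16 vector, which can be packed into a single half2 atomic on CUDA when the optimization is enabled. Flag exposure via ti.init() and the exact kernel shape are assumptions, not part of this commit:

```python
import taichi as ti

ti.init(arch=ti.cuda, half2_vectorization=True)  # assumed flag exposure

acc = ti.Vector.field(2, dtype=ti.f16, shape=())

@ti.kernel
def accumulate():
    for i in range(1024):
        # A 2 x f16 atomic add: the pattern rewritten by the half2 branch above.
        ti.atomic_add(acc[None], ti.cast(ti.math.vec2(0.5, 0.5), ti.f16))

accumulate()
print(acc[None])
```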