Skip to content

Commit

Permalink
[bug] Fix numerical issue with TensorType'd arithmetics (taichi-dev#7526)
Browse files Browse the repository at this point in the history
  • Loading branch information
jim19930609 authored and quadpixels committed May 13, 2023
1 parent 4676bd9 commit 7df2602
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
5 changes: 3 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1897,14 +1897,15 @@ void TaskCodeGenLLVM::visit(ExternalPtrStmt *stmt) {
auto address_offset = builder->CreateSExt(
linear_index, llvm::Type::getInt64Ty(*llvm_context));

if (stmt->ret_type->is<TensorType>()) {
auto stmt_ret_type = stmt->ret_type.ptr_removed();
if (stmt_ret_type->is<TensorType>()) {
// This case corresponds to outer indexing only
// The stride for linear_index is num_elements() in TensorType.
address_offset = builder->CreateMul(
address_offset,
tlctx->get_constant(
get_data_type<int64>(),
stmt->ret_type->cast<TensorType>()->get_num_elements()));
stmt_ret_type->cast<TensorType>()->get_num_elements()));
} else {
// This case corresponds to outer + inner indexing
// Since both outer and inner indices are linearized into linear_index,
Expand Down
34 changes: 34 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,3 +1224,37 @@ def foo():
assert a == 2.5

foo()


@test_utils.test(arch=[ti.cpu, ti.cuda], real_matrix_scalarize=False)
def test_matrix_arithmatics():
    """Check vec4 multiply-add on ndarray elements.

    Runs on the CPU and CUDA archs with ``real_matrix_scalarize=False``.
    One kernel writes four vec4 rows into an ndarray; a second kernel reads
    all four rows into locals, then overwrites rows 0-2 with products and
    sums of those locals. Row 3 is never written by the second kernel.
    """
    f = ti.ndarray(ti.math.vec4, 4)

    @ti.kernel
    def populate(buf: ti.types.ndarray()):
        row0 = ti.math.vec4([0.0, 1.0, 2.0, 3.0])
        row1 = ti.math.vec4([1.0, 2.0, 3.0, 4.0])
        row2 = ti.math.vec4([2.0, 3.0, 4.0, 5.0])
        row3 = ti.math.vec4([4.0, 5.0, 6.0, 7.0])
        buf[0] = row0
        buf[1] = row1
        buf[2] = row2
        buf[3] = row3

    @ti.kernel
    def combine(buf: ti.types.ndarray()):
        # Read every row before any store: the third result below still
        # needs the original row 0, which the first store overwrites.
        a = buf[0]
        b = buf[1]
        c = buf[2]
        d = buf[3]

        buf[0] = a * b + c
        buf[1] = b * c + d
        buf[2] = a * c + d

    populate(f)
    combine(f)

    # Rows 0-2 are the component-wise results of the three expressions in
    # `combine`; row 3 keeps the value written by `populate`.
    expected = np.array([
        [2., 5., 10., 17.],
        [6., 11., 18., 27.],
        [4., 8., 14., 22.],
        [4., 5., 6., 7.],
    ])
    assert (f.to_numpy() == expected).all()

0 comments on commit 7df2602

Please sign in to comment.