Skip to content

Commit

Permalink
[Type] Support different element types for Matrix (#2135)
Browse files Browse the repository at this point in the history
Co-authored-by: Yuanming Hu <[email protected]>
  • Loading branch information
Hanke98 and yuanming-hu authored Jan 2, 2021
1 parent ff4dbdc commit b3f21f2
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 2 deletions.
22 changes: 20 additions & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,8 +821,26 @@ def field(cls,
self.n = n
self.m = m
self.dt = dtype
for i in range(n * m):
self.entries.append(impl.field(dtype))

if isinstance(dtype, (list, tuple, np.ndarray)):
# set different dtype for each element in Matrix
# see #2135
if m == 1:
assert len(np.shape(dtype)) == 1 and len(
dtype
) == n, f'Please set correct dtype list for Vector. The shape of dtype list should be ({n}, ) instead of {np.shape(dtype)}'
for i in range(n):
self.entries.append(impl.field(dtype[i]))
else:
assert len(np.shape(dtype)) == 2 and len(dtype) == n and len(
dtype[0]
) == m, f'Please set correct dtype list for Matrix. The shape of dtype list should be ({n}, {m}) instead of {np.shape(dtype)}'
for i in range(n):
for j in range(m):
self.entries.append(impl.field(dtype[i][j]))
else:
for _ in range(n * m):
self.entries.append(impl.field(dtype))
self.grad = self.make_grad()

if layout is not None:
Expand Down
99 changes: 99 additions & 0 deletions tests/python/test_matrix_different_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import taichi as ti
from pytest import approx


# TODO: test more matrix operations
@ti.test()
def test_vector():
type_list = [ti.f32, ti.i32]

a = ti.Vector.field(len(type_list), dtype=type_list, shape=())
b = ti.Vector.field(len(type_list), dtype=type_list, shape=())
c = ti.Vector.field(len(type_list), dtype=type_list, shape=())

@ti.kernel
def init():
a[None] = [1.0, 3]
b[None] = [2.0, 4]
c[None] = a[None] + b[None]

def verify():
assert isinstance(a[None][0], float)
assert isinstance(a[None][1], int)
assert isinstance(b[None][0], float)
assert isinstance(b[None][1], int)
assert c[None][0] == 3.0
assert c[None][1] == 7

init()
verify()


# TODO: Support different element types of Matrix on opengl
@ti.test(exclude=ti.opengl)
def test_matrix():
type_list = [[ti.f32, ti.i32], [ti.i64, ti.f32]]
a = ti.Matrix.field(len(type_list),
len(type_list[0]),
dtype=type_list,
shape=())
b = ti.Matrix.field(len(type_list),
len(type_list[0]),
dtype=type_list,
shape=())
c = ti.Matrix.field(len(type_list),
len(type_list[0]),
dtype=type_list,
shape=())

@ti.kernel
def init():
a[None] = [[1.0, 3], [1, 3.0]]
b[None] = [[2.0, 4], [-2, -3.0]]
c[None] = a[None] + b[None]

def verify():
assert isinstance(a[None][0], float)
assert isinstance(a[None][1], int)
assert isinstance(b[None][0], float)
assert isinstance(b[None][1], int)
assert c[None][0, 0] == 3.0
assert c[None][0, 1] == 7
assert c[None][1, 0] == -1
assert c[None][1, 1] == 0.0

init()
verify()


@ti.test(require=ti.extension.quant)
def test_custom_type():
cit1 = ti.type_factory.custom_int(bits=10, signed=True)
cft1 = ti.type_factory.custom_float(cit1, scale=0.1)
cit2 = ti.type_factory.custom_int(bits=22, signed=False)
cft2 = ti.type_factory.custom_float(cit2, scale=0.1)
type_list = [[cit1, cft2], [cft1, cit2]]
a = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
b = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
c = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list)
ti.root.dense(ti.i, 1)._bit_struct(num_bits=32).place(a(0, 0), a(0, 1))
ti.root.dense(ti.i, 1)._bit_struct(num_bits=32).place(a(1, 0), a(1, 1))
ti.root.dense(ti.i, 1)._bit_struct(num_bits=32).place(b(0, 0), b(0, 1))
ti.root.dense(ti.i, 1)._bit_struct(num_bits=32).place(b(1, 0), b(1, 1))
ti.root.dense(ti.i, 1)._bit_struct(num_bits=32).place(c(0, 0), c(0, 1))
ti.root.dense(ti.i, 1)._bit_struct(num_bits=32).place(c(1, 0), c(1, 1))

@ti.kernel
def init():
a[0] = [[1, 3.], [2., 1]]
b[0] = [[2, 4.], [-2., 1]]
c[0] = a[0] + b[0]

def verify():
assert c[0][0, 0] == approx(3, 1e-3)
assert c[0][0, 1] == approx(7.0, 1e-3)
assert c[0][1, 0] == approx(0, 1e-3)
assert c[0][1, 1] == approx(2, 1e-3)

init()
verify()

0 comments on commit b3f21f2

Please sign in to comment.