Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Add ti.Vector.ndarray and ti.Matrix.ndarray #2808

Merged
merged 13 commits into from
Aug 30, 2021
3 changes: 2 additions & 1 deletion .github/workflows/presubmit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ jobs:
$PYTHON examples/algorithm/laplace.py
ti diagnose
ti changelog
ti test -vr2 -t2
ti test -vr2 -t2 -k "not ndarray"
ti test -vr2 -t1 -k "ndarray"

build_and_test_windows:
name: Build and Test (Windows)
Expand Down
4 changes: 3 additions & 1 deletion python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
def ndarray(dtype, shape):
    """Defines a Taichi ndarray with scalar elements.

    Args:
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray.

    Example:
        The code below shows how a Taichi ndarray with scalar elements can be declared and defined::

            >>> x = ti.ndarray(ti.f32, shape=(16, 8))
    """
    # Normalize a bare integer shape into a 1-tuple before construction.
    arr_shape = (shape, ) if isinstance(shape, numbers.Number) else shape
    return ScalarNdarray(dtype, arr_shape)


Expand Down
20 changes: 19 additions & 1 deletion python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ArgAnyArray:
"""Type annotation for arbitrary arrays, including external arrays and Taichi ndarrays.

For external arrays, we can treat it as a Taichi field with Vector or Matrix elements by specifying element shape and layout.
For Taichi vector/matrix ndarrays, we need to specify element shape and layout for type checking.

Args:
element_shape (Tuple[Int], optional): () if scalar elements (default), (n) if vector elements, and (n, m) if matrix elements.
Expand All @@ -57,8 +58,25 @@ def __init__(self, element_shape=(), layout=Layout.AOS):
self.layout = layout

def extract(self, x):
shape = tuple(x.shape)
from taichi.lang.matrix import MatrixNdarray, VectorNdarray
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move to top? (or will it cause circular import?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will cause circular import :-(

from taichi.lang.ndarray import Ndarray, ScalarNdarray
element_dim = len(self.element_shape)
if isinstance(x, Ndarray):
if isinstance(x, ScalarNdarray) and element_dim != 0:
raise ValueError("Invalid argument passed to ti.any_arr()")
if isinstance(x, VectorNdarray) and (element_dim != 1 or
self.element_shape[0] != x.n
or self.layout != x.layout):
raise ValueError("Invalid argument passed to ti.any_arr()")
if isinstance(x,
MatrixNdarray) and (element_dim != 2
or self.element_shape[0] != x.n
or self.element_shape[1] != x.m
or self.layout != x.layout):
raise ValueError("Invalid argument passed to ti.any_arr()")
return x.dtype, len(
x.shape) + element_dim, self.element_shape, self.layout
shape = tuple(x.shape)
if len(shape) < element_dim:
raise ValueError("Invalid argument passed to ti.any_arr()")
if element_dim > 0:
Expand Down
6 changes: 3 additions & 3 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from taichi.lang.exception import TaichiSyntaxError
from taichi.lang.kernel_arguments import (any_arr, ext_arr,
sparse_matrix_builder, template)
from taichi.lang.ndarray import ScalarNdarray
from taichi.lang.ndarray import Ndarray
from taichi.lang.shell import _shell_pop_print, oinspect
from taichi.lang.transformer import ASTTransformerTotal
from taichi.misc.util import obsolete
Expand Down Expand Up @@ -490,8 +490,8 @@ def func__(*args):
# Pass only the base pointer of the ti.sparse_matrix_builder() argument
launch_ctx.set_arg_int(actual_argument_slot, v.get_addr())
elif (isinstance(needed, (any_arr, ext_arr)) and self.match_ext_arr(v)) or \
(isinstance(needed, any_arr) and isinstance(v, ScalarNdarray)):
if isinstance(v, ScalarNdarray):
(isinstance(needed, any_arr) and isinstance(v, Ndarray)):
if isinstance(v, Ndarray):
v = v.arr
has_external_arrays = True
has_torch = util.has_pytorch()
Expand Down
140 changes: 140 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError
from taichi.lang.field import Field, ScalarField, SNodeHostAccess
from taichi.lang.ndarray import Ndarray, NdarrayHostAccess
from taichi.lang.ops import cast
from taichi.lang.types import CompoundType
from taichi.lang.util import (cook_dtype, in_python_scope, is_taichi_class,
Expand Down Expand Up @@ -316,6 +317,8 @@ def __call__(self, *args, **kwargs):
ret = self.entries[self.linearize_entry_id(*args)]
if isinstance(ret, SNodeHostAccess):
ret = ret.accessor.getter(*ret.key)
elif isinstance(ret, NdarrayHostAccess):
ret = ret.getter()
return ret

def set_entry(self, i, j, e):
Expand All @@ -325,6 +328,8 @@ def set_entry(self, i, j, e):
else:
if isinstance(self.entries[idx], SNodeHostAccess):
self.entries[idx].accessor.setter(e, *self.entries[idx].key)
elif isinstance(self.entries[idx], NdarrayHostAccess):
self.entries[idx].setter(e)
else:
self.entries[idx] = e

Expand Down Expand Up @@ -1035,6 +1040,47 @@ def _Vector_var(cls, n, dt, *args, **kwargs):
_taichi_skip_traceback = 1
return cls._Vector_field(n, dt, *args, **kwargs)

@classmethod
@python_scope
def ndarray(cls, n, m, dtype, shape, layout=Layout.AOS):
    """Defines a Taichi ndarray with matrix elements.

    Args:
        n (int): Number of rows of the matrix.
        m (int): Number of columns of the matrix.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray.
        layout (Layout, optional): Memory layout, AOS by default.

    Example:
        The code below shows how a Taichi ndarray with matrix elements can be declared and defined::

            >>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
    """
    # A bare integer shape means a 1-D ndarray of that length.
    arr_shape = (shape, ) if isinstance(shape, numbers.Number) else shape
    return MatrixNdarray(n, m, dtype, arr_shape, layout)

@classmethod
@python_scope
def _Vector_ndarray(cls, n, dtype, shape, layout=Layout.AOS):
    """Defines a Taichi ndarray with vector elements.

    Args:
        n (int): Size of the vector.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray.
        layout (Layout, optional): Memory layout, AOS by default.

    Example:
        The code below shows how a Taichi ndarray with vector elements can be declared and defined::

            >>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
    """
    # A bare integer shape means a 1-D ndarray of that length.
    arr_shape = (shape, ) if isinstance(shape, numbers.Number) else shape
    return VectorNdarray(n, dtype, arr_shape, layout)

@staticmethod
def rows(rows):
"""Construct a Matrix instance by concatenating Vectors/lists row by row.
Expand Down Expand Up @@ -1220,6 +1266,7 @@ def Vector(n, dt=None, shape=None, offset=None, **kwargs):

Vector.var = Matrix._Vector_var
Vector.field = Matrix._Vector_field
Vector.ndarray = Matrix._Vector_ndarray
Vector.zero = Matrix.zero
Vector.one = Matrix.one
Vector.dot = Matrix.dot
Expand Down Expand Up @@ -1435,3 +1482,96 @@ def empty(self):

def field(self, **kwargs):
return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)


class MatrixNdarray(Ndarray):
    """Taichi ndarray with matrix elements implemented with a torch tensor.

    Args:
        n (int): Number of rows of the matrix.
        m (int): Number of columns of the matrix.
        dtype (DataType): Data type of each value.
        shape (Union[int, tuple[int]]): Shape of the ndarray.
        layout (Layout): Memory layout (AOS or SOA).
    """
    def __init__(self, n, m, dtype, shape, layout):
        self.layout = layout
        # SOA places the (n, m) element axes before the field axes;
        # AOS appends them after.
        if layout == Layout.SOA:
            arr_shape = (n, m) + shape
        else:
            arr_shape = shape + (n, m)
        super().__init__(dtype, arr_shape)

    @property
    def n(self):
        # Row count: leading axis for SOA, second-to-last for AOS.
        if self.layout == Layout.SOA:
            return self.arr.shape[0]
        return self.arr.shape[-2]

    @property
    def m(self):
        # Column count: second axis for SOA, last for AOS.
        if self.layout == Layout.SOA:
            return self.arr.shape[1]
        return self.arr.shape[-1]

    @property
    def shape(self):
        # Field shape with the two element axes stripped off.
        whole = tuple(self.arr.shape)
        return whole[2:] if self.layout == Layout.SOA else whole[:-2]

    @python_scope
    def __setitem__(self, key, value):
        # Coerce arbitrary iterables to a nested list-of-rows form.
        if not isinstance(value, (list, tuple)):
            value = list(value)
        if not isinstance(value[0], (list, tuple)):
            value = [[i] for i in value]
        entry = self[key]
        for r in range(self.n):
            for c in range(self.m):
                entry[r, c] = value[r][c]

    @python_scope
    def __getitem__(self, key):
        # Normalize the key to a tuple of field coordinates.
        if key is None:
            key = ()
        elif isinstance(key, numbers.Number):
            key = (key, )
        else:
            key = tuple(key)
        accessors = [
            NdarrayHostAccess(self, key, (r, c)) for r in range(self.n)
            for c in range(self.m)
        ]
        return Matrix.with_entries(self.n, self.m, accessors)

    def __repr__(self):
        return f'<{self.n}x{self.m} {self.layout} ti.Matrix.ndarray>'


class VectorNdarray(Ndarray):
    """Taichi ndarray with vector elements implemented with a torch tensor.

    Args:
        n (int): Size of the vector.
        dtype (DataType): Data type of each value.
        shape (Tuple[int]): Shape of the ndarray.
        layout (Layout): Memory layout (AOS or SOA).
    """
    def __init__(self, n, dtype, shape, layout):
        self.layout = layout
        # SOA places the vector axis before the field axes; AOS appends it.
        if layout == Layout.SOA:
            arr_shape = (n, ) + shape
        else:
            arr_shape = shape + (n, )
        super().__init__(dtype, arr_shape)

    @property
    def n(self):
        # Vector length: leading axis for SOA, trailing for AOS.
        if self.layout == Layout.SOA:
            return self.arr.shape[0]
        return self.arr.shape[-1]

    @property
    def shape(self):
        # Field shape with the vector axis stripped off.
        whole = tuple(self.arr.shape)
        return whole[1:] if self.layout == Layout.SOA else whole[:-1]

    @python_scope
    def __setitem__(self, key, value):
        if not isinstance(value, (list, tuple)):
            value = list(value)
        entry = self[key]
        for i in range(self.n):
            entry[i] = value[i]

    @python_scope
    def __getitem__(self, key):
        # Normalize the key to a tuple of field coordinates.
        if key is None:
            key = ()
        elif isinstance(key, numbers.Number):
            key = (key, )
        else:
            key = tuple(key)
        return Matrix.with_entries(
            self.n, 1,
            [NdarrayHostAccess(self, key, (i, )) for i in range(self.n)])

    def __repr__(self):
        return f'<{self.n} {self.layout} ti.Vector.ndarray>'
44 changes: 31 additions & 13 deletions python/taichi/lang/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,27 @@
import numbers

from taichi.core.util import ti_core as _ti_core
from taichi.lang import impl
from taichi.lang.util import (has_pytorch, python_scope, to_pytorch_type,
to_taichi_type)
from taichi.lang.enums import Layout
from taichi.lang.util import (cook_dtype, has_pytorch, python_scope,
to_pytorch_type, to_taichi_type)


class Ndarray:
"""Taichi ndarray class implemented with a torch tensor.

Args:
dtype (DataType): Data type of the ndarray.
shape (Union[int, tuple[int]]): Shape of the torch tensor.
dtype (DataType): Data type of each value.
shape (Tuple[int]): Shape of the torch tensor.
"""
def __init__(self, dtype, shape):
if isinstance(shape, numbers.Number):
shape = (shape, )
assert has_pytorch(
), "PyTorch must be available if you want to create a Taichi ndarray."
import torch
if impl.current_cfg().arch == _ti_core.Arch.cuda:
device = 'cuda:0'
else:
device = 'cpu'
self.arr = torch.empty(shape,
dtype=to_pytorch_type(dtype),
self.arr = torch.zeros(shape,
dtype=to_pytorch_type(cook_dtype(dtype)),
device=device)

@property
Expand Down Expand Up @@ -72,8 +69,8 @@ class ScalarNdarray(Ndarray):
"""Taichi ndarray with scalar elements implemented with a torch tensor.

Args:
dtype (DataType): Data type of the ndarray.
shape (Union[int, tuple[int]]): Shape of the ndarray.
dtype (DataType): Data type of each value.
shape (Tuple[int]): Shape of the ndarray.
"""
def __init__(self, dtype, shape):
super().__init__(dtype, shape)
Expand All @@ -84,11 +81,32 @@ def shape(self):

@python_scope
def __setitem__(self, key, value):
return self.arr.__setitem__(key, value)
self.arr.__setitem__(key, value)

@python_scope
def __getitem__(self, key):
    # Delegate indexing straight to the backing torch tensor.
    return self.arr[key]

def __repr__(self):
    # Scalar ndarrays use a fixed tag; element shape is not encoded here.
    return '<ti.ndarray>'


class NdarrayHostAccess:
    """Class for accessing VectorNdarray/MatrixNdarray in Python scope.

    Args:
        arr (Union[VectorNdarray, MatrixNdarray]): The ndarray being accessed.
        indices_first (Tuple[Int]): Indices of first-level access (coordinates in the field).
        indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix).
    """
    def __init__(self, arr, indices_first, indices_second):
        self.arr = arr.arr
        # SOA stores element axes first, so intra-element indices lead;
        # AOS stores them last, so field coordinates lead.
        if arr.layout == Layout.SOA:
            self.indices = indices_second + indices_first
        else:
            self.indices = indices_first + indices_second

    def getter(self):
        """Reads the referenced scalar from the underlying tensor."""
        return self.arr[self.indices]

    def setter(self, value):
        """Writes ``value`` to the referenced scalar in the underlying tensor."""
        self.arr[self.indices] = value
Loading