[lang] [test] Improve code coverage in SNode (#1214)
* [lang] [test] Improve code coverage in Matrix and fix some methods

* [skip ci] enforce code format

* Improve tensor_reflection test

* Better SNode with __repr__

* [skip ci] apply

* [skip ci] Apply suggestions from code review

Co-authored-by: Ye Kuang <[email protected]>

* revert k-ye

* Revert {Expr => PyTaichi}.materialize_layout_callback

* Revert "[lang] [test] Improve code coverage in Matrix"

This reverts commit 00a05b2.

* [skip ci] Fix Matrix.snode() not found

fix

* [skip ci] loop_range().snode()

* improve & deprecate duplicated functions

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Ye Kuang <[email protected]>
3 people authored Jun 23, 2020
1 parent 850ba56 commit 307b676
Showing 7 changed files with 156 additions and 43 deletions.
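
Before the per-file diffs, a minimal usage sketch of the new ``SNode.__repr__`` added by this commit (the printed format follows the comment in ``python/taichi/lang/snode.py`` below; the exact spacing is an assumption)::

import taichi as ti

ti.init(arch=ti.cpu)

x = ti.var(dt=ti.f32)
ti.root.dense(ti.i, 3).dense(ti.jk, (4, 5)).place(x)

# Prints the chain of ancestor SNodes together with their shapes, e.g.:
#   ti.root => dense [3] => dense [3, 4, 5] => place [3, 4, 5]
print(x.snode())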
65 changes: 51 additions & 14 deletions docs/snode.rst
@@ -30,12 +30,16 @@ See :ref:`layout` for more details. ``ti.root`` is the root node of the data structure
x = ti.var(dt=ti.i32)
y = ti.var(dt=ti.f32)
ti.root.place(x, y)
assert x.snode() == y.snode()


.. function:: tensor.shape()

:parameter tensor: (Tensor)
:return: (tuple of integers) the shape of tensor

Equivalent to ``tensor.snode().shape()``.

For example,

::
@@ -44,33 +48,61 @@ See :ref:`layout` for more details. ``ti.root`` is the root node of the data structure
x.shape() # returns (3, 5, 4)


.. function:: snode.get_shape(index)
.. function:: tensor.dim()

:parameter snode: (SNode)
:parameter index: axis (0 for ``i`` and 1 for ``j``)
:return: (scalar) the size of tensor along that axis
:parameter tensor: (Tensor)
:return: (scalar) the dimensionality of the tensor

Equivalent to ``tensor.shape()[i]``.
Equivalent to ``len(tensor.shape())``.

::

ti.root.dense(ti.ijk, (3, 5, 4)).place(x)
x.snode().get_shape(0) # 3
x.snode().get_shape(1) # 5
x.snode().get_shape(2) # 4
ti.root.dense(ti.ijk, (8, 9, 10)).place(x)
x.dim() # 3


.. function:: tensor.dim()
.. function:: tensor.snode()

:parameter tensor: (Tensor)
:return: (scalar) the dimensionality of the tensor
:return: (SNode) the structural node where ``tensor`` is placed

Equivalent to ``len(tensor.shape())``.
::

x = ti.var(dt=ti.i32)
y = ti.var(dt=ti.f32)
ti.root.place(x, y)
x.snode()
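
A short follow-up sketch (assuming the same ``x`` and ``y`` as above) relating the tensor-level accessors to their SNode-level equivalents::

assert x.snode() == y.snode()        # both are placed under the same SNode
assert x.shape() == x.snode().shape()
assert x.dim() == x.snode().dim()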


.. function:: snode.shape()

:parameter snode: (SNode)
:return: (tuple of integers) the shape of ``snode``

::

ti.root.dense(ti.ijk, (8, 9, 10)).place(x)
x.dim() # 3
blk1 = ti.root
blk2 = blk1.dense(ti.i, 3)
blk3 = blk2.dense(ti.jk, (5, 2))
blk4 = blk3.dense(ti.k, 2)
blk1.shape() # ()
blk2.shape() # (3, )
blk3.shape() # (3, 5, 2)
blk4.shape() # (3, 5, 4)


.. function:: snode.dim()

:parameter snode: (SNode)
:return: (scalar) the dimensionality of ``snode``

Equivalent to ``len(snode.shape())``.

::

blk1 = ti.root.dense(ti.ijk, (8, 9, 10))
ti.root.dim() # 0
blk1.dim() # 3


.. function:: snode.parent()
@@ -188,6 +220,11 @@ Indices
.. function:: ti.j
.. function:: ti.k
.. function:: ti.ij
.. function:: ti.ji
.. function:: ti.jk
.. function:: ti.kj
.. function:: ti.ik
.. function:: ti.ki
.. function:: ti.ijk
.. function:: ti.ijkl
.. function:: ti.indices(a, b, ...)
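
For example, a brief sketch using one of the newly added aliases (``x`` is a hypothetical ``ti.var``; ``ti.jk`` is defined as ``ti.indices(1, 2)`` in ``python/taichi/lang/__init__.py`` below)::

x = ti.var(dt=ti.f32)
# the first extent maps to axis j and the second to axis k
ti.root.dense(ti.jk, (4, 5)).place(x)
x.shape()  # (4, 5)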
4 changes: 3 additions & 1 deletion examples/quadtree.py
@@ -1,7 +1,7 @@
import taichi as ti
import numpy as np

ti.init(arch=ti.cuda)
ti.init(arch=ti.cpu)

RES = 1024
K = 2
@@ -18,6 +18,8 @@

img = ti.Vector(3, dt=ti.f32, shape=(RES, RES))

print('The quad tree layout is:\n', qt.snode())


@ti.kernel
def action(p: ti.ext_arr()):
5 changes: 5 additions & 0 deletions python/taichi/lang/__init__.py
@@ -15,6 +15,11 @@
k = indices(2)
l = indices(3)
ij = indices(0, 1)
ji = indices(1, 0)
jk = indices(1, 2)
kj = indices(2, 1)
ik = indices(0, 2)
ki = indices(2, 0)
ijk = indices(0, 1, 2)
ijkl = indices(0, 1, 2, 3)

24 changes: 6 additions & 18 deletions python/taichi/lang/expr.py
@@ -7,9 +7,6 @@

# Scalar, basic data type
class Expr(TaichiOperations):
materialize_layout_callback = None
layout_materialized = False

def __init__(self, *args, tb=None):
self.getter = None
self.setter = None
@@ -39,8 +36,7 @@ def __init__(self, *args, tb=None):

@python_scope
def __setitem__(self, key, value):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
self.initialize_accessor()
if key is None:
key = ()
@@ -53,8 +49,7 @@ def __setitem__(self, key, value):

@python_scope
def __getitem__(self, key):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
self.initialize_accessor()
if key is None:
key = ()
@@ -67,9 +62,6 @@ def __getitem__(self, key):
def loop_range(self):
return self

def serialize(self):
return self.ptr.serialize()

@python_scope
def initialize_accessor(self):
if self.getter:
@@ -124,10 +116,8 @@ def fill(self, val):

def parent(self, n=1):
import taichi as ti
p = self.ptr.snode()
for i in range(n):
p = p.parent
return Expr(ti.core.global_var_expr_from_snode(p))
p = self.snode().parent(n)
return Expr(ti.core.global_var_expr_from_snode(p.ptr))

def snode(self):
from .snode import SNode
@@ -137,13 +127,11 @@ def __hash__(self):
return self.ptr.get_raw_address()

def dim(self):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
return self.snode().dim()

def shape(self):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
return self.snode().shape()

def data_type(self):
7 changes: 7 additions & 0 deletions python/taichi/lang/impl.py
@@ -162,6 +162,10 @@ def create_program(self):
if self.prog is None:
self.prog = taichi_lang_core.Program()

def try_materialize(self):
if not Expr.layout_materialized:
Expr.materialize_layout_callback()

def materialize(self):
if self.materialized:
return
@@ -249,6 +253,9 @@ def __getattribute__(self, item):
root = SNode(ti.get_runtime().prog.get_root())
return getattr(root, item)

def __repr__(self):
return 'ti.root'


root = Root()

34 changes: 30 additions & 4 deletions python/taichi/lang/snode.py
@@ -1,3 +1,7 @@
from . import impl
from .util import deprecated


class SNode:
def __init__(self, ptr):
self.ptr = ptr
@@ -48,20 +52,31 @@ def place(self, *args, offset=None):
def lazy_grad(self):
self.ptr.lazy_grad()

def parent(self):
return SNode(self.ptr.snode().parent)
def parent(self, n=1):
impl.get_runtime().try_materialize()
p = self.ptr
for i in range(n):
p = p.parent
if p.type == impl.taichi_lang_core.SNodeType.root:
return impl.root
else:
return SNode(p)

def data_type(self):
return self.ptr.data_type()

def dim(self):
impl.get_runtime().try_materialize()
return self.ptr.num_active_indices()

def shape(self):
return tuple(self.get_shape(i) for i in range(self.dim()))
impl.get_runtime().try_materialize()
return tuple(
self.ptr.get_num_elements_along_axis(i) for i in range(self.dim()))

@deprecated('snode.get_shape(i)', 'snode.shape()[i]')
def get_shape(self, i):
return self.ptr.get_num_elements_along_axis(i)
return self.shape()[i]

def loop_range(self):
import taichi as ti
@@ -85,6 +100,17 @@ def deactivate_all(self):
from .meta import snode_deactivate
snode_deactivate(self)

def __repr__(self):
# ti.root.dense(ti.i, 3).dense(ti.jk, (4, 5)).place(x)
# ti.root => dense [3] => dense [3, 4, 5] => place [3, 4, 5]
type = repr(self.ptr.type)[len('SNodeType.'):]
shape = repr(list(self.shape()))
parent = repr(self.parent())
return f'{parent} => {type} {shape}'

def __eq__(self, other):
return self.ptr == other.ptr

def physical_index_position(self):
ret = {}
for virtual, physical in enumerate(
60 changes: 54 additions & 6 deletions tests/python/test_tensor_reflection.py
@@ -2,7 +2,7 @@


@ti.all_archs
def test_POT1():
def test_POT():
val = ti.var(ti.i32)

n = 4
@@ -11,22 +11,27 @@ def test_POT1():

ti.root.dense(ti.i, n).dense(ti.j, m).dense(ti.k, p).place(val)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.data_type() == ti.i32


@ti.all_archs
def test_POT2():
def test_non_POT():
val = ti.var(ti.i32)

n = 3
m = 7
p = 11

ti.root.dense(ti.i, n).dense(ti.j, m).dense(ti.k, p).place(val)
blk1 = ti.root.dense(ti.i, n)
blk2 = blk1.dense(ti.j, m)
blk3 = blk2.dense(ti.k, p)
blk3.place(val)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.data_type() == ti.i32


@ti.all_archs
@@ -37,7 +42,50 @@ def test_unordered():
m = 7
p = 11

ti.root.dense(ti.k, n).dense(ti.i, m).dense(ti.j, p).place(val)
blk1 = ti.root.dense(ti.k, n)
blk2 = blk1.dense(ti.i, m)
blk3 = blk2.dense(ti.j, p)
blk3.place(val)

assert val.data_type() == ti.i32
assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.snode().parent(0) == val.snode()
assert val.snode().parent() == blk3
assert val.snode().parent(1) == blk3
assert val.snode().parent(2) == blk2
assert val.snode().parent(3) == blk1
assert val.snode().parent(4) == ti.root

assert val.snode() in blk3.get_children()
assert blk3 in blk2.get_children()
assert blk2 in blk1.get_children()
assert blk1 in ti.root.get_children()

expected_repr = f'ti.root => dense {[n]} => dense {[n, m]}' \
f' => dense {[n, m, p]} => place {[n, m, p]}'
assert repr(val.snode()) == expected_repr


@ti.all_archs
def test_unordered_matrix():
val = ti.Matrix(3, 2, ti.i32)

n = 3
m = 7
p = 11

blk1 = ti.root.dense(ti.k, n)
blk2 = blk1.dense(ti.i, m)
blk3 = blk2.dense(ti.j, p)
blk3.place(val)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert val.data_type() == ti.i32
assert val.loop_range().snode().parent(0) == val.loop_range().snode()
assert val.loop_range().snode().parent() == blk3
assert val.loop_range().snode().parent(1) == blk3
assert val.loop_range().snode().parent(2) == blk2
assert val.loop_range().snode().parent(3) == blk1
assert val.loop_range().snode().parent(4) == ti.root
