Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] [test] Improve code coverage in SNode #1214

Merged
merged 15 commits into from
Jun 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions docs/snode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ See :ref:`layout` for more details. ``ti.root`` is the root node of the data str
x = ti.var(dt=ti.i32)
y = ti.var(dt=ti.f32)
ti.root.place(x, y)
assert x.snode() == y.snode()


.. function:: tensor.shape()

:parameter tensor: (Tensor)
:return: (tuple of integers) the shape of tensor

Equivalent to ``tensor.snode().shape()``.

For example,

::
Expand All @@ -44,33 +48,61 @@ See :ref:`layout` for more details. ``ti.root`` is the root node of the data str
x.shape() # returns (3, 5, 4)


.. function:: snode.get_shape(index)
.. function:: tensor.dim()

:parameter snode: (SNode)
:parameter index: axis (0 for ``i`` and 1 for ``j``)
:return: (scalar) the size of tensor along that axis
:parameter tensor: (Tensor)
:return: (scalar) the dimensionality of the tensor

Equivalent to ``tensor.shape()[i]``.
Equivalent to ``len(tensor.shape())``.

::

ti.root.dense(ti.ijk, (3, 5, 4)).place(x)
x.snode().get_shape(0) # 3
x.snode().get_shape(1) # 5
x.snode().get_shape(2) # 4
ti.root.dense(ti.ijk, (8, 9, 10)).place(x)
x.dim() # 3


.. function:: tensor.dim()
.. function:: tensor.snode()

:parameter tensor: (Tensor)
:return: (scalar) the dimensionality of the tensor
:return: (SNode) the structual node where ``tensor`` is placed

Equivalent to ``len(tensor.shape())``.
::

x = ti.var(dt=ti.i32)
y = ti.var(dt=ti.f32)
ti.root.place(x, y)
x.snode()


.. function:: snode.shape()

:parameter snode: (SNode)
:return: (tuple) the size of node along that axis

::

ti.root.dense(ti.ijk, (8, 9, 10)).place(x)
x.dim() # 3
blk1 = ti.root
blk2 = blk1.dense(ti.i, 3)
blk3 = blk2.dense(ti.jk, (5, 2))
blk4 = blk3.dense(ti.k, 2)
blk1.shape() # ()
blk2.shape() # (3, )
blk3.shape() # (3, 5, 2)
blk4.shape() # (3, 5, 4)


.. function:: snode.dim()

:parameter snode: (SNode)
:return: (scalar) the dimensionality of ``snode``

Equivalent to ``len(snode.shape())``.

::

blk1 = ti.root.dense(ti.ijk, (8, 9, 10))
ti.root.dim() # 0
blk1.dim() # 3


.. function:: snode.parent()
Expand Down Expand Up @@ -188,6 +220,11 @@ Indices
.. function:: ti.j
.. function:: ti.k
.. function:: ti.ij
.. function:: ti.ji
.. function:: ti.jk
.. function:: ti.kj
.. function:: ti.ik
.. function:: ti.ki
.. function:: ti.ijk
.. function:: ti.ijkl
.. function:: ti.indices(a, b, ...)
Expand Down
4 changes: 3 additions & 1 deletion examples/quadtree.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import taichi as ti
import numpy as np

ti.init(arch=ti.cuda)
ti.init(arch=ti.cpu)

RES = 1024
K = 2
Expand All @@ -18,6 +18,8 @@

img = ti.Vector(3, dt=ti.f32, shape=(RES, RES))

print('The quad tree layout is:\n', qt.snode())


@ti.kernel
def action(p: ti.ext_arr()):
Expand Down
5 changes: 5 additions & 0 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
k = indices(2)
l = indices(3)
ij = indices(0, 1)
ji = indices(1, 0)
jk = indices(1, 2)
kj = indices(2, 1)
ik = indices(0, 2)
ki = indices(2, 0)
ijk = indices(0, 1, 2)
ijkl = indices(0, 1, 2, 3)

Expand Down
24 changes: 6 additions & 18 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@

# Scalar, basic data type
class Expr(TaichiOperations):
materialize_layout_callback = None
layout_materialized = False

def __init__(self, *args, tb=None):
self.getter = None
self.setter = None
Expand Down Expand Up @@ -39,8 +36,7 @@ def __init__(self, *args, tb=None):

@python_scope
def __setitem__(self, key, value):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
self.initialize_accessor()
if key is None:
key = ()
Expand All @@ -53,8 +49,7 @@ def __setitem__(self, key, value):

@python_scope
def __getitem__(self, key):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
self.initialize_accessor()
if key is None:
key = ()
Expand All @@ -67,9 +62,6 @@ def __getitem__(self, key):
def loop_range(self):
return self

def serialize(self):
return self.ptr.serialize()
archibate marked this conversation as resolved.
Show resolved Hide resolved

@python_scope
def initialize_accessor(self):
if self.getter:
Expand Down Expand Up @@ -124,10 +116,8 @@ def fill(self, val):

def parent(self, n=1):
import taichi as ti
p = self.ptr.snode()
for i in range(n):
p = p.parent
return Expr(ti.core.global_var_expr_from_snode(p))
p = self.snode().parent(n)
return Expr(ti.core.global_var_expr_from_snode(p.ptr))

def snode(self):
from .snode import SNode
Expand All @@ -137,13 +127,11 @@ def __hash__(self):
return self.ptr.get_raw_address()

def dim(self):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
return self.snode().dim()

def shape(self):
if not Expr.layout_materialized:
self.materialize_layout_callback()
impl.get_runtime().try_materialize()
return self.snode().shape()

def data_type(self):
Expand Down
7 changes: 7 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ def create_program(self):
if self.prog is None:
self.prog = taichi_lang_core.Program()

def try_materialize(self):
if not Expr.layout_materialized:
Expr.materialize_layout_callback()

def materialize(self):
if self.materialized:
return
Expand Down Expand Up @@ -249,6 +253,9 @@ def __getattribute__(self, item):
root = SNode(ti.get_runtime().prog.get_root())
return getattr(root, item)

def __repr__(self):
return 'ti.root'


root = Root()

Expand Down
34 changes: 30 additions & 4 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from . import impl
from .util import deprecated


class SNode:
def __init__(self, ptr):
self.ptr = ptr
Expand Down Expand Up @@ -48,20 +52,31 @@ def place(self, *args, offset=None):
def lazy_grad(self):
self.ptr.lazy_grad()

def parent(self):
return SNode(self.ptr.snode().parent)
def parent(self, n=1):
impl.get_runtime().try_materialize()
p = self.ptr
for i in range(n):
p = p.parent
if p.type == impl.taichi_lang_core.SNodeType.root:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, the parent of root is nullptr, so I think we should return None instead of making a cycle?

return impl.root
else:
return SNode(p)

def data_type(self):
return self.ptr.data_type()

def dim(self):
impl.get_runtime().try_materialize()
return self.ptr.num_active_indices()

def shape(self):
return tuple(self.get_shape(i) for i in range(self.dim()))
impl.get_runtime().try_materialize()
return tuple(
self.ptr.get_num_elements_along_axis(i) for i in range(self.dim()))

@deprecated('snode.get_shape(i)', 'snode.shape()[i]')
def get_shape(self, i):
return self.ptr.get_num_elements_along_axis(i)
return self.shape()[i]

def loop_range(self):
import taichi as ti
Expand All @@ -85,6 +100,17 @@ def deactivate_all(self):
from .meta import snode_deactivate
snode_deactivate(self)

def __repr__(self):
# ti.root.dense(ti.i, 3).dense(ti.jk, (4, 5)).place(x)
# ti.root => dense [3] => dense [3, 4, 5] => place [3, 4, 5]
type = repr(self.ptr.type)[len('SNodeType.'):]
shape = repr(list(self.shape()))
parent = repr(self.parent())
return f'{parent} => {type} {shape}'

def __eq__(self, other):
return self.ptr == other.ptr

def physical_index_position(self):
ret = {}
for virtual, physical in enumerate(
Expand Down
60 changes: 54 additions & 6 deletions tests/python/test_tensor_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


@ti.all_archs
def test_POT1():
def test_POT():
val = ti.var(ti.i32)

n = 4
Expand All @@ -11,22 +11,27 @@ def test_POT1():

ti.root.dense(ti.i, n).dense(ti.j, m).dense(ti.k, p).place(val)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.data_type() == ti.i32


@ti.all_archs
def test_POT2():
def test_non_POT():
val = ti.var(ti.i32)

n = 3
m = 7
p = 11

ti.root.dense(ti.i, n).dense(ti.j, m).dense(ti.k, p).place(val)
blk1 = ti.root.dense(ti.i, n)
blk2 = blk1.dense(ti.j, m)
blk3 = blk2.dense(ti.k, p)
blk3.place(val)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.data_type() == ti.i32


@ti.all_archs
Expand All @@ -37,7 +42,50 @@ def test_unordered():
m = 7
p = 11

ti.root.dense(ti.k, n).dense(ti.i, m).dense(ti.j, p).place(val)
blk1 = ti.root.dense(ti.k, n)
blk2 = blk1.dense(ti.i, m)
blk3 = blk2.dense(ti.j, p)
blk3.place(val)

assert val.data_type() == ti.i32
assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.snode().parent(0) == val.snode()
assert val.snode().parent() == blk3
assert val.snode().parent(1) == blk3
assert val.snode().parent(2) == blk2
assert val.snode().parent(3) == blk1
assert val.snode().parent(4) == ti.root

assert val.snode() in blk3.get_children()
assert blk3 in blk2.get_children()
assert blk2 in blk1.get_children()
assert blk1 in ti.root.get_children()

expected_repr = f'ti.root => dense {[n]} => dense {[n, m]}' \
f' => dense {[n, m, p]} => place {[n, m, p]}'
assert repr(val.snode()) == expected_repr


@ti.all_archs
def test_unordered_matrix():
val = ti.Matrix(3, 2, ti.i32)

n = 3
m = 7
p = 11

blk1 = ti.root.dense(ti.k, n)
blk2 = blk1.dense(ti.i, m)
blk3 = blk2.dense(ti.j, p)
blk3.place(val)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert val.data_type() == ti.i32
assert val.loop_range().snode().parent(0) == val.loop_range().snode()
assert val.loop_range().snode().parent() == blk3
assert val.loop_range().snode().parent(1) == blk3
assert val.loop_range().snode().parent(2) == blk2
assert val.loop_range().snode().parent(3) == blk1
assert val.loop_range().snode().parent(4) == ti.root