Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] [refactor] Deprecate x.shape() and x.dim(), use x.shape instead #1318

Merged
merged 5 commits into from
Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,11 @@ def __hash__(self):
return self.ptr.get_raw_address()

def dim(self):
impl.get_runtime().try_materialize()
return self.snode().dim()

@property
def shape(self):
impl.get_runtime().try_materialize()
return self.snode().shape()
return self.snode().shape

def data_type(self):
return self.snode().data_type()
Expand All @@ -141,7 +140,7 @@ def data_type(self):
def to_numpy(self):
from .meta import tensor_to_ext_arr
import numpy as np
arr = np.zeros(shape=self.shape(),
arr = np.zeros(shape=self.shape,
dtype=to_numpy_type(self.snode().data_type()))
tensor_to_ext_arr(self, arr)
import taichi as ti
Expand All @@ -152,7 +151,7 @@ def to_numpy(self):
def to_torch(self, device=None):
from .meta import tensor_to_ext_arr
import torch
arr = torch.zeros(size=self.shape(),
arr = torch.zeros(size=self.shape,
dtype=to_pytorch_type(self.snode().data_type()),
device=device)
tensor_to_ext_arr(self, arr)
Expand All @@ -163,7 +162,7 @@ def to_torch(self, device=None):
@python_scope
def from_numpy(self, arr):
assert self.dim() == len(arr.shape)
s = self.shape()
s = self.shape
for i in range(self.dim()):
assert s[i] == arr.shape[i]
from .meta import ext_arr_to_tensor
Expand Down
13 changes: 7 additions & 6 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,13 +512,14 @@ def diag(dim, val):
def loop_range(self):
return self.entries[0]

def dim(self):
return self.loop_range().dim()

@property
def shape(self):
# Took `self.entries[0]` as a representation of this tensor-of-matrices.
# https://github.com/taichi-dev/taichi/issues/1069#issuecomment-635712140
return self.loop_range().shape()

def dim(self):
return self.loop_range().dim()
return self.loop_range().shape

def data_type(self):
return self.loop_range().data_type()
Expand Down Expand Up @@ -622,7 +623,7 @@ def to_numpy(self, keep_dims=False, as_vector=None):
if not self.is_global():
return np.array(self.entries).reshape(shape_ext)

ret = np.empty(self.loop_range().shape() + shape_ext,
ret = np.empty(self.loop_range().shape + shape_ext,
dtype=to_numpy_type(
self.loop_range().snode().data_type()))
from .meta import matrix_to_ext_arr
Expand All @@ -636,7 +637,7 @@ def to_torch(self, device=None, keep_dims=False):
import torch
as_vector = self.m == 1 and not keep_dims
shape_ext = (self.n, ) if as_vector else (self.n, self.m)
ret = torch.empty(self.loop_range().shape() + shape_ext,
ret = torch.empty(self.loop_range().shape + shape_ext,
dtype=to_pytorch_type(
self.loop_range().snode().data_type()),
device=device)
Expand Down
24 changes: 17 additions & 7 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,28 @@ def parent(self, n=1):
def data_type(self):
return self.ptr.data_type()

@deprecated('x.dim()', 'len(x.shape)')
def dim(self):
impl.get_runtime().try_materialize()
return self.ptr.num_active_indices()
return len(self.shape)

@property
def shape(self):
impl.get_runtime().try_materialize()
return tuple(
self.ptr.get_num_elements_along_axis(i) for i in range(self.dim()))
dim = self.ptr.num_active_indices()
ret = [
self.ptr.get_num_elements_along_axis(i) for i in range(dim)]

class callable_tuple(tuple):
@deprecated('x.shape()', 'x.shape')
def __call__(self):
return self

ret = callable_tuple(ret)
return ret

@deprecated('snode.get_shape(i)', 'snode.shape()[i]')
@deprecated('x.get_shape(i)', 'x.shape[i]')
def get_shape(self, i):
return self.shape()[i]
return self.shape[i]

def loop_range(self):
import taichi as ti
Expand Down Expand Up @@ -104,7 +114,7 @@ def __repr__(self):
# ti.root.dense(ti.i, 3).dense(ti.jk, (4, 5)).place(x)
# ti.root => dense [3] => dense [3, 4, 5] => place [3, 4, 5]
type = repr(self.ptr.type)[len('SNodeType.'):]
shape = repr(list(self.shape()))
shape = repr(list(self.shape))
parent = repr(self.parent())
return f'{parent} => {type} {shape}'

Expand Down
2 changes: 1 addition & 1 deletion python/taichi/misc/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def set_image(self, img):
import taichi as ti

if isinstance(img, ti.Expr):
if ti.core.is_integral(img.data_type()) or len(img.shape()) != 2:
if ti.core.is_integral(img.data_type()) or len(img.shape) != 2:
# Images of uint is not optimized by xxx_to_image
self.img = self.cook_image(img.to_numpy())
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def fill():
assert m4[0][j, i] == int(i + 3 * j + 1)


# Remove this once the apis are fully deprecated in incoming version.
# TODO: Remove this once the apis are fully deprecated in incoming version.
@pytest.mark.filterwarnings('ignore')
@ti.host_arch_only
def test_init_matrix_from_vectors_deprecated():
Expand Down
36 changes: 28 additions & 8 deletions tests/python/test_tensor_reflection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import taichi as ti
import pytest


@ti.all_archs
Expand All @@ -11,8 +12,7 @@ def test_POT():

ti.root.dense(ti.i, n).dense(ti.j, m).dense(ti.k, p).place(val)

assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.shape == (n, m, p)
assert val.data_type() == ti.i32


Expand All @@ -29,8 +29,7 @@ def test_non_POT():
blk3 = blk2.dense(ti.k, p)
blk3.place(val)

assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.shape == (n, m, p)
assert val.data_type() == ti.i32


Expand All @@ -48,8 +47,7 @@ def test_unordered():
blk3.place(val)

assert val.data_type() == ti.i32
assert val.shape() == (n, m, p)
assert val.dim() == 3
assert val.shape == (n, m, p)
assert val.snode().parent(0) == val.snode()
assert val.snode().parent() == blk3
assert val.snode().parent(1) == blk3
Expand Down Expand Up @@ -80,12 +78,34 @@ def test_unordered_matrix():
blk3 = blk2.dense(ti.j, p)
blk3.place(val)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert val.shape == (n, m, p)
assert val.data_type() == ti.i32
assert val.loop_range().snode().parent(0) == val.loop_range().snode()
assert val.loop_range().snode().parent() == blk3
assert val.loop_range().snode().parent(1) == blk3
assert val.loop_range().snode().parent(2) == blk2
assert val.loop_range().snode().parent(3) == blk1
assert val.loop_range().snode().parent(4) == ti.root


@pytest.mark.filterwarnings('ignore')
@ti.host_arch_only
def test_deprecated():
val = ti.var(ti.f32)
mat = ti.Matrix(3, 2, ti.i32)

n = 3
m = 7
p = 11

blk1 = ti.root.dense(ti.k, n)
blk2 = blk1.dense(ti.i, m)
blk3 = blk2.dense(ti.j, p)
blk3.place(val, mat)

assert val.dim() == 3
assert val.shape() == (n, m, p)
assert mat.dim() == 3
assert mat.shape() == (n, m, p)
assert blk3.dim() == 3
assert blk3.shape() == (n, m, p)
2 changes: 1 addition & 1 deletion tests/python/test_tuple_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def test_unpack_from_shape():

@ti.kernel
def func():
a[None], b[None], c[None] = d.shape()
a[None], b[None], c[None] = d.shape

func()
assert a[None] == 2
Expand Down