[Refactor] [lang] Deprecate x.data_type() and use x.dtype instead #1374

Status: Merged
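The change is mechanical at every call site: each class gains a read-only `dtype` property, and `data_type()` survives as a deprecated alias that forwards to it. A minimal before/after sketch (the 0.6-era `ti.var` API and the field name are illustrative, not taken from the diff):

```python
import taichi as ti

ti.init()
x = ti.var(ti.f32, shape=(4, 4))

x.data_type()  # old spelling: still works, but now emits a deprecation warning
x.dtype        # new spelling: a property, so no parentheses
```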
python/taichi/lang/expr.py (16 changes: 10 additions & 6 deletions)
```diff
@@ -68,7 +68,7 @@ def initialize_accessor(self):
             return
         snode = self.ptr.snode()
 
-        if self.snode().data_type() == f32 or self.snode().data_type() == f64:
+        if self.dtype == f32 or self.dtype == f64:
 
             def getter(*key):
                 assert len(key) == taichi_lang_core.get_max_num_indices()
@@ -78,7 +78,7 @@ def setter(value, *key):
                 assert len(key) == taichi_lang_core.get_max_num_indices()
                 snode.write_float(key, value)
         else:
-            if taichi_lang_core.is_signed(self.snode().data_type()):
+            if taichi_lang_core.is_signed(self.dtype):
 
                 def getter(*key):
                     assert len(key) == taichi_lang_core.get_max_num_indices()
@@ -135,15 +135,19 @@ def shape(self):
     def dim(self):
         return len(self.shape)
 
+    @property
+    def dtype(self):
+        return self.snode().dtype
+
+    @deprecated('x.data_type()', 'x.dtype')
     def data_type(self):
-        return self.snode().data_type()
+        return self.snode().dtype
 
     @python_scope
     def to_numpy(self):
         from .meta import tensor_to_ext_arr
         import numpy as np
-        arr = np.zeros(shape=self.shape,
-                       dtype=to_numpy_type(self.snode().data_type()))
+        arr = np.zeros(shape=self.shape, dtype=to_numpy_type(self.dtype))
         tensor_to_ext_arr(self, arr)
         import taichi as ti
         ti.sync()
@@ -154,7 +158,7 @@ def to_torch(self, device=None):
         from .meta import tensor_to_ext_arr
         import torch
         arr = torch.zeros(size=self.shape,
-                          dtype=to_pytorch_type(self.snode().data_type()),
+                          dtype=to_pytorch_type(self.dtype),
                           device=device)
         tensor_to_ext_arr(self, arr)
         import taichi as ti
```
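The `@deprecated(old, new)` helper applied above is imported from Taichi's utility module and its body is not part of this diff. A minimal sketch of such a decorator, assuming it emits a standard Python warning (the warning class and message wording are assumptions):

```python
import functools
import warnings


def deprecated(old, new):
    # Factory: returns a decorator that warns `old` is superseded by `new`,
    # then calls through to the wrapped function unchanged.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(f'{old} is deprecated, use {new} instead',
                          DeprecationWarning,
                          stacklevel=2)
            return func(*args, **kwargs)

        return wrapper

    return decorator
```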
python/taichi/lang/matrix.py (16 changes: 9 additions & 7 deletions)
```diff
@@ -524,8 +524,13 @@ def shape(self):
     def dim(self):
         return len(self.shape)
 
+    @property
+    def dtype(self):
+        return self.loop_range().dtype
+
+    @deprecated('x.data_type()', 'x.dtype')
     def data_type(self):
-        return self.loop_range().data_type()
+        return self.dtype
 
     def make_grad(self):
         ret = self.empty_copy()
@@ -619,9 +624,7 @@ def to_numpy(self, keep_dims=False, as_vector=None):
         if not self.is_global():
             return np.array(self.entries).reshape(shape_ext)
 
-        ret = np.empty(self.loop_range().shape + shape_ext,
-                       dtype=to_numpy_type(
-                           self.loop_range().snode().data_type()))
+        ret = np.empty(self.shape + shape_ext, dtype=to_numpy_type(self.dtype))
         from .meta import matrix_to_ext_arr
         matrix_to_ext_arr(self, ret, as_vector)
         import taichi as ti
@@ -633,9 +636,8 @@ def to_torch(self, device=None, keep_dims=False):
         import torch
         as_vector = self.m == 1 and not keep_dims
         shape_ext = (self.n, ) if as_vector else (self.n, self.m)
-        ret = torch.empty(self.loop_range().shape + shape_ext,
-                          dtype=to_pytorch_type(
-                              self.loop_range().snode().data_type()),
+        ret = torch.empty(self.shape + shape_ext,
+                          dtype=to_pytorch_type(self.dtype),
                           device=device)
         from .meta import matrix_to_ext_arr
         matrix_to_ext_arr(self, ret, as_vector)
```
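With `Matrix.dtype` forwarding to `loop_range().dtype`, the export helpers no longer reach through `loop_range().snode()`. A usage sketch of the simplified `to_numpy()` path (0.6-era `ti.Vector` field API; the shapes in the comments follow from the `as_vector` logic above):

```python
import taichi as ti

ti.init(arch=ti.cpu)
vec = ti.Vector(3, dt=ti.f32, shape=(4, 4))  # a 4x4 field of 3-vectors

print(vec.dtype)             # ti.f32, via vec.loop_range().dtype
arr = vec.to_numpy()         # numpy dtype derived from vec.dtype
print(arr.shape, arr.dtype)  # (4, 4, 3) float32, since m == 1 collapses
```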
python/taichi/lang/snode.py (7 changes: 6 additions & 1 deletion)
```diff
@@ -64,9 +64,14 @@ def parent(self, n=1):
             return impl.root
         return SNode(p)
 
-    def data_type(self):
+    @property
+    def dtype(self):
         return self.ptr.data_type()
 
+    @deprecated('x.data_type()', 'x.dtype')
+    def data_type(self):
+        return self.dtype
+
     @deprecated('x.dim()', 'len(x.shape)')
     def dim(self):
         return len(self.shape)
```
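`SNode.dtype` is the primitive accessor; `Expr.dtype` and `Matrix.dtype` both bottom out here. A quick check of the forwarding chain (assuming the 0.6-era `ti.var` API):

```python
import taichi as ti

ti.init()
x = ti.var(ti.f32, shape=(8, ))

assert x.snode().dtype == ti.f32  # the property on the SNode itself
assert x.dtype == ti.f32          # Expr.dtype forwards to its SNode
x.snode().data_type()             # deprecated alias: warns, then returns ti.f32
```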
python/taichi/misc/gui.py (4 changes: 2 additions & 2 deletions)
```diff
@@ -80,7 +80,7 @@ def set_image(self, img):
         import taichi as ti
 
         if isinstance(img, ti.Expr):
-            if ti.core.is_integral(img.data_type()) or len(img.shape) != 2:
+            if ti.core.is_integral(img.dtype) or len(img.shape) != 2:
                 # Images of uint are not optimized by xxx_to_image
                 self.img = self.cook_image(img.to_numpy())
             else:
@@ -92,7 +92,7 @@ def set_image(self, img):
                 ti.sync()
 
         elif isinstance(img, ti.Matrix):
-            if ti.core.is_integral(img.data_type()):
+            if ti.core.is_integral(img.dtype):
                 self.img = self.cook_image(img.to_numpy())
             else:
                 # Type matched! We can use an optimized copy kernel.
```
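`set_image` chooses between a slow `to_numpy()` round-trip and an optimized copy kernel purely from `img.dtype`. A sketch of the fast path (0.6-era `ti.GUI` API; window size and loop count are arbitrary):

```python
import taichi as ti

ti.init(arch=ti.cpu)
img = ti.var(ti.f32, shape=(512, 512))  # f32 and 2-D: takes the optimized branch
gui = ti.GUI('dtype demo', res=(512, 512))

for _ in range(100):
    gui.set_image(img)  # is_integral(img.dtype) is False, so no numpy copy
    gui.show()
```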
tests/python/test_tensor_reflection.py (10 changes: 6 additions & 4 deletions)
```diff
@@ -13,7 +13,7 @@ def test_POT():
     ti.root.dense(ti.i, n).dense(ti.j, m).dense(ti.k, p).place(val)
 
     assert val.shape == (n, m, p)
-    assert val.data_type() == ti.i32
+    assert val.dtype == ti.i32
 
 
 @ti.all_archs
@@ -30,7 +30,7 @@ def test_non_POT():
     blk3.place(val)
 
     assert val.shape == (n, m, p)
-    assert val.data_type() == ti.i32
+    assert val.dtype == ti.i32
 
 
 @ti.all_archs
@@ -46,7 +46,7 @@ def test_unordered():
     blk3 = blk2.dense(ti.j, p)
     blk3.place(val)
 
-    assert val.data_type() == ti.i32
+    assert val.dtype == ti.i32
     assert val.shape == (n, m, p)
     assert val.snode().parent(0) == val.snode()
     assert val.snode().parent() == blk3
@@ -79,7 +79,7 @@ def test_unordered_matrix():
     blk3.place(val)
 
     assert val.shape == (n, m, p)
-    assert val.data_type() == ti.i32
+    assert val.dtype == ti.i32
     assert val.loop_range().snode().parent(0) == val.loop_range().snode()
     assert val.loop_range().snode().parent() == blk3
     assert val.loop_range().snode().parent(1) == blk3
@@ -104,8 +104,10 @@ def test_deprecated():
     blk3.place(val, mat)
 
     assert val.dim() == 3
+    assert val.data_type() == ti.f32
     assert val.shape() == (n, m, p)
     assert mat.dim() == 3
+    assert mat.data_type() == ti.i32
     assert mat.shape() == (n, m, p)
     assert blk3.dim() == 3
     assert blk3.shape() == (n, m, p)
```
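`test_deprecated` only checks that the old spellings still return the right values. If the `deprecated` helper emits a standard `DeprecationWarning` (an assumption; the diff does not show its implementation), the warning itself could be asserted too:

```python
import pytest
import taichi as ti


@ti.all_archs
def test_data_type_warns():
    val = ti.var(ti.f32, shape=(4, ))
    with pytest.warns(DeprecationWarning):
        assert val.data_type() == ti.f32
```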