Skip to content

Commit

Permalink
[Doc] Add Python docstring in several places (#2592)
Browse files Browse the repository at this point in the history
* [Doc] Add Python docstring in several places

* Auto Format

* Add more docs

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
strongoier and taichi-gardener authored Jul 27, 2021
1 parent a3acd1e commit 740a289
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 2 deletions.
27 changes: 26 additions & 1 deletion python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,19 @@ def memory_profiler_print():


extension = _ti_core.Extension
is_extension_supported = _ti_core.is_extension_supported


def is_extension_supported(arch, ext):
"""Checks whether an extension is supported on an arch.
Args:
arch (taichi_core.Arch): Specified arch.
ext (taichi_core.Extension): Specified extension.
Returns:
bool: Whether `ext` is supported on `arch`.
"""
return _ti_core.is_extension_supported(arch, ext)


def reset():
Expand Down Expand Up @@ -609,6 +621,14 @@ def stat_write(key, value):


def is_arch_supported(arch):
"""Checks whether an arch is supported on the machine.
Args:
arch (taichi_core.Arch): Specified arch.
Returns:
bool: Whether `arch` is supported on the machine.
"""
arch_table = {
cuda: _ti_core.with_cuda,
metal: _ti_core.with_metal,
Expand All @@ -631,6 +651,11 @@ def is_arch_supported(arch):


def supported_archs():
"""Gets all supported archs on the machine.
Returns:
List[taichi_core.Arch]: All supported archs on the machine.
"""
archs = [cpu, cuda, metal, opengl, cc]

wanted_archs = os.environ.get('TI_WANTED_ARCHS', '')
Expand Down
36 changes: 36 additions & 0 deletions python/taichi/lang/quant_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,39 @@


class Quant:
"""Generator of quantized types.
For more details, read https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf.
"""
@staticmethod
def int(bits, signed=False, compute=None):
"""Generates a quantized type for integers.
Args:
bits (int): Number of bits.
signed (bool): Signed or unsigned.
compute (DataType): Type for computation.
Returns:
DataType: The specified type.
"""
if compute is None:
compute = impl.get_runtime().default_ip
return tf_impl.type_factory.custom_int(bits, signed, compute)

@staticmethod
def fixed(frac, signed=True, range=1.0, compute=None):
"""Generates a quantized type for fixed-point real numbers.
Args:
frac (int): Number of bits.
signed (bool): Signed or unsigned.
range (float): Range of the number.
compute (DataType): Type for computation.
Returns:
DataType: The specified type.
"""
# TODO: handle cases with frac > 32
frac_type = Quant.int(bits=frac, signed=signed, compute=ti.i32)
if signed:
Expand All @@ -26,6 +51,17 @@ def fixed(frac, signed=True, range=1.0, compute=None):

@staticmethod
def float(exp, frac, signed=True, compute=None):
"""Generates a quantized type for floating-point real numbers.
Args:
exp (int): Number of exponent bits.
frac (int): Number of fraction bits.
signed (bool): Signed or unsigned.
compute (DataType): Type for computation.
Returns:
DataType: The specified type.
"""
# Exponent is always unsigned
exp_type = Quant.int(bits=exp, signed=False, compute=ti.i32)
# TODO: handle cases with frac > 32
Expand Down
56 changes: 55 additions & 1 deletion python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def place(self, *args, offset=None, shared_exponent=False):
Args:
*args (List[ti.field]): A list of Taichi fields to place.
offsest (Union[Number, tuple[Number]]): Offset of the field domain.
offset (Union[Number, tuple[Number]]): Offset of the field domain.
shared_exponent (bool): Only useful for quant types.
Returns:
Expand Down Expand Up @@ -160,6 +160,14 @@ def lazy_grad(self):
self.ptr.lazy_grad()

def parent(self, n=1):
"""Gets an ancestor of `self` in the SNode tree.
Args:
n (int): the number of levels going up from `self`.
Returns:
Union[None, _Root, SNode]: The n-th parent of `self`.
"""
impl.get_runtime().materialize()
p = self.ptr
while p and n > 0:
Expand All @@ -173,6 +181,11 @@ def parent(self, n=1):

@property
def dtype(self):
"""Gets the data type of `self`.
Returns:
DataType: The data type of `self`.
"""
return self.ptr.data_type()

@deprecated('x.data_type()', 'x.dtype')
Expand All @@ -185,10 +198,20 @@ def dim(self):

@property
def id(self):
"""Gets the id of `self`.
Returns:
int: The id of `self`.
"""
return self.ptr.id

@property
def shape(self):
"""Gets the number of elements from root in each axis of `self`.
Returns:
Tuple[int]: The number of elements from root in each axis of `self`.
"""
impl.get_runtime().materialize()
dim = self.ptr.num_active_indices()
ret = [self.ptr.get_shape_along_axis(i) for i in range(dim)]
Expand All @@ -206,10 +229,20 @@ def get_shape(self, i):
return self.shape[i]

def loop_range(self):
"""Wraps `self` into an :class:`~taichi.lang.Expr` to serve as loop range.
Returns:
Expr: The wrapped result.
"""
return Expr(_ti_core.global_var_expr_from_snode(self.ptr))

@property
def name(self):
"""Gets the name of `self`.
Returns:
str: The name of `self`.
"""
return self.ptr.name()

@deprecated('x.snode()', 'x.snode')
Expand All @@ -218,13 +251,28 @@ def __call__(self): # TODO: remove this after v0.7.0

@property
def snode(self):
"""Gets `self`.
Returns:
SNode: `self`.
"""
return self

@property
def needs_grad(self):
"""Checks whether `self` has a corresponding gradient :class:`~taichi.lang.SNode`.
Returns:
bool: Whether `self` has a corresponding gradient :class:`~taichi.lang.SNode`.
"""
return self.ptr.has_grad()

def get_children(self):
"""Gets all children components of `self`.
Returns:
List[SNode]: All children components of `self`.
"""
children = []
for i in range(self.ptr.get_num_ch()):
children.append(SNode(self.ptr.get_ch(i)))
Expand All @@ -243,6 +291,7 @@ def cell_size_bytes(self):
return self.ptr.cell_size_bytes

def deactivate_all(self):
"""Recursively deactivate all children components of `self`."""
ch = self.get_children()
for c in ch:
c.deactivate_all()
Expand Down Expand Up @@ -272,6 +321,11 @@ def __eq__(self, other):
return self.ptr == other.ptr

def physical_index_position(self):
"""Gets mappings from virtual axes to physical axes.
Returns:
Dict[int, int]: Mappings from virtual axes to physical axes.
"""
ret = {}
for virtual, physical in enumerate(
self.ptr.get_physical_index_position()):
Expand Down
22 changes: 22 additions & 0 deletions python/taichi/lang/type_factory_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,21 @@


class TypeFactory:
"""A Python-side TypeFactory wrapper."""
def __init__(self):
self.core = _ti_core.get_type_factory_instance()

def custom_int(self, bits, signed=True, compute_type=None):
"""Generates a custom int type.
Args:
bits (int): Number of bits.
signed (bool): Signed or unsigned.
compute_type (DataType): Type for computation.
Returns:
DataType: The specified type.
"""
if compute_type is None:
compute_type = impl.get_runtime().default_ip
if isinstance(compute_type, _ti_core.DataType):
Expand All @@ -18,6 +29,17 @@ def custom_float(self,
exponent_type=None,
compute_type=None,
scale=1.0):
"""Generates a custom float type.
Args:
significand_type (DataType): Type of significand.
exponent_type (DataType): Type of exponent.
compute_type (DataType): Type for computation.
scale (float): Scaling factor.
Returns:
DataType: The specified type.
"""
if compute_type is None:
compute_type = impl.get_runtime().default_fp
if isinstance(compute_type, _ti_core.DataType):
Expand Down

0 comments on commit 740a289

Please sign in to comment.