Skip to content

Commit

Permalink
[Lang] [type] Add bool type in python as an alias to i32 (taichi-dev#…
Browse files Browse the repository at this point in the history
…6742)

Issue: taichi-dev#577 taichi-dev#6036

### Brief Summary

This PR adds `bool` as an alias to `ti.i32`. Specifically,

- `x: bool` is equivalent to `x: i32`
- `-> bool` is equivalent to `-> i32`
- `bool(x)` is equivalent to `ti.cast(x, i32)`.

This is a temporary solution while we work towards a standalone bool
type.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 924403d commit 7904eed
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def build_call_if_is_builtin(ctx, node, args, keywords):
id(min): ti_ops.min,
id(max): ti_ops.max,
id(int): impl.ti_int,
id(bool): impl.ti_bool,
id(float): impl.ti_float,
id(any): matrix_ops.any,
id(all): matrix_ops.all,
Expand Down
5 changes: 5 additions & 0 deletions python/taichi/lang/common_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from taichi.lang import ops
from taichi.lang.util import in_python_scope
from taichi.types import primitive_types


class TaichiOperations:
Expand Down Expand Up @@ -298,5 +299,9 @@ def _augassign(self, x, op):
def __ti_int__(self):
return ops.cast(self, int)

def __ti_bool__(self):
return ops.cast(
self, primitive_types.i32) # TODO[Xiaoyan]: Use i1 in the future

def __ti_float__(self):
return ops.cast(self, float)
7 changes: 7 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,13 @@ def ti_int(_var):
return int(_var)


@taichi_scope
def ti_bool(_var):
if hasattr(_var, '__ti_bool__'):
return _var.__ti_bool__()
return bool(_var)


@taichi_scope
def ti_float(_var):
if hasattr(_var, '__ti_float__'):
Expand Down
2 changes: 2 additions & 0 deletions python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ def cook_dtype(dtype):
return impl.get_runtime().default_fp
if dtype is int:
return impl.get_runtime().default_ip
if dtype is bool:
return i32 # TODO[Xiaoyan]: Use i1 in the future
raise ValueError(f'Invalid data type {dtype}')


Expand Down
2 changes: 1 addition & 1 deletion python/taichi/types/primitive_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def ref(tp):
real_types = [f16, f32, f64, float]
real_type_ids = [id(t) for t in real_types]

integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int]
integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int, bool]
integer_type_ids = [id(t) for t in integer_types]

all_types = real_types + integer_types
Expand Down
30 changes: 30 additions & 0 deletions tests/python/test_bool_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import taichi as ti
from tests import test_utils


@test_utils.test(debug=True)
def test_bool_type_anno():
@ti.func
def f(x: bool) -> bool:
return not x

@ti.kernel
def test():
assert f(True) == False
assert f(False) == True

test()


@test_utils.test(debug=True)
def test_bool_type_conv():
@ti.func
def f(x: ti.u32) -> bool:
return bool(x)

@ti.kernel
def test():
assert f(1000) == 1000
assert f(ti.u32(4_294_967_295)) == -1

test()

0 comments on commit 7904eed

Please sign in to comment.