Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] MatrixType bug fix: Fix error with BLS #6664

Merged
merged 13 commits into from
Dec 2, 2022
10 changes: 7 additions & 3 deletions python/taichi/lang/_ndrange.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import collections.abc

import numpy as np
from taichi.lang import ops
from taichi.lang import impl, ops
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.expr import Expr
from taichi.lang.matrix import _IntermediateMatrix
from taichi.lang.matrix import _IntermediateMatrix, make_matrix
from taichi.types import primitive_types
from taichi.types.utils import is_integral


Expand Down Expand Up @@ -144,7 +145,10 @@ def __init__(self, r):

def __iter__(self):
for ind in self.r:
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
if impl.current_cfg().real_matrix:
yield make_matrix(list(ind), dt=primitive_types.i32)
else:
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)


__all__ = ['ndrange']
61 changes: 52 additions & 9 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import numbers
import warnings

Expand All @@ -7,6 +8,34 @@
from taichi.lang.util import get_traceback


def _get_expanded_indices(indices):
if isinstance(indices, matrix.Matrix):
indices = indices.entries
elif isinstance(indices, expr.Expr) and indices.is_tensor():
indices = [
expr.Expr(x)
for x in impl.get_runtime().prog.current_ast_builder().expand_expr(
[indices.ptr])
]
return indices


def _expand_indices(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# indices is the second argument to ti.append, ti.activate, ...
if len(args) > 1:
args = list(args)
args[1] = _get_expanded_indices(args[1])
else:
assert "indices" in kwargs.keys()
kwargs["indices"] = _get_expanded_indices(kwargs["indices"])

return func(*args, **kwargs)

return wrapper


class SNode:
"""A Python-side SNode wrapper.

Expand Down Expand Up @@ -357,6 +386,7 @@ def rescale_index(a, b, I):
Returns:
Ib (:class:`~taichi.Vector`): rescaled grouped loop index
"""

assert isinstance(
a, (Field, SNode)), "The first argument must be a field or an SNode"
assert isinstance(
Expand All @@ -365,17 +395,25 @@ def rescale_index(a, b, I):
I = matrix.Vector(I)
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
else:
assert isinstance(
I, matrix.Matrix
), "The third argument must be an index (list or ti.Vector)"
entries = [I(i) for i in range(I.n)]
for n in range(min(I.n, min(len(a.shape), len(b.shape)))):
if a.shape[n] > b.shape[n]:
entries[n] = I(n) // (a.shape[n] // b.shape[n])
if a.shape[n] < b.shape[n]:
entries[n] = I(n) * (b.shape[n] // a.shape[n])
return matrix.Vector(entries)
I, (list, expr.Expr, matrix.Matrix)
), "The third argument must be an index (list, ti.Vector, or Expr with TensorType)"
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved

from taichi.lang.kernel_impl import pyfunc # pylint: disable=C0415

@pyfunc
def _rescale_index():
entries = [I[i] for i in range(I.n)]
for n in impl.static(range(min(I.n, min(len(a.shape), len(b.shape))))):
if a.shape[n] > b.shape[n]:
entries[n] = I[n] // (a.shape[n] // b.shape[n])
if a.shape[n] < b.shape[n]:
entries[n] = I[n] * (b.shape[n] // a.shape[n])
return matrix.Vector(entries)
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved

return _rescale_index()


@_expand_indices
def append(node, indices, val):
"""Append a value `val` to a SNode `node` at index `indices`.

Expand All @@ -392,6 +430,7 @@ def append(node, indices, val):
return a


@_expand_indices
def is_active(node, indices):
"""Explicitly query whether a cell in a SNode `node` at location
`indices` is active or not.
Expand All @@ -408,6 +447,7 @@ def is_active(node, indices):
expr.make_expr_group(indices)))


@_expand_indices
def activate(node, indices):
"""Explicitly activate a cell of `node` at location `indices`.

Expand All @@ -419,6 +459,7 @@ def activate(node, indices):
node._snode.ptr, expr.make_expr_group(indices))


@_expand_indices
def deactivate(node, indices):
"""Explicitly deactivate a cell of `node` at location `indices`.

Expand All @@ -433,6 +474,7 @@ def deactivate(node, indices):
node._snode.ptr, expr.make_expr_group(indices))


@_expand_indices
def length(node, indices):
"""Return the length of the dynamic SNode `node` at index `indices`.

Expand All @@ -448,6 +490,7 @@ def length(node, indices):
expr.make_expr_group(indices)))


@_expand_indices
def get_addr(f, indices):
"""Query the memory address (on CUDA/x64) of field `f` at index `indices`.

Expand Down
75 changes: 65 additions & 10 deletions tests/python/test_bls_assume_in_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,23 @@
from .bls_test_template import bls_particle_grid


@test_utils.test(require=ti.extension.bls)
def test_scattering():
def _test_scattering():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=True,
use_offset=False)


@test_utils.test(require=ti.extension.bls)
def test_scattering_offset():
def _test_scattering_offset():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=True,
use_offset=True)


@test_utils.test(require=ti.extension.bls)
def test_scattering_two_pointer_levels():
def _test_scattering_two_pointer_levels():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
Expand All @@ -32,22 +29,80 @@ def test_scattering_two_pointer_levels():
use_offset=False)


@test_utils.test(require=ti.extension.bls)
def test_gathering():
def _test_gathering():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=False,
use_offset=False)


@test_utils.test(require=ti.extension.bls)
def test_gathering_offset():
def _test_gathering_offset():
bls_particle_grid(N=128,
ppc=10,
block_size=8,
scatter=False,
use_offset=True)


@test_utils.test(require=ti.extension.bls)
def test_gathering():
_test_gathering()


@test_utils.test(require=ti.extension.bls)
def test_gathering_offset():
_test_gathering_offset()


@test_utils.test(require=ti.extension.bls)
def test_scattering_two_pointer_levels():
_test_scattering_two_pointer_levels()


@test_utils.test(require=ti.extension.bls)
def test_scattering():
_test_scattering()


@test_utils.test(require=ti.extension.bls)
def test_scattering_offset():
_test_scattering_offset()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gathering_matrix_scalarize():
_test_gathering()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_gathering_offset_matrix_scalarize():
_test_gathering_offset()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scattering_matrix_scalarize():
_test_scattering()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scattering_offset_matrix_scalarize():
_test_scattering_offset()


@test_utils.test(require=ti.extension.bls,
real_matrix=True,
real_matrix_scalarize=True)
def test_scattering_two_pointer_levels_matrix_scalarize():
_test_scattering_two_pointer_levels()


# TODO: debug mode behavior of assume_in_range