From 4c35879ba4ae39bacfa776690265da3fe301d6ca Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Tue, 23 Nov 2021 16:39:22 +0800 Subject: [PATCH 1/2] [mesh] Make ti.Mesh compatible with dynamic index --- python/taichi/lang/__init__.py | 3 --- taichi/ir/frontend_ir.h | 8 ++++---- tests/python/test_mesh.py | 14 +++++--------- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 4de41afb4ace0..dbe46b5535d0d 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -591,9 +591,6 @@ def block_local(*args): def mesh_local(*args): - if ti.current_cfg().dynamic_index: - raise InvalidOperationError( - 'dynamic_index is not allowed when mesh_local is turned on.') for a in args: for v in a.get_field_members(): _ti_core.insert_snode_access_flag( diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 11f7efcbdd0d1..d91750ffa646b 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -836,7 +836,7 @@ class MeshRelationAccessExpression : public Expression { MeshRelationAccessExpression(mesh::Mesh *mesh, const Expr mesh_idx, mesh::MeshElementType to_type) - : mesh(mesh), mesh_idx(mesh_idx), to_type(to_type) { + : mesh(mesh), mesh_idx(load_if_ptr(mesh_idx)), to_type(to_type) { } MeshRelationAccessExpression(mesh::Mesh *mesh, @@ -844,9 +844,9 @@ class MeshRelationAccessExpression : public Expression { mesh::MeshElementType to_type, const Expr neighbor_idx) : mesh(mesh), - mesh_idx(mesh_idx), + mesh_idx(load_if_ptr(mesh_idx)), to_type(to_type), - neighbor_idx(neighbor_idx) { + neighbor_idx(load_if_ptr(neighbor_idx)) { } void flatten(FlattenContext *ctx) override; @@ -872,7 +872,7 @@ class MeshIndexConversionExpression : public Expression { mesh::MeshElementType idx_type, const Expr idx, mesh::ConvType conv_type) - : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { + : mesh(mesh), idx_type(idx_type), idx(load_if_ptr(idx)), conv_type(conv_type) { } void flatten(FlattenContext *ctx) override; diff --git a/tests/python/test_mesh.py b/tests/python/test_mesh.py index e2215b5d98ffe..f3104fa49d682 100644 --- a/tests/python/test_mesh.py +++ b/tests/python/test_mesh.py @@ -8,7 +8,7 @@ model_file_path = os.path.join(this_dir, 'ell.json') -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_patch_idx(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'idx': ti.i32}) @@ -83,21 +83,19 @@ def vert_vert(): assert total == 1144 -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_for(): _test_mesh_for(False, False) _test_mesh_for(False, True) @ti.test(require=ti.extension.mesh, - dynamic_index=False, optimize_mesh_reordered_mapping=False) def test_mesh_reordered_opt(): _test_mesh_for(True, True, False) @ti.test(require=ti.extension.mesh, - dynamic_index=False, mesh_localize_to_end_mapping=False) def test_mesh_localize_mapping0(): _test_mesh_for(False, False, False) @@ -105,14 +103,13 @@ def test_mesh_localize_mapping0(): @ti.test(require=ti.extension.mesh, - dynamic_index=False, mesh_localize_from_end_mapping=True) def test_mesh_localize_mapping1(): _test_mesh_for(False, False, False) _test_mesh_for(True, True, False) -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_reorder(): vec3i = ti.types.vector(3, ti.i32) mesh_builder = ti.Mesh.Tet() @@ -150,7 +147,7 @@ def foo(): assert id234[i][2] == i**4 -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_minor_relations(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'y': ti.i32}) @@ -175,7 +172,6 @@ def foo(): @ti.test(require=ti.extension.mesh, - dynamic_index=False, demote_no_access_mesh_fors=True) def test_multiple_meshes(): mesh_builder = ti.Mesh.Tet() @@ -198,7 +194,7 @@ def foo(): assert out[i] == i**2 -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_local(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'a': ti.i32}) From 455b66c3f111c6791bd8b1027ccc04791f4e5137 Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Tue, 23 Nov 2021 08:45:10 +0000 Subject: [PATCH 2/2] Auto Format --- taichi/ir/frontend_ir.h | 5 ++++- tests/python/test_mesh.py | 12 ++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index d91750ffa646b..389c7f27f7b81 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -872,7 +872,10 @@ class MeshIndexConversionExpression : public Expression { mesh::MeshElementType idx_type, const Expr idx, mesh::ConvType conv_type) - : mesh(mesh), idx_type(idx_type), idx(load_if_ptr(idx)), conv_type(conv_type) { + : mesh(mesh), + idx_type(idx_type), + idx(load_if_ptr(idx)), + conv_type(conv_type) { } void flatten(FlattenContext *ctx) override; diff --git a/tests/python/test_mesh.py b/tests/python/test_mesh.py index f3104fa49d682..667187a390210 100644 --- a/tests/python/test_mesh.py +++ b/tests/python/test_mesh.py @@ -89,21 +89,18 @@ def test_mesh_for(): _test_mesh_for(False, True) -@ti.test(require=ti.extension.mesh, - optimize_mesh_reordered_mapping=False) +@ti.test(require=ti.extension.mesh, optimize_mesh_reordered_mapping=False) def test_mesh_reordered_opt(): _test_mesh_for(True, True, False) -@ti.test(require=ti.extension.mesh, - mesh_localize_to_end_mapping=False) +@ti.test(require=ti.extension.mesh, mesh_localize_to_end_mapping=False) def test_mesh_localize_mapping0(): _test_mesh_for(False, False, False) _test_mesh_for(True, True, False) -@ti.test(require=ti.extension.mesh, - mesh_localize_from_end_mapping=True) +@ti.test(require=ti.extension.mesh, mesh_localize_from_end_mapping=True) def test_mesh_localize_mapping1(): _test_mesh_for(False, False, False) _test_mesh_for(True, True, False) @@ -171,8 +168,7 @@ def foo(): assert total == 576 -@ti.test(require=ti.extension.mesh, - demote_no_access_mesh_fors=True) +@ti.test(require=ti.extension.mesh, demote_no_access_mesh_fors=True) def test_multiple_meshes(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'y': ti.i32})