diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index 4de41afb4ace0..dbe46b5535d0d 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -591,9 +591,6 @@ def block_local(*args): def mesh_local(*args): - if ti.current_cfg().dynamic_index: - raise InvalidOperationError( - 'dynamic_index is not allowed when mesh_local is turned on.') for a in args: for v in a.get_field_members(): _ti_core.insert_snode_access_flag( diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 11f7efcbdd0d1..389c7f27f7b81 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -836,7 +836,7 @@ class MeshRelationAccessExpression : public Expression { MeshRelationAccessExpression(mesh::Mesh *mesh, const Expr mesh_idx, mesh::MeshElementType to_type) - : mesh(mesh), mesh_idx(mesh_idx), to_type(to_type) { + : mesh(mesh), mesh_idx(load_if_ptr(mesh_idx)), to_type(to_type) { } MeshRelationAccessExpression(mesh::Mesh *mesh, @@ -844,9 +844,9 @@ class MeshRelationAccessExpression : public Expression { mesh::MeshElementType to_type, const Expr neighbor_idx) : mesh(mesh), - mesh_idx(mesh_idx), + mesh_idx(load_if_ptr(mesh_idx)), to_type(to_type), - neighbor_idx(neighbor_idx) { + neighbor_idx(load_if_ptr(neighbor_idx)) { } void flatten(FlattenContext *ctx) override; @@ -872,7 +872,10 @@ class MeshIndexConversionExpression : public Expression { mesh::MeshElementType idx_type, const Expr idx, mesh::ConvType conv_type) - : mesh(mesh), idx_type(idx_type), idx(idx), conv_type(conv_type) { + : mesh(mesh), + idx_type(idx_type), + idx(load_if_ptr(idx)), + conv_type(conv_type) { } void flatten(FlattenContext *ctx) override; diff --git a/tests/python/test_mesh.py b/tests/python/test_mesh.py index e2215b5d98ffe..667187a390210 100644 --- a/tests/python/test_mesh.py +++ b/tests/python/test_mesh.py @@ -8,7 +8,7 @@ model_file_path = os.path.join(this_dir, 'ell.json') -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_patch_idx(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'idx': ti.i32}) @@ -83,36 +83,30 @@ def vert_vert(): assert total == 1144 -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_for(): _test_mesh_for(False, False) _test_mesh_for(False, True) -@ti.test(require=ti.extension.mesh, - dynamic_index=False, - optimize_mesh_reordered_mapping=False) +@ti.test(require=ti.extension.mesh, optimize_mesh_reordered_mapping=False) def test_mesh_reordered_opt(): _test_mesh_for(True, True, False) -@ti.test(require=ti.extension.mesh, - dynamic_index=False, - mesh_localize_to_end_mapping=False) +@ti.test(require=ti.extension.mesh, mesh_localize_to_end_mapping=False) def test_mesh_localize_mapping0(): _test_mesh_for(False, False, False) _test_mesh_for(True, True, False) -@ti.test(require=ti.extension.mesh, - dynamic_index=False, - mesh_localize_from_end_mapping=True) +@ti.test(require=ti.extension.mesh, mesh_localize_from_end_mapping=True) def test_mesh_localize_mapping1(): _test_mesh_for(False, False, False) _test_mesh_for(True, True, False) -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_reorder(): vec3i = ti.types.vector(3, ti.i32) mesh_builder = ti.Mesh.Tet() @@ -150,7 +144,7 @@ def foo(): assert id234[i][2] == i**4 -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_minor_relations(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'y': ti.i32}) @@ -174,9 +168,7 @@ def foo(): assert total == 576 -@ti.test(require=ti.extension.mesh, - dynamic_index=False, - demote_no_access_mesh_fors=True) +@ti.test(require=ti.extension.mesh, demote_no_access_mesh_fors=True) def test_multiple_meshes(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'y': ti.i32}) @@ -198,7 +190,7 @@ def foo(): assert out[i] == i**2 -@ti.test(require=ti.extension.mesh, dynamic_index=False) +@ti.test(require=ti.extension.mesh) def test_mesh_local(): mesh_builder = ti.Mesh.Tet() mesh_builder.verts.place({'a': ti.i32})