Skip to content

Commit

Permalink
[bug] Fix name collision in ti.dataclass (#6737)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Dec 5, 2022
1 parent 354e7de commit f430eae
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
18 changes: 11 additions & 7 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class Struct(TaichiOperations):
dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})])
"""
_is_taichi_class = True
_instance_count = 0

def __init__(self, *args, **kwargs):
# converts lists to matrices and dicts to structs
Expand Down Expand Up @@ -96,12 +97,15 @@ def items(self):
return self.entries.items()

def _register_members(self):
for k in self.keys:
setattr(Struct, k,
property(
Struct._make_getter(k),
Struct._make_setter(k),
))
# https://stackoverflow.com/questions/48448074/adding-a-property-to-an-existing-object-instance
cls = self.__class__
new_cls_name = cls.__name__ + str(cls._instance_count)
cls._instance_count += 1
properties = {
k: property(cls._make_getter(k), cls._make_setter(k))
for k in self.keys
}
self.__class__ = type(new_cls_name, (cls, ), properties)

def _register_methods(self):
for name, method in self.methods.items():
Expand Down Expand Up @@ -769,7 +773,7 @@ def dataclass(cls):
and methods from the class attached.
"""
# save the annotation fields for the struct
fields = cls.__annotations__
fields = getattr(cls, '__annotations__', {})
# get the class methods to be attached to the struct types
fields['__struct_methods'] = {
attribute: getattr(cls, attribute)
Expand Down
18 changes: 18 additions & 0 deletions tests/python/test_custom_struct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from pytest import approx
from taichi.lang.misc import get_host_arch_list

import taichi as ti
from tests import test_utils
Expand Down Expand Up @@ -443,3 +444,20 @@ def test():
assert A.mass == 2.0

test()


@test_utils.test(arch=get_host_arch_list())
def test_name_collision():
# https://github.com/taichi-dev/taichi/issues/6652
@ti.dataclass
class Foo:
zoo: ti.f32

@ti.dataclass
class Bar:
@ti.func
def zoo(self):
return 0

Foo() # instantiate struct with zoo as member first
Bar() # then instantiate struct with zoo as method

0 comments on commit f430eae

Please sign in to comment.