Skip to content

Commit

Permalink
Enable symbolic comparison/hashing on Python functions and methods.
Browse files Browse the repository at this point in the history
This allows us to have consistent hash for the same functions/methods in semantics instead of their memory addresses.

PiperOrigin-RevId: 584911091
  • Loading branch information
daiyip authored and pyglove authors committed Nov 23, 2023
1 parent e223f6d commit 0ccc11b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
7 changes: 6 additions & 1 deletion pyglove/core/symbolic/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,8 @@ class B:
and not inspect.isclass(right)
and right.sym_eq.__code__ is not Symbolic.sym_eq.__code__):
return right.sym_eq(left)
return left == right
# Compare two maybe callable objects.
return pg_typing.callable_eq(left, right)


def ne(left: Any, right: Any) -> bool:
Expand Down Expand Up @@ -1791,6 +1792,10 @@ def __init__(self, x):
"""
if isinstance(x, Symbolic):
return x.sym_hash()
if inspect.isfunction(x):
return hash(x.__code__.co_code)
if inspect.ismethod(x):
return hash((sym_hash(x.__self__), x.__code__.co_code)) # pytype: disable=attribute-error
return hash(x)


Expand Down
14 changes: 13 additions & 1 deletion pyglove/core/symbolic/object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,9 @@ def test_sym_hash(self):
('y', pg_typing.Int().noneable())
])
class A(Object):
pass

def result(self):
return self.x + self.y

self.assertEqual(hash(A(0)), hash(A(0)))
self.assertEqual(hash(A(1, None)), hash(A(1, None)))
Expand Down Expand Up @@ -1384,6 +1386,16 @@ def __hash__(self):
self.assertEqual(hash(A(a)), hash(A(b)))
self.assertNotEqual(hash(A(Y(1))), hash(A(Y(2))))

# Test symbolic hashing for functions and methods.
a = lambda x: x
b = base.from_json_str(base.to_json_str(a))
self.assertNotEqual(hash(a), hash(b))
self.assertEqual(base.sym_hash(a), base.sym_hash(b))
self.assertEqual(
base.sym_hash(A(1, 2).result), base.sym_hash(A(1, 2).result))
self.assertNotEqual(
base.sym_hash(A(1, 2).result), base.sym_hash(A(2, 3).result))

def test_sym_parent(self):

@pg_members([
Expand Down

0 comments on commit 0ccc11b

Please sign in to comment.