From 0ccc11b076306c9b98373c8581e7e985626810dd Mon Sep 17 00:00:00 2001 From: Daiyi Peng Date: Thu, 23 Nov 2023 10:00:02 -0800 Subject: [PATCH] Enable symbolic comparison/hashing on Python functions and methods. This allows us to have consistent hash for the same functions/methods in semantics instead of their memory addresses. PiperOrigin-RevId: 584911091 --- pyglove/core/symbolic/base.py | 7 ++++++- pyglove/core/symbolic/object_test.py | 14 +++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pyglove/core/symbolic/base.py b/pyglove/core/symbolic/base.py index 2df8aea..ad3d9a5 100644 --- a/pyglove/core/symbolic/base.py +++ b/pyglove/core/symbolic/base.py @@ -1597,7 +1597,8 @@ class B: and not inspect.isclass(right) and right.sym_eq.__code__ is not Symbolic.sym_eq.__code__): return right.sym_eq(left) - return left == right + # Compare two maybe callable objects. + return pg_typing.callable_eq(left, right) def ne(left: Any, right: Any) -> bool: @@ -1791,6 +1792,10 @@ def __init__(self, x): """ if isinstance(x, Symbolic): return x.sym_hash() + if inspect.isfunction(x): + return hash(x.__code__.co_code) + if inspect.ismethod(x): + return hash((sym_hash(x.__self__), x.__code__.co_code)) # pytype: disable=attribute-error return hash(x) diff --git a/pyglove/core/symbolic/object_test.py b/pyglove/core/symbolic/object_test.py index 24158ef..6213ab4 100644 --- a/pyglove/core/symbolic/object_test.py +++ b/pyglove/core/symbolic/object_test.py @@ -1355,7 +1355,9 @@ def test_sym_hash(self): ('y', pg_typing.Int().noneable()) ]) class A(Object): - pass + + def result(self): + return self.x + self.y self.assertEqual(hash(A(0)), hash(A(0))) self.assertEqual(hash(A(1, None)), hash(A(1, None))) @@ -1384,6 +1386,16 @@ def __hash__(self): self.assertEqual(hash(A(a)), hash(A(b))) self.assertNotEqual(hash(A(Y(1))), hash(A(Y(2)))) + # Test symbolic hashing for functions and methods. + a = lambda x: x + b = base.from_json_str(base.to_json_str(a)) + self.assertNotEqual(hash(a), hash(b)) + self.assertEqual(base.sym_hash(a), base.sym_hash(b)) + self.assertEqual( + base.sym_hash(A(1, 2).result), base.sym_hash(A(1, 2).result)) + self.assertNotEqual( + base.sym_hash(A(1, 2).result), base.sym_hash(A(2, 3).result)) + def test_sym_parent(self): @pg_members([