Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional strict argument to Type.is_valid_value #995

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aesara/graph/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ def convert_variable(self, var: Variable) -> Optional[Variable]:

return None

def is_valid_value(self, data: D) -> bool:
def is_valid_value(self, data: D, strict: bool = True) -> bool:
"""Return ``True`` for any python object that would be a legal value for a `Variable` of this `Type`."""
try:
self.filter(data, strict=True)
self.filter(data, strict=strict)
return True
except (TypeError, ValueError):
return False
Expand Down
91 changes: 54 additions & 37 deletions aesara/tensor/random/type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import Generic, TypeVar

import numpy as np

import aesara
from aesara.graph.type import Type


T = TypeVar("T", np.random.RandomState, np.random.Generator)


gen_states_keys = {
"MT19937": (["state"], ["key", "pos"]),
"PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]),
Expand All @@ -18,22 +23,15 @@
numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}


class RandomType(Type):
class RandomType(Type, Generic[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""

@classmethod
def filter(cls, data, strict=False, allow_downcast=None):
if cls.is_valid_value(data, strict):
return data
else:
raise TypeError()

@staticmethod
def may_share_memory(a, b):
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator


class RandomStateType(RandomType):
class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`.

The reason this exists (and `Generic` doesn't suffice) is that
Expand All @@ -49,28 +47,38 @@ class RandomStateType(RandomType):
def __repr__(self):
return "RandomStateType"

@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.RandomState):
return True
def filter(self, data, strict: bool = False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.

In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data

if not strict and isinstance(a, dict):
if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]

for key in gen_keys:
if key not in a:
return False
if key not in data:
raise TypeError()

for key in state_keys:
if key not in a["state"]:
return False
if key not in data["state"]:
raise TypeError()

state_key = a["state"]["key"]
state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
return True
# TODO: Add an option to convert to a `RandomState` instance?
return data

return False
raise TypeError()

@staticmethod
def values_eq(a, b):
Expand Down Expand Up @@ -114,7 +122,7 @@ def __hash__(self):
random_state_type = RandomStateType()


class RandomGeneratorType(RandomType):
class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`.

The reason this exists (and `Generic` doesn't suffice) is that
Expand All @@ -130,16 +138,25 @@ class RandomGeneratorType(RandomType):
def __repr__(self):
return "RandomGeneratorType"

@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.Generator):
return True
def filter(self, data, strict=False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomGeneratorType`.

In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.Generator):
return data

if not strict and isinstance(a, dict):
if "bit_generator" not in a:
return False
if not strict and isinstance(data, dict):
if "bit_generator" not in data:
raise TypeError()
else:
bit_gen_key = a["bit_generator"]
bit_gen_key = data["bit_generator"]

if hasattr(bit_gen_key, "_value"):
bit_gen_key = int(bit_gen_key._value)
Expand All @@ -148,16 +165,16 @@ def is_valid_value(a, strict):
gen_keys, state_keys = gen_states_keys[bit_gen_key]

for key in gen_keys:
if key not in a:
return False
if key not in data:
raise TypeError()

for key in state_keys:
if key not in a["state"]:
return False
if key not in data["state"]:
raise TypeError()

return True
return data

return False
raise TypeError()

@staticmethod
def values_eq(a, b):
Expand Down
28 changes: 16 additions & 12 deletions tests/tensor/random/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,17 @@ def test_filter(self):
with pytest.raises(TypeError):
rng_type.filter(1)

rng = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng, strict=False)
rng_dict = rng.get_state(legacy=False)

rng["state"] = {}
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)

assert rng_type.is_valid_value(rng, strict=False) is False
rng_dict["state"] = {}

rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
assert rng_type.is_valid_value(rng_dict, strict=False) is False

rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False

def test_values_eq(self):

Expand Down Expand Up @@ -147,15 +149,17 @@ def test_filter(self):
with pytest.raises(TypeError):
rng_type.filter(1)

rng = rng.__getstate__()
assert rng_type.is_valid_value(rng, strict=False)
rng_dict = rng.__getstate__()

assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)

rng["state"] = {}
rng_dict["state"] = {}

assert rng_type.is_valid_value(rng, strict=False) is False
assert rng_type.is_valid_value(rng_dict, strict=False) is False

rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False

def test_values_eq(self):

Expand Down