Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reset random state on seed change #2130

Merged
merged 8 commits into from
Oct 1, 2021
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_func_execution(func, func_mode):
elif func_mode == 'Python':
return func.function
else:
assert False, "Unknown function mode: {}".format(mech_mode)
assert False, "Unknown function mode: {}".format(func_mode)

@pytest.helpers.register
def get_mech_execution(mech, mech_mode):
Expand Down
72 changes: 35 additions & 37 deletions psyneulink/core/components/functions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,49 +333,48 @@ def _seed_setter(value, owning_component, context):
if value in {None, DEFAULT_SEED}:
value = get_global_seed()

value = int(value)
# Remove any old PRNG state
owning_component.parameters.random_state.set(None, context=context)
return int(value)

owning_component.parameters.random_state._set(
np.random.RandomState([value]),
context
)

return value
class SeededRandomState(np.random.RandomState):
def __init__(self, *args, **kwargs):
# Extract seed
self.used_seed = (kwargs.get('seed', None) or args[0])[:]
super().__init__(*args, **kwargs)

def __deepcopy__(self, memo):
# There's no easy way to deepcopy parent first.
# Create new instance and rewrite the state.
dup = type(self)(seed=self.used_seed)
dup.set_state(self.get_state())
return dup

def seed(self, seed):
assert False, "Use 'seed' parameter instead of seeding the random state directly"


def _random_state_getter(self, owning_component, context):
seed_param = owning_component.parameters.seed

seed_param = owning_component.parameters.seed
try:
is_modulated = seed_param._port.is_modulated(context)
except AttributeError:
# no ParameterPort
pass
is_modulated = False

if is_modulated:
seed_value = [int(owning_component._get_current_parameter_value(seed_param, context))]
else:
if is_modulated:
# can manage reset_for_context only in getter because we
# don't want to store any copied values from other contexts
# (from _initialize_from_context)
try:
reset_for_context = self._reset_for_context[context.execution_id]
except AttributeError:
self._reset_for_context = {}
reset_for_context = False
except KeyError:
reset_for_context = False

if not reset_for_context:
self._reset_for_context[context.execution_id] = True
return np.random.RandomState([
int(
owning_component._get_current_parameter_value(
seed_param,
context
)
)
])

return self.values[context.execution_id]
seed_value = [int(seed_param._get(context=context))]

assert seed_value != [DEFAULT_SEED], "Invalid seed for {} in context: {} ({})".format(owning_component, context.execution_id, seed_param)

current_state = self.values.get(context.execution_id, None)
if current_state is None or current_state.used_seed != seed_value:
return SeededRandomState(seed_value)

return current_state


class Function_Base(Function):
Expand Down Expand Up @@ -616,11 +615,10 @@ def __deepcopy__(self, memo):
new = super().__deepcopy__(memo)
# ensure copy does not have identical name
register_category(new, Function_Base, new.name, FunctionRegistry)
try:
if "random_state" in new.parameters:
# HACK: Make sure any copies are re-seeded to avoid dependent RNG.
new.random_state.seed([get_global_seed()])
except:
pass
# functions with "random_state" param must have "seed" parameter
new.seed.base = DEFAULT_SEED
return new

@handle_external_context()
Expand Down
64 changes: 35 additions & 29 deletions tests/composition/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,46 +1050,52 @@ def test_control_of_mech_port(self, comp_mode):

@pytest.mark.control
@pytest.mark.composition
def test_modulation_of_random_state(self):
src = pnl.ProcessingMechanism()
mech = pnl.ProcessingMechanism(function=pnl.UniformDist())

comp = pnl.Composition(retain_old_simulation_data=True)
@pytest.mark.parametrize("mode", [pnl.ExecutionMode.Python])
@pytest.mark.parametrize("num_generators", [5])
def test_modulation_of_random_state(self, mode, num_generators):
obj = pnl.ObjectiveMechanism()
# Set original seed that is not used by any evaluation
# this prevents dirty state from initialization skewing the results.
# The alternative would be to set:
# mech.functions.seed.base = mech.functions.seed.base
# to reset the PRNG
mech = pnl.ProcessingMechanism(function=pnl.UniformDist(seed=num_generators))

comp = pnl.Composition(retain_old_simulation_data=True,
controller_mode=pnl.BEFORE)
comp.add_node(mech, required_roles=pnl.NodeRole.INPUT)
comp.add_node(src)
comp.add_linear_processing_pathway([mech, obj])

comp.add_controller(
pnl.OptimizationControlMechanism(
monitor_for_control=src,
objective_mechanism=obj,
control_signals=pnl.ControlSignal(
modulates=('seed', mech),
modulation=pnl.OVERRIDE,
allocation_samples=pnl.SampleSpec(start=0, stop=5, step=1),
allocation_samples=pnl.SampleSpec(start=0, stop=num_generators - 1, step=1),
cost_options=pnl.CostFunctions.NONE
)
)
)

def seed_check(context):
latest_sim = comp.controller.parameters.simulation_ids._get(context)[-1]

seed = mech.get_mod_seed(latest_sim)
rs = mech.function.parameters.random_state.get(latest_sim)

# mech (and so its function and random_state) should be called
# exactly twice in one trial given the specified condition,
# and the random_state should be reset before the first execution
new_rs = np.random.RandomState([int(seed)])

for i in range(2):
new_rs.uniform(0, 1)

assert rs.uniform(0, 1) == new_rs.uniform(0, 1)

comp.termination_processing = {pnl.TimeScale.TRIAL: pnl.AfterNCalls(mech, 2)}
comp.run(
inputs={src: [1], mech: [1]},
call_after_trial=seed_check
)
comp.run(inputs={mech: [1]}, num_trials=2, execution_mode=mode)

# Construct expected results.
# First all generators rest their sequence.
# In the second trial, the "winning" seed from the previous one continues its
# random sequence
all_generators = [np.random.RandomState([seed]) for seed in range(num_generators)]
first_generator_samples = [g.uniform(0, 1) for g in all_generators]
best_first = max(first_generator_samples)
index_best = first_generator_samples.index(best_first)
second_generator_samples = [g.uniform(0, 1) for g in all_generators]
second_considerations = first_generator_samples[:index_best] + \
second_generator_samples[index_best:index_best + 1] + \
first_generator_samples[index_best + 1:]
best_second = max(second_considerations)
# Check that we select the maximum of generated values
assert np.allclose(best_first, comp.results[0])
assert np.allclose(best_second, comp.results[1])


class TestModelBasedOptimizationControlMechanisms:
Expand Down
4 changes: 2 additions & 2 deletions tests/mechanisms/test_mechanisms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def test_noise_variations(self, noise):
t2 = pnl.TransferMechanism(name='t2', size=2)
t2.integrator_function.parameters.noise.set(noise())

t1.integrator_function.noise.base.random_state = np.random.RandomState([0])
t2.integrator_function.noise.base.random_state = np.random.RandomState([0])
t1.integrator_function.noise.base.seed = 0
t2.integrator_function.noise.base.seed = 0

for _ in range(5):
np.testing.assert_equal(t1.execute([1, 1]), t2.execute([1, 1]))
Expand Down
27 changes: 15 additions & 12 deletions tests/mechanisms/test_transfer_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,18 +301,7 @@ def test_transfer_mech_exponential_noise(self):
def test_transfer_mech_uniform_to_normal_noise(self):
try:
import scipy
T = TransferMechanism(
name='T',
default_variable=[0, 0, 0, 0],
function=Linear(),
noise=UniformToNormalDist(),
integration_rate=1.0
)
T.noise.base.parameters.random_state.get(None).seed(22)
val = T.execute([0, 0, 0, 0])
assert np.allclose(val, [[-0.81177443, -0.04593492, -0.20051725, 1.07665147]])

except:
except ModuleNotFoundError:
with pytest.raises(FunctionError) as error_text:
T = TransferMechanism(
name='T',
Expand All @@ -322,6 +311,20 @@ def test_transfer_mech_uniform_to_normal_noise(self):
integration_rate=1.0
)
assert "The UniformToNormalDist function requires the SciPy package." in str(error_text.value)
else:
T = TransferMechanism(
name='T',
default_variable=[0, 0, 0, 0],
function=Linear(),
noise=UniformToNormalDist(),
integration_rate=1.0
)
# This is equivalent to
# T.noise.base.parameters.random_state.get(None).seed([22])
T.noise.base.parameters.seed.set(22, None)
val = T.execute([0, 0, 0, 0])
assert np.allclose(val, [[1.73027452, -1.07866481, -1.98421126, 2.99564032]])



@pytest.mark.mechanism
Expand Down