Reset random state on seed change (#2130)
The original semantics was to reset the random state once per context, using the most up-to-date value of the seed.
The new approach is to reset the random state every time the seed value changes, whether through direct assignment or through modulation.

The practical change in semantics is subtle; it helps match results between virtual executions and the simulated composition (see the comments in 'test_modulation_of_random_state' for details).
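
The mechanism behind the change, in brief: the random_state getter now caches a RandomState subclass that remembers the seed it was created from, and it rebuilds the state only when the current seed parameter no longer matches that remembered value. A minimal standalone sketch of the pattern (simplified from the diff below; the helper function and its signature are illustrative, not the actual PsyNeuLink getter API):

import numpy as np

class SeededRandomState(np.random.RandomState):
    # RandomState that remembers the seed it was constructed with
    def __init__(self, seed):
        self.used_seed = list(seed)
        super().__init__(seed)

def get_random_state(cached_state, seed_value):
    # Rebuild the state when none exists yet, or when the seed changed;
    # otherwise keep the cached state so its sequence continues.
    if cached_state is None or cached_state.used_seed != [seed_value]:
        return SeededRandomState([seed_value])
    return cached_state

state = get_random_state(None, 5)      # fresh sequence for seed 5
state = get_random_state(state, 5)     # same state, sequence continues
state = get_random_state(state, 7)     # seed changed -> state is rebuilt
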
jvesely authored Oct 1, 2021
2 parents 2c3a429 + 429ffea commit 5405ac9
Showing 5 changed files with 88 additions and 81 deletions.
2 changes: 1 addition & 1 deletion conftest.py
@@ -118,7 +118,7 @@ def get_func_execution(func, func_mode):
elif func_mode == 'Python':
return func.function
else:
-assert False, "Unknown function mode: {}".format(mech_mode)
+assert False, "Unknown function mode: {}".format(func_mode)

@pytest.helpers.register
def get_mech_execution(mech, mech_mode):
72 changes: 35 additions & 37 deletions psyneulink/core/components/functions/function.py
@@ -333,49 +333,48 @@ def _seed_setter(value, owning_component, context):
if value in {None, DEFAULT_SEED}:
value = get_global_seed()

value = int(value)
# Remove any old PRNG state
owning_component.parameters.random_state.set(None, context=context)
return int(value)

owning_component.parameters.random_state._set(
np.random.RandomState([value]),
context
)

return value
class SeededRandomState(np.random.RandomState):
def __init__(self, *args, **kwargs):
# Extract seed
self.used_seed = (kwargs.get('seed', None) or args[0])[:]
super().__init__(*args, **kwargs)

def __deepcopy__(self, memo):
# There's no easy way to deepcopy parent first.
# Create new instance and rewrite the state.
dup = type(self)(seed=self.used_seed)
dup.set_state(self.get_state())
return dup

def seed(self, seed):
assert False, "Use 'seed' parameter instead of seeding the random state directly"


def _random_state_getter(self, owning_component, context):
seed_param = owning_component.parameters.seed

seed_param = owning_component.parameters.seed
try:
is_modulated = seed_param._port.is_modulated(context)
except AttributeError:
# no ParameterPort
pass
is_modulated = False

if is_modulated:
seed_value = [int(owning_component._get_current_parameter_value(seed_param, context))]
else:
if is_modulated:
# can manage reset_for_context only in getter because we
# don't want to store any copied values from other contexts
# (from _initialize_from_context)
try:
reset_for_context = self._reset_for_context[context.execution_id]
except AttributeError:
self._reset_for_context = {}
reset_for_context = False
except KeyError:
reset_for_context = False

if not reset_for_context:
self._reset_for_context[context.execution_id] = True
return np.random.RandomState([
int(
owning_component._get_current_parameter_value(
seed_param,
context
)
)
])

return self.values[context.execution_id]
seed_value = [int(seed_param._get(context=context))]

assert seed_value != [DEFAULT_SEED], "Invalid seed for {} in context: {} ({})".format(owning_component, context.execution_id, seed_param)

current_state = self.values.get(context.execution_id, None)
if current_state is None or current_state.used_seed != seed_value:
return SeededRandomState(seed_value)

return current_state


class Function_Base(Function):
@@ -616,11 +615,10 @@ def __deepcopy__(self, memo):
new = super().__deepcopy__(memo)
# ensure copy does not have identical name
register_category(new, Function_Base, new.name, FunctionRegistry)
try:
if "random_state" in new.parameters:
# HACK: Make sure any copies are re-seeded to avoid dependent RNG.
new.random_state.seed([get_global_seed()])
except:
pass
# functions with "random_state" param must have "seed" parameter
new.seed.base = DEFAULT_SEED
return new

@handle_external_context()
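
For users, the practical effect of the getter/setter pair above: assigning a function's seed parameter (even to the same value), or modulating it to a new value, restarts its random sequence, while repeated executions with an unchanged seed continue the existing sequence. A rough usage sketch under that assumption (the attribute paths mirror the updated tests below and are illustrative rather than a verified recipe):

import psyneulink as pnl

mech = pnl.ProcessingMechanism(function=pnl.UniformDist(seed=5))

mech.function.parameters.seed.set(5, None)   # start from a clean state for seed 5
a = mech.execute([0])                        # first draw of the sequence
b = mech.execute([0])                        # seed unchanged -> sequence continues

mech.function.parameters.seed.set(5, None)   # re-assigning the seed resets the state
c = mech.execute([0])                        # should repeat the first draw: c == a
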
64 changes: 35 additions & 29 deletions tests/composition/test_control.py
@@ -1050,46 +1050,52 @@ def test_control_of_mech_port(self, comp_mode):

@pytest.mark.control
@pytest.mark.composition
def test_modulation_of_random_state(self):
src = pnl.ProcessingMechanism()
mech = pnl.ProcessingMechanism(function=pnl.UniformDist())

comp = pnl.Composition(retain_old_simulation_data=True)
@pytest.mark.parametrize("mode", [pnl.ExecutionMode.Python])
@pytest.mark.parametrize("num_generators", [5])
def test_modulation_of_random_state(self, mode, num_generators):
obj = pnl.ObjectiveMechanism()
# Set original seed that is not used by any evaluation
# this prevents dirty state from initialization skewing the results.
# The alternative would be to set:
# mech.functions.seed.base = mech.functions.seed.base
# to reset the PRNG
mech = pnl.ProcessingMechanism(function=pnl.UniformDist(seed=num_generators))

comp = pnl.Composition(retain_old_simulation_data=True,
controller_mode=pnl.BEFORE)
comp.add_node(mech, required_roles=pnl.NodeRole.INPUT)
comp.add_node(src)
comp.add_linear_processing_pathway([mech, obj])

comp.add_controller(
pnl.OptimizationControlMechanism(
monitor_for_control=src,
objective_mechanism=obj,
control_signals=pnl.ControlSignal(
modulates=('seed', mech),
modulation=pnl.OVERRIDE,
allocation_samples=pnl.SampleSpec(start=0, stop=5, step=1),
allocation_samples=pnl.SampleSpec(start=0, stop=num_generators - 1, step=1),
cost_options=pnl.CostFunctions.NONE
)
)
)

def seed_check(context):
latest_sim = comp.controller.parameters.simulation_ids._get(context)[-1]

seed = mech.get_mod_seed(latest_sim)
rs = mech.function.parameters.random_state.get(latest_sim)

# mech (and so its function and random_state) should be called
# exactly twice in one trial given the specified condition,
# and the random_state should be reset before the first execution
new_rs = np.random.RandomState([int(seed)])

for i in range(2):
new_rs.uniform(0, 1)

assert rs.uniform(0, 1) == new_rs.uniform(0, 1)

comp.termination_processing = {pnl.TimeScale.TRIAL: pnl.AfterNCalls(mech, 2)}
comp.run(
inputs={src: [1], mech: [1]},
call_after_trial=seed_check
)
comp.run(inputs={mech: [1]}, num_trials=2, execution_mode=mode)

# Construct expected results.
# First, all generators reset their sequences.
# In the second trial, the "winning" seed from the previous one continues its
# random sequence.
all_generators = [np.random.RandomState([seed]) for seed in range(num_generators)]
first_generator_samples = [g.uniform(0, 1) for g in all_generators]
best_first = max(first_generator_samples)
index_best = first_generator_samples.index(best_first)
second_generator_samples = [g.uniform(0, 1) for g in all_generators]
second_considerations = first_generator_samples[:index_best] + \
second_generator_samples[index_best:index_best + 1] + \
first_generator_samples[index_best + 1:]
best_second = max(second_considerations)
# Check that we select the maximum of generated values
assert np.allclose(best_first, comp.results[0])
assert np.allclose(best_second, comp.results[1])


class TestModelBasedOptimizationControlMechanisms:
4 changes: 2 additions & 2 deletions tests/mechanisms/test_mechanisms.py
@@ -50,8 +50,8 @@ def test_noise_variations(self, noise):
t2 = pnl.TransferMechanism(name='t2', size=2)
t2.integrator_function.parameters.noise.set(noise())

-t1.integrator_function.noise.base.random_state = np.random.RandomState([0])
-t2.integrator_function.noise.base.random_state = np.random.RandomState([0])
+t1.integrator_function.noise.base.seed = 0
+t2.integrator_function.noise.base.seed = 0

for _ in range(5):
np.testing.assert_equal(t1.execute([1, 1]), t2.execute([1, 1]))
27 changes: 15 additions & 12 deletions tests/mechanisms/test_transfer_mechanism.py
@@ -301,18 +301,7 @@ def test_transfer_mech_exponential_noise(self):
def test_transfer_mech_uniform_to_normal_noise(self):
try:
import scipy
T = TransferMechanism(
name='T',
default_variable=[0, 0, 0, 0],
function=Linear(),
noise=UniformToNormalDist(),
integration_rate=1.0
)
T.noise.base.parameters.random_state.get(None).seed(22)
val = T.execute([0, 0, 0, 0])
assert np.allclose(val, [[-0.81177443, -0.04593492, -0.20051725, 1.07665147]])

except:
except ModuleNotFoundError:
with pytest.raises(FunctionError) as error_text:
T = TransferMechanism(
name='T',
@@ -322,6 +311,20 @@ def test_transfer_mech_uniform_to_normal_noise(self):
integration_rate=1.0
)
assert "The UniformToNormalDist function requires the SciPy package." in str(error_text.value)
else:
T = TransferMechanism(
name='T',
default_variable=[0, 0, 0, 0],
function=Linear(),
noise=UniformToNormalDist(),
integration_rate=1.0
)
# This is equivalent to
# T.noise.base.parameters.random_state.get(None).seed([22])
T.noise.base.parameters.seed.set(22, None)
val = T.execute([0, 0, 0, 0])
assert np.allclose(val, [[1.73027452, -1.07866481, -1.98421126, 2.99564032]])



@pytest.mark.mechanism
