Skip to content

Commit

Permalink
Merge pull request #2702 from PrincetonUniversity/pec_optuna
Browse files Browse the repository at this point in the history
Add Optuna support to PEC
  • Loading branch information
davidt0x authored Jun 26, 2023
2 parents a944ad1 + fbc78b5 commit d4a5c64
Show file tree
Hide file tree
Showing 5 changed files with 389 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import psyneulink as pnl
import pandas as pd
import optuna

from psyneulink.core.globals.utilities import set_global_seed

Expand Down Expand Up @@ -91,7 +92,7 @@
responseGate = comp.nodes["RESPONSE_GATE"]

fit_parameters = {
("threshold", decisionMaker): np.linspace(0.01, 0.5, 10), # Threshold
("threshold", decisionMaker): np.linspace(0.01, 0.5, 100), # Threshold
}


Expand All @@ -113,7 +114,9 @@ def reward_rate(sim_data):
responseGate.output_ports[0],
],
objective_function=reward_rate,
optimization_function='differential_evolution',
optimization_function=pnl.PECOptimizationFunction(method=optuna.samplers.CmaEsSampler(),
max_iterations=50,
direction='minimize'),
num_estimates=num_estimates,
)

Expand Down
4 changes: 3 additions & 1 deletion psyneulink/core/components/functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import function
from .nonstateful import selectionfunctions, objectivefunctions, optimizationfunctions, combinationfunctions, \
learningfunctions, transferfunctions, distributionfunctions
learningfunctions, transferfunctions, distributionfunctions, fitfunctions
from . import stateful
from .stateful import integratorfunctions, memoryfunctions
from . import userdefinedfunction
Expand All @@ -12,6 +12,7 @@
from psyneulink.core.components.functions.nonstateful.distributionfunctions import *
from psyneulink.core.components.functions.nonstateful.objectivefunctions import *
from psyneulink.core.components.functions.nonstateful.optimizationfunctions import *
from psyneulink.core.components.functions.nonstateful.fitfunctions import *
from psyneulink.core.components.functions.nonstateful.learningfunctions import *
from .stateful import *
from psyneulink.core.components.functions.stateful.integratorfunctions import *
Expand All @@ -27,6 +28,7 @@
__all__.extend(distributionfunctions.__all__)
__all__.extend(objectivefunctions.__all__)
__all__.extend(optimizationfunctions.__all__)
__all__.extend(fitfunctions.__all__)
__all__.extend(learningfunctions.__all__)
__all__.extend(integratorfunctions.__all__)
__all__.extend(memoryfunctions.__all__)
Loading

0 comments on commit d4a5c64

Please sign in to comment.