Skip to content

Commit

Permalink
Avoid specialization on search options
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Mar 14, 2024
1 parent 7d3d4f6 commit 3a33882
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ abstract type AbstractSRRegressor <: MMI.Deterministic end
"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr(model_name::Symbol)
struct_def = :(Base.@kwdef mutable struct $(model_name){
D<:AbstractDimensions,L,use_recorder,N<:AbstractExpressionNode
D<:AbstractDimensions,L,N<:AbstractExpressionNode
} <: AbstractSRRegressor
niterations::Int = 10
node_type::Type{N} = Node
Expand Down
14 changes: 4 additions & 10 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ function Options end
optimizer_probability::Real=0.14,
optimizer_iterations::Union{Nothing,Integer}=nothing,
optimizer_options::Union{Dict,NamedTuple,Optim.Options,Nothing}=nothing,
val_recorder::Val{use_recorder}=Val(false),
use_recorder::Bool=false,
recorder_file::AbstractString="pysr_recorder.json",
early_stop_condition::Union{Function,Real,Nothing}=nothing,
timeout_in_seconds::Union{Nothing,Real}=nothing,
Expand All @@ -435,7 +435,7 @@ function Options end
npopulations::Union{Nothing,Integer}=nothing,
npop::Union{Nothing,Integer}=nothing,
kws...,
) where {use_recorder}
)
for k in keys(kws)
!haskey(deprecated_options_mapping, k) && error("Unknown keyword argument: $k")
new_key = deprecated_options_mapping[k]
Expand Down Expand Up @@ -753,14 +753,7 @@ function Options end
@assert print_precision > 0

options = Options{
eltype(complexity_mapping),
typeof(operators),
use_recorder,
typeof(optimizer_options),
typeof(optimizer_algorithm),
turbo,
bumper,
typeof(tournament_selection_weights),
eltype(complexity_mapping),turbo,bumper,typeof(tournament_selection_weights)
}(
operators,
bin_constraints,
Expand Down Expand Up @@ -822,6 +815,7 @@ function Options end
nested_constraints,
deterministic,
define_helper_functions,
use_recorder,
)

return options
Expand Down
20 changes: 6 additions & 14 deletions src/OptionsStruct.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module OptionsStructModule

using Optim: Optim
using DynamicExpressions: AbstractOperatorEnum
using DynamicExpressions: AbstractOperatorEnum, OperatorEnum
using LossFunctions: SupervisedLoss

import ..MutationWeightsModule: MutationWeights
Expand Down Expand Up @@ -38,17 +38,8 @@ function ComplexityMapping(;
)
end

struct Options{
CT,
OP<:AbstractOperatorEnum,
use_recorder,
OPT<:Optim.Options,
OPT_A<:Optim.AbstractOptimizer,
_turbo,
_bumper,
W,
}
operators::OP
struct Options{CT,_turbo,_bumper,W}
operators::AbstractOperatorEnum
bin_constraints::Vector{Tuple{Int,Int}}
una_constraints::Vector{Int}
complexity_mapping::ComplexityMapping{CT}
Expand Down Expand Up @@ -94,10 +85,10 @@ struct Options{
loss_function::Union{Nothing,Function}
progress::Union{Bool,Nothing}
terminal_width::Union{Int,Nothing}
optimizer_algorithm::OPT_A
optimizer_algorithm::Optim.AbstractOptimizer
optimizer_probability::Float32
optimizer_nrestarts::Int
optimizer_options::OPT
optimizer_options::Optim.Options
recorder_file::String
prob_pick_first::Float32
early_stop_condition::Union{Function,Nothing}
Expand All @@ -108,6 +99,7 @@ struct Options{
nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
deterministic::Bool
define_helper_functions::Bool
use_recorder::Bool
end

function Base.print(io::IO, options::Options)
Expand Down
4 changes: 1 addition & 3 deletions src/Recorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ module RecorderModule

using ..CoreModule: RecordType, Options

is_recording(::Options{<:Any,<:Any,use_recorder}) where {use_recorder} = use_recorder

"Assumes that `options` holds the user options::Options"
macro recorder(ex)
quote
if is_recording($(esc(:options)))
if $(esc(:options)).use_recorder
$(esc(ex))
end
end
Expand Down
26 changes: 14 additions & 12 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ using .HallOfFameModule:
HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve
using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population
using .ProgressBarsModule: WrappedProgressBar
using .RecorderModule: @recorder, is_recording, find_iteration_from_record
using .RecorderModule: @recorder, find_iteration_from_record
using .MigrationModule: migrate!
using .SearchUtilsModule:
DefaultWorkerOutputType,
Expand Down Expand Up @@ -354,7 +354,7 @@ function equation_search(
y::AbstractMatrix{T};
niterations::Int=10,
weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
options::Options=Options(),
@nospecialize(options::Options = Options()),
node_type::Type{N}=Node,
variable_names::Union{AbstractVector{String},Nothing}=nothing,
display_variable_names::Union{AbstractVector{String},Nothing}=variable_names,
Expand Down Expand Up @@ -446,7 +446,7 @@ function equation_search(
datasets::Vector{D},
::Type{N}=Node;
niterations::Int=10,
options::Options=Options(),
@nospecialize(options::Options = Options()),
parallelism=:multithreading,
numprocs::Union{Int,Nothing}=nothing,
procs::Union{Vector{Int},Nothing}=nothing,
Expand Down Expand Up @@ -597,7 +597,7 @@ function _equation_search(
::Val{DIM_OUT},
datasets::Vector{D},
niterations::Int,
options::Options,
@nospecialize(options::Options),
::Type{N},
numprocs::Int,
procs::Union{Vector{Int},Nothing},
Expand Down Expand Up @@ -629,7 +629,9 @@ function _equation_search(

example_dataset = datasets[1]
nout = size(datasets, 1)
@assert nout >= 1
@assert (nout == 1 || DIM_OUT == 2)
@assert options.populations >= 1

if runtests
test_option_configuration(PARALLELISM, datasets, saved_state, options, verbosity)
Expand Down Expand Up @@ -701,23 +703,23 @@ function _equation_search(
# Get the next worker process to give a job:
worker_assignment = initialize_worker_assignment()

hallOfFame = load_saved_hall_of_fame(saved_state)
hallOfFame = if hallOfFame === nothing
[HallOfFame(options, T, L, N) for j in 1:nout]
init_hall_of_fame = load_saved_hall_of_fame(saved_state)
hallOfFame = if init_hall_of_fame === nothing
HallOfFameType[HallOfFame(options, T, L, N) for j in 1:nout]
else
# Recompute losses for the hall of fame, in
# case the dataset changed:
for (hof, dataset) in zip(hallOfFame, datasets)
for (hof, dataset) in zip(init_hall_of_fame, datasets)
for member in hof.members[hof.exists]
score, result_loss = score_func(dataset, member, options)
member.score = score
member.loss = result_loss
end
end
hallOfFame
init_hall_of_fame
end
@assert length(hallOfFame) == nout
hallOfFame::Vector{HallOfFameType}
@assert length(hallOfFame) == nout

for j in 1:nout, i in 1:(options.populations)
worker_idx = assign_next_worker!(
Expand Down Expand Up @@ -944,7 +946,7 @@ function _equation_search(
worker_idx = assign_next_worker!(
worker_assignment; out=j, pop=i, parallelism=PARALLELISM, procs
)
iteration = if is_recording(options)
iteration = if options.use_recorder
key = "out$(j)_pop$(i)"
find_iteration_from_record(key, record) + 1
else
Expand Down Expand Up @@ -1076,7 +1078,7 @@ end
function _dispatch_s_r_cycle(
in_pop::Population,
dataset::Dataset,
options::Options;
@nospecialize(options::Options);
pop::Int,
out::Int,
iteration::Int,
Expand Down
2 changes: 1 addition & 1 deletion test/test_params.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ default_params = (
optimizer_nrestarts=3,
optimizer_probability=0.1f0,
optimizer_iterations=100,
val_recorder=Val(false),
use_recorder=false,
recorder_file="pysr_recorder.json",
tournament_selection_p=1.0,
early_stop_condition=nothing,
Expand Down
2 changes: 1 addition & 1 deletion test/test_recorder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2
options = SymbolicRegression.Options(;
binary_operators=(+, *, /, -),
unary_operators=(cos,),
val_recorder=Val(true),
use_recorder=true,
recorder_file="pysr_recorder.json",
crossover_probability=0.0, # required for recording, as not set up to track crossovers.
populations=2,
Expand Down

0 comments on commit 3a33882

Please sign in to comment.