Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplification of the divide output #1080

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 9 additions & 28 deletions src/Algorithm/branching/branchingalgo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
struct NoBranching <: AlgoAPI.AbstractDivideAlgorithm end

function run!(::NoBranching, ::Env, reform::Reformulation, ::Branching.AbstractDivideInput)
return DivideOutput([], OptimizationState(getmaster(reform)))
return DivideOutput([])

Check warning on line 26 in src/Algorithm/branching/branchingalgo.jl

View check run for this annotation

Codecov / codecov/patch

src/Algorithm/branching/branchingalgo.jl#L26

Added line #L26 was not covered by tests
end

############################################################################################
Expand Down Expand Up @@ -76,11 +76,6 @@
Branching.get_selection_criterion(ctx::BranchingContext) = ctx.selection_criterion
Branching.get_rules(ctx::BranchingContext) = ctx.rules

function Branching.new_ip_primal_sols_pool(ctx::BranchingContext, reform::Reformulation, input)
# Optimization state with no information.
return OptimizationState(getmaster(reform))
end

function _is_integer(sol::PrimalSolution)
for (varid, val) in sol
integer_val = abs(val - round(val)) < 1e-5
Expand Down Expand Up @@ -130,8 +125,8 @@
return _why_no_candidate(master, reform, input, extended_sol, original_sol)
end

Branching.new_divide_output(children::Vector{SbNode}, optimization_state) = DivideOutput(children, optimization_state)
Branching.new_divide_output(::Nothing, optimization_state) = DivideOutput(SbNode[], optimization_state)
Branching.new_divide_output(children::Vector{SbNode}) = DivideOutput(children)
Branching.new_divide_output(::Nothing) = DivideOutput(SbNode[])

############################################################################################
# Branching API implementation for the strong branching
Expand Down Expand Up @@ -272,38 +267,24 @@
)
end

function Branching.eval_child_of_candidate!(child, phase::Branching.AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
function Branching.eval_child_of_candidate!(child, phase::Branching.AbstractStrongBrPhaseContext, env, reform, input)
child_state = OptimizationState(getmaster(reform))
child.conquer_output = child_state

# In the `ip_primal_sols_found`, we maintain all the primal solutions found during the
# strong branching procedure but also the best primal bound found so far (in the whole optimization).
update_ip_primal_bound!(child_state, get_ip_primal_bound(ip_primal_sols_found))

global_primal_handler = Branching.get_global_primal_handler(input)
update_ip_primal_bound!(child_state, get_global_primal_bound(global_primal_handler))

if !ip_gap_closed(child_state)
units_to_restore = Branching.get_units_to_restore_for_conquer(phase)
restore_from_records!(units_to_restore, child.records)
conquer_input = ConquerInputFromSb(Branching.get_global_primal_handler(input), child, units_to_restore)
conquer_input = ConquerInputFromSb(global_primal_handler, child, units_to_restore)
child.conquer_output = run!(Branching.get_conquer(phase), env, reform, conquer_input)
TreeSearch.set_records!(child, create_records(reform))
end

# Store new primal solutions found during the evaluation of the child.
add_ip_primal_sols!(ip_primal_sols_found, get_ip_primal_sols(child_state)...)
for sol in get_ip_primal_sols(child_state)
store_ip_primal_sol!(Branching.get_global_primal_handler(input), sol)
store_ip_primal_sol!(global_primal_handler, sol)

Check warning on line 287 in src/Algorithm/branching/branchingalgo.jl

View check run for this annotation

Codecov / codecov/patch

src/Algorithm/branching/branchingalgo.jl#L287

Added line #L287 was not covered by tests
end
return
end

function Branching.new_ip_primal_sols_pool(ctx::StrongBranchingContext, reform, input)
# Optimization state with copy of bounds only (except lp_primal_bound).
# Only the ip primal bound is used to avoid inserting integer solutions that are not
# better than the incumbent.
# We also use the primal bound to init candidate nodes in the strong branching procedure.
input_opt_state = Branching.get_conquer_opt_state(input)
return OptimizationState(
getmaster(reform);
ip_primal_bound = get_ip_primal_bound(input_opt_state),
)
end
2 changes: 0 additions & 2 deletions src/Algorithm/branching/interface.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
struct DivideOutput{N} <: Branching.AbstractDivideOutput
children::Vector{N}
optstate::Union{Nothing,OptimizationState}
end

Branching.get_children(output::DivideOutput) = output.children
#Branching.get__opt_state(output::DivideOutput) = output.optstate

function get_extended_sol(reform, opt_state)
return get_best_lp_primal_sol(opt_state)
Expand Down
1 change: 0 additions & 1 deletion src/Algorithm/branching/printer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ Branching.get_int_tol(ctx::BranchingPrinter) = Branching.get_int_tol(ctx.inner)
Branching.get_selection_criterion(ctx::BranchingPrinter) = Branching.get_selection_criterion(ctx.inner)
Branching.get_selection_nb_candidates(ctx::BranchingPrinter) = Branching.get_selection_nb_candidates(ctx.inner)
Branching.get_phases(ctx::BranchingPrinter) = Branching.get_phases(ctx.inner)
Branching.new_ip_primal_sols_pool(ctx::BranchingPrinter, reform, input) = Branching.new_ip_primal_sols_pool(ctx.inner, reform, input)

struct PhasePrinter{PhaseContext<:Branching.AbstractStrongBrPhaseContext} <: Branching.AbstractStrongBrPhaseContext
inner::PhaseContext
Expand Down
42 changes: 16 additions & 26 deletions src/Branching/Branching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@

# branching output
"""
new_divide_output(children::Union{Vector{N}, Nothing}, ip_primal_sols_found{C, Nothing}) where {N, C} -> AbstractDivideOutput
new_divide_output(children::Union{Vector{N}, Nothing}) where {N} -> AbstractDivideOutput

where:
- `N` is the type of nodes generated by the branching algorithm.
- `C` is the type of the collection that stores all ip primal solutions found by the branching algorithm.

If no nodes nor ip primal solutions are found, the generic implementation may provide `nothing`.
If no nodes are found, the generic implementation may provide `nothing`.
"""
@mustimplement "BranchingOutput" new_divide_output(children, ip_primal_sols_found) = nothing
@mustimplement "BranchingOutput" new_divide_output(children) = nothing

Check warning on line 75 in src/Branching/Branching.jl

View check run for this annotation

Codecov / codecov/patch

src/Branching/Branching.jl#L75

Added line #L75 was not covered by tests

# Default implementations.
"Candidates selection for branching algorithms."
Expand All @@ -89,7 +88,7 @@

function advanced_select!(ctx::AbstractBranchingContext, candidates, env, reform, input::AbstractDivideInput)
children = generate_children!(first(candidates), env, reform, input)
return new_divide_output(children, nothing)
return new_divide_output(children)
end

############################################################################################
Expand Down Expand Up @@ -124,13 +123,8 @@
"Returns the maximum number of candidates kept at the end of a given strong branching phase."
@mustimplement "StrongBranching" get_max_nb_candidates(::AbstractStrongBrPhaseContext) = nothing

""
@mustimplement "StrongBranchingOptState" new_ip_primal_sols_pool(ctx, reform, input) = nothing

function advanced_select!(ctx::Branching.AbstractStrongBrContext, candidates, env, reform, input::Branching.AbstractDivideInput)
return perform_strong_branching!(ctx, env, reform, input, candidates)
#children = get_children(first(candidates))
#return new_divide_output(children, ip_primal_sols_found)
end

function perform_strong_branching!(
Expand All @@ -142,10 +136,6 @@
function perform_strong_branching_inner!(
ctx::AbstractStrongBrContext, env, reform, input::Branching.AbstractDivideInput, candidates::Vector{C}
) where {C<:AbstractBranchingCandidate}
# We will store all the new ip primal solution found during the strong branching in the
# following data structure.
ip_primal_sols_found = new_ip_primal_sols_pool(ctx, reform, input)

cand_children = [generate_children!(candidate, env, reform, input) for candidate in candidates]

phases = get_phases(ctx)
Expand All @@ -163,7 +153,7 @@
nb_candidates_for_next_phase = min(nb_candidates_for_next_phase, length(cand_children))
end

scores = perform_branching_phase!(candidates, cand_children, current_phase, ip_primal_sols_found, env, reform, input)
scores = perform_branching_phase!(candidates, cand_children, current_phase, env, reform, input)

perm = sortperm(scores, rev=true)
permute!(cand_children, perm)
Expand All @@ -178,15 +168,15 @@
resize!(cand_children, nb_candidates_for_next_phase)
resize!(candidates, nb_candidates_for_next_phase)
end
return new_divide_output(first(cand_children), ip_primal_sols_found)
return new_divide_output(first(cand_children))
end

function perform_branching_phase!(candidates, cand_children, phase, ip_primal_sols_found, env, reform, input)
return perform_branching_phase_inner!(cand_children, phase, ip_primal_sols_found, env, reform, input)
function perform_branching_phase!(candidates, cand_children, phase, env, reform, input)
return perform_branching_phase_inner!(cand_children, phase, env, reform, input)
end

"Performs a branching phase."
function perform_branching_phase_inner!(cand_children, phase, ip_primal_sols_found, env, reform, input)
function perform_branching_phase_inner!(cand_children, phase, env, reform, input)

return map(cand_children) do children
# TODO; I don't understand why we need to sort the children here.
Expand All @@ -205,24 +195,24 @@
# by = child -> get_lp_primal_bound(child)
# )

return eval_candidate!(children, phase, ip_primal_sols_found, env, reform, input)
return eval_candidate!(children, phase, env, reform, input)
end
end

function eval_candidate!(children, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
return eval_candidate_inner!(children, phase, ip_primal_sols_found, env, reform, input)
function eval_candidate!(children, phase::AbstractStrongBrPhaseContext, env, reform, input)
return eval_candidate_inner!(children, phase, env, reform, input)
end

"Evaluates a candidate."
function eval_candidate_inner!(children, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
function eval_candidate_inner!(children, phase::AbstractStrongBrPhaseContext, env, reform, input)
for child in children
eval_child_of_candidate!(child, phase, ip_primal_sols_found, env, reform, input)
eval_child_of_candidate!(child, phase, env, reform, input)
end
return compute_score(get_score(phase), children, input)
end

"Evaluate children of a candidate."
@mustimplement "StrongBranching" eval_child_of_candidate!(child, phase, ip_primal_sols_found, env, reform, input) = nothing
@mustimplement "StrongBranching" eval_child_of_candidate!(child, phase, env, reform, input) = nothing

Check warning on line 215 in src/Branching/Branching.jl

View check run for this annotation

Codecov / codecov/patch

src/Branching/Branching.jl#L215

Added line #L215 was not covered by tests

@mustimplement "Branching" isroot(node) = nothing

Expand Down Expand Up @@ -302,7 +292,7 @@
if length(candidates) == 0
@warn "No candidate generated. No children will be generated. However, the node is not conquered."
why_no_candidate(reform, input, extended_sol, original_sol)
return new_divide_output(nothing, nothing)
return new_divide_output(nothing)
end
return advanced_select!(ctx, candidates, env, reform, input)
end
Expand Down
2 changes: 1 addition & 1 deletion test/unit/TreeSearch/treesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function Coluna.Algorithm.run!(alg::DeterministicDivide, ::Coluna.Env, ::Coluna.
for c in children ## update flag
push!(alg.nodes_created_by_divide, c.id)
end
return Coluna.Algorithm.DivideOutput(children, nothing)
return Coluna.Algorithm.DivideOutput(children)
end

# constructs a real node from a LightNode, used in new_children to built real children from the minimal information contained in LightNode
Expand Down
Loading