Skip to content

Commit

Permalink
Refactor strong branching algorithm (#682)
Browse files Browse the repository at this point in the history
* wip

* ok
  • Loading branch information
guimarqu authored Aug 10, 2022
1 parent 9043f4d commit c9d3009
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 282 deletions.
6 changes: 3 additions & 3 deletions src/Algorithm/branchcutprice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function BranchCutAndPriceAlgorithm(;
branching_phases = BranchingPhase[]
if length(stbranch_phases_num_candidates) >= 2
push!(branching_phases,
BranchingPhase(first(stbranch_phases_num_candidates), RestrMasterLPConquer())
BranchingPhase(first(stbranch_phases_num_candidates), RestrMasterLPConquer(), ProductScore())
)
if length(stbranch_phases_num_candidates) >= 3
intrmphase_stages = ColumnGeneration[]
Expand Down Expand Up @@ -145,11 +145,11 @@ function BranchCutAndPriceAlgorithm(;
opt_rtol = opt_rtol
)
push!(branching_phases,
BranchingPhase(stbranch_phases_num_candidates[2], intrmphase_conquer)
BranchingPhase(stbranch_phases_num_candidates[2], intrmphase_conquer, ProductScore())
)
end
end
push!(branching_phases, BranchingPhase(last(stbranch_phases_num_candidates), conquer))
push!(branching_phases, BranchingPhase(last(stbranch_phases_num_candidates), conquer, TreeDepthScore()))
branching = StrongBranching(rules = branching_rules, phases = branching_phases)
else
branching = StrongBranching(rules = branching_rules)
Expand Down
242 changes: 109 additions & 133 deletions src/Algorithm/branching/branchingalgo.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
############################################################################################
# NoBranching
############################################################################################

# This is only for strong branching
# returns the optimization part of the output of the conquer algorithm
function apply_conquer_alg_to_node!(
node::SbNode, algo::AbstractConquerAlgorithm, env::Env, reform::Reformulation,
units_to_restore::UnitsUsage, opt_rtol::Float64 = Coluna.DEF_OPTIMALITY_RTOL,
opt_atol::Float64 = Coluna.DEF_OPTIMALITY_ATOL
)
nodestate = getoptstate(node)
if isverbose(algo)
@logmsg LogLevel(-1) string("Node IP DB: ", get_ip_dual_bound(nodestate))
@logmsg LogLevel(-1) string("Tree IP PB: ", get_ip_primal_bound(nodestate))
end
if ip_gap_closed(nodestate, rtol = opt_rtol, atol = opt_atol)
@info "IP Gap is closed: $(ip_gap(nodestate)). Abort treatment."
else
isverbose(algo) && @logmsg LogLevel(-1) string("IP Gap is positive. Need to treat node.")
"""
NoBranching
run!(algo, env, reform, ConquerInput(Node(node, -1), units_to_restore))
store_records!(reform, node.recordids)
end
node.conquerwasrun = true
return
Divide algorithm that does nothing. It does not generate any child.
"""
struct NoBranching <: AbstractDivideAlgorithm end

function run!(::NoBranching, ::Env, reform::Reformulation, ::DivideInput)::DivideOutput
return DivideOutput([], OptimizationState(getmaster(reform)))
end

############################################################################################
# StrongBranching
############################################################################################

"""
BranchingPhase(max_nb_candidates, conquer_algo)
Expand All @@ -32,12 +26,13 @@ to evaluate and the conquer algorithm which does evaluation.
struct BranchingPhase
max_nb_candidates::Int64
conquer_algo::AbstractConquerAlgorithm
score::AbstractBranchingScore
end

"""
PrioritisedBranchingRule
A branching rule with root and non-root priorities.
A branching rule with root and non-root priorities.
"""
struct PrioritisedBranchingRule
rule::AbstractBranchingRule
Expand All @@ -49,25 +44,6 @@ function getpriority(rule::PrioritisedBranchingRule, isroot::Bool)::Float64
return isroot ? rule.root_priority : rule.nonroot_priority
end

############################################################################################
# NoBranching
############################################################################################

"""
NoBranching
Divide algorithm that does nothing. It does not generate any child.
"""
struct NoBranching <: AbstractDivideAlgorithm end

function run!(::NoBranching, ::Env, reform::Reformulation, ::DivideInput)::DivideOutput
return DivideOutput([], OptimizationState(getmaster(reform)))
end

############################################################################################
# StrongBranching
############################################################################################

"""
StrongBranching
Expand Down Expand Up @@ -118,37 +94,83 @@ function get_child_algorithms(algo::StrongBranching, reform::Reformulation)
return child_algos
end

function exploits_primal_solutions(algo::StrongBranching)
for phase in algo.phases
exploits_primal_solutions(phase.conquer_algo) && return true
# This is only for strong branching
# returns the optimization part of the output of the conquer algorithm
function _apply_conquer_alg_to_child!(
child::SbNode, algo::AbstractConquerAlgorithm, env::Env, reform::Reformulation,
units_to_restore::UnitsUsage, opt_rtol::Float64 = Coluna.DEF_OPTIMALITY_RTOL,
opt_atol::Float64 = Coluna.DEF_OPTIMALITY_ATOL
)
child_state = getoptstate(child)
if ip_gap_closed(child_state, rtol = opt_rtol, atol = opt_atol)
@info "IP Gap is closed: $(ip_gap(child_state)). Abort treatment."
else
run!(algo, env, reform, ConquerInput(Node(child, -1), units_to_restore))
store_records!(reform, child.recordids)
end
return false
child.conquerwasrun = true
return
end

function perform_strong_branching_with_phases!(
function _eval_children_of_candidate!(
children::Vector{SbNode}, phase::BranchingPhase, phase_index, conquer_units_to_restore,
sbstate, env, reform, varname
)
for (child_index, child) in enumerate(children)
#### TODO: remove logs from algo logic
if isverbose(phase.conquer_algo)
print(
"**** SB phase ", phase_index, " evaluation of candidate ",
varname, " (branch ", child_index, " : ", child.branchdescription
)
@printf "), value = %6.2f\n" getvalue(get_lp_primal_bound(getoptstate(child)))
end

child_state = getoptstate(child)
update_ip_primal_bound!(child_state, get_ip_primal_bound(sbstate))

# TODO: We consider that all branching algorithms don't exploit the primal solution
# at the moment.
# best_ip_primal_sol = get_best_ip_primal_sol(sbstate)
# if !isnothing(best_ip_primal_sol)
# set_ip_primal_sol!(nodestate, best_ip_primal_sol)
# end

_apply_conquer_alg_to_child!(
child, phase.conquer_algo, env, reform, conquer_units_to_restore
)

add_ip_primal_sols!(sbstate, get_ip_primal_sols(child_state)...)

if to_be_pruned(child)
if isverbose(phase.conquer_algo)
println("Branch is conquered!")
end
end
end
return
end

function _perform_strong_branching_with_phases!(
algo::StrongBranching, env::Env, reform::Reformulation, input::DivideInput, candidates::Vector{C}
)::OptimizationState where {C<:AbstractBranchingCandidate}

parent = getparent(input)
exploitsprimalsolutions::Bool = exploits_primal_solutions(algo)
# TODO: We consider that conquer algorithms in the branching algo don't exploit the
# primal solution at the moment (3rd arg).
sbstate = OptimizationState(
getmaster(reform), getoptstate(input), exploitsprimalsolutions, false
getmaster(reform), getoptstate(input), false, false
)

for (phase_index, current_phase) in enumerate(algo.phases)
nb_candidates_for_next_phase = 1

# If at the current phase, we have less candidates than the number of candidates
# we want to evaluate at the next phase, we skip the current phase.
# We always execute phase 1 because it is the phase in which we generate the
# children for each branching candidate.
if phase_index < length(algo.phases)
nb_candidates_for_next_phase = algo.phases[phase_index + 1].max_nb_candidates
if phase_index > 1 && length(candidates) <= nb_candidates_for_next_phase
if length(candidates) <= nb_candidates_for_next_phase
# If at the current phase, we have less candidates than the number of candidates
# we want to evaluate at the next phase, we skip the current phase.
continue
end
# In phase 1, we make sure that the number of candidates for the next phase is
# at least equal to the number of initial candidates
# at least equal to the number of initial candidates.
nb_candidates_for_next_phase = min(nb_candidates_for_next_phase, length(candidates))
end

Expand All @@ -157,78 +179,32 @@ function perform_strong_branching_with_phases!(
conquer_units_to_restore, current_phase.conquer_algo, reform
)

#TO DO : we need to define a print level parameter
# TODO: separate printing logic from algo logic.
println("**** Strong branching phase ", phase_index, " is started *****");

#for nice printing, we compute the maximum description length
max_descr_length::Int64 = 0
for candidate in candidates
description = getdescription(candidate)
if (max_descr_length < length(description))
max_descr_length = length(description)
end
end

for (candidate_index, candidate) in enumerate(candidates)

if phase_index > 1
sort!(candidate.children, by = x -> get_lp_primal_bound(getoptstate(x)))
end
# TODO: end

# Here, we avoid the removal of pruned nodes at this point to let them
# appear in the branching tree file
for (node_index, node) in enumerate(candidate.children)
if isverbose(current_phase.conquer_algo)
print(
"**** SB phase ", phase_index, " evaluation of candidate ",
candidate_index, " (branch ", node_index, " : ", node.branchdescription
)
@printf "), value = %6.2f\n" getvalue(get_lp_primal_bound(getoptstate(node)))
end

nodestate = getoptstate(node)

update_ip_primal_bound!(nodestate, get_ip_primal_bound(sbstate))
best_ip_primal_sol = get_best_ip_primal_sol(sbstate)
if exploitsprimalsolutions && best_ip_primal_sol !== nothing
set_ip_primal_sol!(nodestate, best_ip_primal_sol)
end

apply_conquer_alg_to_node!(
node, current_phase.conquer_algo, env, reform, conquer_units_to_restore
)

add_ip_primal_sols!(sbstate, get_ip_primal_sols(nodestate)...)

if to_be_pruned(node)
if isverbose(current_phase.conquer_algo)
println("Branch is conquered!")
end
end
end
scores = map(candidates) do candidate
children = sort(get_children(candidate), by = child -> get_lp_primal_bound(getoptstate(child)))
_eval_children_of_candidate!(
children, current_phase, phase_index, conquer_units_to_restore, sbstate, env,
reform, candidate.varname
)

if phase_index < length(algo.phases)
# not the last phase, thus we compute the product score
candidate.score = compute_score(ProductScore(), candidate)
else
# the last phase, thus we compute the tree size score
candidate.score = compute_score(TreeDepthScore(), candidate)
end
print_bounds_and_score(candidate, phase_index, max_descr_length)
score = compute_score(current_phase.score, candidate)
print_bounds_and_score(candidate, phase_index, 30, score) # TODO: rm
return score
end

sort!(candidates, rev = true, by = x -> (x.isconquered, x.score))

if candidates[1].isconquered
nb_candidates_for_next_phase = 1
end
perm = sortperm(scores, rev=true)
permute!(candidates, perm)

# before deleting branching groups which are not kept for the next phase
# The case where one/many candidate is conquered is not supported yet.
# In this case, the number of candidates for next phase is one.

# before deleting branching candidates which are not kept for the next phase
# we need to remove record kept in these nodes
for candidate_index = nb_candidates_for_next_phase + 1 : length(candidates)
for (node_index, node) in enumerate(candidates[candidate_index].children)
remove_records!(node.recordids)
for child in candidates[candidate_index].children
remove_records!(child.recordids)
end
end
resize!(candidates, nb_candidates_for_next_phase)
Expand All @@ -253,11 +229,11 @@ function _select_candidates_with_branching_rule(rules, phases, selection_criteri
else
# If the branching algorithm has phases, then it first selects the maximum number of
# candidates required by the first phases. The last phase returns the "best" candidate.
phases[1].max_nb_candidates
first(phases).max_nb_candidates
end

local_id = 0 # TODO: this variable needs an explicit name.
priority_of_last_generated_groups = nothing
priority_of_last_gen_candidates = nothing

for prioritised_rule in sorted_rules
rule = prioritised_rule.rule
Expand All @@ -273,38 +249,38 @@ function _select_candidates_with_branching_rule(rules, phases, selection_criteri
# than priorities of not yet considered branching rules; (TODO: example? use case?)
# 2. all needed candidates were generated and their smallest priority is strictly greater
# than priorities of not yet considered branching rules.
stop_gen_condition_1 = !isnothing(priority_of_last_generated_groups) &&
nb_candidates_found > 0 && priority < floor(priority_of_last_generated_groups)
stop_gen_condition_1 = !isnothing(priority_of_last_gen_candidates) &&
nb_candidates_found > 0 && priority < floor(priority_of_last_gen_candidates)

stop_gen_condition_2 = !isnothing(priority_of_last_generated_groups) &&
nb_candidates_found >= max_nb_candidates && priority < priority_of_last_generated_groups
stop_gen_condition_2 = !isnothing(priority_of_last_gen_candidates) &&
nb_candidates_found >= max_nb_candidates && priority < priority_of_last_gen_candidates

if stop_gen_condition_1 || stop_gen_condition_2
break
end

# Generate candidates.
output = run!(
output = select!(
rule, env, reform, BranchingRuleInput(
original_solution, true, max_nb_candidates, selection_criterion,
local_id, int_tol, priority, parent
)
)
append!(kept_branch_candidates, output.groups)
append!(kept_branch_candidates, output.candidates)
local_id = output.local_id

if projection_is_possible(getmaster(reform)) && !isnothing(extended_solution)
output = run!(
output = select!(
rule, env, reform, BranchingRuleInput(
extended_solution, false, max_nb_candidates, selection_criterion,
local_id, int_tol, priority, parent
)
)
append!(kept_branch_candidates, output.groups)
append!(kept_branch_candidates, output.candidates)
local_id = output.local_id
end
select_candidates!(kept_branch_candidates, selection_criterion, max_nb_candidates)
priority_of_last_generated_groups = priority
priority_of_last_gen_candidates = priority
end
return kept_branch_candidates
end
Expand Down Expand Up @@ -349,7 +325,7 @@ function run!(algo::StrongBranching, env::Env, reform::Reformulation, input::Div
return DivideOutput(children, OptimizationState(getmaster(reform)))
end

sbstate = perform_strong_branching_with_phases!(algo, env, reform, input, kept_branch_candidates)
sbstate = _perform_strong_branching_with_phases!(algo, env, reform, input, kept_branch_candidates)
children = get_children(first(kept_branch_candidates))
return DivideOutput(children, sbstate)
end
Loading

0 comments on commit c9d3009

Please sign in to comment.