Skip to content

Commit

Permalink
Remove children nodes from branching candidates (#1072)
Browse files Browse the repository at this point in the history
  • Loading branch information
rrsadykov authored Sep 22, 2023
1 parent 6084518 commit bad4a12
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 56 deletions.
3 changes: 0 additions & 3 deletions docs/src/api/branching.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,6 @@ Branching.AbstractBranchingCandidate
Branching.getdescription
Branching.get_lhs
Branching.get_local_id
Branching.get_children
Branching.set_children!
Branching.get_parent
Branching.generate_children!
```

Expand Down
8 changes: 4 additions & 4 deletions src/Algorithm/branching/printer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ function new_phase_context(
return PhasePrinter(inner_ctx, phase_index)
end

function Branching.perform_branching_phase!(candidates, phase::PhasePrinter, sb_state, env, reform, input)
function Branching.perform_branching_phase!(candidates, cand_children, phase::PhasePrinter, sb_state, env, reform, input)
println("**** Strong branching phase ", phase.phase_index, " is started *****");
scores = Branching.perform_branching_phase_inner!(candidates, phase, sb_state, env, reform, input)
for (candidate, score) in Iterators.zip(candidates, scores)
scores = Branching.perform_branching_phase_inner!(cand_children, phase, sb_state, env, reform, input)
for (candidate, children, score) in Iterators.zip(candidates, cand_children, scores)
@printf "SB phase %i branch on %+10s" phase.phase_index Branching.getdescription(candidate)
@printf " (lhs=%.4f) : [" Branching.get_lhs(candidate)
for (node_index, node) in enumerate(Branching.get_children(candidate))
for (node_index, node) in enumerate(children)
node_index > 1 && print(",")
@printf "%10.4f" getvalue(get_lp_primal_bound(node.optstate))
end
Expand Down
7 changes: 3 additions & 4 deletions src/Algorithm/branching/scores.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
struct ProductScore <: Branching.AbstractBranchingScore end

function Branching.compute_score(::ProductScore, candidate, input)
function Branching.compute_score(::ProductScore, children, input)
parent = Branching.get_conquer_opt_state(input)
parent_lp_dual_bound = get_lp_dual_bound(parent)
parent_ip_primal_bound = get_ip_primal_bound(parent)
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(Branching.get_children(candidate), Ref(:optstate)))
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, Ref(:optstate)))
return _product_score(parent_lp_dual_bound, parent_ip_primal_bound, children_lp_primal_bounds)
end

struct TreeDepthScore <: Branching.AbstractBranchingScore end

function Branching.compute_score(::TreeDepthScore, candidate, input)
function Branching.compute_score(::TreeDepthScore, children, input)
parent = Branching.get_conquer_opt_state(input)
parent_lp_dual_bound = get_lp_dual_bound(parent)
parent_ip_primal_bound = get_ip_primal_bound(parent)
children = Branching.get_children(candidate)
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, :optstate))
return _tree_depth_score(parent_lp_dual_bound, parent_ip_primal_bound, children_lp_primal_bounds)
end
Expand Down
7 changes: 1 addition & 6 deletions src/Algorithm/branching/single_var_branching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,16 @@ mutable struct SingleVarBranchingCandidate <: Branching.AbstractBranchingCandida
varid::VarId
local_id::Int64
lhs::Float64
score::Float64
children::Vector{SbNode}
isconquered::Bool
function SingleVarBranchingCandidate(
varname::String, varid::VarId, local_id::Int64, lhs::Float64
)
return new(varname, varid, local_id, lhs, 0.0, SbNode[], false)
return new(varname, varid, local_id, lhs)
end
end

Branching.getdescription(candidate::SingleVarBranchingCandidate) = candidate.varname
Branching.get_lhs(candidate::SingleVarBranchingCandidate) = candidate.lhs
Branching.get_local_id(candidate::SingleVarBranchingCandidate) = candidate.local_id
Branching.get_children(candidate::SingleVarBranchingCandidate) = candidate.children
Branching.set_children!(candidate::SingleVarBranchingCandidate, children) = candidate.children = children

function get_branching_candidate_units_usage(::SingleVarBranchingCandidate, reform)
units_to_restore = UnitsUsage()
Expand Down
50 changes: 25 additions & 25 deletions src/Branching/Branching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,13 @@ function select!(rule::AbstractBranchingRule, env, reform, input::Branching.Bran
local_id = input.local_id + length(candidates)
select_candidates!(candidates, input.criterion, input.max_nb_candidates)

for candidate in candidates
children = generate_children!(candidate, env, reform, input.input)
set_children!(candidate, children)
end
return BranchingRuleOutput(local_id, candidates)
end

abstract type AbstractBranchingContext <: AbstractDivideContext end

function advanced_select!(ctx::AbstractBranchingContext, candidates, _, reform, input::AbstractDivideInput)
children = get_children(first(candidates))
function advanced_select!(ctx::AbstractBranchingContext, candidates, env, reform, input::AbstractDivideInput)
children = generate_children!(first(candidates), env, reform, input)
return new_divide_output(children, nothing)
end

Expand Down Expand Up @@ -132,9 +128,9 @@ strong branching phase.
@mustimplement "StrongBranchingOptState" new_ip_primal_sols_pool(ctx, reform, input) = nothing

function advanced_select!(ctx::Branching.AbstractStrongBrContext, candidates, env, reform, input::Branching.AbstractDivideInput)
ip_primal_sols_found = perform_strong_branching!(ctx, env, reform, input, candidates)
children = get_children(first(candidates))
return new_divide_output(children, ip_primal_sols_found)
return perform_strong_branching!(ctx, env, reform, input, candidates)
#children = get_children(first(candidates))
#return new_divide_output(children, ip_primal_sols_found)
end

function perform_strong_branching!(
Expand All @@ -144,30 +140,33 @@ function perform_strong_branching!(
end

function perform_strong_branching_inner!(
ctx::AbstractStrongBrContext, env, model, input::Branching.AbstractDivideInput, candidates::Vector{C}
ctx::AbstractStrongBrContext, env, reform, input::Branching.AbstractDivideInput, candidates::Vector{C}
) where {C<:AbstractBranchingCandidate}
# We will store all the new ip primal solution found during the strong branching in the
# following data structure.
ip_primal_sols_found = new_ip_primal_sols_pool(ctx, model, input)
ip_primal_sols_found = new_ip_primal_sols_pool(ctx, reform, input)

cand_children = [generate_children!(candidate, env, reform, input) for candidate in candidates]

phases = get_phases(ctx)
for (phase_index, current_phase) in enumerate(phases)
nb_candidates_for_next_phase = 1
if phase_index < length(phases)
nb_candidates_for_next_phase = get_max_nb_candidates(phases[phase_index + 1])
if length(candidates) <= nb_candidates_for_next_phase
if length(cand_children) <= nb_candidates_for_next_phase
# If at the current phase, we have less candidates than the number of candidates
# we want to evaluate at the next phase, we skip the current phase.
continue
end
# In phase 1, we make sure that the number of candidates for the next phase is
# at least equal to the number of initial candidates.
nb_candidates_for_next_phase = min(nb_candidates_for_next_phase, length(candidates))
nb_candidates_for_next_phase = min(nb_candidates_for_next_phase, length(cand_children))
end

scores = perform_branching_phase!(candidates, current_phase, ip_primal_sols_found, env, model, input)
scores = perform_branching_phase!(candidates, cand_children, current_phase, ip_primal_sols_found, env, reform, input)

perm = sortperm(scores, rev=true)
permute!(cand_children, perm)
permute!(candidates, perm)

# The case where one/many candidate is conquered is not supported yet.
Expand All @@ -176,19 +175,20 @@ function perform_strong_branching_inner!(
# before deleting branching candidates which are not kept for the next phase
# we need to remove record kept in these nodes

resize!(cand_children, nb_candidates_for_next_phase)
resize!(candidates, nb_candidates_for_next_phase)
end
return ip_primal_sols_found
return new_divide_output(first(cand_children), ip_primal_sols_found)
end

function perform_branching_phase!(candidates, phase, ip_primal_sols_found, env, reform, input)
return perform_branching_phase_inner!(candidates, phase, ip_primal_sols_found, env, reform, input)
function perform_branching_phase!(candidates, cand_children, phase, ip_primal_sols_found, env, reform, input)
return perform_branching_phase_inner!(cand_children, phase, ip_primal_sols_found, env, reform, input)
end

"Performs a branching phase."
function perform_branching_phase_inner!(candidates, phase, ip_primal_sols_found, env, reform, input)
function perform_branching_phase_inner!(cand_children, phase, ip_primal_sols_found, env, reform, input)

return map(candidates) do candidate
return map(cand_children) do children
# TODO; I don't understand why we need to sort the children here.
# Looks like eval_children_of_candidiate! and the default implementation of
# eval_child_of_candidate is fully independent of the order of the children.
Expand All @@ -205,20 +205,20 @@ function perform_branching_phase_inner!(candidates, phase, ip_primal_sols_found,
# by = child -> get_lp_primal_bound(child)
# )

return eval_candidate!(candidate, phase, ip_primal_sols_found, env, reform, input)
return eval_candidate!(children, phase, ip_primal_sols_found, env, reform, input)
end
end

function eval_candidate!(candidate, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
return eval_candidate_inner!(candidate, phase, ip_primal_sols_found, env, reform, input)
function eval_candidate!(children, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
return eval_candidate_inner!(children, phase, ip_primal_sols_found, env, reform, input)
end

"Evaluates a candidate."
function eval_candidate_inner!(candidate, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
for child in get_children(candidate)
function eval_candidate_inner!(children, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
for child in children
eval_child_of_candidate!(child, phase, ip_primal_sols_found, env, reform, input)
end
return compute_score(get_score(phase), candidate, input)
return compute_score(get_score(phase), children, input)
end

"Evaluate children of a candidate."
Expand Down
14 changes: 0 additions & 14 deletions src/Branching/candidate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,10 @@ abstract type AbstractBranchingCandidate end
"Returns the generation id of the candidiate."
@mustimplement "BranchingCandidate" get_local_id(c::AbstractBranchingCandidate) = nothing

"Returns the children of the candidate."
@mustimplement "BranchingCandidate" get_children(c::AbstractBranchingCandidate) = nothing

"Set the children of the candidate."
@mustimplement "BranchingCandidate" set_children!(c::AbstractBranchingCandidate, children) = nothing

"Returns the parent node of the candidate's children."
@mustimplement "BranchingCandidate" get_parent(c::AbstractBranchingCandidate) = nothing

# TODO: this method should not generate the children of the tree search algorithm.
# However, AbstractBranchingCandidate should implement an interface to retrieve data to
# generate a children.
"""
generate_children!(branching_candidate, lhs, env, reform, node)
This method generates the children of a node described by `branching_candidate`.
Make sure that this method returns an object the same type as the second argument of
`set_children!(candiate, children)`.
"""
@mustimplement "BranchingCandidate" generate_children!(c::AbstractBranchingCandidate, env, reform, parent) = nothing

Expand Down

0 comments on commit bad4a12

Please sign in to comment.