Skip to content

Commit

Permalink
Merge branch 'master' into propagate_local_bound
Browse files Browse the repository at this point in the history
  • Loading branch information
guimarqu authored Sep 28, 2023
2 parents f7513ea + ed60579 commit 5000e14
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 90 deletions.
43 changes: 11 additions & 32 deletions src/Algorithm/branching/branchingalgo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Divide algorithm that does nothing. It does not generate any child.
struct NoBranching <: AlgoAPI.AbstractDivideAlgorithm end

function run!(::NoBranching, ::Env, reform::Reformulation, ::Branching.AbstractDivideInput)
return DivideOutput([], OptimizationState(getmaster(reform)))
return DivideOutput([])
end

############################################################################################
Expand Down Expand Up @@ -76,11 +76,6 @@ Branching.get_int_tol(ctx::BranchingContext) = ctx.int_tol
Branching.get_selection_criterion(ctx::BranchingContext) = ctx.selection_criterion
Branching.get_rules(ctx::BranchingContext) = ctx.rules

function Branching.new_ip_primal_sols_pool(ctx::BranchingContext, reform::Reformulation, input)
# Optimization state with no information.
return OptimizationState(getmaster(reform))
end

function _is_integer(sol::PrimalSolution)
for (varid, val) in sol
integer_val = abs(val - round(val)) < 1e-5
Expand Down Expand Up @@ -130,8 +125,8 @@ function Branching.why_no_candidate(reform::Reformulation, input, extended_sol,
return _why_no_candidate(master, reform, input, extended_sol, original_sol)
end

Branching.new_divide_output(children::Vector{SbNode}, optimization_state) = DivideOutput(children, optimization_state)
Branching.new_divide_output(::Nothing, optimization_state) = DivideOutput(SbNode[], optimization_state)
Branching.new_divide_output(children::Vector{SbNode}) = DivideOutput(children)
Branching.new_divide_output(::Nothing) = DivideOutput(SbNode[])

############################################################################################
# Branching API implementation for the strong branching
Expand Down Expand Up @@ -272,40 +267,24 @@ function new_context(
)
end

function Branching.eval_child_of_candidate!(child, phase::Branching.AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
function Branching.eval_child_of_candidate!(child, phase::Branching.AbstractStrongBrPhaseContext, env, reform, input)
child_state = OptimizationState(getmaster(reform))
child.optstate = child_state
child.conquer_output = child_state

# In the `ip_primal_sols_found`, we maintain all the primal solutions found during the
# strong branching procedure but also the best primal bound found so far (in the whole optimization).
update_ip_primal_bound!(child_state, get_ip_primal_bound(ip_primal_sols_found))

global_primal_handler = Branching.get_global_primal_handler(input)
update_ip_primal_bound!(child_state, get_global_primal_bound(global_primal_handler))

if !ip_gap_closed(child_state)
units_to_restore = Branching.get_units_to_restore_for_conquer(phase)
restore_from_records!(units_to_restore, child.records)
conquer_input = ConquerInputFromSb(Branching.get_global_primal_handler(input), child, units_to_restore)
child_state = run!(Branching.get_conquer(phase), env, reform, conquer_input)
child.optstate = child_state
conquer_input = ConquerInputFromSb(global_primal_handler, child, units_to_restore)
child.conquer_output = run!(Branching.get_conquer(phase), env, reform, conquer_input)
TreeSearch.set_records!(child, create_records(reform))
end
child.conquerwasrun = true

# Store new primal solutions found during the evaluation of the child.
add_ip_primal_sols!(ip_primal_sols_found, get_ip_primal_sols(child_state)...)
for sol in get_ip_primal_sols(child_state)
store_ip_primal_sol!(Branching.get_global_primal_handler(input), sol)
store_ip_primal_sol!(global_primal_handler, sol)
end
return
end

function Branching.new_ip_primal_sols_pool(ctx::StrongBranchingContext, reform, input)
# Optimization state with copy of bounds only (except lp_primal_bound).
# Only the ip primal bound is used to avoid inserting integer solutions that are not
# better than the incumbent.
# We also use the primal bound to init candidate nodes in the strong branching procedure.
input_opt_state = Branching.get_conquer_opt_state(input)
return OptimizationState(
getmaster(reform);
ip_primal_bound = get_ip_primal_bound(input_opt_state),
)
end
2 changes: 0 additions & 2 deletions src/Algorithm/branching/interface.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
struct DivideOutput{N} <: Branching.AbstractDivideOutput
children::Vector{N}
optstate::Union{Nothing,OptimizationState}
end

Branching.get_children(output::DivideOutput) = output.children
#Branching.get__opt_state(output::DivideOutput) = output.optstate

function get_extended_sol(reform, opt_state)
return get_best_lp_primal_sol(opt_state)
Expand Down
3 changes: 1 addition & 2 deletions src/Algorithm/branching/printer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ Branching.get_int_tol(ctx::BranchingPrinter) = Branching.get_int_tol(ctx.inner)
Branching.get_selection_criterion(ctx::BranchingPrinter) = Branching.get_selection_criterion(ctx.inner)
Branching.get_selection_nb_candidates(ctx::BranchingPrinter) = Branching.get_selection_nb_candidates(ctx.inner)
Branching.get_phases(ctx::BranchingPrinter) = Branching.get_phases(ctx.inner)
Branching.new_ip_primal_sols_pool(ctx::BranchingPrinter, reform, input) = Branching.new_ip_primal_sols_pool(ctx.inner, reform, input)

struct PhasePrinter{PhaseContext<:Branching.AbstractStrongBrPhaseContext} <: Branching.AbstractStrongBrPhaseContext
inner::PhaseContext
Expand Down Expand Up @@ -39,7 +38,7 @@ function Branching.perform_branching_phase!(candidates, cand_children, phase::Ph
@printf " (lhs=%.4f) : [" Branching.get_lhs(candidate)
for (node_index, node) in enumerate(children)
node_index > 1 && print(",")
@printf "%10.4f" getvalue(get_lp_primal_bound(node.optstate))
@printf "%10.4f" getvalue(get_lp_primal_bound(node.conquer_output))
end
@printf "], score = %10.4f\n" score
end
Expand Down
14 changes: 4 additions & 10 deletions src/Algorithm/branching/sbnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,21 @@ mutable struct SbNode <: TreeSearch.AbstractNode
# There information are printed by the StrongBranchingPrinter.
# These information will be then transfered to the B&B algorithm when instantating the
# node of the tree search.
optstate::OptimizationState
conquer_output::Union{Nothing, OptimizationState}

var_name::String
branchdescription::String
ip_dual_bound::Bound
records::Records
conquerwasrun::Bool
function SbNode(
reform::Reformulation, depth, var_name::String, branch_description::String, records::Records, input
depth, branch_description::String, ip_dual_bound::Bound, records::Records
)
node_state = OptimizationState(
getmaster(reform);
ip_dual_bound = get_ip_dual_bound(Branching.get_conquer_opt_state(input))
)
return new(depth, node_state, var_name, branch_description, records, false)
return new(depth, nothing, branch_description, ip_dual_bound, records)
end
end

getdepth(n::SbNode) = n.depth

TreeSearch.set_records!(n::SbNode, records) = n.records = records
TreeSearch.get_branch_description(n::SbNode) = n.branchdescription
get_var_name(n::SbNode) = n.var_name
TreeSearch.isroot(n::SbNode) = false
Branching.isroot(n::SbNode) = TreeSearch.isroot(n)
4 changes: 2 additions & 2 deletions src/Algorithm/branching/scores.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function Branching.compute_score(::ProductScore, children, input)
parent = Branching.get_conquer_opt_state(input)
parent_lp_dual_bound = get_lp_dual_bound(parent)
parent_ip_primal_bound = get_ip_primal_bound(parent)
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, Ref(:optstate)))
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, Ref(:conquer_output)))
return _product_score(parent_lp_dual_bound, parent_ip_primal_bound, children_lp_primal_bounds)
end

Expand All @@ -14,7 +14,7 @@ function Branching.compute_score(::TreeDepthScore, children, input)
parent = Branching.get_conquer_opt_state(input)
parent_lp_dual_bound = get_lp_dual_bound(parent)
parent_ip_primal_bound = get_ip_primal_bound(parent)
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, :optstate))
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, :conquer_output))
return _tree_depth_score(parent_lp_dual_bound, parent_ip_primal_bound, children_lp_primal_bounds)
end

Expand Down
6 changes: 4 additions & 2 deletions src/Algorithm/branching/single_var_branching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ function Branching.generate_children!(
units_to_restore = get_branching_candidate_units_usage(candidate, reform)
d = Branching.get_parent_depth(input)

parent_ip_dual_bound = get_ip_dual_bound(Branching.get_conquer_opt_state(input))

# adding the first branching constraints
restore_from_records!(units_to_restore, Branching.parent_records(input))
setconstr!(
Expand All @@ -57,7 +59,7 @@ function Branching.generate_children!(
members = Dict{VarId,Float64}(candidate.varid => 1.0)
)
child1description = candidate.varname * ">=" * string(ceil(lhs))
child1 = SbNode(reform, d+1, candidate.varname, child1description, create_records(reform), input)
child1 = SbNode(d+1, child1description, parent_ip_dual_bound, create_records(reform))

# adding the second branching constraints
restore_from_records!(units_to_restore, Branching.parent_records(input))
Expand All @@ -71,7 +73,7 @@ function Branching.generate_children!(
members = Dict{VarId,Float64}(candidate.varid => 1.0)
)
child2description = candidate.varname * "<=" * string(floor(lhs))
child2 = SbNode(reform, d+1, candidate.varname, child2description, create_records(reform), input)
child2 = SbNode(d+1, child2description, parent_ip_dual_bound, create_records(reform))

return [child1, child2]
end
Expand Down
11 changes: 7 additions & 4 deletions src/Algorithm/treesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,13 @@ function TreeSearch.children(space::AbstractColunaSearchSpace, current::TreeSear
# Else we run the conquer algorithm.
# This algorithm has the responsibility to check whether the node is pruned.
reform = get_reformulation(space)
conquer_alg = get_conquer(space)
conquer_input = get_input(conquer_alg, space, current)
conquer_output = run!(conquer_alg, env, reform, conquer_input)
after_conquer!(space, current, conquer_output) # callback to do some operations after the conquer.
conquer_output = TreeSearch.get_conquer_output(current)
if conquer_output === nothing
conquer_alg = get_conquer(space)
conquer_input = get_input(conquer_alg, space, current)
conquer_output = run!(conquer_alg, env, reform, conquer_input)
after_conquer!(space, current, conquer_output) # callback to do some operations after the conquer.
end
# Build the divide input from the conquer output
divide_alg = get_divide(space)
divide_input = get_input(divide_alg, space, current, conquer_output)
Expand Down
11 changes: 4 additions & 7 deletions src/Algorithm/treesearch/branch_and_bound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ mutable struct Node <: TreeSearch.AbstractNode
# Information to restore the reformulation after the creation of the node (e.g. creation
# of the branching constraint) or its partial evaluation (e.g. strong branching).
records::Records

conquerwasrun::Bool # TODO: rename: full_evaluation (?)
end

getdepth(n::Node) = n.depth

TreeSearch.isroot(n::Node) = n.depth == 0
Branching.isroot(n::Node) = TreeSearch.isroot(n)
TreeSearch.set_records!(n::Node, records) = n.records = records
TreeSearch.get_conquer_output(n::Node) = n.conquer_output

TreeSearch.get_branch_description(n::Node) = n.branchdescription # printer

Expand All @@ -45,10 +44,9 @@ TreeSearch.get_priority(::TreeSearch.BestDualBoundStrategy, n::Node) = n.ip_dual

# TODO move
function Node(node::SbNode)
ip_dual_bound = get_ip_dual_bound(node.optstate)
return Node(
node.depth, node.branchdescription, node.optstate, ip_dual_bound,
node.records, node.conquerwasrun
node.depth, node.branchdescription, node.conquer_output, node.ip_dual_bound,
node.records
)
end

Expand Down Expand Up @@ -203,7 +201,7 @@ end
function TreeSearch.new_root(sp::BaBSearchSpace, input)
nodestate = OptimizationState(getmaster(sp.reformulation), input, false, false)
return Node(
0, "", nothing, get_ip_dual_bound(nodestate), create_records(sp.reformulation), false
0, "", nothing, get_ip_dual_bound(nodestate), create_records(sp.reformulation)
)
end

Expand All @@ -215,7 +213,6 @@ function after_conquer!(space::BaBSearchSpace, current, conquer_output)
store_ip_primal_sol!(space.inc_primal_manager, sol)
end
current.records = create_records(space.reformulation)
current.conquerwasrun = true
space.nb_nodes_treated += 1

# Branch & Bound returns the primal LP & the dual solution found at the root node.
Expand Down
42 changes: 16 additions & 26 deletions src/Branching/Branching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@ abstract type AbstractDivideContext end

# branching output
"""
new_divide_output(children::Union{Vector{N}, Nothing}, ip_primal_sols_found{C, Nothing}) where {N, C} -> AbstractDivideOutput
new_divide_output(children::Union{Vector{N}, Nothing}) where {N} -> AbstractDivideOutput
where:
- `N` is the type of nodes generated by the branching algorithm.
- `C` is the type of the collection that stores all ip primal solutions found by the branching algorithm.
If no nodes nor ip primal solutions are found, the generic implementation may provide `nothing`.
If no nodes are found, the generic implementation may provide `nothing`.
"""
@mustimplement "BranchingOutput" new_divide_output(children, ip_primal_sols_found) = nothing
@mustimplement "BranchingOutput" new_divide_output(children) = nothing

# Default implementations.
"Candidates selection for branching algorithms."
Expand All @@ -89,7 +88,7 @@ abstract type AbstractBranchingContext <: AbstractDivideContext end

function advanced_select!(ctx::AbstractBranchingContext, candidates, env, reform, input::AbstractDivideInput)
children = generate_children!(first(candidates), env, reform, input)
return new_divide_output(children, nothing)
return new_divide_output(children)
end

############################################################################################
Expand Down Expand Up @@ -124,13 +123,8 @@ strong branching phase.
"Returns the maximum number of candidates kept at the end of a given strong branching phase."
@mustimplement "StrongBranching" get_max_nb_candidates(::AbstractStrongBrPhaseContext) = nothing

""
@mustimplement "StrongBranchingOptState" new_ip_primal_sols_pool(ctx, reform, input) = nothing

function advanced_select!(ctx::Branching.AbstractStrongBrContext, candidates, env, reform, input::Branching.AbstractDivideInput)
return perform_strong_branching!(ctx, env, reform, input, candidates)
#children = get_children(first(candidates))
#return new_divide_output(children, ip_primal_sols_found)
end

function perform_strong_branching!(
Expand All @@ -142,10 +136,6 @@ end
function perform_strong_branching_inner!(
ctx::AbstractStrongBrContext, env, reform, input::Branching.AbstractDivideInput, candidates::Vector{C}
) where {C<:AbstractBranchingCandidate}
# We will store all the new ip primal solution found during the strong branching in the
# following data structure.
ip_primal_sols_found = new_ip_primal_sols_pool(ctx, reform, input)

cand_children = [generate_children!(candidate, env, reform, input) for candidate in candidates]

phases = get_phases(ctx)
Expand All @@ -163,7 +153,7 @@ function perform_strong_branching_inner!(
nb_candidates_for_next_phase = min(nb_candidates_for_next_phase, length(cand_children))
end

scores = perform_branching_phase!(candidates, cand_children, current_phase, ip_primal_sols_found, env, reform, input)
scores = perform_branching_phase!(candidates, cand_children, current_phase, env, reform, input)

perm = sortperm(scores, rev=true)
permute!(cand_children, perm)
Expand All @@ -178,15 +168,15 @@ function perform_strong_branching_inner!(
resize!(cand_children, nb_candidates_for_next_phase)
resize!(candidates, nb_candidates_for_next_phase)
end
return new_divide_output(first(cand_children), ip_primal_sols_found)
return new_divide_output(first(cand_children))
end

function perform_branching_phase!(candidates, cand_children, phase, ip_primal_sols_found, env, reform, input)
return perform_branching_phase_inner!(cand_children, phase, ip_primal_sols_found, env, reform, input)
function perform_branching_phase!(candidates, cand_children, phase, env, reform, input)
return perform_branching_phase_inner!(cand_children, phase, env, reform, input)
end

"Performs a branching phase."
function perform_branching_phase_inner!(cand_children, phase, ip_primal_sols_found, env, reform, input)
function perform_branching_phase_inner!(cand_children, phase, env, reform, input)

return map(cand_children) do children
# TODO; I don't understand why we need to sort the children here.
Expand All @@ -205,24 +195,24 @@ function perform_branching_phase_inner!(cand_children, phase, ip_primal_sols_fou
# by = child -> get_lp_primal_bound(child)
# )

return eval_candidate!(children, phase, ip_primal_sols_found, env, reform, input)
return eval_candidate!(children, phase, env, reform, input)
end
end

function eval_candidate!(children, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
return eval_candidate_inner!(children, phase, ip_primal_sols_found, env, reform, input)
function eval_candidate!(children, phase::AbstractStrongBrPhaseContext, env, reform, input)
return eval_candidate_inner!(children, phase, env, reform, input)
end

"Evaluates a candidate."
function eval_candidate_inner!(children, phase::AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
function eval_candidate_inner!(children, phase::AbstractStrongBrPhaseContext, env, reform, input)
for child in children
eval_child_of_candidate!(child, phase, ip_primal_sols_found, env, reform, input)
eval_child_of_candidate!(child, phase, env, reform, input)
end
return compute_score(get_score(phase), children, input)
end

"Evaluate children of a candidate."
@mustimplement "StrongBranching" eval_child_of_candidate!(child, phase, ip_primal_sols_found, env, reform, input) = nothing
@mustimplement "StrongBranching" eval_child_of_candidate!(child, phase, env, reform, input) = nothing

@mustimplement "Branching" isroot(node) = nothing

Expand Down Expand Up @@ -302,7 +292,7 @@ function run_branching!(ctx, env, reform, input::Branching.AbstractDivideInput,
if length(candidates) == 0
@warn "No candidate generated. No children will be generated. However, the node is not conquered."
why_no_candidate(reform, input, extended_sol, original_sol)
return new_divide_output(nothing, nothing)
return new_divide_output(nothing)
end
return advanced_select!(ctx, candidates, env, reform, input)
end
Expand Down
3 changes: 3 additions & 0 deletions src/TreeSearch/TreeSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ abstract type AbstractNode end
"Returns the priority of the node depending on the explore strategy."
@mustimplement "Node" get_priority(::AbstractExploreStrategy, ::AbstractNode) = nothing

"Returns the conquer output if the conquer was already run for this node, otherwise returns nothing"
get_conquer_output(::AbstractNode) = nothing

##### Additional methods for the node interface (needed by conquer)
## TODO: move outside TreeSearch module.
@mustimplement "Node" set_records!(::AbstractNode, records) = nothing
Expand Down
2 changes: 1 addition & 1 deletion test/unit/Branching/branching_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ function test_strong_branching()
# TODO: interface to register Records.
records = Coluna.Algorithm.Records()
node = Coluna.Algorithm.Node(
0, "", nothing, MathProg.DualBound(reform), records, false
0, "", nothing, MathProg.DualBound(reform), records
)

global_primal_handler = Coluna.Algorithm.GlobalPrimalBoundHandler(reform)
Expand Down
Loading

0 comments on commit 5000e14

Please sign in to comment.