Skip to content

Commit

Permalink
Run conquer algorithm in TreeSearch only if conquer_output of the n…
Browse files Browse the repository at this point in the history
…ode is empty (#1076)

* Get rid of Node.conquerwasrun, instead we maintain conquer_output
Now, we run conquer only if conquer_output is empty

* Correction to pass the tests

* Forgot to commit a file
  • Loading branch information
rrsadykov authored Sep 27, 2023
1 parent cb8cf62 commit adea8e0
Show file tree
Hide file tree
Showing 10 changed files with 29 additions and 32 deletions.
6 changes: 2 additions & 4 deletions src/Algorithm/branching/branchingalgo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ end

function Branching.eval_child_of_candidate!(child, phase::Branching.AbstractStrongBrPhaseContext, ip_primal_sols_found, env, reform, input)
child_state = OptimizationState(getmaster(reform))
child.optstate = child_state
child.conquer_output = child_state

# In the `ip_primal_sols_found`, we maintain all the primal solutions found during the
# strong branching procedure but also the best primal bound found so far (in the whole optimization).
Expand All @@ -284,11 +284,9 @@ function Branching.eval_child_of_candidate!(child, phase::Branching.AbstractStro
units_to_restore = Branching.get_units_to_restore_for_conquer(phase)
restore_from_records!(units_to_restore, child.records)
conquer_input = ConquerInputFromSb(Branching.get_global_primal_handler(input), child, units_to_restore)
child_state = run!(Branching.get_conquer(phase), env, reform, conquer_input)
child.optstate = child_state
child.conquer_output = run!(Branching.get_conquer(phase), env, reform, conquer_input)
TreeSearch.set_records!(child, create_records(reform))
end
child.conquerwasrun = true

# Store new primal solutions found during the evaluation of the child.
add_ip_primal_sols!(ip_primal_sols_found, get_ip_primal_sols(child_state)...)
Expand Down
2 changes: 1 addition & 1 deletion src/Algorithm/branching/printer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function Branching.perform_branching_phase!(candidates, cand_children, phase::Ph
@printf " (lhs=%.4f) : [" Branching.get_lhs(candidate)
for (node_index, node) in enumerate(children)
node_index > 1 && print(",")
@printf "%10.4f" getvalue(get_lp_primal_bound(node.optstate))
@printf "%10.4f" getvalue(get_lp_primal_bound(node.conquer_output))
end
@printf "], score = %10.4f\n" score
end
Expand Down
14 changes: 4 additions & 10 deletions src/Algorithm/branching/sbnode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,21 @@ mutable struct SbNode <: TreeSearch.AbstractNode
# There information are printed by the StrongBranchingPrinter.
# These information will be then transfered to the B&B algorithm when instantating the
# node of the tree search.
optstate::OptimizationState
conquer_output::Union{Nothing, OptimizationState}

var_name::String
branchdescription::String
ip_dual_bound::Bound
records::Records
conquerwasrun::Bool
function SbNode(
reform::Reformulation, depth, var_name::String, branch_description::String, records::Records, input
depth, branch_description::String, ip_dual_bound::Bound, records::Records
)
node_state = OptimizationState(
getmaster(reform);
ip_dual_bound = get_ip_dual_bound(Branching.get_conquer_opt_state(input))
)
return new(depth, node_state, var_name, branch_description, records, false)
return new(depth, nothing, branch_description, ip_dual_bound, records)
end
end

getdepth(n::SbNode) = n.depth

TreeSearch.set_records!(n::SbNode, records) = n.records = records
TreeSearch.get_branch_description(n::SbNode) = n.branchdescription
get_var_name(n::SbNode) = n.var_name
TreeSearch.isroot(n::SbNode) = false
Branching.isroot(n::SbNode) = TreeSearch.isroot(n)
4 changes: 2 additions & 2 deletions src/Algorithm/branching/scores.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function Branching.compute_score(::ProductScore, children, input)
parent = Branching.get_conquer_opt_state(input)
parent_lp_dual_bound = get_lp_dual_bound(parent)
parent_ip_primal_bound = get_ip_primal_bound(parent)
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, Ref(:optstate)))
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, Ref(:conquer_output)))
return _product_score(parent_lp_dual_bound, parent_ip_primal_bound, children_lp_primal_bounds)
end

Expand All @@ -14,7 +14,7 @@ function Branching.compute_score(::TreeDepthScore, children, input)
parent = Branching.get_conquer_opt_state(input)
parent_lp_dual_bound = get_lp_dual_bound(parent)
parent_ip_primal_bound = get_ip_primal_bound(parent)
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, :optstate))
children_lp_primal_bounds = get_lp_primal_bound.(getfield.(children, :conquer_output))
return _tree_depth_score(parent_lp_dual_bound, parent_ip_primal_bound, children_lp_primal_bounds)
end

Expand Down
6 changes: 4 additions & 2 deletions src/Algorithm/branching/single_var_branching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ function Branching.generate_children!(
units_to_restore = get_branching_candidate_units_usage(candidate, reform)
d = Branching.get_parent_depth(input)

parent_ip_dual_bound = get_ip_dual_bound(Branching.get_conquer_opt_state(input))

# adding the first branching constraints
restore_from_records!(units_to_restore, Branching.parent_records(input))
setconstr!(
Expand All @@ -57,7 +59,7 @@ function Branching.generate_children!(
members = Dict{VarId,Float64}(candidate.varid => 1.0)
)
child1description = candidate.varname * ">=" * string(ceil(lhs))
child1 = SbNode(reform, d+1, candidate.varname, child1description, create_records(reform), input)
child1 = SbNode(d+1, child1description, parent_ip_dual_bound, create_records(reform))

# adding the second branching constraints
restore_from_records!(units_to_restore, Branching.parent_records(input))
Expand All @@ -71,7 +73,7 @@ function Branching.generate_children!(
members = Dict{VarId,Float64}(candidate.varid => 1.0)
)
child2description = candidate.varname * "<=" * string(floor(lhs))
child2 = SbNode(reform, d+1, candidate.varname, child2description, create_records(reform), input)
child2 = SbNode(d+1, child2description, parent_ip_dual_bound, create_records(reform))

return [child1, child2]
end
Expand Down
11 changes: 7 additions & 4 deletions src/Algorithm/treesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,13 @@ function TreeSearch.children(space::AbstractColunaSearchSpace, current::TreeSear
# Else we run the conquer algorithm.
# This algorithm has the responsibility to check whether the node is pruned.
reform = get_reformulation(space)
conquer_alg = get_conquer(space)
conquer_input = get_input(conquer_alg, space, current)
conquer_output = run!(conquer_alg, env, reform, conquer_input)
after_conquer!(space, current, conquer_output) # callback to do some operations after the conquer.
conquer_output = TreeSearch.get_conquer_output(current)
if conquer_output === nothing
conquer_alg = get_conquer(space)
conquer_input = get_input(conquer_alg, space, current)
conquer_output = run!(conquer_alg, env, reform, conquer_input)
after_conquer!(space, current, conquer_output) # callback to do some operations after the conquer.
end
# Build the divide input from the conquer output
divide_alg = get_divide(space)
divide_input = get_input(divide_alg, space, current, conquer_output)
Expand Down
11 changes: 4 additions & 7 deletions src/Algorithm/treesearch/branch_and_bound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ mutable struct Node <: TreeSearch.AbstractNode
# Information to restore the reformulation after the creation of the node (e.g. creation
# of the branching constraint) or its partial evaluation (e.g. strong branching).
records::Records

conquerwasrun::Bool # TODO: rename: full_evaluation (?)
end

getdepth(n::Node) = n.depth

TreeSearch.isroot(n::Node) = n.depth == 0
Branching.isroot(n::Node) = TreeSearch.isroot(n)
TreeSearch.set_records!(n::Node, records) = n.records = records
TreeSearch.get_conquer_output(n::Node) = n.conquer_output

TreeSearch.get_branch_description(n::Node) = n.branchdescription # printer

Expand All @@ -45,10 +44,9 @@ TreeSearch.get_priority(::TreeSearch.BestDualBoundStrategy, n::Node) = n.ip_dual

# TODO move
function Node(node::SbNode)
ip_dual_bound = get_ip_dual_bound(node.optstate)
return Node(
node.depth, node.branchdescription, node.optstate, ip_dual_bound,
node.records, node.conquerwasrun
node.depth, node.branchdescription, node.conquer_output, node.ip_dual_bound,
node.records
)
end

Expand Down Expand Up @@ -203,7 +201,7 @@ end
function TreeSearch.new_root(sp::BaBSearchSpace, input)
nodestate = OptimizationState(getmaster(sp.reformulation), input, false, false)
return Node(
0, "", nothing, get_ip_dual_bound(nodestate), create_records(sp.reformulation), false
0, "", nothing, get_ip_dual_bound(nodestate), create_records(sp.reformulation)
)
end

Expand All @@ -215,7 +213,6 @@ function after_conquer!(space::BaBSearchSpace, current, conquer_output)
store_ip_primal_sol!(space.inc_primal_manager, sol)
end
current.records = create_records(space.reformulation)
current.conquerwasrun = true
space.nb_nodes_treated += 1

# Branch & Bound returns the primal LP & the dual solution found at the root node.
Expand Down
3 changes: 3 additions & 0 deletions src/TreeSearch/TreeSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ abstract type AbstractNode end
"Returns the priority of the node depending on the explore strategy."
@mustimplement "Node" get_priority(::AbstractExploreStrategy, ::AbstractNode) = nothing

"Returns the conquer output if the conquer was already run for this node, otherwise returns nothing"
get_conquer_output(::AbstractNode) = nothing

##### Additional methods for the node interface (needed by conquer)
## TODO: move outside TreeSearch module.
@mustimplement "Node" set_records!(::AbstractNode, records) = nothing
Expand Down
2 changes: 1 addition & 1 deletion test/unit/Branching/branching_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ function test_strong_branching()
# TODO: interface to register Records.
records = Coluna.Algorithm.Records()
node = Coluna.Algorithm.Node(
0, "", nothing, MathProg.DualBound(reform), records, false
0, "", nothing, MathProg.DualBound(reform), records
)

global_primal_handler = Coluna.Algorithm.GlobalPrimalBoundHandler(reform)
Expand Down
2 changes: 1 addition & 1 deletion test/unit/TreeSearch/treesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ end

# constructs a real node from a LightNode, used in new_children to built real children from the minimal information contained in LightNode
function Coluna.Algorithm.Node(node::LightNode)
return Coluna.Algorithm.Node(node.depth, " ", nothing, node.parent_ip_dual_bound, Coluna.Algorithm.Records(), false)
return Coluna.Algorithm.Node(node.depth, " ", nothing, node.parent_ip_dual_bound, Coluna.Algorithm.Records())
end

## The candidates are passed as LightNodes and the current node is passed as a TestBaBNode. The method retrieves the inner nodes to run the native method new_children of branch_and_bound.jl, gets the result as a vector of Nodes and then re-built a solution as a vector of TestBaBNodes using the nodes ids contained in LightNode structures.
Expand Down

0 comments on commit adea8e0

Please sign in to comment.