Skip to content

Commit

Permalink
1. Maxcut runtime optimization.
Browse files Browse the repository at this point in the history
2. Add "warning experimental" to maxcut feature.
3. Add timeout to the max-cut solver after the 1st iteration.
  • Loading branch information
elad-c committed Jan 15, 2025
1 parent 228c35f commit 0fd48ce
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.
# ==============================================================================
from collections import namedtuple

from typing import Tuple, List
import timeout_decorator

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.graph.memory_graph.cut import Cut
Expand Down Expand Up @@ -47,9 +48,25 @@ def compute_graph_max_cut(memory_graph: MemoryGraph,
l_bound = memory_graph.memory_lbound_single_op
u_bound = 2 * sum([t.total_size for t in memory_graph.b_nodes]) - l_bound
it = 0

@timeout_decorator.timeout(300)
def solver_wrapper(_estimate, _iter_limit):
return max_cut_astar.solve(estimate=_estimate, iter_limit=_iter_limit)

while it < n_iter:
estimate = (u_bound + l_bound) / 2
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter)
if it == 0:
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter)
else:
try:
schedule, max_cut_size, cuts = solver_wrapper(estimate=estimate, iter_limit=astar_n_iter)
except timeout_decorator.TimeoutError:
if last_result[0] is None:
Logger.critical(f"Max-cut solver stopped on timeout in iteration {it} before finding a solution.") # pragma: no cover
else:
Logger.warning(f"Max-cut solver stopped on timeout in iteration {it}.")
return last_result

if schedule is None:
l_bound = estimate
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __eq__(self, other) -> bool:
return False # pragma: no cover

def __hash__(self):
    """Identity-based hash: each Cut instance hashes to its own ``id()``.

    NOTE(review): this trades hash/eq-contract correctness for speed —
    two Cut objects that compare equal via ``__eq__`` will in general NOT
    share a hash value, so equal-but-distinct instances land in different
    hash buckets of a set/dict. This is only safe if membership tests are
    intended to be per-instance; confirm callers rely on identity here.
    """
    return id(self)

def __repr__(self):
    """Debug representation: the names of the cut's memory elements and its total memory size."""
    return f"<Cut: Nodes={[e.node_name for e in self.mem_elements.elements]}, size={self.memory_size()}>"  # pragma: no cover
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
import copy
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Set

from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.constants import DUMMY_TENSOR, DUMMY_NODE
Expand Down Expand Up @@ -139,8 +139,8 @@ def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode],
"""

open_list = [self.src_cut]
closed_list = []
open_list = {self.src_cut}
closed_list = set()
costs = {self.src_cut: self.src_cut.memory_size()}
routes = {self.src_cut: [self.src_cut]}

Expand All @@ -159,22 +159,21 @@ def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode],

if self.is_pivot(next_cut):
# Can clear all search history
open_list = []
closed_list = []
open_list.clear()
closed_list.clear()
routes = {}
else:
# Can remove only next_cut and put it in closed_list
open_list.remove(next_cut)
del routes[next_cut]
closed_list.append(next_cut)
closed_list.add(next_cut)

# Expand the chosen cut
expanded_cuts = self.expand(next_cut)
expansion_count += 1

# Only consider nodes that were not already visited
expanded_cuts = [_c for _c in expanded_cuts if _c not in closed_list]
for c in expanded_cuts:
for c in filter(lambda _c: _c not in closed_list, expanded_cuts):
cost = self.accumulate(cut_cost, c.memory_size())
if c not in open_list:
self._update_expanded_node(c, cost, cut_route, open_list, costs, routes)
Expand All @@ -192,7 +191,7 @@ def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode],
return None, 0, None # pragma: no cover

@staticmethod
def _update_expanded_node(cut: Cut, cost: float, route: List[Cut], open_list: Set[Cut],
                          costs: Dict[Cut, float], routes: Dict[Cut, List[Cut]]):
    """
    An auxiliary method for updating search data structures according to an expanded node.

    Args:
        cut: A cut to expand the search to.
        cost: The cost of the cut.
        route: The route to the cut.
        open_list: The search open set.
        costs: The search utility mapping between cuts and their cost.
        routes: The search utility mapping between cuts and their routes.
    """
    # Register the expanded cut in all three structures; the stored route
    # is the path to `cut` (cut prepended to the route of its parent).
    open_list.add(cut)
    costs.update({cut: cost})
    routes.update({cut: [cut] + route})

def _get_cut_to_expand(self, open_list: List[Cut], costs: Dict[Cut, float], routes: Dict[Cut, List[Cut]],
def _get_cut_to_expand(self, open_list: Set[Cut], costs: Dict[Cut, float], routes: Dict[Cut, List[Cut]],
estimate: float) -> Cut:
"""
An auxiliary method for finding a cut for expanding the search out of a set of potential cuts for expansion.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
from typing import List
from operator import getitem
from functools import cache

from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.graph.edge import EDGE_SOURCE_INDEX
Expand Down Expand Up @@ -82,7 +83,6 @@ def __init__(self, model_graph: Graph):
inputs_tensors_memory = [sum([t.total_size for t in self.operation_node_children(n)])
for n in nodes if n in model_graph.get_inputs()]

# TODO maxcut: why are both inputs and outputs of each node counted, while the A* solves for node outputs only?
nodes_total_memory = [sum([t.total_size for t in self.operation_node_children(n)] +
[t.total_size for t in self.operation_node_parents(n)])
for n in nodes if n not in model_graph.get_inputs()]
Expand Down Expand Up @@ -117,6 +117,7 @@ def update_sinks_b(self):
"""
self.sinks_b = [n for n in self.b_nodes if len(list(self.successors(n))) == 0]

@cache
def activation_tensor_children(self, activation_tensor: ActivationMemoryTensor) -> List[BaseNode]:
    """
    Returns the children nodes of a side B node (activation tensor) in the bipartite graph.

    Args:
        activation_tensor: A side B node (activation tensor).

    Returns:
        A list of the tensor's child nodes (side A operation nodes) in the bipartite graph.

    NOTE(review): `@cache` on an instance method keys on `self` and keeps every
    graph instance alive for the cache's lifetime (ruff B019); it also assumes
    the graph's edges do not change after the first call — confirm both hold.
    """
    return [oe[1] for oe in self.out_edges(activation_tensor)]

@cache
def activation_tensor_parents(self, activation_tensor: ActivationMemoryTensor) -> List[BaseNode]:
    """
    Returns the parents nodes of a side B node (activation tensor) in the bipartite graph.

    Args:
        activation_tensor: A side B node (activation tensor).

    Returns:
        A list of the tensor's parent nodes (side A operation nodes) in the bipartite graph.

    NOTE(review): `@cache` on an instance method keys on `self` and keeps every
    graph instance alive for the cache's lifetime (ruff B019); it also assumes
    the graph's edges do not change after the first call — confirm both hold.
    """
    return [ie[0] for ie in self.in_edges(activation_tensor)]

@cache
def operation_node_children(self, op_node: BaseNode) -> List[ActivationMemoryTensor]:
    """
    Returns the children nodes of a side A node (operation) in the bipartite graph.

    Args:
        op_node: A side A node (operation).

    Returns:
        A list of the operation's child nodes (side B activation tensors) in the bipartite graph.

    NOTE(review): `@cache` on an instance method keys on `self` and keeps every
    graph instance alive for the cache's lifetime (ruff B019); it also assumes
    the graph's edges do not change after the first call — confirm both hold.
    """
    return [oe[1] for oe in self.out_edges(op_node)]

@cache
def operation_node_parents(self, op_node: BaseNode) -> List[ActivationMemoryTensor]:
"""
Returns the parents nodes of a side A node (operation) in the bipartite graph.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from enum import Enum, auto
from typing import Dict, NamedTuple, Optional, Tuple, List, Iterable, Union, Literal, Sequence, Set

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.core import FrameworkInfo
from model_compression_toolkit.core.common import Graph, BaseNode
Expand Down Expand Up @@ -169,6 +170,7 @@ def compute_resource_utilization(self,
w_total, *_ = self.compute_weights_utilization(target_criterion, bitwidth_mode, w_qcs)

if {RUTarget.ACTIVATION, RUTarget.TOTAL}.intersection(ru_targets):
Logger.warning("Using an experimental feature max-cut for activation memory utilization estimation.")
a_total = self.compute_activations_utilization(target_criterion, bitwidth_mode, act_qcs)

ru = ResourceUtilization()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _add_ru_constraints(search_manager: MixedPrecisionSearchManager,
target_resource_utilization: ResourceUtilization,
indicators_matrix: np.ndarray,
lp_problem: LpProblem,
non_conf_ru_dict: Optional[Dict[RUTarget, np.ndarray]]):
non_conf_ru_dict: Dict[RUTarget, np.ndarray]):
"""
Adding targets constraints for the Lp problem for the given target resource utilization.
The update to the Lp problem object is done inplace.
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ matplotlib<3.10.0
scipy
protobuf
mct-quantizers==1.5.2
pydantic<2.0
pydantic<2.0
timeout-decorator

0 comments on commit 0fd48ce

Please sign in to comment.