From cfd98902bac1026a18a4f40cf80b1ca86f3fe0c3 Mon Sep 17 00:00:00 2001 From: ddavis-2015 Date: Sun, 10 Nov 2024 00:26:14 -0800 Subject: [PATCH] expand model_facade redo var handle tracking --- .../lite/micro/compression/model_facade.py | 8 ++ .../lite/micro/compression/relocate_ops.py | 118 +++++++++++++----- 2 files changed, 95 insertions(+), 31 deletions(-) diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py index f2b920af760..6136fa43313 100644 --- a/tensorflow/lite/micro/compression/model_facade.py +++ b/tensorflow/lite/micro/compression/model_facade.py @@ -160,6 +160,14 @@ def inputs_indices(self): def outputs_indices(self): return self.operator.outputs + @property + def builtin_options_type(self) -> int: + return self.operator.builtinOptionsType + + @property + def builtin_options(self): + return self.operator.builtinOptions + class Tensor: diff --git a/tensorflow/lite/micro/compression/relocate_ops.py b/tensorflow/lite/micro/compression/relocate_ops.py index 6d32041c0e3..df4ec935fdf 100644 --- a/tensorflow/lite/micro/compression/relocate_ops.py +++ b/tensorflow/lite/micro/compression/relocate_ops.py @@ -24,7 +24,7 @@ from absl import app from absl import flags from pathlib import Path -from typing import List, Set, Dict +from typing import List, Set, Dict, Tuple FLAGS = flags.FLAGS @@ -40,10 +40,24 @@ help='path for the .tflite output file', ) -VarHandles = Set[int] +VarHandleId = int +VarHandles = Set[VarHandleId] TensorIndex = int -ReadVarOps = List[Dict[TensorIndex, model_facade.Operator]] -ConcatOps = List[Dict[TensorIndex, model_facade.Operator]] +SubgraphIndex = int + +VarHandleByName = Dict[Tuple[str | None, str], VarHandleId] +"""VarHandleByName +{ ( container_name | None, resource_name ) : var_handle_id } +""" + +VarHandleByTensor = Dict[Tuple[SubgraphIndex, TensorIndex], VarHandleId] +"""VarHandleByTensor +{ ( subgraph_index, tensor_index ) : var_handle_id } +""" + +ReadVarOps = List[Dict[int, model_facade.Operator]] + +ConcatOps = List[Dict[int, model_facade.Operator]] class Context: @@ -54,13 +68,16 @@ class Context: def __init__(self, model: model_facade.Model) -> None: self._model = model self._current_op_stack: List[model_facade.Operator] = [] - self._operators: List[List[model_facade.Operator]] = [[]] * len( + self._reordered_operators: List[List[model_facade.Operator]] = [[]] * len( model.subgraphs) self._subgraph_processed: List[bool] = [False] * len(model.subgraphs) self._subgraph_modified_vars: List[VarHandles] = [set()] * len( model.subgraphs) self._subgraph_read_var_ops: ReadVarOps = [{}] * len(model.subgraphs) self._subgraph_concat_ops: ConcatOps = [{}] * len(model.subgraphs) + self._var_handles_by_name: VarHandleByName = {} + self._var_handles_by_tensor: VarHandleByTensor = {} + self._current_var_handle_id: VarHandleId = 0 @property def model(self): @@ -75,26 +92,27 @@ def push_current_op(self, op: model_facade.Operator) -> None: def pop_current_op(self) -> None: _ = self._current_op_stack.pop() - def append_to_operations(self, op: model_facade.Operator) -> None: + def append_to_reordered_operations(self, op: model_facade.Operator) -> None: subgraph_index: int = op.subgraph.index - new_op = model_facade.Operator(op.operator, - len(self._operators[subgraph_index]), - op.subgraph) - self._operators[subgraph_index].append(new_op) + new_op = model_facade.Operator( + op.operator, len(self._reordered_operators[subgraph_index]), + op.subgraph) + self._reordered_operators[subgraph_index].append(new_op) - def operations(self, subgraph_index: int) -> List[model_facade.Operator]: - return self._operators[subgraph_index] + def reordered_operations( + self, subgraph_index: SubgraphIndex) -> List[model_facade.Operator]: + return self._reordered_operators[subgraph_index] - def is_subgraph_processed(self, subgraph_index: int) -> bool: + def is_subgraph_processed(self, subgraph_index: SubgraphIndex) -> bool: return self._subgraph_processed[subgraph_index] - def mark_subgraph_processed(self, subgraph_index: int) -> None: + def mark_subgraph_processed(self, subgraph_index: SubgraphIndex) -> None: self._subgraph_processed[subgraph_index] = True - def subgraph_var_handles(self, subgraph_index: int) -> VarHandles: + def subgraph_var_handles(self, subgraph_index: SubgraphIndex) -> VarHandles: return self._subgraph_modified_vars[subgraph_index] - def set_subgraph_var_handles(self, subgraph_index: int, + def set_subgraph_var_handles(self, subgraph_index: SubgraphIndex, handles: VarHandles) -> None: self._subgraph_modified_vars[subgraph_index] = handles @@ -110,15 +128,17 @@ def remove_read_var_op(self, op: model_facade.Operator) -> None: def get_read_var_op_by_tensor( self, tensor_index: TensorIndex, - subgraph_index: int) -> model_facade.Operator | None: + subgraph_index: SubgraphIndex) -> model_facade.Operator | None: return self._subgraph_read_var_ops[subgraph_index].get(tensor_index, None) def get_read_var_op_by_handle( - self, tensor_index: TensorIndex, - subgraph_index: int) -> List[model_facade.Operator]: + self, resource_tensor_index: TensorIndex, + subgraph_index: SubgraphIndex) -> List[model_facade.Operator]: result: List[model_facade.Operator] = [] + var_handle_id = self.get_var_handle(subgraph_index, resource_tensor_index) for op in self._subgraph_read_var_ops[subgraph_index].values(): - if op.inputs_indices[0] == tensor_index: + if self.get_var_handle(op.subgraph.index, + op.inputs_indices[0]) == var_handle_id: result.append(op) return result @@ -134,13 +154,49 @@ def remove_concat_op(self, op: model_facade.Operator) -> None: def get_concat_op_by_tensor( self, tensor_index: TensorIndex, - subgraph_index: int) -> model_facade.Operator | None: + subgraph_index: SubgraphIndex) -> model_facade.Operator | None: return self._subgraph_concat_ops[subgraph_index].get(tensor_index, None) + def create_var_handle(self, container_name: str | None, resource_name: str, + subgraph_index: SubgraphIndex, + resource_tensor_index: TensorIndex) -> VarHandleId: + key = (container_name, resource_name) + var_handle_id = self._var_handles_by_name.get(key) + if var_handle_id is None: + var_handle_id = self._current_var_handle_id + self._current_var_handle_id += 1 + self._var_handles_by_name[key] = var_handle_id + + self.add_var_handle(subgraph_index, resource_tensor_index, var_handle_id) + + return var_handle_id + + def get_var_handle(self, subgraph_index: SubgraphIndex, + resource_tensor_index: TensorIndex) -> VarHandleId: + return self._var_handles_by_tensor[(subgraph_index, resource_tensor_index)] + + def add_var_handle(self, subgraph_index: SubgraphIndex, + resource_tensor_index: TensorIndex, + var_handle_id: VarHandleId) -> None: + key = (subgraph_index, resource_tensor_index) + assert self._var_handles_by_tensor.get(key, None) is None + self._var_handles_by_tensor[key] = var_handle_id + + +# Begin global methods + def process_operator_var_handle(context: Context) -> VarHandles: - context.append_to_operations(context.current_op()) - return set() + op = context.current_op() + assert op.builtin_options_type == tflite.BuiltinOptions.VarHandleOptions + assert op.builtin_options is not None + container_name: str = op.builtin_options.container + resource_name: str = op.builtin_options.sharedName + var_handle_id = context.create_var_handle(container_name, resource_name, + op.subgraph.index, + op.outputs_indices[0]) + context.append_to_reordered_operations(op) + return set([var_handle_id]) def process_operator_assign_variable(context: Context) -> VarHandles: @@ -151,15 +207,15 @@ def process_operator_assign_variable(context: Context) -> VarHandles: read_var_op = context.get_read_var_op_by_tensor(assign_op.inputs_indices[1], assign_op.subgraph.index) if read_var_op is not None: - context.append_to_operations(read_var_op) + context.append_to_reordered_operations(read_var_op) context.remove_read_var_op(read_var_op) for read_var_op in context.get_read_var_op_by_handle( assign_op.inputs_indices[0], assign_op.subgraph.index): - context.append_to_operations(read_var_op) + context.append_to_reordered_operations(read_var_op) context.remove_read_var_op(read_var_op) - context.append_to_operations(assign_op) + context.append_to_reordered_operations(assign_op) return set() @@ -210,18 +266,18 @@ def process_operator(context: Context) -> VarHandles: read_var_op = context.get_read_var_op_by_tensor( concat_tensor_input, op.subgraph.index) if read_var_op is not None: - context.append_to_operations(read_var_op) + context.append_to_reordered_operations(read_var_op) context.remove_read_var_op(read_var_op) - context.append_to_operations(concat_op) + context.append_to_reordered_operations(concat_op) context.remove_concat_op(concat_op) read_var_op = context.get_read_var_op_by_tensor(tensor_input, op.subgraph.index) if read_var_op is not None: - context.append_to_operations(read_var_op) + context.append_to_reordered_operations(read_var_op) context.remove_read_var_op(read_var_op) - context.append_to_operations(op) + context.append_to_reordered_operations(op) return set() @@ -241,7 +297,7 @@ def process_subgraph(context: Context, subgraph_index: int) -> VarHandles: context.pop_current_op() operators: List[tflite.OperatorT] = [] - for op in context.operations(subgraph_index): + for op in context.reordered_operations(subgraph_index): operators.append(op.operator) context.model.root.subgraphs[subgraph_index].operators = operators