From c0604d8803cd753ab2c69a6fa79b971b4aa9bdb2 Mon Sep 17 00:00:00 2001 From: ddavis-2015 Date: Fri, 20 Dec 2024 17:48:22 -0800 Subject: [PATCH] Updates to operator reordering script. --- tensorflow/lite/micro/compression/BUILD | 13 -- .../lite/micro/compression/model_facade.py | 10 +- tensorflow/lite/micro/tools/BUILD | 14 ++ .../relocate_ops.py => tools/reorder_ops.py} | 144 +++++++++++++----- 4 files changed, 123 insertions(+), 58 deletions(-) rename tensorflow/lite/micro/{compression/relocate_ops.py => tools/reorder_ops.py} (68%) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index a3147c90ff6..8e037260215 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -152,19 +152,6 @@ py_test( ], ) -py_binary( - name = "relocate_ops", - srcs = [ - "relocate_ops.py", - ], - deps = [ - "model_facade", - "//tensorflow/lite/python:schema_py", - "@absl_py//absl:app", - "@absl_py//absl/flags", - ], -) - py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py index fafa0d92c5e..06c14d072c1 100644 --- a/tensorflow/lite/micro/compression/model_facade.py +++ b/tensorflow/lite/micro/compression/model_facade.py @@ -30,7 +30,7 @@ import numpy as np from numpy.typing import NDArray from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from typing import ByteString, Generic, TypeVar +from typing import ByteString, Generic, TypeVar, List _IteratorTo = TypeVar("_IteratorTo") @@ -116,11 +116,11 @@ def outputs(self): return _IndirectIterator(self.operator.outputs, self.subgraph.tensors) @property - def inputs_indices(self): + def inputs_indices(self) -> List[int]: return self.operator.inputs @property - def outputs_indices(self): + def outputs_indices(self) -> List[int]: return self.operator.outputs @property @@ -235,6 +235,10 @@ def operators(self) -> _Iterator[_Operator]: def tensors(self) -> _Iterator[_Tensor]: return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self) + @property + def outputs_indices(self) -> List[int]: + return self._subgraph_t.outputs + class _Model: """A facade for manipulating tflite.Model. diff --git a/tensorflow/lite/micro/tools/BUILD b/tensorflow/lite/micro/tools/BUILD index 2d1e1874280..a6c2d45924c 100644 --- a/tensorflow/lite/micro/tools/BUILD +++ b/tensorflow/lite/micro/tools/BUILD @@ -223,3 +223,17 @@ flatbuffer_py_library( name = "layer_by_layer_schema_py", srcs = ["layer_by_layer_schema.fbs"], ) + +py_binary( + name = "reorder_ops", + srcs = [ + "reorder_ops.py", + ], + deps = [ + ":model_transforms_utils", + "//tensorflow/lite/micro/compression:model_facade", + "//tensorflow/lite/python:schema_py", + "@absl_py//absl:app", + "@absl_py//absl/flags", + ], +) diff --git a/tensorflow/lite/micro/compression/relocate_ops.py b/tensorflow/lite/micro/tools/reorder_ops.py similarity index 68% rename from tensorflow/lite/micro/compression/relocate_ops.py rename to tensorflow/lite/micro/tools/reorder_ops.py index d2f082ac630..8f8945cc114 100644 --- a/tensorflow/lite/micro/compression/relocate_ops.py +++ b/tensorflow/lite/micro/tools/reorder_ops.py @@ -11,15 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" +r""" +*** EXPERIMENTAL *** +This is an experimental tool and is subject to change at any time + +This tool will allow reordering of the command line specified operators. +The reordered operators will be moved within their subgraph, such that they +are more closely colocated to another operator which consumes their output +tensor. + +The output model will be properly aligned as per the .tflite flatbuffer schema. + Usage: - bazel run tensorflow/lite/micro/tools:relocate_read_variable -- \\ - --input= --output= + bazel run tensorflow/lite/micro/tools:reorder_ops -- \ + --input= \ + --output= \ + --ops= """ -import model_facade - -from tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.tools import model_transforms_utils from absl import app from absl import flags @@ -30,14 +42,22 @@ flags.DEFINE_string( name='input', - default='', + default=None, help='path for the .tflite input file', + required=True, ) flags.DEFINE_string( name='output', - default='', + default=None, help='path for the .tflite output file', + required=True, +) + +flags.DEFINE_list( + name='ops', + default=[], + help='comma separated names of operators to reorder (case insensitive)', ) VarHandleId = int @@ -64,9 +84,13 @@ class Context: """ Context: + + The Context holds the stack of operators currently being processed, + the list of pending operations which may be relocated, + and the list of reordered operations representing the new subgraph(s) """ - def __init__(self, model: model_facade._Model) -> None: + def __init__(self, model: model_facade._Model, ops: List[int]) -> None: self._model = model self._current_op_stack: List[model_facade._Operator] = [] self._reordered_operators: List[List[model_facade._Operator]] = [[]] * len( @@ -78,6 +102,7 @@ def __init__(self, model: model_facade._Model) -> None: self._var_handles_by_name: VarHandleByName = {} self._var_handles_by_tensor: VarHandleByTensor = {} self._current_var_handle_id: VarHandleId = 0 + self._ops_to_reorder: List[int] = ops @property def model(self): @@ -130,13 +155,12 @@ def remove_pending_op(self, op: model_facade._Operator) -> None: def get_pending_op( self, tensor_index: TensorIndex, subgraph_index: SubgraphIndex) -> model_facade._Operator | None: - return self._pending_ops[subgraph_index].get(tensor_index, None) + return self._pending_ops[subgraph_index].get(tensor_index) - def get_read_var_op_by_handle( - self, resource_tensor_index: TensorIndex, + def get_pending_read_var_ops_by_handle( + self, var_handle_id: VarHandleId, subgraph_index: SubgraphIndex) -> List[model_facade._Operator]: result: List[model_facade._Operator] = [] - var_handle_id = self.get_var_handle(subgraph_index, resource_tensor_index) for op in self._pending_ops[subgraph_index].values(): if op.builtin_opcode != tflite.BuiltinOperator.READ_VARIABLE: continue @@ -145,6 +169,11 @@ def get_read_var_op_by_handle( result.append(op) return result + def can_be_pending_op(self, op: model_facade._Operator) -> bool: + return (op.builtin_opcode in self._ops_to_reorder + and op.outputs_indices is not None + and len(op.outputs_indices) == 1) + def create_var_handle(self, container_name: str | None, resource_name: str, subgraph_index: SubgraphIndex, resource_tensor_index: TensorIndex) -> VarHandleId: @@ -167,36 +196,38 @@ def add_var_handle(self, subgraph_index: SubgraphIndex, resource_tensor_index: TensorIndex, var_handle_id: VarHandleId) -> None: key = (subgraph_index, resource_tensor_index) - assert self._var_handles_by_tensor.get(key, None) is None + assert self._var_handles_by_tensor.get(key) is None self._var_handles_by_tensor[key] = var_handle_id # Begin global methods -def process_operator_var_handle(context: Context) -> VarHandles: +def process_operator_var_handle(context: Context) -> None: op = context.current_op() assert op.builtin_options_type == tflite.BuiltinOptions.VarHandleOptions assert op.builtin_options is not None container_name: str = op.builtin_options.container resource_name: str = op.builtin_options.sharedName - var_handle_id = context.create_var_handle(container_name, resource_name, - op.subgraph.index, - op.outputs_indices[0]) - context.append_to_reordered_operations(op) - return set([var_handle_id]) + _ = context.create_var_handle(container_name, resource_name, + op.subgraph.index, op.outputs_indices[0]) + if context.can_be_pending_op(op): + context.add_pending_op(op) + else: + context.append_to_reordered_operations(op) def process_operator_assign_variable(context: Context) -> VarHandles: assign_op = context.current_op() - for read_var_op in context.get_read_var_op_by_handle( - assign_op.inputs_indices[0], assign_op.subgraph.index): + var_handle_id = context.get_var_handle(assign_op.subgraph.index, + assign_op.inputs_indices[0]) + for read_var_op in context.get_pending_read_var_ops_by_handle( + var_handle_id, assign_op.subgraph.index): context.append_to_reordered_operations(read_var_op) context.remove_pending_op(read_var_op) process_pending_ops(context) - var_handle_id = context.get_var_handle(assign_op.subgraph.index, - assign_op.inputs_indices[0]) + return set([var_handle_id]) @@ -215,40 +246,52 @@ def process_operator_while(context: Context) -> VarHandles: return set() -def process_operator_as_pending(context: Context) -> VarHandles: - context.add_pending_op(context.current_op()) - return set() - - def process_pending_ops(context: Context) -> None: + """Process current operator against any pending operators. + Then add the current operator to the list of reordered operations. + """ op = context.current_op() - for tensor_input in op.inputs_indices: - pending_op = context.get_pending_op(tensor_input, op.subgraph.index) - if pending_op is not None: - context.remove_pending_op(pending_op) - context.push_current_op(pending_op) - process_pending_ops(context) - context.pop_current_op() + if op.inputs_indices is not None: + for tensor_input in op.inputs_indices: + pending_op = context.get_pending_op(tensor_input, op.subgraph.index) + if pending_op is not None: + context.remove_pending_op(pending_op) + context.push_current_op(pending_op) + process_pending_ops(context) + context.pop_current_op() context.append_to_reordered_operations(op) +def process_subgraph_pending_ops(context: Context, + subgraph_index: SubgraphIndex) -> None: + """Process subgraph outputs against any pending operators. + """ + outputs_indices = context.model.subgraphs[subgraph_index].outputs_indices + if outputs_indices is not None: + for tensor_index in outputs_indices: + pending_op = context.get_pending_op(tensor_index, subgraph_index) + if pending_op is not None: + context.remove_pending_op(pending_op) + context.push_current_op(pending_op) + process_pending_ops(context) + context.pop_current_op() + + def process_operator(context: Context) -> VarHandles: op = context.current_op() if op.builtin_opcode == tflite.BuiltinOperator.VAR_HANDLE: - return process_operator_var_handle(context) + process_operator_var_handle(context) elif op.builtin_opcode == tflite.BuiltinOperator.ASSIGN_VARIABLE: return process_operator_assign_variable(context) - elif op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE: - return process_operator_as_pending(context) - elif op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION: - return process_operator_as_pending(context) elif op.builtin_opcode == tflite.BuiltinOperator.IF: return process_operator_if(context) elif op.builtin_opcode == tflite.BuiltinOperator.WHILE: return process_operator_while(context) elif op.builtin_opcode == tflite.BuiltinOperator.CALL_ONCE: return process_operator_call_once(context) + elif context.can_be_pending_op(op): + context.add_pending_op(op) else: process_pending_ops(context) @@ -269,6 +312,8 @@ def process_subgraph(context: Context, subgraph_index: int) -> VarHandles: var_handles.update(var_handles_processed) context.pop_current_op() + process_subgraph_pending_ops(context, subgraph_index) + operators: List[tflite.OperatorT] = [] for op in context.reordered_operations(subgraph_index): operators.append(op.operator) @@ -280,6 +325,18 @@ def process_subgraph(context: Context, subgraph_index: int) -> VarHandles: return var_handles +def op_names_to_values(op_names: List[str]) -> List[int]: + op_values = [] + builtin_operators = vars(tflite.BuiltinOperator) + for name in op_names: + value = builtin_operators.get(name.upper()) + if value is None: + raise ValueError(f'unknowm operator: {name}') + else: + op_values.append(value) + return op_values + + def main(_): input_path = Path(FLAGS.input) output_path = Path(FLAGS.output) @@ -288,12 +345,15 @@ def main(_): buffer = bytes(file.read()) input_model: model_facade._Model = model_facade.read(buffer) - context = Context(input_model) + context = Context(input_model, op_names_to_values(FLAGS.ops)) _ = process_subgraph(context, 0) output_model: bytearray = input_model.compile() with open(output_path, 'wb') as file: file.write(output_model) + model_transforms_utils.tflite_flatbuffer_align(str(output_path), + str(output_path)) + print("\nreordering and alignment completed.") if __name__ == '__main__':