Updates to operator reordering script.
ddavis-2015 committed Dec 21, 2024
1 parent 3840b33 commit c0604d8
Showing 4 changed files with 123 additions and 58 deletions.
13 changes: 0 additions & 13 deletions tensorflow/lite/micro/compression/BUILD
@@ -152,19 +152,6 @@ py_test(
],
)

py_binary(
name = "relocate_ops",
srcs = [
"relocate_ops.py",
],
deps = [
"model_facade",
"//tensorflow/lite/python:schema_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)

py_library(
name = "spec",
srcs = ["spec.py"],
10 changes: 7 additions & 3 deletions tensorflow/lite/micro/compression/model_facade.py
@@ -30,7 +30,7 @@
import numpy as np
from numpy.typing import NDArray
from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite
from typing import ByteString, Generic, TypeVar
from typing import ByteString, Generic, TypeVar, List

_IteratorTo = TypeVar("_IteratorTo")

@@ -116,11 +116,11 @@ def outputs(self):
return _IndirectIterator(self.operator.outputs, self.subgraph.tensors)

@property
def inputs_indices(self):
def inputs_indices(self) -> List[int]:
return self.operator.inputs

@property
def outputs_indices(self):
def outputs_indices(self) -> List[int]:
return self.operator.outputs

@property
@@ -235,6 +235,10 @@ def operators(self) -> _Iterator[_Operator]:
def tensors(self) -> _Iterator[_Tensor]:
return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self)

@property
def outputs_indices(self) -> List[int]:
return self._subgraph_t.outputs


class _Model:
"""A facade for manipulating tflite.Model.
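
For context, a minimal sketch of how the facade's index properties might be used. This is not part of the commit; it assumes a .tflite file path in sys.argv[1] and the same import path that reorder_ops.py uses below:

import sys

from tflite_micro.tensorflow.lite.micro.compression import model_facade

# Illustrative only: print the raw tensor indices exposed by the
# properties annotated/added above. Assumes sys.argv[1] is a .tflite file.
with open(sys.argv[1], 'rb') as file:
  model = model_facade.read(bytes(file.read()))
subgraph = model.subgraphs[0]
print('subgraph output tensor indices:', subgraph.outputs_indices)
for op in subgraph.operators:
  # inputs_indices/outputs_indices are flatbuffer index lists and may be
  # None for operators without inputs or outputs.
  print(op.builtin_opcode, op.inputs_indices, op.outputs_indices)
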
14 changes: 14 additions & 0 deletions tensorflow/lite/micro/tools/BUILD
@@ -223,3 +223,17 @@ flatbuffer_py_library(
name = "layer_by_layer_schema_py",
srcs = ["layer_by_layer_schema.fbs"],
)

py_binary(
name = "reorder_ops",
srcs = [
"reorder_ops.py",
],
deps = [
":model_transforms_utils",
"//tensorflow/lite/micro/compression:model_facade",
"//tensorflow/lite/python:schema_py",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
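
With this target in place, the tool is invoked as documented in the reorder_ops.py docstring below; the file names and the operator list here are illustrative:

bazel run tensorflow/lite/micro/tools:reorder_ops -- \
  --input=model.tflite \
  --output=model_reordered.tflite \
  --ops=read_variable,concatenation
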
144 changes: 102 additions & 42 deletions tensorflow/lite/micro/{compression/relocate_ops.py → tools/reorder_ops.py}
@@ -11,15 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
r"""
*** EXPERIMENTAL ***
This is an experimental tool and is subject to change at any time.
This tool reorders the operators specified on the command line. The
reordered operators are moved within their subgraph so that each is
colocated more closely with the operator that consumes its output
tensor.
The output model will be properly aligned as per the .tflite flatbuffer schema.
Usage:
bazel run tensorflow/lite/micro/tools:relocate_read_variable -- \\
--input=<input.tflite> --output=<output.tflite>
bazel run tensorflow/lite/micro/tools:reorder_ops -- \
--input=<input.tflite> \
--output=<output.tflite> \
--ops=<operator_name[,operator_name]...>
"""

import model_facade

from tensorflow.lite.python import schema_py_generated as tflite
from tflite_micro.tensorflow.lite.micro.compression import model_facade
from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite
from tflite_micro.tensorflow.lite.micro.tools import model_transforms_utils

from absl import app
from absl import flags
@@ -30,14 +42,22 @@

flags.DEFINE_string(
name='input',
default='',
default=None,
help='path for the .tflite input file',
required=True,
)

flags.DEFINE_string(
name='output',
default='',
default=None,
help='path for the .tflite output file',
required=True,
)

flags.DEFINE_list(
name='ops',
default=[],
help='comma-separated names of operators to reorder (case-insensitive)',
)

VarHandleId = int
@@ -64,9 +84,13 @@
class Context:
"""
Context:
The Context holds the stack of operators currently being processed,
the list of pending operations that may be reordered,
and the list of reordered operations representing the new subgraph(s).
"""

def __init__(self, model: model_facade._Model) -> None:
def __init__(self, model: model_facade._Model, ops: List[int]) -> None:
self._model = model
self._current_op_stack: List[model_facade._Operator] = []
self._reordered_operators: List[List[model_facade._Operator]] = [[]] * len(
@@ -78,6 +102,7 @@ def __init__(self, model: model_facade._Model) -> None:
self._var_handles_by_name: VarHandleByName = {}
self._var_handles_by_tensor: VarHandleByTensor = {}
self._current_var_handle_id: VarHandleId = 0
self._ops_to_reorder: List[int] = ops

@property
def model(self):
@@ -130,13 +155,12 @@ def remove_pending_op(self, op: model_facade._Operator) -> None:
def get_pending_op(
self, tensor_index: TensorIndex,
subgraph_index: SubgraphIndex) -> model_facade._Operator | None:
return self._pending_ops[subgraph_index].get(tensor_index, None)
return self._pending_ops[subgraph_index].get(tensor_index)

def get_read_var_op_by_handle(
self, resource_tensor_index: TensorIndex,
def get_pending_read_var_ops_by_handle(
self, var_handle_id: VarHandleId,
subgraph_index: SubgraphIndex) -> List[model_facade._Operator]:
result: List[model_facade._Operator] = []
var_handle_id = self.get_var_handle(subgraph_index, resource_tensor_index)
for op in self._pending_ops[subgraph_index].values():
if op.builtin_opcode != tflite.BuiltinOperator.READ_VARIABLE:
continue
@@ -145,6 +169,11 @@ def get_read_var_op_by_handle(
result.append(op)
return result

def can_be_pending_op(self, op: model_facade._Operator) -> bool:
return (op.builtin_opcode in self._ops_to_reorder
and op.outputs_indices is not None
and len(op.outputs_indices) == 1)

def create_var_handle(self, container_name: str | None, resource_name: str,
subgraph_index: SubgraphIndex,
resource_tensor_index: TensorIndex) -> VarHandleId:
@@ -167,36 +196,38 @@ def add_var_handle(self, subgraph_index: SubgraphIndex,
resource_tensor_index: TensorIndex,
var_handle_id: VarHandleId) -> None:
key = (subgraph_index, resource_tensor_index)
assert self._var_handles_by_tensor.get(key, None) is None
assert self._var_handles_by_tensor.get(key) is None
self._var_handles_by_tensor[key] = var_handle_id


# Begin global methods


def process_operator_var_handle(context: Context) -> VarHandles:
def process_operator_var_handle(context: Context) -> None:
op = context.current_op()
assert op.builtin_options_type == tflite.BuiltinOptions.VarHandleOptions
assert op.builtin_options is not None
container_name: str = op.builtin_options.container
resource_name: str = op.builtin_options.sharedName
var_handle_id = context.create_var_handle(container_name, resource_name,
op.subgraph.index,
op.outputs_indices[0])
context.append_to_reordered_operations(op)
return set([var_handle_id])
_ = context.create_var_handle(container_name, resource_name,
op.subgraph.index, op.outputs_indices[0])
if context.can_be_pending_op(op):
context.add_pending_op(op)
else:
context.append_to_reordered_operations(op)


def process_operator_assign_variable(context: Context) -> VarHandles:
assign_op = context.current_op()
for read_var_op in context.get_read_var_op_by_handle(
assign_op.inputs_indices[0], assign_op.subgraph.index):
var_handle_id = context.get_var_handle(assign_op.subgraph.index,
assign_op.inputs_indices[0])
for read_var_op in context.get_pending_read_var_ops_by_handle(
var_handle_id, assign_op.subgraph.index):
context.append_to_reordered_operations(read_var_op)
context.remove_pending_op(read_var_op)

process_pending_ops(context)
var_handle_id = context.get_var_handle(assign_op.subgraph.index,
assign_op.inputs_indices[0])

return set([var_handle_id])


@@ -215,40 +246,52 @@ def process_operator_while(context: Context) -> VarHandles:
return set()


def process_operator_as_pending(context: Context) -> VarHandles:
context.add_pending_op(context.current_op())
return set()


def process_pending_ops(context: Context) -> None:
"""Process current operator against any pending operators.
Then add the current operator to the list of reordered operations.
"""
op = context.current_op()
for tensor_input in op.inputs_indices:
pending_op = context.get_pending_op(tensor_input, op.subgraph.index)
if pending_op is not None:
context.remove_pending_op(pending_op)
context.push_current_op(pending_op)
process_pending_ops(context)
context.pop_current_op()
if op.inputs_indices is not None:
for tensor_input in op.inputs_indices:
pending_op = context.get_pending_op(tensor_input, op.subgraph.index)
if pending_op is not None:
context.remove_pending_op(pending_op)
context.push_current_op(pending_op)
process_pending_ops(context)
context.pop_current_op()

context.append_to_reordered_operations(op)


def process_subgraph_pending_ops(context: Context,
subgraph_index: SubgraphIndex) -> None:
"""Process subgraph outputs against any pending operators.
"""
outputs_indices = context.model.subgraphs[subgraph_index].outputs_indices
if outputs_indices is not None:
for tensor_index in outputs_indices:
pending_op = context.get_pending_op(tensor_index, subgraph_index)
if pending_op is not None:
context.remove_pending_op(pending_op)
context.push_current_op(pending_op)
process_pending_ops(context)
context.pop_current_op()


def process_operator(context: Context) -> VarHandles:
op = context.current_op()
if op.builtin_opcode == tflite.BuiltinOperator.VAR_HANDLE:
return process_operator_var_handle(context)
process_operator_var_handle(context)
elif op.builtin_opcode == tflite.BuiltinOperator.ASSIGN_VARIABLE:
return process_operator_assign_variable(context)
elif op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE:
return process_operator_as_pending(context)
elif op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION:
return process_operator_as_pending(context)
elif op.builtin_opcode == tflite.BuiltinOperator.IF:
return process_operator_if(context)
elif op.builtin_opcode == tflite.BuiltinOperator.WHILE:
return process_operator_while(context)
elif op.builtin_opcode == tflite.BuiltinOperator.CALL_ONCE:
return process_operator_call_once(context)
elif context.can_be_pending_op(op):
context.add_pending_op(op)
else:
process_pending_ops(context)

@@ -269,6 +312,8 @@ def process_subgraph(context: Context, subgraph_index: int) -> VarHandles:
var_handles.update(var_handles_processed)
context.pop_current_op()

process_subgraph_pending_ops(context, subgraph_index)

operators: List[tflite.OperatorT] = []
for op in context.reordered_operations(subgraph_index):
operators.append(op.operator)
@@ -280,6 +325,18 @@ def process_subgraph(context: Context, subgraph_index: int) -> VarHandles:
return var_handles


def op_names_to_values(op_names: List[str]) -> List[int]:
op_values = []
builtin_operators = vars(tflite.BuiltinOperator)
for name in op_names:
value = builtin_operators.get(name.upper())
if value is None:
raise ValueError(f'unknown operator: {name}')
else:
op_values.append(value)
return op_values


def main(_):
input_path = Path(FLAGS.input)
output_path = Path(FLAGS.output)
@@ -288,12 +345,15 @@ def main(_):
buffer = bytes(file.read())
input_model: model_facade._Model = model_facade.read(buffer)

context = Context(input_model)
context = Context(input_model, op_names_to_values(FLAGS.ops))
_ = process_subgraph(context, 0)

output_model: bytearray = input_model.compile()
with open(output_path, 'wb') as file:
file.write(output_model)
model_transforms_utils.tflite_flatbuffer_align(str(output_path),
str(output_path))
print("\nreordering and alignment completed.")


if __name__ == '__main__':
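
For reference, a small sketch (not part of the commit) of how op_names_to_values() resolves case-insensitive operator names via vars(tflite.BuiltinOperator):

from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite

# Illustrative only: mirrors the lookup in op_names_to_values() above.
names = ['read_variable', 'Concatenation']
values = [vars(tflite.BuiltinOperator)[name.upper()] for name in names]
assert values == [
    tflite.BuiltinOperator.READ_VARIABLE,
    tflite.BuiltinOperator.CONCATENATION,
]
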
