expand model_facade
redo var handle tracking
ddavis-2015 committed Nov 10, 2024
1 parent 7776cda commit cfd9890
Showing 2 changed files with 95 additions and 31 deletions.
8 changes: 8 additions & 0 deletions tensorflow/lite/micro/compression/model_facade.py
@@ -160,6 +160,14 @@ def inputs_indices(self):
  def outputs_indices(self):
    return self.operator.outputs

+  @property
+  def builtin_options_type(self) -> int:
+    return self.operator.builtinOptionsType
+
+  @property
+  def builtin_options(self):
+    return self.operator.builtinOptions
+

class Tensor:

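An illustrative use of the two new accessors (an editor's sketch, not part of the commit): given a model_facade.Model named model, the pair lets a pass detect VAR_HANDLE operators and read their options through the flatbuffer object API that model_facade wraps. The loop assumes the facade exposes subgraphs and operators as iterables.

# Find every VAR_HANDLE operator and print its variable's identity.
for subgraph in model.subgraphs:
  for op in subgraph.operators:
    if op.builtin_options_type == tflite.BuiltinOptions.VarHandleOptions:
      options = op.builtin_options  # a generated VarHandleOptionsT instance
      print(options.container, options.sharedName)
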
118 changes: 87 additions & 31 deletions tensorflow/lite/micro/compression/relocate_ops.py
@@ -24,7 +24,7 @@
from absl import app
from absl import flags
from pathlib import Path
-from typing import List, Set, Dict
+from typing import List, Set, Dict, Tuple

FLAGS = flags.FLAGS

@@ -40,10 +40,24 @@
help='path for the .tflite output file',
)

-VarHandles = Set[int]
+VarHandleId = int
+VarHandles = Set[VarHandleId]
 TensorIndex = int
+ReadVarOps = List[Dict[TensorIndex, model_facade.Operator]]
+ConcatOps = List[Dict[TensorIndex, model_facade.Operator]]
+SubgraphIndex = int
+
+VarHandleByName = Dict[Tuple[str | None, str], VarHandleId]
+"""VarHandleByName
+  { ( container_name | None, resource_name ) : var_handle_id }
+"""
+
+VarHandleByTensor = Dict[Tuple[SubgraphIndex, TensorIndex], VarHandleId]
+"""VarHandleByTensor
+  { ( subgraph_index, tensor_index ) : var_handle_id }
+"""
-
-ReadVarOps = List[Dict[int, model_facade.Operator]]
-
-ConcatOps = List[Dict[int, model_facade.Operator]]
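
Editor's note, illustrative only: the two new tables give each logical variable a single VarHandleId. VarHandleByName deduplicates by (container, sharedName), so repeated VAR_HANDLE operators for the same variable share one id, while VarHandleByTensor records which resource tensor in which subgraph carries that id. Hypothetical contents for one variable surfaced in two subgraphs:

by_name: VarHandleByName = {(None, "step_counter"): 0}
by_tensor: VarHandleByTensor = {(0, 5): 0, (1, 2): 0}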


class Context:
@@ -54,13 +68,16 @@ class Context:
def __init__(self, model: model_facade.Model) -> None:
self._model = model
self._current_op_stack: List[model_facade.Operator] = []
-    self._operators: List[List[model_facade.Operator]] = [[]] * len(
+    self._reordered_operators: List[List[model_facade.Operator]] = [[]] * len(
         model.subgraphs)
self._subgraph_processed: List[bool] = [False] * len(model.subgraphs)
self._subgraph_modified_vars: List[VarHandles] = [set()] * len(
model.subgraphs)
self._subgraph_read_var_ops: ReadVarOps = [{}] * len(model.subgraphs)
self._subgraph_concat_ops: ConcatOps = [{}] * len(model.subgraphs)
+    self._var_handles_by_name: VarHandleByName = {}
+    self._var_handles_by_tensor: VarHandleByTensor = {}
+    self._current_var_handle_id: VarHandleId = 0
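
An editorial aside on this constructor, not touched by the commit: expressions like [[]] * len(model.subgraphs) and [{}] * len(model.subgraphs) bind every subgraph slot to one shared object, so an in-place mutation (an append, or a dict insert) through one index is visible through all of them whenever a model has more than one subgraph. If that sharing is unintended, a per-subgraph form would be:

# Sketch only: independent containers per subgraph.
self._reordered_operators = [[] for _ in model.subgraphs]
self._subgraph_read_var_ops = [{} for _ in model.subgraphs]
self._subgraph_concat_ops = [{} for _ in model.subgraphs]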

@property
def model(self):
@@ -75,26 +92,27 @@ def push_current_op(self, op: model_facade.Operator) -> None:
def pop_current_op(self) -> None:
_ = self._current_op_stack.pop()

-  def append_to_operations(self, op: model_facade.Operator) -> None:
+  def append_to_reordered_operations(self, op: model_facade.Operator) -> None:
     subgraph_index: int = op.subgraph.index
-    new_op = model_facade.Operator(op.operator,
-                                   len(self._operators[subgraph_index]),
-                                   op.subgraph)
-    self._operators[subgraph_index].append(new_op)
+    new_op = model_facade.Operator(
+        op.operator, len(self._reordered_operators[subgraph_index]),
+        op.subgraph)
+    self._reordered_operators[subgraph_index].append(new_op)

-  def operations(self, subgraph_index: int) -> List[model_facade.Operator]:
-    return self._operators[subgraph_index]
+  def reordered_operations(
+      self, subgraph_index: SubgraphIndex) -> List[model_facade.Operator]:
+    return self._reordered_operators[subgraph_index]

-  def is_subgraph_processed(self, subgraph_index: int) -> bool:
+  def is_subgraph_processed(self, subgraph_index: SubgraphIndex) -> bool:
     return self._subgraph_processed[subgraph_index]

-  def mark_subgraph_processed(self, subgraph_index: int) -> None:
+  def mark_subgraph_processed(self, subgraph_index: SubgraphIndex) -> None:
     self._subgraph_processed[subgraph_index] = True

-  def subgraph_var_handles(self, subgraph_index: int) -> VarHandles:
+  def subgraph_var_handles(self, subgraph_index: SubgraphIndex) -> VarHandles:
     return self._subgraph_modified_vars[subgraph_index]

-  def set_subgraph_var_handles(self, subgraph_index: int,
+  def set_subgraph_var_handles(self, subgraph_index: SubgraphIndex,
                                handles: VarHandles) -> None:
     self._subgraph_modified_vars[subgraph_index] = handles

@@ -110,15 +128,17 @@ def remove_read_var_op(self, op: model_facade.Operator) -> None:

  def get_read_var_op_by_tensor(
      self, tensor_index: TensorIndex,
-      subgraph_index: int) -> model_facade.Operator | None:
+      subgraph_index: SubgraphIndex) -> model_facade.Operator | None:
    return self._subgraph_read_var_ops[subgraph_index].get(tensor_index, None)

  def get_read_var_op_by_handle(
-      self, tensor_index: TensorIndex,
-      subgraph_index: int) -> List[model_facade.Operator]:
+      self, resource_tensor_index: TensorIndex,
+      subgraph_index: SubgraphIndex) -> List[model_facade.Operator]:
    result: List[model_facade.Operator] = []
+    var_handle_id = self.get_var_handle(subgraph_index, resource_tensor_index)
    for op in self._subgraph_read_var_ops[subgraph_index].values():
-      if op.inputs_indices[0] == tensor_index:
+      if self.get_var_handle(op.subgraph.index,
+                             op.inputs_indices[0]) == var_handle_id:
        result.append(op)
    return result
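
Editor's note on the rewrite above: matching by tensor index found only READ_VARIABLE ops wired to that exact resource tensor, while matching by handle id also catches reads that reach the same variable through a different resource tensor, for example one minted by a second VAR_HANDLE op with an identical (container, sharedName) pair.

# Hypothetical indices: tensors 5 and 9 both carry handle id 3, so a read
# through tensor 9 now matches an assign whose resource input is tensor 5.
#   ctx.get_var_handle(sg, 5) == ctx.get_var_handle(sg, 9) == 3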

@@ -134,13 +154,49 @@ def remove_concat_op(self, op: model_facade.Operator) -> None:

  def get_concat_op_by_tensor(
      self, tensor_index: TensorIndex,
-      subgraph_index: int) -> model_facade.Operator | None:
+      subgraph_index: SubgraphIndex) -> model_facade.Operator | None:
    return self._subgraph_concat_ops[subgraph_index].get(tensor_index, None)

+  def create_var_handle(self, container_name: str | None, resource_name: str,
+                        subgraph_index: SubgraphIndex,
+                        resource_tensor_index: TensorIndex) -> VarHandleId:
+    key = (container_name, resource_name)
+    var_handle_id = self._var_handles_by_name.get(key)
+    if var_handle_id is None:
+      var_handle_id = self._current_var_handle_id
+      self._current_var_handle_id += 1
+      self._var_handles_by_name[key] = var_handle_id
+
+    self.add_var_handle(subgraph_index, resource_tensor_index, var_handle_id)
+
+    return var_handle_id
+
+  def get_var_handle(self, subgraph_index: SubgraphIndex,
+                     resource_tensor_index: TensorIndex) -> VarHandleId:
+    return self._var_handles_by_tensor[(subgraph_index, resource_tensor_index)]
+
+  def add_var_handle(self, subgraph_index: SubgraphIndex,
+                     resource_tensor_index: TensorIndex,
+                     var_handle_id: VarHandleId) -> None:
+    key = (subgraph_index, resource_tensor_index)
+    assert self._var_handles_by_tensor.get(key, None) is None
+    self._var_handles_by_tensor[key] = var_handle_id
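
A sketch of the intended flow through this new API (editor's example with hypothetical names and indices): the first VAR_HANDLE op for a (container, name) pair mints a fresh id, a later one, even in another subgraph, reuses it, and tensor-based lookup agrees with both:

# ctx = Context(model)
# first = ctx.create_var_handle(None, "v", 0, 5)   # subgraph 0, tensor 5
# again = ctx.create_var_handle(None, "v", 1, 2)   # subgraph 1, tensor 2
# assert first == again == ctx.get_var_handle(1, 2)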


# Begin global methods


 def process_operator_var_handle(context: Context) -> VarHandles:
-  context.append_to_operations(context.current_op())
-  return set()
+  op = context.current_op()
+  assert op.builtin_options_type == tflite.BuiltinOptions.VarHandleOptions
+  assert op.builtin_options is not None
+  container_name: str = op.builtin_options.container
+  resource_name: str = op.builtin_options.sharedName
+  var_handle_id = context.create_var_handle(container_name, resource_name,
+                                            op.subgraph.index,
+                                            op.outputs_indices[0])
+  context.append_to_reordered_operations(op)
+  return set([var_handle_id])
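
Note the changed return value (editor's reading): the function now reports the handle id it minted instead of an empty set; judging from subgraph_var_handles and set_subgraph_var_handles above, these per-operator sets appear to accumulate into a subgraph's set of touched variables, roughly:

# handles |= process_operator_var_handle(ctx)   # e.g. {3}
# ctx.set_subgraph_var_handles(subgraph_index, handles)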


def process_operator_assign_variable(context: Context) -> VarHandles:
@@ -151,15 +207,15 @@
read_var_op = context.get_read_var_op_by_tensor(assign_op.inputs_indices[1],
assign_op.subgraph.index)
   if read_var_op is not None:
-    context.append_to_operations(read_var_op)
+    context.append_to_reordered_operations(read_var_op)
     context.remove_read_var_op(read_var_op)

   for read_var_op in context.get_read_var_op_by_handle(
       assign_op.inputs_indices[0], assign_op.subgraph.index):
-    context.append_to_operations(read_var_op)
+    context.append_to_reordered_operations(read_var_op)
     context.remove_read_var_op(read_var_op)

-  context.append_to_operations(assign_op)
+  context.append_to_reordered_operations(assign_op)
   return set()
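
The ordering rule this function enforces, as far as the diff shows: any deferred READ_VARIABLE on the assign's value input, plus every deferred read of the same variable (now located by handle id rather than by resource tensor), is emitted ahead of the ASSIGN_VARIABLE, so those reads still observe the pre-assignment value.

# Emitted order for one subgraph (hypothetical):
#   VAR_HANDLE(v), READ_VARIABLE(v), ASSIGN_VARIABLE(v)
# even when the READ_VARIABLE had been deferred past the assign.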


@@ -210,18 +266,18 @@ def process_operator(context: Context) -> VarHandles:
read_var_op = context.get_read_var_op_by_tensor(
concat_tensor_input, op.subgraph.index)
     if read_var_op is not None:
-      context.append_to_operations(read_var_op)
+      context.append_to_reordered_operations(read_var_op)
       context.remove_read_var_op(read_var_op)

-    context.append_to_operations(concat_op)
+    context.append_to_reordered_operations(concat_op)
     context.remove_concat_op(concat_op)

     read_var_op = context.get_read_var_op_by_tensor(tensor_input,
                                                     op.subgraph.index)
     if read_var_op is not None:
-      context.append_to_operations(read_var_op)
+      context.append_to_reordered_operations(read_var_op)
       context.remove_read_var_op(read_var_op)
-  context.append_to_operations(op)
+  context.append_to_reordered_operations(op)

return set()

@@ -241,7 +297,7 @@ def process_subgraph(context: Context, subgraph_index: int) -> VarHandles:
context.pop_current_op()

operators: List[tflite.OperatorT] = []
-  for op in context.operations(subgraph_index):
+  for op in context.reordered_operations(subgraph_index):
operators.append(op.operator)
context.model.root.subgraphs[subgraph_index].operators = operators
