Skip to content

Commit

Permalink
single pending ops queue
Browse files Browse the repository at this point in the history
process pending ops recursively
  • Loading branch information
ddavis-2015 committed Nov 10, 2024
1 parent cfd9890 commit f651c88
Showing 1 changed file with 46 additions and 73 deletions.
119 changes: 46 additions & 73 deletions tensorflow/lite/micro/compression/relocate_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@
{ ( subgraph_index, tensor_index ) : var_handle_id }
"""

ReadVarOps = List[Dict[int, model_facade.Operator]]

ConcatOps = List[Dict[int, model_facade.Operator]]
PendingOps = List[Dict[TensorIndex, model_facade.Operator]]
"""PendingOps
[ { output_tensor_index : operator }]
"""


class Context:
Expand All @@ -73,8 +74,7 @@ def __init__(self, model: model_facade.Model) -> None:
self._subgraph_processed: List[bool] = [False] * len(model.subgraphs)
self._subgraph_modified_vars: List[VarHandles] = [set()] * len(
model.subgraphs)
self._subgraph_read_var_ops: ReadVarOps = [{}] * len(model.subgraphs)
self._subgraph_concat_ops: ConcatOps = [{}] * len(model.subgraphs)
self._pending_ops: PendingOps = [{}] * len(model.subgraphs)
self._var_handles_by_name: VarHandleByName = {}
self._var_handles_by_tensor: VarHandleByTensor = {}
self._current_var_handle_id: VarHandleId = 0
Expand Down Expand Up @@ -116,47 +116,35 @@ def set_subgraph_var_handles(self, subgraph_index: SubgraphIndex,
handles: VarHandles) -> None:
self._subgraph_modified_vars[subgraph_index] = handles

def add_read_var_op(self, op: model_facade.Operator) -> None:
assert op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE
def add_pending_op(self, op: model_facade.Operator) -> None:
assert len(op.outputs_indices) == 1
key: TensorIndex = op.outputs_indices[0]
self._subgraph_read_var_ops[op.subgraph.index][key] = op
assert self._pending_ops[op.subgraph.index].get(key) is None
self._pending_ops[op.subgraph.index][key] = op

def remove_read_var_op(self, op: model_facade.Operator) -> None:
assert op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE
def remove_pending_op(self, op: model_facade.Operator) -> None:
key: TensorIndex = op.outputs_indices[0]
del self._subgraph_read_var_ops[op.subgraph.index][key]
assert self._pending_ops[op.subgraph.index][key].index == op.index
del self._pending_ops[op.subgraph.index][key]

def get_read_var_op_by_tensor(
def get_pending_op(
self, tensor_index: TensorIndex,
subgraph_index: SubgraphIndex) -> model_facade.Operator | None:
return self._subgraph_read_var_ops[subgraph_index].get(tensor_index, None)
return self._pending_ops[subgraph_index].get(tensor_index, None)

def get_read_var_op_by_handle(
self, resource_tensor_index: TensorIndex,
subgraph_index: SubgraphIndex) -> List[model_facade.Operator]:
result: List[model_facade.Operator] = []
var_handle_id = self.get_var_handle(subgraph_index, resource_tensor_index)
for op in self._subgraph_read_var_ops[subgraph_index].values():
for op in self._pending_ops[subgraph_index].values():
if op.builtin_opcode != tflite.BuiltinOperator.READ_VARIABLE:
continue
if self.get_var_handle(op.subgraph.index,
op.inputs_indices[0]) == var_handle_id:
result.append(op)
return result

def add_concat_op(self, op: model_facade.Operator) -> None:
assert op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION
key: TensorIndex = op.outputs_indices[0]
self._subgraph_concat_ops[op.subgraph.index][key] = op

def remove_concat_op(self, op: model_facade.Operator) -> None:
assert op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION
key: TensorIndex = op.outputs_indices[0]
del self._subgraph_concat_ops[op.subgraph.index][key]

def get_concat_op_by_tensor(
self, tensor_index: TensorIndex,
subgraph_index: SubgraphIndex) -> model_facade.Operator | None:
return self._subgraph_concat_ops[subgraph_index].get(tensor_index, None)

def create_var_handle(self, container_name: str | None, resource_name: str,
subgraph_index: SubgraphIndex,
resource_tensor_index: TensorIndex) -> VarHandleId:
Expand Down Expand Up @@ -201,27 +189,15 @@ def process_operator_var_handle(context: Context) -> VarHandles:

def process_operator_assign_variable(context: Context) -> VarHandles:
assign_op = context.current_op()
pending_concat_op = context.get_concat_op_by_tensor(
assign_op.inputs_indices[1], assign_op.subgraph.index)
assert pending_concat_op is None
read_var_op = context.get_read_var_op_by_tensor(assign_op.inputs_indices[1],
assign_op.subgraph.index)
if read_var_op is not None:
context.append_to_reordered_operations(read_var_op)
context.remove_read_var_op(read_var_op)

for read_var_op in context.get_read_var_op_by_handle(
assign_op.inputs_indices[0], assign_op.subgraph.index):
context.append_to_reordered_operations(read_var_op)
context.remove_read_var_op(read_var_op)
context.remove_pending_op(read_var_op)

context.append_to_reordered_operations(assign_op)
return set()


def process_operator_read_variable(context: Context) -> VarHandles:
context.add_read_var_op(context.current_op())
return set()
process_pending_ops(context)
var_handle_id = context.get_var_handle(assign_op.subgraph.index,
assign_op.inputs_indices[0])
return set([var_handle_id])


def process_operator_call_once(context: Context) -> VarHandles:
Expand All @@ -239,45 +215,42 @@ def process_operator_while(context: Context) -> VarHandles:
return set()


def process_operator_concatenation(context: Context) -> VarHandles:
context.add_concat_op(context.current_op())
def process_operator_as_pending(context: Context) -> VarHandles:
context.add_pending_op(context.current_op())
return set()


def process_pending_ops(context: Context) -> None:
op = context.current_op()
for tensor_input in op.inputs_indices:
pending_op = context.get_pending_op(tensor_input, op.subgraph.index)
if pending_op is not None:
context.remove_pending_op(pending_op)
context.push_current_op(pending_op)
process_pending_ops(context)
context.pop_current_op()

context.append_to_reordered_operations(op)


def process_operator(context: Context) -> VarHandles:
op = context.current_op()
if op.builtin_opcode == tflite.BuiltinOperator.VAR_HANDLE:
return process_operator_var_handle(context)
elif op.builtin_opcode == tflite.BuiltinOperator.ASSIGN_VARIABLE:
return process_operator_assign_variable(context)
elif op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE:
return process_operator_read_variable(context)
return process_operator_as_pending(context)
elif op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION:
return process_operator_concatenation(context)
return process_operator_as_pending(context)
elif op.builtin_opcode == tflite.BuiltinOperator.IF:
return process_operator_if(context)
elif op.builtin_opcode == tflite.BuiltinOperator.WHILE:
return process_operator_while(context)
elif op.builtin_opcode == tflite.BuiltinOperator.CALL_ONCE:
return process_operator_call_once(context)
else:
for tensor_input in op.inputs_indices:
concat_op = context.get_concat_op_by_tensor(tensor_input,
op.subgraph.index)
if concat_op is not None:
for concat_tensor_input in concat_op.inputs_indices:
pending_concat_op = context.get_concat_op_by_tensor(
concat_tensor_input, op.subgraph.index)
assert pending_concat_op is None
read_var_op = context.get_read_var_op_by_tensor(
concat_tensor_input, op.subgraph.index)
if read_var_op is not None:
context.append_to_reordered_operations(read_var_op)
context.remove_read_var_op(read_var_op)

context.append_to_reordered_operations(concat_op)
context.remove_concat_op(concat_op)

read_var_op = context.get_read_var_op_by_tensor(tensor_input,
op.subgraph.index)
if read_var_op is not None:
context.append_to_reordered_operations(read_var_op)
context.remove_read_var_op(read_var_op)
context.append_to_reordered_operations(op)
process_pending_ops(context)

return set()

Expand Down

0 comments on commit f651c88

Please sign in to comment.