Commit 9cc8669: Update

[ghstack-poisoned]

swolchok committed Jan 22, 2025 (2 parents: ad860de + 1655fc5)
Showing 56 changed files with 2,122 additions and 1,439 deletions.
26 changes: 19 additions & 7 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -32,10 +32,16 @@ def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
     Raises a ValueError if the node doesn't have any parameters set.
     """
     if "input_qparams" not in node.meta.keys():
-        raise ValueError(f"No input quantization parameter found in node {node}")
+        raise ValueError(
+            f"No input quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
     input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
     if len(input_qparams) == 0:
-        raise ValueError(f"No input quantization parameter found in node {node}")
+        raise ValueError(
+            f"No input quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
     return input_qparams


@@ -45,11 +51,17 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
     Raises a ValueError if the node doesn't have any parameters set.
     """
     if "output_qparams" not in node.meta.keys():
-        raise ValueError(f"No output quantization parameter found in node {node}")
-    input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
-    if len(input_qparams) == 0:
-        raise ValueError(f"No output quantization parameter found in node {node}")
-    return input_qparams
+        raise ValueError(
+            f"No output quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
+    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
+    if len(output_qparams) == 0:
+        raise ValueError(
+            f"No output quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
+    return output_qparams


 class FoldAndAnnotateQParamsPass(ExportPass):
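
For orientation, a minimal sketch of how a backend pass might consume these helpers; the consumer function below is a hypothetical illustration, not code from this commit:

# Hypothetical consumer of get_input_qparams/get_output_qparams; only the two
# imported helper names reflect real code from this file.
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
    get_output_qparams,
)


def lower_quantized_node(node):  # node: a torch.fx.Node annotated by the pass
    try:
        in_qp = get_input_qparams(node)  # {input_index: QuantArgs}
        out_qp = get_output_qparams(node)  # {output_index: QuantArgs}
    except ValueError as err:
        # The enriched message now includes node.meta["original_aten"], which
        # names the ATen op this node was decomposed from.
        raise RuntimeError(f"cannot lower {node}") from err
    return in_qp, out_qp
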
6 changes: 4 additions & 2 deletions backends/cadence/aot/memory_planning.py
@@ -367,7 +367,8 @@ def print_memory_planning_info(

     # Print the memory usage per memory space as a table
     logging.info(
-        tabulate(
+        "\n"
+        + tabulate(
             memory_usage_table,
             headers=[
                 "Memory Space",
@@ -398,7 +399,8 @@ def print_memory_planning_info(

     # Print the total memory usage as a table
     logging.info(
-        tabulate(
+        "\n"
+        + tabulate(
             total_memory_usage_table,
             tablefmt="outline",
         )
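
The leading "\n" exists because logging prepends a prefix (level, logger name) to each record; starting the table on its own line keeps the columns aligned. A standalone sketch of the effect, assuming the tabulate package (the rows below are illustrative, not from this commit):

import logging

from tabulate import tabulate

logging.basicConfig(level=logging.INFO)
rows = [["dtcm", 1024], ["sram", 4096]]  # illustrative memory spaces
# Without the "\n", the first table row is glued to the "INFO:root:" prefix
# and the column rules no longer line up with the rest of the table.
logging.info("\n" + tabulate(rows, headers=["Memory Space", "Bytes"], tablefmt="outline"))
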
162 changes: 161 additions & 1 deletion backends/cadence/aot/tests/test_memory_passes.py
@@ -1,7 +1,9 @@
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

 import logging
+import math
 import unittest
+from typing import cast

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -110,7 +112,121 @@ def forward(self, x):


 class TestMemTransform(unittest.TestCase):
-    def test_optimize_cat(self):
+    def _verify_cat_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        dim_offset = 0
+        for arg in cast(list[torch.fx.Node], node.args[0]):
+            arg_spec = arg.meta.get("spec", None)
+            self.assertEqual(arg_spec.mem_id, spec.mem_id)
+            self.assertEqual(
+                arg_spec.mem_offset,
+                spec.mem_offset + dim_offset * inner_dim_elements,
+                f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
+            )
+            dim_offset += arg_spec.shape[dim]
+
+    def _verify_slice_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        start: int = (
+            cast(int, node.args[2])
+            if (len(node.args) > 2 and node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + start * inner_dim_elements,
+            f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
+        )
+
+    def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim:]) * spec.dtype.itemsize
+        index: int = (
+            cast(int, node.args[2])
+            if (len(node.args) > 2 and node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + index * inner_dim_elements,
+            f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} for select on {dim=} {index=}, "
+            f"but output has {spec.mem_offset=}"
+            f"{spec=} {arg_spec=}",
+        )
+
+    def verify_nop_memory_alloc(self, graph_module):
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._cat_nop.out
+        ):
+            self._verify_cat_nop_memory_alloc(node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
+        ):
+            self._verify_slice_nop_memory_alloc(node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._select_copy_nop.int_out
+        ):
+            self._verify_select_nop_memory_alloc(node)
+
+    def test_optimize_cat_on_placeholders(self):
+        class Cat(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat((x, y))
+
+        x = torch.ones(3, 6)
+        y = torch.ones(2, 6)
+        # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
+        # pass to run:
+        graph_module = (
+            compiler.export_to_executorch_gen_etrecord(
+                Cat(), (x, y), opt_level=2, mem_algo=1
+            )
+            .exported_program()
+            .graph_module
+        )
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        graph_module.graph.eliminate_dead_code()
+        # Assert that cat op is optimized away
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        # Assert that cat op is replaced by its nop version post optimization
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_outermost(self):
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +251,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_optimize_cat_non_outermost(self):
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +278,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost(self):
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +304,9 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away, since the concat is not
         # along the outermost dim
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost1(self):
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +331,7 @@ def forward(self, x, y):
         # offsets are not multiple of 8 bytes, and the cat is not the output
         # of the graph.
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice(self):
         class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +360,7 @@ def forward(self, x):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice_infeasible(self):
         class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +386,7 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_Tensor(self):
         class SliceTensor(torch.nn.Module):
@@ -323,6 +448,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_select_Tensor(self):
         class SelectTensor(torch.nn.Module):
@@ -387,6 +513,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
@@ -416,6 +543,32 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+        class CatWithPadding(torch.nn.Module):
+            def __init__(self, padding_shape):
+                super().__init__()
+                zeros = torch.zeros(padding_shape)
+                self.register_buffer("padding", zeros)
+
+            def forward(self, x, y):
+                x = x.view(3, 5)
+                cat = torch.ops.aten.cat((x, self.padding.clone()))
+                slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
+                self.padding.copy_(slice_copy)
+                return cat.view(-1) + y
+
+        x = torch.ones(15)
+        y = torch.ones(1)
+        et_prog_manager = compiler.export_to_executorch_gen_etrecord(
+            CatWithPadding((1, 5)), (x, y), opt_level=3
+        )
+        graph_module = et_prog_manager.exported_program().graph_module
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_view(self):
         class CatViewFeasible(torch.nn.Module):
@@ -442,6 +595,7 @@ def forward(self, x, y):
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_repeated_args(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +619,7 @@ def forward(self, x):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_placeholder(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +647,7 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -522,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_copy(self) -> None:
         class Model(torch.nn.Module):
@@ -553,6 +710,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_cat_then_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -579,6 +737,7 @@ def forward(self, x) -> torch.Tensor:
         graph_module.print_readable()
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_view_for_unallocated_output(self):
         class Model(torch.nn.Module):
@@ -602,3 +761,4 @@ def forward(self, x, y):
             .graph_module
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
+        self.verify_nop_memory_alloc(graph_module)
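
The three _verify_*_nop_memory_alloc helpers added above all assert one placement invariant: a nop cat/slice/select is only sound if every aliased tensor sits at the output's planned offset plus the byte size of the elements that precede it along the affected dimension. A worked example with illustrative numbers (not taken from the tests):

import math

# cat([x, y], dim=0) with x: (3, 6) and y: (2, 6), float32 (itemsize 4).
dim, itemsize = 0, 4
x_shape, y_shape = (3, 6), (2, 6)
inner_dim_bytes = math.prod(x_shape[dim + 1 :]) * itemsize  # 6 * 4 = 24

out_offset = 128  # wherever memory planning placed the cat output
x_offset = out_offset + 0 * inner_dim_bytes  # x aliases the start of the output
y_offset = out_offset + x_shape[dim] * inner_dim_bytes  # y starts 3 rows (72 bytes) in

# The mirror invariant for a nop slice(dim=0, start=2) of that output:
slice_offset = out_offset + 2 * inner_dim_bytes  # 48 bytes past the output start
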
2 changes: 0 additions & 2 deletions backends/xnnpack/test/tester/tester.py
@@ -558,10 +558,8 @@ def export(self, export_stage: Optional[Export] = None):
         )

     def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
-        # TODO(T182187531): Skip dim order for now. Support dim order and its op after alpha release.
         if not to_edge_stage:
             to_edge_stage = ToEdge()
-        to_edge_stage.edge_compile_conf._skip_dim_order = True
         res = self._run_stage(to_edge_stage)
         return res
6 changes: 3 additions & 3 deletions docs/source/index.rst
@@ -4,9 +4,9 @@ Welcome to the ExecuTorch Documentation
 =======================================

 .. important::
-   v0.4.0 is a beta release of ExecuTorch. As of this release, the API will
-   follow the `API Lifecycle and Deprecation Policy <api-life-cycle.html>`__,
-   and the ``.pte`` binary format will comply with the `Runtime Compatibility
+   v0.4.0 was the beta release of ExecuTorch. Starting from v0.4.0, the API
+   follows the `API Lifecycle and Deprecation Policy <api-life-cycle.html>`__,
+   and the ``.pte`` binary format complies with the `Runtime Compatibility
    Policy
    <https://github.com/pytorch/executorch/tree/main/runtime/COMPATIBILITY.md>`__.
    This helps ensure that application developers can update to the latest
4 changes: 4 additions & 0 deletions exir/capture/_config.py
@@ -92,3 +92,7 @@ class ExecutorchBackendConfig:
     # If set to true, all constant tensors will be stored in a separate file,
     # external to the PTE file.
     external_constants: bool = False
+
+    # If set to true, all trainable weights will be stored in a separate file,
+    # external to the PTE file.
+    external_mutable_weights: bool = False
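
A sketch of how the new flag might be threaded through the usual export flow; only ExecutorchBackendConfig and external_mutable_weights come from this diff, while the model and call sequence are assumptions:

import torch

from executorch.exir import to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig


class TinyLinear(torch.nn.Module):  # stand-in model for illustration
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)


exported = torch.export.export(TinyLinear().eval(), (torch.randn(1, 8),))
et_program = to_edge(exported).to_executorch(
    ExecutorchBackendConfig(external_mutable_weights=True)
)
# With the flag set, mutable/trainable weights would be serialized to a
# separate file rather than embedded in the .pte.
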