Added (optional) debugging info to the partitioner (AOT_PARTITIONER_DEBUG=1) and added a bunch of ops that inductor supports #947

Open · wants to merge 1 commit into base: main
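For context, a minimal usage sketch of how the new flag is intended to be enabled. This is not taken from the PR; it assumes the existing functorch.compile entry points aot_function and min_cut_rematerialization_partition, and that the variable is set before functorch is imported, since both changed modules read it at import time.

import os
os.environ["AOT_PARTITIONER_DEBUG"] = "1"  # read at import time by aot_autograd.py and partitioners.py

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def fn(x):
    return torch.nn.functional.silu(x).sum()

def nop(fx_module, example_inputs):
    # "No-op" compiler for this sketch: run the traced forward/backward graphs as-is.
    return fx_module

# The min-cut partitioner needs networkx installed.
aot_fn = aot_function(fn, fw_compiler=nop, bw_compiler=nop,
                      partition_fn=min_cut_rematerialization_partition)
aot_fn(torch.randn(128, 128, requires_grad=True)).backward()
# On the first call, the banned ops, activation sizes (GB), and rematerialization
# counts added in this PR should be printed to stdout.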
9 changes: 9 additions & 0 deletions functorch/_src/aot_autograd.py
@@ -15,6 +15,9 @@
from .named_members_polyfill import _named_parameters, _named_buffers
from typing import Callable, List, Dict, Any, Tuple, Optional
from functools import wraps
import os

AOT_PARTITIONER_DEBUG = 'AOT_PARTITIONER_DEBUG' in os.environ

try:
from torchdynamo import disable as disable_torchdynamo
@@ -202,6 +205,12 @@ def fake_fn(primals, tangents):
compiled_fw = fw_compiler(fw_module, flat_tensor_args)
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))

if AOT_PARTITIONER_DEBUG:
activation_sizes = 0
for out in fw_outs[num_outs:]:
if isinstance(out, torch.Tensor):
activation_sizes += out.storage().nbytes()
print(f"Real Activations Stored(GB): {activation_sizes/1e9}")
bw_args = fw_outs[num_outs:] + fw_outs[0:num_outs]
compiled_bw = bw_compiler(bw_module, bw_args)
else:
42 changes: 34 additions & 8 deletions functorch/_src/partitioners.py
@@ -5,10 +5,13 @@
import torch.utils._pytree as pytree
import copy
import os
from collections import defaultdict
from torch.fx.passes import graph_drawer
from typing import Tuple
from .compile_utils import fx_graph_cse, get_aten_target

AOT_PARTITIONER_DEBUG = bool(os.environ.get('AOT_PARTITIONER_DEBUG', False))


class InvalidNodeBase(object):
def __repr__(self):
@@ -268,11 +271,14 @@ def classify_nodes(joint_module):
node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)

aten = torch.ops.aten
prims = torch.ops.prims

pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward] # noqa: E501
pointwise_ops += [prims.div, prims.convert_element_type, aten.sign, aten.clone]
misc_ops = [aten.to, aten.type_as, operator.getitem]

reduction_ops = [aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax] # noqa: E501
reduction_ops += [prims.var, prims.sum, aten.var]

# not recomputed by default since these are kinda expensive/hard to fuse into
# norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward] # noqa: E501
@@ -282,8 +288,10 @@ def classify_nodes(joint_module):

# These are the view ops that NVFuser can fuse
view_ops = [aten.squeeze, aten.unsqueeze]
view_ops += [prims.broadcast_in_dim, aten.select, aten.permute, aten._unsafe_view, aten.view, aten.expand, aten.slice, aten.reshape, aten.broadcast_tensors]
random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
compute_intensive_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.upsample_bilinear2d] # noqa: E501

unrecomputable_ops = random_ops + compute_intensive_ops

recomputable_ops = set(
@@ -293,6 +301,10 @@
+ view_ops
)
fusible_ops = recomputable_ops | set(random_ops)
if AOT_PARTITIONER_DEBUG:
ops_ignored = set([str(node.target._overloadpacket) for node in joint_module.graph.nodes if node.op == 'call_function' and hasattr(node.target, '_overloadpacket')]) - set([str(i) for i in recomputable_ops])
print("Ops banned from rematerialization: ", ops_ignored)
print()

AGGRESSIVE_RECOMPUTATION = False

@@ -304,13 +316,15 @@ def ban_recomputation(node):
return False
if get_aten_target(node) not in recomputable_ops:
return True
-            # If the output of the reduction is 4x smaller (arbitrary choice),
if node.target == operator.getitem:
return False
# If the output of an op is 4x smaller (arbitrary choice),
# then we don't allow recomputation.
-            if get_aten_target(node) in reduction_ops:
-                input_tensors_size = sum(_size_of(i.meta['tensor_meta']) for i in node.args if isinstance(i, fx.Node))
-                output_size = _size_of(node.meta['tensor_meta'])
-                return (output_size * 4 < input_tensors_size)
-            return False
if 'tensor_meta' not in node.meta:
return False
input_tensors_size = sum(_size_of(i.meta['tensor_meta']) for i in node.args if isinstance(i, fx.Node) and 'tensor_meta' in i.meta)
output_size = _size_of(node.meta['tensor_meta'])
return (output_size * 4 < input_tensors_size)

def is_fusible(a, b):
return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops
@@ -325,7 +339,9 @@ def get_node_weight(node):
mem_sz = _size_of(node.meta['tensor_meta'])

# Heuristic to bias towards nodes closer to the backwards pass
-        mem_sz = int(mem_sz + node.dist_from_bw)
# Complete guess about current value
mem_sz = int(mem_sz * (1.5 ** max(min(node.dist_from_bw, 100), 1)))
# mem_sz = int(mem_sz + node.dist_from_bw)

if is_materialized(node):
return mem_sz
@@ -375,7 +391,17 @@ def get_node_weight(node):
# To make this stuff deterministic
node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)}
saved_values = sorted((name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x])
-    return _extract_fwd_bwd_modules(joint_module, saved_values)
fw_module, bw_module = _extract_fwd_bwd_modules(joint_module, saved_values)
if AOT_PARTITIONER_DEBUG:
print("Theoretical Activations Stored: ", sum([_size_of(i.meta['tensor_meta']) for i in saved_values])/1e9)
remat_nodes = set([node.name for node in fw_module.graph.nodes if node.op == 'call_function']) & set([node.name for node in bw_module.graph.nodes if node.op == 'call_function'])
counts = defaultdict(int)
for node in fw_module.graph.nodes:
if node.name in remat_nodes and hasattr(node.target, '_overloadpacket'):
counts[str(node.target._overloadpacket)] += 1
print("# nodes rematerialized: ", len(remat_nodes))
print("Count of Ops Rematerialized: ", sorted(counts.items(), key=lambda x: x[1], reverse=True))
return fw_module, bw_module


def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_graph", clear_meta=True):
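Side note on the changed weight heuristic in get_node_weight above: a small self-contained sketch, not part of the diff, of what the 1.5 ** max(min(dist_from_bw, 100), 1) factor does to a node's effective size.

def weight(mem_sz, dist_from_bw):
    # Mirrors the formula added above: node sizes are inflated exponentially
    # with distance from the backward pass (clamped to the range [1, 100]).
    return int(mem_sz * (1.5 ** max(min(dist_from_bw, 100), 1)))

print(weight(1_000_000, 1))   # 1500000  (~1.5 MB for a 1 MB activation one hop away)
print(weight(1_000_000, 10))  # 57665039 (~57.7 MB for the same activation ten hops away)
# The min-cut therefore prefers to save values produced close to the backward
# pass and to recompute ones produced far away from it.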