Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Greatly reduce the host memory usage and device memory usage during autocompare #50

Merged
merged 4 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions op_tools/base_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def wrapper(*args, **kwargs):
except Exception as e:
self.result = None
self.exception = e
self.after_call_op(self.result)
return self.result
result = self.after_call_op(self.result)
return result if result is not None else self.result

return wrapper

Expand Down
81 changes: 51 additions & 30 deletions op_tools/op_autocompare_hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import torch
import gc
import os
import time
import atexit
Expand All @@ -15,10 +14,11 @@
is_opname_match,
compare_result,
is_random_number_gen_op,
garbage_collect,
)
from .save_op_args import save_op_args, serialize_args_to_dict

from .pretty_print import pretty_print_op_args, dict_data_list_to_table, packect_data_to_dict_list
from .pretty_print import dict_data_list_to_table, packect_data_to_dict_list


SKIP_LIST_OPS = []
Expand Down Expand Up @@ -68,9 +68,9 @@ def register_grad_fn_hook(self, tensor):
hook_handle = None

def grad_fun(grad_inputs, grad_outputs):
hook_handle.remove()
self.compare_hook.run_backward_on_cpu(grad_inputs, grad_outputs)
self.compare_hook.compare_backward_relate()
hook_handle.remove()

hook_handle = tensor.grad_fn.register_hook(grad_fun)
return grad_fun
Expand Down Expand Up @@ -98,9 +98,9 @@ def copy_input_to_cpu(self):
)

def run_forward_on_cpu(self):
self.result_device = to_device("cpu", self.result, detach=True)
try:
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
self.result_device = to_device("cpu", self.result, detach=True)
self.dtype_cast_dict = dict()
args_cpu = self.args_cpu
except Exception as e: # noqa: F841
Expand All @@ -119,7 +119,7 @@ def run_forward_on_cpu(self):
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or is_view_op(self.name)) and self.args[0].requires_grad:
if (is_inplace_op(self.name) or self.kwargs.get("inplace", False) or is_view_op(self.name)) and self.args[0].requires_grad:
args_cpu = [item for item in self.args_cpu]
args_cpu[0] = args_cpu[0].clone()
self.args_cpu = tuple(args_cpu)
Expand All @@ -128,15 +128,7 @@ def run_forward_on_cpu(self):
args_cpu = self.args_cpu
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)

self.result_device = to_device(
"cpu",
self.result,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.op_backward_args_to_table(grad_inputs, grad_output)
self.grad_outputs_cpu = to_device("cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True)
for arg_cpu in traverse_container(self.args_cpu):
Expand All @@ -160,7 +152,12 @@ def post_hook(grad_inputs, grad_outputs):
result_cpu.backward(*self.grad_outputs_cpu)
handle.remove()

self.op_backward_args_to_table(grad_inputs, grad_output)

def register_backward_hook_for_grads(self):
if self.count_params_with_requires_grad() <= 0:
self.backward_hook_handle = None
return
self.backward_hook_handle = BackwardHookHandle(self)
for result in traverse_container(self.result):
if isinstance(result, torch.Tensor):
Expand All @@ -187,23 +184,34 @@ def compare_forward_relate(self):
output_compare_result = self.compare_forward_result()

result_list = input_compare_result["result_list"] + output_compare_result["result_list"]

dtype_cast_info = ""
if len(self.dtype_cast_dict) > 0:
dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"
print("\n" * 2)
print(f"{self.name} forward_id: {self.forward_op_id}")
self.pretty_print_op_forward_args()
print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(self.op_forward_args_to_table())
print(dict_data_list_to_table(result_list))
print("\n" * 2)

self.forward_allclose = self.forward_allclose and self.input_allclose
if not self.forward_allclose:
self.save_forward_args()

def pretty_print_op_forward_args(self):
pretty_print_op_args(self.name, serialize_args_to_dict(*self.args, **self.kwargs), serialize_args_to_dict(self.result))
def op_forward_args_to_table(self):
    """Return one table describing the forward call's arguments and results.

    Four groups of rows are emitted in this order, each labelled with the op
    name plus a suffix: `` inputs`` (``self.args`` / ``self.kwargs``),
    `` outputs`` (``self.result``), `` inputs(cpu)`` (``self.args_cpu`` /
    ``self.kwargs_cpu``) and `` outputs(cpu)`` (``self.result_cpu``).

    Returns:
        Whatever ``dict_data_list_to_table`` builds from the combined rows.
    """
    # (label suffix, positional payload, keyword payload) for every group.
    sections = (
        (" inputs", self.args, self.kwargs),
        (" outputs", (self.result,), {}),
        (" inputs(cpu)", self.args_cpu, self.kwargs_cpu),
        (" outputs(cpu)", (self.result_cpu,), {}),
    )
    rows = []
    for suffix, positional, keyword in sections:
        serialized = serialize_args_to_dict(*positional, **keyword)
        rows = rows + packect_data_to_dict_list(self.name + suffix, serialized)
    return dict_data_list_to_table(rows)

def op_backward_args_to_table(self, grad_inputs, grad_output):
    """Build, cache and return the table describing the backward gradients.

    Rows are emitted in this order: the ``grad_output`` and ``grad_inputs``
    passed by the caller, then their CPU counterparts read from
    ``self.grad_outputs_cpu`` and ``self.args_cpu_grad``.

    Args:
        grad_inputs: iterable of input gradients; it is unpacked into
            positional arguments of ``serialize_args_to_dict``.
        grad_output: iterable of output gradients; unpacked the same way.

    Returns:
        The table built by ``dict_data_list_to_table``. The same object is
        stored on ``self.backward_args_table``.
    """
    grad_output_list = packect_data_to_dict_list(self.name + " grad_output", serialize_args_to_dict(*grad_output))
    grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs", serialize_args_to_dict(*grad_inputs))
    # The CPU rows need their own name: reusing ``grad_output_list`` here would
    # overwrite the rows built above and drop them from the table.
    cpu_grad_output_list = packect_data_to_dict_list(self.name + " grad_output(cpu)", serialize_args_to_dict(*self.grad_outputs_cpu))
    cpu_grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs(cpu)", serialize_args_to_dict(*tuple(self.args_cpu_grad)))  # noqa: E501
    self.backward_args_table = dict_data_list_to_table(
        grad_output_list + grad_inputs_list + cpu_grad_output_list + cpu_grad_inputs_list
    )
    return self.backward_args_table

def count_params_with_requires_grad(self):
Expand All @@ -215,9 +223,8 @@ def count_params_with_requires_grad(self):

def compare_input_grad(self):
self.args_grad = self.grad_inputs_cpu
compare_info = compare_result(self.name + " grad", self.args_cpu_grad, self.args_grad)
compare_info = compare_result(self.name + " grad", self.args_grad, self.args_cpu_grad)
compare_info["forward_id"] = self.forward_op_id
print(dict_data_list_to_table(compare_info["result_list"]))

compare_result_cache.append(self.forward_op_id, compare_info)

Expand All @@ -228,8 +235,12 @@ def compare_input_grad(self):
def compare_backward_relate(self):
backward_compare_result = self.compare_input_grad()

dtype_cast_info = ""
if len(self.dtype_cast_dict) > 0:
dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"

print("\n" * 2)
print(f"{self.name} forward_id: {self.forward_op_id}")
print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(self.backward_args_table)
print(dict_data_list_to_table(backward_compare_result["result_list"]))
print("\n" * 2)
Expand All @@ -239,8 +250,10 @@ def compare_backward_relate(self):
if self.forward_allclose:
self.save_forward_args()
self.save_backward_args()

id = self.forward_op_id
self = None
gc.collect()
garbage_collect(id)

def save_forward_args(self):
save_op_args(
Expand Down Expand Up @@ -294,18 +307,26 @@ def after_call_op(self, result): # noqa:C901
return
with DisableHookGuard():
self.run_forward_on_cpu()

self.register_backward_hook_for_grads()

self.compare_forward_relate()

self.args = to_device("cpu", self.args, detach=True)
self.kwargs = to_device("cpu", self.kwargs or {}, detach=True)

if self.result is None and self.result_cpu is None:
print(f"{self.name} output is None, no check for backward accuracy")
return

self.register_backward_hook_for_grads()
id = self.forward_op_id
result = self.result
# for reduce device memory usage
if self.backward_hook_handle is not None:
self.args = to_device("cpu", self.args, detach=True)
self.kwargs = to_device("cpu", self.kwargs or {}, detach=True)
self.result = to_device("cpu", self.result, detach=True)
else:
self = None

garbage_collect(id)
return result

def is_should_apply(self, *args, **kwargs):
if is_random_number_gen_op(self.name):
return False
Expand Down
11 changes: 5 additions & 6 deletions op_tools/op_dispatch_watch_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .pretty_print import pretty_print_op_args
from .pretty_print import packect_data_to_dict_list, dict_data_list_to_table


class OpDispatchWatcherHook(BaseHook):
Expand All @@ -16,11 +16,10 @@ def before_call_op(self, *args, **kwargs):

def after_call_op(self, result):
with DisableHookGuard():
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
inputs_list = packect_data_to_dict_list(self.name + " inputs", serialize_args_to_dict(*self.args, **self.kwargs))
output_list = packect_data_to_dict_list(self.name + " outputs", serialize_args_to_dict(self.result))
forward_args_table = dict_data_list_to_table(inputs_list + output_list)
print(forward_args_table)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DISPATCH_WATCH_DISABLE_LIST", "")):
Expand Down
12 changes: 5 additions & 7 deletions op_tools/op_fallback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,15 @@ def after_call_op(self, result):
def dump_op_args(self):
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args_device, **self.kwargs_device),
prefix="device_input ",
self.name + " input(device)",
serialize_args_to_dict(*self.args_device, **self.kwargs_device)
)
data_dict_list += packect_data_to_dict_list(
self.name,
self.name + " input(cpu)",
serialize_args_to_dict(*self.args, **self.kwargs),
prefix="cpu_input ",
)
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result), prefix="device_output")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result_cpu), prefix="cpu_output ")
data_dict_list += packect_data_to_dict_list(self.name + " output(device)", serialize_args_to_dict(self.result))
data_dict_list += packect_data_to_dict_list(self.name + " output(cpu)", serialize_args_to_dict(self.result_cpu))

table = dict_data_list_to_table(data_dict_list)
print(table)
Expand Down
37 changes: 17 additions & 20 deletions op_tools/op_time_measure_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .utils import is_opname_match
from .utils import is_opname_match, traverse_container, garbage_collect
from .pretty_print import (
pretty_print_op_args,
dict_data_list_to_table,
packect_data_to_dict_list,
)
Expand Down Expand Up @@ -91,8 +90,8 @@ def grad_fun(grad_inputs, grad_outputs):
self.end_time = time.time()
self.backward_elasped = self.end_time - self.start_time
data_dict_list = []
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(grad_outputs), prefix="grad_outputs ")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(grad_inputs), prefix="grad_inputs ")
data_dict_list += packect_data_to_dict_list(self.name + " grad_outputs", serialize_args_to_dict(grad_outputs))
data_dict_list += packect_data_to_dict_list(self.name + " grad_inputs", serialize_args_to_dict(grad_inputs))
table = dict_data_list_to_table(data_dict_list)
print(table)
elasped_info_dict = {
Expand Down Expand Up @@ -123,35 +122,33 @@ def after_call_op(self, result):
self.foward_elasped = self.end_time - self.start_time

self.backward_hook_handle = BackwardHookHandle(self.name, self.id)
if isinstance(self.result, torch.Tensor):
if self.result.grad_fn is not None:
self.result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())
self.result.grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())
elif isinstance(self.result, (tuple, list)) or type(self.result).__module__.startswith("torch.return_types"):
# torch.return_types is a structseq, aka a "namedtuple"-like thing defined by the Python C-API.
for i in range(len(self.result)):
if isinstance(self.result[i], torch.Tensor) and self.result[i].grad_fn is not None:
self.result[i].grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())

self.result[i].grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())

for result in traverse_container(self.result):
if isinstance(result, torch.Tensor) and result.grad_fn is not None:
result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())
result.grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())

with DisableHookGuard():
pretty_print_op_args(
self.name,
serialize_args_to_dict(*self.args, **self.kwargs),
serialize_args_to_dict(self.result),
)
inputs_list = packect_data_to_dict_list(self.name + " inputs", serialize_args_to_dict(*self.args, **self.kwargs))
output_list = packect_data_to_dict_list(self.name + " outputs", serialize_args_to_dict(self.result))
forward_args_table = dict_data_list_to_table(inputs_list + output_list)

elasped_info_dict = {
"name": self.name,
"forward_id": self.id,
"forward_elasped": f"{(self.foward_elasped * 1000):>10.8f}",
"unit": "ms",
}
print("\n" * 2)
print(forward_args_table)
print(dict_data_list_to_table([elasped_info_dict]))
print("\n" * 2)
elasped_info_dict["input"] = serialize_args_to_dict(*self.args, **self.kwargs)
elasped_info_dict["output"] = serialize_args_to_dict(self.result)
time_measure_result_cache.append(self.id, elasped_info_dict)

garbage_collect(self.id)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
return False
Expand Down
21 changes: 6 additions & 15 deletions op_tools/pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,32 @@ def dict_data_list_to_table(data_dict_list):
return table


def packect_data_to_dict_list(op_name, inputs_dict, prefix):
def packect_data_to_dict_list(op_name, inputs_dict):
data_dict_list = []
args = inputs_dict.get("args", [])
kwargs = inputs_dict.get("kwargs", {})
arg_index = -1
for arg in args:
arg_index += 1
if isinstance(arg, dict):
item_name = op_name + f" {prefix}" + (f"[{arg_index}]" if len(args) > 1 else "")
item_name = op_name + (f"[{arg_index}]" if len(args) > 1 else "")
data_dict = {"name": item_name}
data_dict.update(arg)
data_dict_list.append(data_dict)
elif isinstance(arg, (tuple, list)):
arg_sub_index = -1
for item in arg:
arg_sub_index += 1
item_name = op_name + f" {prefix}[{arg_index}]" + f"[{arg_sub_index}]"
item_name = op_name + f" [{arg_index}]" + f"[{arg_sub_index}]"
if isinstance(item, dict):
data_dict = {"name": item_name}
data_dict.update(item)
data_dict_list.append(data_dict)
else:
data_dict_list.append({"name": item_name, "value": item})
elif isinstance(arg, (str, int, float, bool)):
data_dict_list.append({"name": op_name + (f"[{arg_index}]" if len(args) > 1 else ""), "value": arg})
for key, value in kwargs.items():
data_dict_list.append({"name": op_name + f" [{key}]", "value": value})
data_dict_list.append({"name": op_name + f" {key}", "value": value})

return data_dict_list


def pretty_print_op_args(op_name, inputs_dict, outputs_dict=None):
    """Print and return a table of an op's serialized inputs and outputs.

    Args:
        op_name: name used as the prefix of every row label.
        inputs_dict: serialized inputs; read through ``.get("args")`` and
            ``.get("kwargs")`` by ``packect_data_to_dict_list``.
        outputs_dict: serialized outputs in the same layout, or ``None`` when
            there is nothing to report for the outputs.

    Returns:
        The table built by ``dict_data_list_to_table``. It is printed only
        when at least one row was produced.
    """
    input_data_dict_list = packect_data_to_dict_list(op_name, inputs_dict, "inputs")
    # ``outputs_dict`` defaults to None, and packect_data_to_dict_list calls
    # ``.get`` on its argument, so the default must not be forwarded as-is.
    if outputs_dict is None:
        output_data_dict_list = []
    else:
        output_data_dict_list = packect_data_to_dict_list(op_name, outputs_dict, "outputs")
    data_dict_list = input_data_dict_list + output_data_dict_list
    table = dict_data_list_to_table(data_dict_list)
    if data_dict_list:
        print(table)
    return table
14 changes: 9 additions & 5 deletions op_tools/test/test_compare_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,15 @@ def test_compare_different_bool(self):
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_invalid_input(self):
compare_result("empty_list", [], []) # 输入空列表
compare_result("empty_tesnsor", torch.empty(0).cuda(), torch.empty(0).cuda()) # 输入空张量
compare_result("invalid_type", (), []) # 输入元组
compare_result("invalid_value_a", ["1", 2, 3], [1, 2, 3]) # 输入a的元素类型不符合要求
compare_result("invalid_value_b", [1, 2, 3], ["1", 2, 3]) # 输入b的元素类型不符合要求
self.assertTrue(compare_result("empty_list", [], [])["allclose"]) # 输入空列表
self.assertTrue(compare_result("empty_tesnsor", torch.empty(0).cuda(), torch.empty(0).cuda())["allclose"]) # 输入空张量
self.assertTrue(compare_result("equal_tesnsor", torch.ones(1).cuda(), torch.ones(1).cuda())["allclose"]) # 输入相等张量empty
self.assertFalse(
compare_result("not_equal_tesnsor", torch.rand(1000).cuda(), -torch.rand(1000).cuda())["allclose"]
) # 输入相等张量empty
self.assertTrue(compare_result("invalid_type", (), [])["allclose"]) # 输入空元组和空列表
self.assertFalse(compare_result("invalid_value_a", ["1", 2, 3], [1, 2, 3])["allclose"]) # 输入a的元素类型不符合要求
self.assertFalse(compare_result("invalid_value_b", [1, 2, 3], ["1", 2, 3])["allclose"]) # 输入b的元素类型不符合要求


if __name__ == "__main__":
Expand Down
Loading