
Commit

minor change
zhaoguochun1995 committed Sep 29, 2024
1 parent 3e5989f commit 41c130d
Showing 8 changed files with 115 additions and 71 deletions.
53 changes: 28 additions & 25 deletions op_tools/op_autocompare_hook.py
@@ -1,6 +1,5 @@
# Copyright (c) 2024, DeepLink.
import torch
import gc
import os
import time
import atexit
@@ -15,6 +14,7 @@
is_opname_match,
compare_result,
is_random_number_gen_op,
garbage_collect,
)
from .save_op_args import save_op_args, serialize_args_to_dict

@@ -68,9 +68,9 @@ def register_grad_fn_hook(self, tensor):
hook_handle = None

def grad_fun(grad_inputs, grad_outputs):
hook_handle.remove()
self.compare_hook.run_backward_on_cpu(grad_inputs, grad_outputs)
self.compare_hook.compare_backward_relate()
hook_handle.remove()

hook_handle = tensor.grad_fn.register_hook(grad_fun)
return grad_fun
@@ -98,9 +98,9 @@ def copy_input_to_cpu(self):
)

def run_forward_on_cpu(self):
self.result_device = to_device("cpu", self.result, detach=True)
try:
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)
self.result_device = to_device("cpu", self.result, detach=True)
self.dtype_cast_dict = dict()
args_cpu = self.args_cpu
except Exception as e: # noqa: F841
@@ -119,7 +119,7 @@ def run_forward_on_cpu(self):
detach=True,
)
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
if (is_inplace_op(self.name) or is_view_op(self.name)) and self.args[0].requires_grad:
if (is_inplace_op(self.name) or self.kwargs.get("inplace", False) or is_view_op(self.name)) and self.args[0].requires_grad:
args_cpu = [item for item in self.args_cpu]
args_cpu[0] = args_cpu[0].clone()
self.args_cpu = tuple(args_cpu)
@@ -128,13 +128,6 @@ def run_forward_on_cpu(self):
args_cpu = self.args_cpu
self.result_cpu = self.func(*self.args_cpu, **self.kwargs_cpu)

self.result_device = to_device(
"cpu",
self.result,
dtype_cast_dict=self.dtype_cast_dict,
detach=True,
)

def run_backward_on_cpu(self, grad_inputs, grad_output):
self.grad_outputs_cpu = to_device("cpu", grad_output, dtype_cast_dict=self.dtype_cast_dict, detach=True)
self.grad_inputs_cpu = to_device("cpu", grad_inputs, dtype_cast_dict=self.dtype_cast_dict, detach=True)
@@ -191,8 +184,12 @@ def compare_forward_relate(self):
output_compare_result = self.compare_forward_result()

result_list = input_compare_result["result_list"] + output_compare_result["result_list"]

dtype_cast_info = ""
if len(self.dtype_cast_dict) > 0:
dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"
print("\n" * 2)
print(f"{self.name} forward_id: {self.forward_op_id} {self.dtype_cast_dict if len(self.dtype_cast_dict) > 0 else ''}")
print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(self.op_forward_args_to_table())
print(dict_data_list_to_table(result_list))
print("\n" * 2)
@@ -202,16 +199,18 @@ def compare_forward_relate(self):
self.save_forward_args()

def op_forward_args_to_table(self):
inputs_list = packect_data_to_dict_list(self.name + " inputs ", serialize_args_to_dict(*self.args, **self.kwargs))
output_list = packect_data_to_dict_list(self.name + " outputs ", serialize_args_to_dict(self.result))
inputs_list = packect_data_to_dict_list(self.name + " inputs", serialize_args_to_dict(*self.args, **self.kwargs))
output_list = packect_data_to_dict_list(self.name + " outputs", serialize_args_to_dict(self.result))
cpu_inputs_list = packect_data_to_dict_list(self.name + " inputs(cpu)", serialize_args_to_dict(*self.args_cpu, **self.kwargs_cpu))
cpu_output_list = packect_data_to_dict_list(self.name + " outputs(cpu)", serialize_args_to_dict(self.result_cpu))
forward_args_table = dict_data_list_to_table(inputs_list + output_list + cpu_output_list)
forward_args_table = dict_data_list_to_table(inputs_list + output_list + cpu_inputs_list + cpu_output_list)
return forward_args_table

def op_backward_args_to_table(self, grad_inputs, grad_output):
grad_output_list = packect_data_to_dict_list(self.name + " grad_output ", serialize_args_to_dict(grad_output))
grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs ", serialize_args_to_dict(grad_inputs))
cpu_grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs(cpu)", serialize_args_to_dict(self.args_cpu_grad))
grad_output_list = packect_data_to_dict_list(self.name + " grad_output", serialize_args_to_dict(*grad_output))
grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs", serialize_args_to_dict(*grad_inputs))
grad_output_list = packect_data_to_dict_list(self.name + " grad_output(cpu)", serialize_args_to_dict(*self.grad_outputs_cpu))
cpu_grad_inputs_list = packect_data_to_dict_list(self.name + " grad_inputs(cpu)", serialize_args_to_dict(*tuple(self.args_cpu_grad))) # noqa: E501
self.backward_args_table = dict_data_list_to_table(grad_output_list + grad_inputs_list + cpu_grad_inputs_list)
return self.backward_args_table

@@ -224,7 +223,7 @@ def count_params_with_requires_grad(self):

def compare_input_grad(self):
self.args_grad = self.grad_inputs_cpu
compare_info = compare_result(self.name + " grad", self.args_cpu_grad, self.args_grad)
compare_info = compare_result(self.name + " grad", self.args_grad, self.args_cpu_grad)
compare_info["forward_id"] = self.forward_op_id

compare_result_cache.append(self.forward_op_id, compare_info)
@@ -236,8 +235,12 @@ def compare_input_grad(self):
def compare_backward_relate(self):
backward_compare_result = self.compare_input_grad()

dtype_cast_info = ""
if len(self.dtype_cast_dict) > 0:
dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"

print("\n" * 2)
print(f"{self.name} forward_id: {self.forward_op_id}")
print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(self.backward_args_table)
print(dict_data_list_to_table(backward_compare_result["result_list"]))
print("\n" * 2)
@@ -247,8 +250,10 @@ def compare_backward_relate(self):
if self.forward_allclose:
self.save_forward_args()
self.save_backward_args()

id = self.forward_op_id
self = None
gc.collect()
garbage_collect(id)

def save_forward_args(self):
save_op_args(
@@ -309,8 +314,8 @@ def after_call_op(self, result): # noqa:C901
return

self.register_backward_hook_for_grads()
id = self.forward_op_id
result = self.result
id = self.id
# for reduce device memory usage
if self.backward_hook_handle is not None:
self.args = to_device("cpu", self.args, detach=True)
@@ -319,9 +324,7 @@ def after_call_op(self, result): # noqa:C901
else:
self = None

gc_cycle = int(os.getenv("OP_AUTOCOMPARE_GARBAGE_COLLECTION_CYCLE", "100"))
if id % gc_cycle == 0:
gc.collect()
garbage_collect(id)
return result

def is_should_apply(self, *args, **kwargs):
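
Note on the change above: the inline gc.collect() logic in after_call_op and compare_backward_relate is replaced by a garbage_collect(id) helper imported from .utils. That helper's implementation is not part of this diff; a minimal sketch consistent with the code being removed (collect only every N ops, with the cycle length taken from an environment variable whose exact name here is an assumption) could be:

# op_tools/utils.py -- hypothetical sketch, not part of this commit
import gc
import os


def garbage_collect(op_id):
    # Collecting on every op is expensive; only run gc.collect() every gc_cycle-th op.
    gc_cycle = int(os.getenv("OP_TOOL_GARBAGE_COLLECTION_CYCLE", "100"))
    if op_id is not None and op_id % gc_cycle == 0:
        gc.collect()
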
12 changes: 5 additions & 7 deletions op_tools/op_fallback_hook.py
@@ -78,17 +78,15 @@ def after_call_op(self, result):
def dump_op_args(self):
data_dict_list = []
data_dict_list += packect_data_to_dict_list(
self.name,
serialize_args_to_dict(*self.args_device, **self.kwargs_device),
prefix="device_input ",
self.name + " input(device)",
serialize_args_to_dict(*self.args_device, **self.kwargs_device)
)
data_dict_list += packect_data_to_dict_list(
self.name,
self.name + " input(cpu)",
serialize_args_to_dict(*self.args, **self.kwargs),
prefix="cpu_input ",
)
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result), prefix="device_output")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(self.result_cpu), prefix="cpu_output ")
data_dict_list += packect_data_to_dict_list(self.name + " output(device)", serialize_args_to_dict(self.result))
data_dict_list += packect_data_to_dict_list(self.name + " output(cpu)", serialize_args_to_dict(self.result_cpu))

table = dict_data_list_to_table(data_dict_list)
print(table)
26 changes: 12 additions & 14 deletions op_tools/op_time_measure_hook.py
@@ -6,7 +6,7 @@
from .base_hook import BaseHook, DisableHookGuard

from .save_op_args import serialize_args_to_dict
from .utils import is_opname_match
from .utils import is_opname_match, traverse_container, garbage_collect
from .pretty_print import (
dict_data_list_to_table,
packect_data_to_dict_list,
@@ -90,8 +90,8 @@ def grad_fun(grad_inputs, grad_outputs):
self.end_time = time.time()
self.backward_elasped = self.end_time - self.start_time
data_dict_list = []
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(grad_outputs), prefix="grad_outputs ")
data_dict_list += packect_data_to_dict_list(self.name, serialize_args_to_dict(grad_inputs), prefix="grad_inputs ")
data_dict_list += packect_data_to_dict_list(self.name + " grad_outputs", serialize_args_to_dict(grad_outputs))
data_dict_list += packect_data_to_dict_list(self.name + " grad_inputs", serialize_args_to_dict(grad_inputs))
table = dict_data_list_to_table(data_dict_list)
print(table)
elasped_info_dict = {
@@ -122,17 +122,11 @@ def after_call_op(self, result):
self.foward_elasped = self.end_time - self.start_time

self.backward_hook_handle = BackwardHookHandle(self.name, self.id)
if isinstance(self.result, torch.Tensor):
if self.result.grad_fn is not None:
self.result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())
self.result.grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())
elif isinstance(self.result, (tuple, list)) or type(self.result).__module__.startswith("torch.return_types"):
# torch.return_types is a structseq, aka a "namedtuple"-like thing defined by the Python C-API.
for i in range(len(self.result)):
if isinstance(self.result[i], torch.Tensor) and self.result[i].grad_fn is not None:
self.result[i].grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())

self.result[i].grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())

for result in traverse_container(self.result):
if isinstance(result, torch.Tensor) and result.grad_fn is not None:
result.grad_fn.register_hook(self.backward_hook_handle.grad_fun_posthook())
result.grad_fn.register_prehook(self.backward_hook_handle.grad_fun_prehook())

with DisableHookGuard():
inputs_list = packect_data_to_dict_list(self.name + " inputs", serialize_args_to_dict(*self.args, **self.kwargs))
@@ -145,12 +139,16 @@ def after_call_op(self, result):
"forward_elasped": f"{(self.foward_elasped * 1000):>10.8f}",
"unit": "ms",
}
print("\n" * 2)
print(forward_args_table)
print(dict_data_list_to_table([elasped_info_dict]))
print("\n" * 2)
elasped_info_dict["input"] = serialize_args_to_dict(*self.args, **self.kwargs)
elasped_info_dict["output"] = serialize_args_to_dict(self.result)
time_measure_result_cache.append(self.id, elasped_info_dict)

garbage_collect(self.id)

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
return False
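
Note on the change above: the explicit branching on torch.Tensor, tuple/list, and torch.return_types is collapsed into a loop over traverse_container(self.result). traverse_container is imported from .utils and its implementation is not shown in this diff; a rough sketch matching how it is used here (flatten nested containers and yield the leaf values so a grad_fn hook can be registered on every tensor) might be:

# Hypothetical sketch of traverse_container; the real helper lives in op_tools/utils.py.
def traverse_container(obj):
    # torch.return_types.* results (e.g. x.sort()) are structseqs, i.e. "namedtuple"-like.
    if isinstance(obj, (tuple, list)) or type(obj).__module__.startswith("torch.return_types"):
        for item in obj:
            yield from traverse_container(item)
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from traverse_container(value)
    else:
        yield obj
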
4 changes: 3 additions & 1 deletion op_tools/pretty_print.py
@@ -40,7 +40,9 @@ def packect_data_to_dict_list(op_name, inputs_dict):
data_dict_list.append(data_dict)
else:
data_dict_list.append({"name": item_name, "value": item})
elif isinstance(arg, (str, int, float, bool)):
data_dict_list.append({"name": op_name + (f"[{arg_index}]" if len(args) > 1 else ""), "value": arg})
for key, value in kwargs.items():
data_dict_list.append({"name": op_name + f" [{key}]", "value": value})
data_dict_list.append({"name": op_name + f" {key}", "value": value})

return data_dict_list
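
Note on the change above: the new branch gives scalar positional arguments (str, int, float, bool) and keyword arguments their own rows in the printed table instead of skipping them. A usage sketch; the exact row names are an assumption read off the code above rather than confirmed output:

# Sketch only: combining serialize_args_to_dict with the updated helper.
import torch
from op_tools.save_op_args import serialize_args_to_dict
from op_tools.pretty_print import packect_data_to_dict_list, dict_data_list_to_table

x = torch.randn(3, 4)
rows = packect_data_to_dict_list("torch.topk", serialize_args_to_dict(x, 2, dim=-1))
# Expected (assumed): one row for the tensor x, one for the scalar 2 ("torch.topk[1]"),
# and one for the keyword argument ("torch.topk dim").
print(dict_data_list_to_table(rows))
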
14 changes: 9 additions & 5 deletions op_tools/test/test_compare_result.py
@@ -146,11 +146,15 @@ def test_compare_different_bool(self):
self.assertTrue(isinstance(compare_info["result_list"], list))

def test_compare_invalid_input(self):
compare_result("empty_list", [], []) # 输入空列表
compare_result("empty_tesnsor", torch.empty(0).cuda(), torch.empty(0).cuda()) # 输入空张量
compare_result("invalid_type", (), []) # 输入元组
compare_result("invalid_value_a", ["1", 2, 3], [1, 2, 3]) # 输入a的元素类型不符合要求
compare_result("invalid_value_b", [1, 2, 3], ["1", 2, 3]) # 输入b的元素类型不符合要求
self.assertTrue(compare_result("empty_list", [], [])["allclose"]) # 输入空列表
self.assertTrue(compare_result("empty_tesnsor", torch.empty(0).cuda(), torch.empty(0).cuda())["allclose"]) # 输入空张量
self.assertTrue(compare_result("equal_tesnsor", torch.ones(1).cuda(), torch.ones(1).cuda())["allclose"]) # 输入相等张量empty
self.assertFalse(
compare_result("not_equal_tesnsor", torch.rand(1000).cuda(), -torch.rand(1000).cuda())["allclose"]
) # 输入相等张量empty
self.assertTrue(compare_result("invalid_type", (), [])["allclose"]) # 输入空元组和空列表
self.assertFalse(compare_result("invalid_value_a", ["1", 2, 3], [1, 2, 3])["allclose"]) # 输入a的元素类型不符合要求
self.assertFalse(compare_result("invalid_value_b", [1, 2, 3], ["1", 2, 3])["allclose"]) # 输入b的元素类型不符合要求


if __name__ == "__main__":
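
Note on the change above: these assertions rely on compare_result returning a dict with at least an "allclose" flag and a "result_list" entry, which matches how the autocompare hook consumes it earlier in this diff. A hypothetical caller, for illustration only:

# Sketch; the keys "allclose" and "result_list" are inferred from usage in this diff.
import torch
from op_tools.utils import compare_result

info = compare_result("torch.add", torch.ones(4), torch.ones(4))
assert info["allclose"]
for row in info["result_list"]:
    print(row)
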
38 changes: 23 additions & 15 deletions op_tools/test/test_pretty_print.py
@@ -1,19 +1,27 @@
from op_tools.save_op_args import serialize_args_to_dict
from op_tools.pretty_print import pretty_print_op_args
from op_tools.utils import compare_result
from op_tools.pretty_print import dict_data_list_to_table, packect_data_to_dict_list
import torch
import ditorch

x = torch.randn(3, 4, device="cuda")
y = torch.randn(3, 4, 7, 8, device="cpu")

pretty_print_op_args(
op_name="torch.add",
inputs_dict=serialize_args_to_dict(x, x, x),
outputs_dict=serialize_args_to_dict(x),
)
pretty_print_op_args(
op_name="torch.stack",
inputs_dict=serialize_args_to_dict([x, x, x], dim=1),
outputs_dict=serialize_args_to_dict(x),
)
import unittest


class TestPrettyPrint(unittest.TestCase):
def test_pretty_print(self):
x = torch.randn(3, 4, device="cuda")
y = torch.randn(3, 4, 7, 8, device="cpu")

data_list1 = packect_data_to_dict_list("torch.add", serialize_args_to_dict(x, x))
data_list2 = packect_data_to_dict_list("torch.stack", serialize_args_to_dict([y, y, y], dim=1))

data_list = data_list1 + data_list2

self.assertTrue(len(data_list) == 6)

table = dict_data_list_to_table(data_list)

print(table)


if __name__ == "__main__":
unittest.main()
28 changes: 26 additions & 2 deletions op_tools/test/test_tool_with_special_op.py
@@ -24,6 +24,10 @@ def test_untyped_storage(self):
x = torch.randn(4, 5, dtype=torch.float32, device="cuda")
y = x.untyped_storage() # type(y) is class 'torch.storage.UntypedStorage'

with op_tools.OpTimeMeasure():
x = torch.randn(4, 5, dtype=torch.float32, device="cuda")
y = x.untyped_storage() # type(y) is class 'torch.storage.UntypedStorage'

def test_sort(self):
x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda")
y = x.sort()
@@ -41,6 +45,10 @@ def test_sort(self):
x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda")
y = x.sort() # type(y) is class 'torch.return_types.sort'

with op_tools.OpTimeMeasure():
x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda")
y = x.sort() # type(y) is class 'torch.return_types.sort'

def test_traverse_container_with_dtype(self):
x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda")
y = x.dtype
@@ -67,8 +75,12 @@ def test_setitem(self):
x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda")
x[0, 0, 0] = 1.0 # __setitem__ return None

with op_tools.OpTimeMeasure():
x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda")
x[0, 0, 0] = 1.0 # __setitem__ return None

def test_inplace_op(self):
with op_tools.OpAutoCompare():
def f():
m = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda", requires_grad=True)
x = m + 1
x.add_(1.0)
@@ -80,8 +92,14 @@ def test_inplace_op(self):
y = x.abs()
y.backward(torch.ones_like(x))

def test_contiguous(self):
with op_tools.OpAutoCompare():
f()

with op_tools.OpTimeMeasure():
f()

def test_contiguous(self):
def f():
x = torch.randn(3, 4, 5, dtype=torch.float32, device="cuda", requires_grad=True)
y = x.contiguous()
z = y + y
@@ -98,6 +116,12 @@ def test_contiguous(self):
z = y + y
z.backward(torch.ones_like(z))

with op_tools.OpAutoCompare():
f()

with op_tools.OpTimeMeasure():
f()


if __name__ == "__main__":
unittest.main()