Zgc/ditorch enhance garbage collect #57

Merged Oct 10, 2024 (4 commits)

op_tools/op_autocompare_hook.py (10 changes: 4 additions & 6 deletions)
@@ -220,7 +220,7 @@ def compare_forward_relate(self):
         if len(self.dtype_cast_dict) > 0:
             dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"
         print("\n" * 2)
-        print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
+        print(f"autocompare {self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
         print(f"{self.current_location}")
         print(self.op_forward_args_to_table())
         print(dict_data_list_to_table(result_list))
@@ -272,7 +272,7 @@ def compare_backward_relate(self):
             dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"

         print("\n" * 2)
-        print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
+        print(f"autocompare {self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
         print(f"{self.current_location}")
         print(self.backward_args_table)
         print(dict_data_list_to_table(backward_compare_result["result_list"]))
@@ -284,9 +284,8 @@
         self.save_forward_args()
         self.save_backward_args()

-        id = self.forward_op_id
         self = None
-        garbage_collect(id, 2)
+        garbage_collect()

     def save_forward_args(self):
         save_op_args(
@@ -347,7 +346,6 @@ def after_call_op(self, result):  # noqa:C901
             return

         self.register_backward_hook_for_grads()
-        id = self.forward_op_id
         result = self.result
         # for reduce device memory usage
         if self.backward_hook_handle is not None:
@@ -357,7 +355,7 @@
         else:
             self = None

-        garbage_collect(id, 10)
+        garbage_collect()
         return result

     def is_should_apply(self, *args, **kwargs):
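
Across these hooks the pattern is the same: the per-op id bookkeeping (`id = self.forward_op_id` followed by `garbage_collect(id, interval)`) is dropped, and every call site becomes a bare `garbage_collect()`. A minimal sketch of what a zero-argument version could look like; the global counter, the default interval, and the `OP_TOOLS_GC_INTERVAL` variable are illustrative assumptions, not code from this PR:

    import gc
    import os

    _gc_call_count = 0

    def garbage_collect():
        # Count calls globally instead of tracking per-op ids, and only pay
        # the cost of a full gc.collect() every `interval` invocations.
        global _gc_call_count
        _gc_call_count += 1
        interval = int(os.getenv("OP_TOOLS_GC_INTERVAL", "100"))
        if _gc_call_count % interval == 0:
            gc.collect()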

op_tools/op_capture_hook.py (3 changes: 1 addition & 2 deletions)
@@ -67,9 +67,8 @@ def after_call_op(self, result):
         if result.grad_fn is not None:
             self.backward_hook_handle.register_grad_fun_hook(result)

-        id = self.id
         self = None
-        garbage_collect(id)
+        garbage_collect()

     def is_should_apply(self, *args, **kwargs):
         if is_opname_match(self.name, os.getenv("OP_CAPTURE_DISABLE_LIST", "")):

op_tools/op_dtype_cast_hook.py (89 changes: 62 additions & 27 deletions)
@@ -9,6 +9,8 @@
     traverse_container,
     get_dtype_cast_dict_form_str,
     is_opname_match,
+    is_view_op,
+    is_dtype_cast_op,
     garbage_collect
 )
 from .pretty_print import dict_data_list_to_table
@@ -26,10 +28,21 @@ def before_call_op(self, *args, **kwargs):
         )
         self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)
         with DisableHookGuard():
-            self.args_raw = self.args
             self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
             if self.is_cpu_op:
                 return

+            self.raw_ins_dtype_list = []
+            for arg in traverse_container(self.args):
+                if isinstance(arg, torch.Tensor):
+                    self.raw_ins_dtype_list.append(arg.dtype)
+                else:
+                    self.raw_ins_dtype_list.append(None)
+
+            for dtype in set(self.dtype_cast_dict.keys()):
+                if dtype not in self.raw_ins_dtype_list:
+                    self.dtype_cast_dict.pop(dtype)
+
             self.args = to_device(
                 self.device,
                 self.args,
@@ -43,57 +56,79 @@
                 detach=False,
             )
             self.dtype_cast_back_dict = {}
-            self.ins_list = []
+            self.ins_dtype_list = []
             for arg in traverse_container(self.args):
-                self.ins_list.append(arg)
-
-            self.raw_ins_list = []
-            for arg in traverse_container(self.args_raw):
-                self.raw_ins_list.append(arg)
+                if isinstance(arg, torch.Tensor):
+                    self.ins_dtype_list.append(arg.dtype)
+                else:
+                    self.ins_dtype_list.append(None)

             self.data_dict_list = []
-            for i in range(len(self.ins_list)):
-                if isinstance(self.ins_list[i], torch.Tensor):
-                    if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
-                        self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype
-                        data_dict = {
-                            "name": self.name,
-                            "target": f"input[{i}]",
-                            "action": f"{self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype}",
-                            "config": self.dtype_cast_config_str,
-                        }
-                        self.data_dict_list.append(data_dict)
+            for i in range(len(self.ins_dtype_list)):
+                if self.ins_dtype_list[i] != self.raw_ins_dtype_list[i]:
+                    self.dtype_cast_back_dict[self.ins_dtype_list[i]] = self.raw_ins_dtype_list[i]
+                    data_dict = {
+                        "name": self.name,
+                        "target": f"input[{i}]",
+                        "action": f"{self.raw_ins_dtype_list[i]} -> {self.ins_dtype_list[i]}",
+                        "config": self.dtype_cast_config_str,
+                    }
+                    self.data_dict_list.append(data_dict)

     def after_call_op(self, result):
         if self.is_cpu_op:
             return
         with DisableHookGuard():
-            self.result_raw = result
+            self.raw_result_dtype_list = []
+            for arg in traverse_container(self.result):
+                if isinstance(arg, torch.Tensor):
+                    self.raw_result_dtype_list.append(arg.dtype)
+                else:
+                    self.raw_result_dtype_list.append(None)
+
             self.result = to_device(
                 self.device,
                 self.result,
                 dtype_cast_dict=self.dtype_cast_back_dict,
                 detach=False,
             )

+            self.result_dtype_list = []
+            for arg in traverse_container(self.result):
+                if isinstance(arg, torch.Tensor):
+                    self.result_dtype_list.append(arg.dtype)
+                else:
+                    self.result_dtype_list.append(None)
+
             i = -1
-            for out in traverse_container(self.result_raw):
+            for out in traverse_container(self.raw_result_dtype_list):
                 i += 1
-                if isinstance(out, torch.Tensor) and out.dtype in self.dtype_cast_back_dict.keys():
+                if out in self.dtype_cast_back_dict.keys():
                     data_dict = {
                         "name": self.name,
                         "target": f"output[{i}]",
-                        "action": f"{out.dtype} -> {self.dtype_cast_back_dict[out.dtype]}",
+                        "action": f"{out} -> {self.dtype_cast_back_dict[out]}",
                         "config": self.dtype_cast_config_str,
                     }
                     self.data_dict_list.append(data_dict)
-            if len(self.data_dict_list) > 0:
-                print(dict_data_list_to_table(self.data_dict_list))
-            id = self.id
-            self = None
-            garbage_collect(id)
+            if len(self.data_dict_list) > 0:
+                print("\n" * 2, f"cast_dtype {self.name} forward_id: {self.id}")
+                print(f"{self.current_location}")
+                print(dict_data_list_to_table(self.data_dict_list))
+                print("\n" * 2)
+            result = self.result
+            self = None
+            garbage_collect()
         return result

     def is_should_apply(self, *args, **kwargs):
         if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
             return False

+        if is_view_op(self.name):
+            return False
+
+        if is_dtype_cast_op(self.name, *args, **kwargs):
+            return False
+
         return is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_LIST", ".*"))
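
The larger refactor in this file replaces stored tensors (`args_raw`, `result_raw`, `ins_list`) with stored dtype lists, so the hook no longer keeps references that pin the operands' memory after the op returns. A standalone illustration of the difference (an example for this writeup, not code from the repo):

    import torch

    inputs = [torch.randn(1024, 1024), 3, torch.ones(10, dtype=torch.float64)]

    # Before: keeping the tensors themselves pins their (device) memory for
    # as long as the hook object survives.
    raw_ins = list(inputs)

    # After: recording only the dtypes keeps everything the cast report needs,
    # while the tensors can be freed as soon as the op completes.
    raw_ins_dtype_list = [x.dtype if isinstance(x, torch.Tensor) else None
                          for x in inputs]
    print(raw_ins_dtype_list)  # [torch.float32, None, torch.float64]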

op_tools/op_fallback_hook.py (5 changes: 2 additions & 3 deletions)
@@ -76,9 +76,8 @@ def after_call_op(self, result):

         self.result = to_device(self.device, self.result_cpu, self.dtype_convert_back_dict)
         self.dump_op_args()
-        id = self.id
         self = None
-        garbage_collect(id)
+        garbage_collect()

     def dump_op_args(self):
         data_dict_list = []
@@ -99,7 +98,7 @@ def dump_op_args(self):
             dtype_cast_info = "cpu_dtype_cast_info: " + str(self.dtype_cast_dict)

         print("\n" * 2)
-        print(f"{self.name} forward_id: {self.id} {dtype_cast_info}")
+        print(f"fallback {self.name} forward_id: {self.id} {dtype_cast_info}")
         print(f"{self.current_location}")
         print(table)
         print("\n" * 2)

op_tools/op_time_measure_hook.py (2 changes: 1 addition & 1 deletion)
@@ -148,7 +148,7 @@ def after_call_op(self, result):
         elasped_info_dict["output"] = serialize_args_to_dict(self.result)
         time_measure_result_cache.append(self.id, elasped_info_dict)

-        garbage_collect(self.id)
+        garbage_collect()

     def is_should_apply(self, *args, **kwargs):
         if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):

op_tools/test/test_compare_result.py (10 changes: 9 additions & 1 deletion)
@@ -1,5 +1,5 @@
 # Copyright (c) 2024, DeepLink.
-from op_tools.utils import compare_result
+from op_tools.utils import compare_result, tensor_cos_similarity
 import torch
 import ditorch
 import unittest
@@ -156,6 +156,14 @@ def test_compare_invalid_input(self):
         self.assertFalse(compare_result("invalid_value_a", ["1", 2, 3], [1, 2, 3])["allclose"])  # element type of input a does not meet requirements
         self.assertFalse(compare_result("invalid_value_b", [1, 2, 3], ["1", 2, 3])["allclose"])  # element type of input b does not meet requirements

+    def test_cosine_similarity(self):
+        x = torch.randn(3, 4, 4, device="cuda").float()
+        y = torch.randn(3, 4, 4, device="cuda")
+        self.assertTrue(abs(tensor_cos_similarity(x, x) - 1) < 1e-6)
+        self.assertTrue(abs(tensor_cos_similarity(x, -x) + 1) < 1e-6)
+        xy_cos_similarity = tensor_cos_similarity(x, y)
+        self.assertTrue(xy_cos_similarity >= -1 and xy_cos_similarity <= 1)


 if __name__ == "__main__":
     unittest.main()
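
The new test fixes the contract of `tensor_cos_similarity`: 1 for identical tensors, -1 for negated ones, and a value in [-1, 1] in general. A sketch of a helper meeting that contract, inferred from the test rather than taken from the repo's implementation:

    import torch
    import torch.nn.functional as F

    def tensor_cos_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
        # Flatten both tensors and compare them as single vectors; float64
        # keeps the result numerically stable for the 1e-6 tolerance above.
        a = a.detach().flatten().double()
        b = b.detach().flatten().double()
        return F.cosine_similarity(a, b, dim=0).item()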

op_tools/test/test_opname_match.py (11 changes: 11 additions & 0 deletions)
@@ -50,6 +50,17 @@ def test_get_dtype_cast_dict_from_config(self):
             },
         )

+    def test_get_dtype_cast_dict_from_config2(self):
+        dtype_cast_dict = get_dtype_cast_dict_form_str(" torch.float32 ->torch.float16, torch.float64 -> torch.float16, torch.int64-> torch.int32 ")
+        self.assertEqual(
+            dtype_cast_dict,
+            {
+                torch.float32: torch.float16,
+                torch.float64: torch.float16,
+                torch.int64: torch.int32,
+            },
+        )


 if __name__ == "__main__":
     unittest.main()
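
This case pins down that `get_dtype_cast_dict_form_str` must tolerate arbitrary whitespace around the `->` arrows and commas. A whitespace-tolerant parser that would pass it (a sketch under that assumption, not the repo's code):

    import torch

    def get_dtype_cast_dict_form_str(config: str) -> dict:
        # " torch.float32 ->torch.float16, ..." -> {torch.float32: torch.float16, ...}
        dtype_cast_dict = {}
        for pair in config.split(","):
            src, dst = (name.strip() for name in pair.split("->"))
            # Resolve the string "torch.float32" to the torch.float32 dtype object.
            dtype_cast_dict[getattr(torch, src.removeprefix("torch."))] = getattr(
                torch, dst.removeprefix("torch.")
            )
        return dtype_cast_dict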

op_tools/test/test_tool_with_special_op.py (6 changes: 6 additions & 0 deletions)
@@ -128,6 +128,12 @@ def test_exp(self):
         y = x.exp()
         y.backward(torch.ones_like(y))

+    def test_dtype_cast(self):
+        with op_tools.OpDtypeCast():
+            x = torch.randn(3, 4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
+            y = x.to(torch.float32)
+            y.backward(torch.ones_like(y))


 if __name__ == "__main__":
     unittest.main()
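
This test drives a pure dtype conversion (`x.to(torch.float32)`) through `OpDtypeCast`; with the new `is_dtype_cast_op` guard in `is_should_apply`, the cast op itself is skipped, since rewriting its dtype would alter the very conversion being requested. One plausible shape for such a predicate, matching the call signature shown in the diff (the op-name tuple and the dtype-argument heuristic are assumptions):

    import torch

    # Hypothetical sketch: ops whose sole purpose is changing dtype should not
    # themselves be dtype-cast by the hook. The name list is illustrative.
    _DTYPE_CAST_OPS = ("torch.Tensor.to", "torch.Tensor.half",
                       "torch.Tensor.float", "torch.Tensor.double")

    def is_dtype_cast_op(name, *args, **kwargs):
        if name in _DTYPE_CAST_OPS:
            return True
        # A torch.dtype among the arguments also signals an explicit cast.
        return any(isinstance(a, torch.dtype) for a in list(args) + list(kwargs.values()))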