Zgc/ditorch support result cache management #38

Merged · 7 commits · Sep 19, 2024

This PR replaces the module-level result lists in op_autocompare_hook.py and op_time_measure_hook.py with cache classes (AutoCompareResultCache, TimeMeasureResultCache) that buffer rows and flush them to CSV in batches and at process exit, and makes the op_runner.py hooks create their wrapping hook once and reuse it across calls.
70 changes: 40 additions & 30 deletions op_tools/op_autocompare_hook.py
@@ -24,6 +24,42 @@
SKIP_LIST_OPS = []


class AutoCompareResultCache:
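    # class-level buffer: shared across all instances and flushed to CSV in batches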
    global_autocompare_result = []

    def __init__(self) -> None:
        self.file_name = f"op_tools_results/op_autocompare_result/op_autocompare_info_pid{os.getpid()}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv"  # noqa: E501
        self.dir = self.file_name[0 : self.file_name.rfind("/")]

    def append(self, forward_id, compare_info):
        for result in compare_info["result_list"]:
            self.global_autocompare_result.append({"forward_id": forward_id, **result})

        if len(self.global_autocompare_result) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "1000")):
            self.write_to_file()

    def write_to_file(self):
        if len(self.global_autocompare_result) == 0:
            return
        table = dict_data_list_to_table(self.global_autocompare_result)
        print(table)
        self.global_autocompare_result.clear()
        data_string = table.get_csv_string()
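        # get_csv_string() includes the header row, so each flush appends another header line to the same CSV file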

        os.makedirs(self.dir, exist_ok=True)
        with open(self.file_name, "a+") as f:
            f.write(data_string)
        print(f"op autocompare result saved to {self.file_name}")


compare_result_cache = AutoCompareResultCache()


def dump_all_autocompare_info():
    compare_result_cache.write_to_file()
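
A minimal usage sketch of the new cache, assuming result_list entries shaped like the dicts compare_result() produces (the field names and values here are hypothetical); OP_TOOLS_MAX_CACHE_SIZE (default 1000) sets the flush threshold:

import os

os.environ["OP_TOOLS_MAX_CACHE_SIZE"] = "2"  # flush once more than two rows are buffered

cache = AutoCompareResultCache()
cache.append(1, {"result_list": [{"name": "torch.Tensor.add", "allclose": True, "max_abs_diff": 0.0}]})
cache.append(2, {"result_list": [{"name": "torch.Tensor.mul", "allclose": False, "max_abs_diff": 1e-3}]})
cache.append(3, {"result_list": [{"name": "torch.Tensor.sub", "allclose": True, "max_abs_diff": 0.0}]})  # third row exceeds the threshold, triggering write_to_file()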


class BackwardHookHandle:
    def __init__(self, compare_hook) -> None:
        self.compare_hook = compare_hook
@@ -40,9 +76,6 @@ def grad_fun(grad_inputs, grad_outputs):
        return grad_fun


global_autocompare_result = []


class OpAutoCompareHook(BaseHook):
    AUTO_COMPARE_DTYPE_CAST_DICT = {
        torch.half: torch.float32,
@@ -128,8 +161,8 @@ def register_backward_hook_for_grads(self):

    def compare_forward_result(self):
        compare_info = compare_result(self.name, self.result_device, self.result_cpu)
        compare_info.update({"forward_id": self.forward_op_id})
        global_autocompare_result.append(compare_info)
        compare_result_cache.append(self.forward_op_id, compare_info)

        allclose = compare_info["allclose"]
        self.forward_allclose = allclose
        if not allclose:
@@ -151,8 +184,8 @@ def compare_all_grad(self):
        self.args_grad = self.grad_inputs_cpu
        compare_info = compare_result(self.name + " grad", self.args_cpu_grad, self.args_grad)

        compare_info.update({"forward_id": self.forward_op_id})
        global_autocompare_result.append(compare_info)
        compare_result_cache.append(self.forward_op_id, compare_info)

        if not compare_info["allclose"]:
            # Parameters are not saved when forward accuracy is normal
            if self.forward_allclose:
@@ -241,27 +274,4 @@ def is_should_apply(self, *args, **kwargs):
        return is_opname_match(self.name, os.getenv("OP_AUTOCOMPARE_LIST", ".*"))
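
Which ops the hook applies to is controlled entirely by environment variables. A hedged example, assuming is_opname_match treats the value as a regular expression (the ".*" default suggests regex semantics):

import os

# restrict autocompare to two ops; both names are illustrative
os.environ["OP_AUTOCOMPARE_LIST"] = "torch.nn.functional.linear|torch.Tensor.matmul"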


def dump_all_autocompare_info():
    if len(global_autocompare_result) == 0:
        return
    all_compare_info_list = []
    while len(global_autocompare_result) > 0:
        compare_info = global_autocompare_result.pop(0)
        while len(compare_info["result_list"]) > 0:
            compare_result = compare_info["result_list"].pop(0)
            all_compare_info_list.append({"forward_id": compare_info["forward_id"], **compare_result})

    table = dict_data_list_to_table(all_compare_info_list)
    print(table)
    data_string = table.get_csv_string()
    file_name = f"op_tools_results/op_autocompare_result/op_autocompare_info_pid{os.getpid()}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv"  # noqa: E501
    dir = file_name[0 : file_name.rfind("/")]
    os.makedirs(dir, exist_ok=True)

    with open(file_name, "w") as f:
        f.write(data_string)
    print(f"op autocompare info saved to {file_name}")


atexit.register(dump_all_autocompare_info)
66 changes: 18 additions & 48 deletions op_tools/op_runner.py
@@ -2,9 +2,9 @@
from abc import ABC
import os
import torch
import time
from .utils import to_device, get_function_from_string, traverse_container, is_inplace_op
from .op_autocompare_hook import OpAutoCompareHook
from .op_time_measure_hook import OpTimeMeasureHook


class OpRunnerHook(ABC):
@@ -21,69 +21,39 @@ def after_backward(self):
        pass


class AsyncEventTimer(OpRunnerHook):
    def __init__(self) -> None:
        super().__init__()
        self.forward_start_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
        self.forward_end_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)

        self.backward_start_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)
        self.backward_end_event = torch.cuda.Event(enable_timing=True, blocking=False, interprocess=False)

    def before_forward(self):
        self.forward_start_event.record(torch.cuda.current_stream())

    def after_forward(self):
        self.forward_end_event.record(torch.cuda.current_stream())

    def before_backward(self):
        self.backward_start_event.record(torch.cuda.current_stream())

    def after_backward(self):
        self.backward_end_event.record(torch.cuda.current_stream())


class SyncExecuteTimer(OpRunnerHook):
    def __init__(self) -> None:
        super().__init__()
        self.time_measure_hook = None
        self.raw_func = None

    def before_forward(self):
        torch.cuda.current_stream().synchronize()
        self.forward_start_time = time.time()
        if self.time_measure_hook is None:
            self.raw_func = self.runner.func
            self.time_measure_hook = OpTimeMeasureHook(self.runner.name, self.runner.func)
            self.runner.func = self.time_measure_hook
        else:
            self.runner.func = self.time_measure_hook

    def after_forward(self):
        torch.cuda.current_stream().synchronize()
        self.forward_end_time = time.time()
        self.forward_elasped_time = self.forward_end_time - self.forward_start_time
        print(f"SyncExecuteTimer: {self.runner.name} forward elasped {self.forward_elasped_time * 1000:>.8f} ms ")

    def before_backward(self):
        torch.cuda.current_stream().synchronize()
        self.backward_start_time = time.time()

    def after_backward(self):
        torch.cuda.current_stream().synchronize()
        self.backward_end_time = time.time()
        self.backward_elasped_time = self.backward_end_time - self.backward_start_time
        print(f"SyncExecuteTimer: {self.runner.name} backward elasped {self.backward_elasped_time * 1000:>.8f} ms")
        self.runner.func = self.raw_func


class OpAccyChecker(OpRunnerHook):
    def __init__(self) -> None:
        super().__init__()
        self.raw_func = None
        self.aucompare_hook = None

    def before_forward(self):
        self.aucompare_hook = OpAutoCompareHook(self.runner.name, self.runner.func)
        self.runner.func = self.aucompare_hook
        if self.aucompare_hook is None:
            self.raw_func = self.runner.func  # cache the unwrapped op so after_backward can restore it
            self.aucompare_hook = OpAutoCompareHook(self.runner.name, self.runner.func)
            self.runner.func = self.aucompare_hook
        else:
            self.runner.func = self.aucompare_hook

    def after_forward(self):
        pass

    def before_backward(self):
        pass

    def after_backward(self):
        self.runner.func = self.aucompare_hook.func
        self.runner.func = self.raw_func
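
SyncExecuteTimer and OpAccyChecker now share a wrap-once/restore pattern: the first before_forward stashes the raw op in raw_func and swaps a wrapping hook into runner.func, later iterations reuse the cached wrapper, and after_backward restores the raw op. A minimal sketch of that lifecycle, assuming the runner assigns itself to hook.runner as the classes above imply (RecordingWrapper is a hypothetical stand-in for OpTimeMeasureHook/OpAutoCompareHook):

class RecordingWrapper:
    def __init__(self, name, func):
        self.name, self.func = name, func

    def __call__(self, *args, **kwargs):
        print(f"calling {self.name}")  # stand-in for timing/compare logic
        return self.func(*args, **kwargs)


class LazyWrapHook(OpRunnerHook):
    def __init__(self) -> None:
        super().__init__()
        self.wrapper = None
        self.raw_func = None

    def before_forward(self):
        if self.wrapper is None:  # wrap only on the first forward
            self.raw_func = self.runner.func
            self.wrapper = RecordingWrapper(self.runner.name, self.runner.func)
        self.runner.func = self.wrapper

    def after_backward(self):
        self.runner.func = self.raw_func  # hand back the unwrapped op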


class OpRunner:
86 changes: 59 additions & 27 deletions op_tools/op_time_measure_hook.py
@@ -16,6 +16,63 @@
global_elasped_info_dict = dict()


class TimeMeasureResultCache:
    global_elasped_info_dict = dict()

    def __init__(self) -> None:
        self.file_name = f"op_tools_results/op_time_measure_result/op_elasped_info_pid{os.getpid()}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv"  # noqa: E501
        self.dir = self.file_name[0 : self.file_name.rfind("/")]
        self.ordered_keys = [
            "name",
            "forward_id",
            "forward_elasped",
            "backward_elasped",
            "unit",
            "input",
            "output",
            "grad_inputs",
            "grad_outputs",
        ]

    def append(self, forward_id, elasped_info):
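        # forward and backward timings for one op arrive separately; keying by forward_id merges them into a single row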
        if forward_id not in self.global_elasped_info_dict:
            self.global_elasped_info_dict[forward_id] = elasped_info
        else:
            self.global_elasped_info_dict[forward_id].update(elasped_info)

        if len(self.global_elasped_info_dict) > int(os.getenv("OP_TOOLS_MAX_CACHE_SIZE", "5000")):
            self.write_to_file()

    def write_to_file(self):
        if len(self.global_elasped_info_dict) == 0:
            return
        simple_data_list = []
        for key, value in self.global_elasped_info_dict.items():
            new_value = {k: value.get(k, "-") for k in self.ordered_keys}
            simple_value = {k: value.get(k, "-") for k in self.ordered_keys[0:5]}
            simple_data_list.append(simple_value)
            self.global_elasped_info_dict[key] = new_value

        print(dict_data_list_to_table(simple_data_list))

        table = dict_data_list_to_table(list(self.global_elasped_info_dict.values()))
        self.global_elasped_info_dict.clear()
        data_string = table.get_csv_string()
        os.makedirs(self.dir, exist_ok=True)

        with open(self.file_name, "a+") as f:
            f.write(data_string)
        print(f"op elasped info saved to {self.file_name}")


time_measure_result_cache = TimeMeasureResultCache()


def dump_all_op_elasped_info():
    time_measure_result_cache.write_to_file()
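
A quick sketch of how the cache merges the two halves of one measurement (field values are hypothetical):

cache = TimeMeasureResultCache()
# the forward path records name, timing, input and output first ...
cache.append(7, {"name": "torch.Tensor.add", "forward_id": 7, "forward_elasped": 0.12, "unit": "ms"})
# ... and the backward hook later updates the same forward_id entry
cache.append(7, {"backward_elasped": 0.34, "unit": "ms"})
cache.write_to_file()  # prints the table and appends the CSV rows to cache.file_name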


class BackwardHookHandle:
    def __init__(self, name, id) -> None:
        self.name = name
@@ -47,7 +104,7 @@ def grad_fun(grad_inputs, grad_outputs):
            print(dict_data_list_to_table([elasped_info_dict]))
            elasped_info_dict["grad_inputs"] = serialize_args_to_dict(grad_inputs)
            elasped_info_dict["grad_outputs"] = serialize_args_to_dict(grad_outputs)
            global_elasped_info_dict[self.id].update(elasped_info_dict)
            time_measure_result_cache.append(self.id, elasped_info_dict)

        return grad_fun

@@ -93,7 +150,7 @@ def after_call_op(self, result):
        print(dict_data_list_to_table([elasped_info_dict]))
        elasped_info_dict["input"] = serialize_args_to_dict(*self.args, **self.kwargs)
        elasped_info_dict["output"] = serialize_args_to_dict(self.result)
        global_elasped_info_dict[self.id] = elasped_info_dict
        time_measure_result_cache.append(self.id, elasped_info_dict)

    def is_should_apply(self, *args, **kwargs):
        if is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_DISABLE_LIST", "")):
@@ -102,29 +159,4 @@ def is_should_apply(self, *args, **kwargs):
        return is_opname_match(self.name, os.getenv("OP_TIME_MEASURE_LIST", ".*"))


def dump_all_op_elasped_info():
    if len(global_elasped_info_dict) == 0:
        return
    ordered_keys = ["name", "forward_id", "forward_elasped", "backward_elasped", "unit", "input", "output", "grad_inputs", "grad_outputs"]
    simple_data_list = []
    for key, value in global_elasped_info_dict.items():
        new_value = {k: value[k] for k in ordered_keys}
        simple_value = {k: value[k] for k in ordered_keys[0:5]}
        simple_data_list.append(simple_value)
        global_elasped_info_dict[key] = new_value

    print(dict_data_list_to_table(simple_data_list))

    table = dict_data_list_to_table(list(global_elasped_info_dict.values()))
    data_string = table.get_csv_string()
    file_name = f"op_tools_results/op_time_measure_result/op_elasped_info_pid{os.getpid()}_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.csv"  # noqa: E501
    dir = file_name[0 : file_name.rfind("/")]
    os.makedirs(dir, exist_ok=True)

    with open(file_name, "w") as f:
        f.write(data_string)
    print(f"op elasped info saved to {file_name}")


atexit.register(dump_all_op_elasped_info)