Add operator data type conversion capability #19

Merged · 2 commits · Sep 3, 2024
64 changes: 64 additions & 0 deletions README.md
@@ -24,6 +24,7 @@ import ditorch
| 2 | [Precision analysis tool](#tool2) | Performs offline and real-time precision analysis |
| 3 | [Speed analysis tool](#tool3) | Performs offline and real-time timing analysis to assist performance optimization |
| 4 | [Operator fallback](#tool4) | Falls back the device execution of specified or all operators to CPU computation |
| 5 | [Operator data type conversion tool](#tool5) | Casts specified data types of specified or all operators to a given data type for computation |


### **Operator argument capture tool** <a id="tool1"></a>
@@ -344,3 +345,66 @@ OpFallbackHook: torch.Tensor.mean input: {'args
OpFallbackHook: torch.Tensor.mean output: ({'shape': torch.Size([1, 16384, 1]), 'stride': (16384, 1, 1), 'numel': 16384, 'dtype': 'torch.float32', 'device': 'npu:0', 'requires_grad': False, 'layout': 'torch.strided', 'data': 20067180141056},) cpu output: ({'shape': torch.Size([1, 16384, 1]), 'stride': (16384, 1, 1), 'numel': 16384, 'dtype': 'torch.float32', 'device': 'cpu', 'requires_grad': False, 'layout': 'torch.strided', 'data': 33561021952},) dtype_convert_back_dict:{}
...
```


### **Operator data type conversion tool** <a id="tool5"></a>

```
# usage1
export OP_DTYPE_CAST_DICT="torch.float16->torch.float32,torch.bfloat16->torch.float32"
with op_tools.OpDtypeCast():
    f()

# usage2
dtype_caster = op_tools.OpDtypeCast()
dtype_caster.start()
for i in range(3):
    f()
dtype_caster.stop()
```

```
# usage3
os.environ["OP_DTYPE_CAST_DISABLE_LIST"] = "torch.Tensor.add,torch.Tensor.sub"
dtype_caster.start()
f()
dtype_caster.stop()
```
```
# usage4
os.environ["OP_DTYPE_CAST_DISABLE_LIST"] = ""
os.environ["OP_DTYPE_CAST_LIST"] = "torch.Tensor.sort,torch.Tensor.add" # only cast these op
os.environ["OP_DTYPE_CAST_DICT"] = "torch.half->torch.bfloat16"
dtype_caster.start()
f()
dtype_caster.stop()
```

```
apply OpDtypeCastHook on torch.nn.functional.linear
OpDtypeCastHook: torch.nn.functional.linear 0th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.nn.functional.linear 1th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.nn.functional.linear 2th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.nn.functional.linear 0th out torch.float32 -> torch.float16 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
apply OpDtypeCastHook on torch.Tensor.add
OpDtypeCastHook: torch.Tensor.add 0th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.add 1th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.add 0th out torch.float32 -> torch.float16 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
apply OpDtypeCastHook on torch.Tensor.sub
OpDtypeCastHook: torch.Tensor.sub 0th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.sub 1th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.sub 0th out torch.float32 -> torch.float16 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
apply OpDtypeCastHook on torch.Tensor.div
OpDtypeCastHook: torch.Tensor.div 0th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.div 1th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.div 0th out torch.float32 -> torch.float16 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
apply OpDtypeCastHook on torch.Tensor.sort
OpDtypeCastHook: torch.Tensor.sort 0th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.sort 0th out torch.float32 -> torch.float16 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
apply OpDtypeCastHook on torch.Tensor.__getitem__
OpDtypeCastHook: torch.Tensor.__getitem__ 0th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.__getitem__ 0th out torch.float32 -> torch.float16 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
apply OpDtypeCastHook on torch.Tensor.sum
OpDtypeCastHook: torch.Tensor.sum 0th arg torch.float16 -> torch.float32 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
OpDtypeCastHook: torch.Tensor.sum 0th out torch.float32 -> torch.float16 config:torch.float16->torch.float32,torch.bfloat16->torch.float32
```
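
To make the mechanism easier to follow outside of ditorch, here is a minimal, hypothetical sketch of the same cast-then-cast-back idea written directly against `torch.overrides.TorchFunctionMode`; it is only an illustration, not the `op_tools` implementation shown in the files below, and it assumes a PyTorch version that provides `TorchFunctionMode`.

```
# Illustrative only: upcast float16 inputs to float32 around every torch op,
# then cast float32 results back to float16 (mirrors OP_DTYPE_CAST_DICT's default).
import torch
from torch.overrides import TorchFunctionMode

class SimpleDtypeCast(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        casted = False

        def cast(x):
            nonlocal casted
            if isinstance(x, torch.Tensor) and x.dtype == torch.float16:
                casted = True
                return x.float()
            return x

        out = func(*[cast(a) for a in args], **{k: cast(v) for k, v in kwargs.items()})
        # only cast the result back if an input was actually upcast,
        # similar in spirit to OpDtypeCastHook's dtype_cast_back_dict
        if casted and isinstance(out, torch.Tensor) and out.dtype == torch.float32:
            out = out.half()
        return out

with SimpleDtypeCast():
    x = torch.rand(4, 4).half()
    y = (x @ x).sum()   # matmul and sum run in float32 internally

print(y.dtype)  # torch.float16: each result was cast back after the op
```

The real tool additionally honours `OP_DTYPE_CAST_DICT`, `OP_DTYPE_CAST_LIST` and `OP_DTYPE_CAST_DISABLE_LIST`, and logs every cast as shown in the output above.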
1 change: 1 addition & 0 deletions op_tools/__init__.py
@@ -5,4 +5,5 @@
    OpAutoCompare,
    OpDispatchWatcher,
    OpTimeMeasure,
    OpDtypeCast,
)
48 changes: 48 additions & 0 deletions op_tools/apply_hook.py
@@ -8,6 +8,7 @@
from .op_autocompare_hook import OpAutoCompareHook
from .op_dispatch_watch_hook import OpDispatchWatcherHook
from .op_time_measure_hook import OpTimeMeasureHook
from .op_dtype_cast_hook import OpDtypeCastHook
from .utils import is_cpu_op, is_opname_match
import inspect

@@ -325,3 +326,50 @@ def start(self):

    def stop(self):
        super().__exit__(None, None, None)


class OpDtypeCast(OpToolBase):
    """
    Set the OP_DTYPE_CAST_DISABLE_LIST environment variable to ignore specific operators.
    Set the OP_DTYPE_CAST_LIST environment variable to apply the cast only to those operators.
    Usage1:
        with OpDtypeCast():
            f()
    Usage2:
        dtypecaster = OpDtypeCast()
        dtypecaster.start()
        f()
        dtypecaster.stop()
    """

    def __init__(self):
        super().__init__()

    def is_should_cast(self, name, func, args, kwargs=None):
        if not is_should_apply_hook(name, func, args, kwargs):
            return False

        if is_opname_match(name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
            return False

        return is_opname_match(name, os.getenv("OP_DTYPE_CAST_LIST", ".*"))

    def __torch_function__(self, func, types, args, kwargs=None):
        name = resolve_name(func)
        if self.is_should_cast(name, func, args, kwargs):
            print(f"apply OpDtypeCastHook on {name}")
            new_func = OpDtypeCastHook(name)(func)
            return new_func(*args, **(kwargs or {}))
        else:
            if name not in self.skiped_op:
                print(f"skip OpDtypeCastHook on {name}")
                self.skiped_op.add(name)

            return func(*args, **(kwargs or {}))

    def start(self):
        super().__enter__()
        self.skiped_op.clear()

    def stop(self):
        super().__exit__(None, None, None)
66 changes: 66 additions & 0 deletions op_tools/op_dtype_cast_hook.py
@@ -0,0 +1,66 @@
# Copyright (c) 2024, DeepLink.
import torch
import math
import gc
import os

from .base_hook import BaseHook, DisableHookGuard
from .utils import (
    to_device,
    is_cpu_op,
    traverse_container,
    is_inplace_op,
    get_dtype_cast_dict_form_str,
)
from .op_fallback_hook import OpFallbackHook
from .save_op_args import save_op_args, serialize_args_to_dict


class OpDtypeCastHook(BaseHook):

    def __init__(self, name) -> None:
        super().__init__(name)
        self.dtype_cast_config_str = os.environ.get(
            "OP_DTYPE_CAST_DICT",
            "torch.float16->torch.float32,torch.bfloat16->torch.float32",
        )
        self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)

    def before_call_op(self, *args, **kwargs):
        with DisableHookGuard():
            self.args_raw = self.args
            self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
            if self.is_cpu_op:
                return
            self.args = to_device(self.device, self.args, self.dtype_cast_dict)
            self.kwargs = to_device(
                self.device, self.kwargs or {}, self.dtype_cast_dict
            )
            self.dtype_cast_back_dict = {}
            for i in range(len(self.args_raw)):
                if isinstance(self.args_raw[i], torch.Tensor) and (
                    self.args_raw[i].dtype in self.dtype_cast_dict
                ):
                    print(
                        f"OpDtypeCastHook: {self.name:<50} {i}th arg {self.args_raw[i].dtype} -> {self.args[i].dtype} config:{self.dtype_cast_config_str}"
                    )
                    self.dtype_cast_back_dict[self.args[i].dtype] = self.args_raw[
                        i
                    ].dtype

    def after_call_op(self, result):
        if self.is_cpu_op:
            return
        with DisableHookGuard():
            self.result_raw = result
            self.result = to_device(self.device, self.result, self.dtype_cast_back_dict)
            i = -1
            for out in traverse_container(self.result_raw):
                i += 1
                if (
                    isinstance(out, torch.Tensor)
                    and out.dtype in self.dtype_cast_back_dict.keys()
                ):
                    print(
                        f"OpDtypeCastHook: {self.name:<50} {i}th out {out.dtype} -> {self.dtype_cast_back_dict[out.dtype]} config:{self.dtype_cast_config_str}"
                    )
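
For a quick manual experiment, the hook can in principle be applied to a single op using the same wrapping pattern that `apply_hook.py` uses above (`OpDtypeCastHook(name)(func)`). The snippet below is a hypothetical sketch under that assumption; it expects a non-CPU device, since the hook returns early for CPU ops.

```
# Hypothetical direct use of OpDtypeCastHook; op_tools.OpDtypeCast normally does this wrapping.
import torch
import ditorch  # enables the vendor device backend, as in the test below
from op_tools.op_dtype_cast_hook import OpDtypeCastHook

wrapped_add = OpDtypeCastHook("torch.Tensor.add")(torch.Tensor.add)

a = torch.rand(4, device="cuda").half()
b = torch.rand(4, device="cuda").half()
c = wrapped_add(a, b)  # inputs upcast to float32 for the call, result cast back
print(c.dtype)         # expected: torch.float16
```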
54 changes: 54 additions & 0 deletions op_tools/test/test_op_dtype_cast.py
@@ -0,0 +1,54 @@
# Copyright (c) 2024, DeepLink.
import torch
import ditorch

import op_tools
import os


def f():
    a = torch.rand(10, 20, requires_grad=True).cuda().half()
    b = a * 2
    c = b + a
    d = c - a
    e = d / c
    sorted, indices = e.sort()  # returns a torch.return_types.sort named tuple
    y = sorted[2:8:2, ::3]
    y.sum().backward()

    m = torch.nn.Linear(4, 4, device="cuda").half()
    x = torch.randn(3, 4, 4, device="cuda", requires_grad=True, dtype=torch.half)
    y = m(x)
    y.backward(torch.ones_like(y))


f()

# usage1
with op_tools.OpDtypeCast():
    f()

# usage2
dtype_caster = op_tools.OpDtypeCast()
dtype_caster.start()
for i in range(3):
    f()
dtype_caster.stop()


# usage3
os.environ["OP_DTYPE_CAST_DISABLE_LIST"] = "torch.Tensor.add,torch.Tensor.sub"
dtype_caster.start()
f()
dtype_caster.stop()


# usage4
os.environ["OP_DTYPE_CAST_DISABLE_LIST"] = ""
os.environ["OP_DTYPE_CAST_LIST"] = "torch.Tensor.sort" # only cast this op
os.environ["OP_DTYPE_CAST_DICT"] = (
"torch.half->torch.float32" # camb 370 not support bfloat16
)
dtype_caster.start()
f()
dtype_caster.stop()
16 changes: 15 additions & 1 deletion op_tools/test/test_opname_match.py
@@ -1,5 +1,6 @@
# Copyright (c) 2024, DeepLink.
from op_tools.utils import is_opname_match, is_inplace_op
from op_tools.utils import is_opname_match, is_inplace_op, get_dtype_cast_dict_form_str
import torch

import unittest

@@ -32,6 +33,19 @@ def test_inplace_op(self):
        self.assertEqual(is_inplace_op("torch.Tensor.add_"), True)
        self.assertEqual(is_inplace_op("torch.Tensadd"), False)

    def test_get_dtype_cast_dict_from_config(self):
        dtype_cast_dict = get_dtype_cast_dict_form_str(
            "torch.float32->torch.float16,torch.float64->torch.float16,torch.int64->torch.int32"
        )
        self.assertEqual(
            dtype_cast_dict,
            {
                torch.float32: torch.float16,
                torch.float64: torch.float16,
                torch.int64: torch.int32,
            },
        )


if __name__ == "__main__":
    unittest.main()
13 changes: 13 additions & 0 deletions op_tools/utils.py
@@ -78,3 +78,16 @@ def get_function_from_string(func_str):
        attrs.append(attr)

    return attrs[len(parts) - 1]


def get_dtype_cast_dict_form_str(config):
    """
    'torch.float16->torch.float32,torch.bfloat16->torch.float32' -> {torch.float16: torch.float32, torch.bfloat16: torch.float32}
    """
    dtype_cast_dict = dict()
    if config is not None:
        for item in config.split(","):
            dtype_cast_dict[get_function_from_string(item.split("->")[0])] = (
                get_function_from_string(item.split("->")[1])
            )
    return dtype_cast_dict
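
As a small illustration (using the same arrow-separated format as the OP_DTYPE_CAST_DICT environment variable), the helper resolves each side of the arrow to a torch dtype; the exact values below are an assumed example, not part of the PR:

```
import torch
from op_tools.utils import get_dtype_cast_dict_form_str

mapping = get_dtype_cast_dict_form_str("torch.half->torch.bfloat16")
assert mapping == {torch.float16: torch.bfloat16}  # torch.half is an alias of torch.float16
```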