Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zgc/ditorch refactor #76

Merged
merged 15 commits into from
Oct 31, 2024
2 changes: 1 addition & 1 deletion .github/workflows/runs_on_ascend.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${GITHUB_JOB} && cp -R Build ${GITHUB_JOB} && cd ${GITHUB_JOB}
export PYTHONPATH=${PYTHONPATH}:$PWD
echo "start to test"
bash ci/run_op_tools_test_cases.sh
bash ci/run_individual_test_cases.sh

Test_use_pytorch_test_case:
name: run pytorch test case on ascend
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/runs_on_camb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
echo "start to test"
source /mnt/cache/share/platform/env/ditorch_env
export PYTHONPATH=${PYTHONPATH}:$PWD
srun -p camb_mlu370_m8 -n 1 --gres=mlu:1 bash ci/run_op_tools_test_cases.sh
srun -p camb_mlu370_m8 -n 1 --gres=mlu:1 bash ci/run_individual_test_cases.sh

Test_use_pytorch_test_case:
name: run pytorch test case on camb
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/runs_on_nv.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
set -ex
cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER} && rm -rf ${GITHUB_JOB} && cp -R Build ${GITHUB_JOB} && cd ${GITHUB_JOB}
echo "start to test"
srun --job-name=${GITHUB_JOB} bash -c "source /mnt/cache/share/platform/env/ditorch_env && cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${GITHUB_JOB} && export PYTHONPATH=${PYTHONPATH}:$PWD && bash ci/run_op_tools_test_cases.sh"
srun --job-name=${GITHUB_JOB} bash -c "source /mnt/cache/share/platform/env/ditorch_env && cd ${DEEPLINK_PATH}/${GITHUB_RUN_NUMBER}/${GITHUB_JOB} && export PYTHONPATH=${PYTHONPATH}:$PWD && bash ci/run_individual_test_cases.sh"

Test_use_pytorch_test_case:
name: run pytorch test case on nv
Expand Down
11 changes: 11 additions & 0 deletions ci/run_individual_test_cases.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
date
# Run every test file under ditorch/ and op_tools/ as its own Python process,
# logging a PASSED/FAILED marker per file so failures are visible in one log.
# NOTE: the -name pattern must be quoted, otherwise the shell expands
# test*.py against the current directory before find ever sees it.
find ditorch op_tools -name "test*.py" | xargs -I {} bash -c 'echo "start run {}"; date; time python {} && printf "Test %s PASSED\n\n\n" "{}" || printf "Test %s FAILED\n\n\n" "{}"' 2>&1 | tee test_individual_cases.log
# printf is used instead of echo: bash echo does not interpret \n without -e,
# so the original printed literal backslash-n instead of blank lines.

# Check if any tests failed (the FAILED marker is only emitted on failure).
if grep -q "FAILED" test_individual_cases.log; then
    echo "tests failed"
    exit 1
else
    echo "all tests passed"
    exit 0
fi
11 changes: 0 additions & 11 deletions ci/run_op_tools_test_cases.sh

This file was deleted.

35 changes: 16 additions & 19 deletions ditorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,33 @@
# Copyright (c) 2024, DeepLink.
import os

framework = None
adapter = None
try:
from ditorch import torch_npu_adapter

framework = "torch_npu:" + torch_npu_adapter.torch_npu.__version__
except Exception as e: # noqa: F841
import ditorch.torch_npu_adapter as adapter # noqa: F811
except ImportError as e: # noqa: F841
pass
try:
from ditorch import torch_dipu_adapter # noqa: F401

framework = "torch_dipu" # torch_dipu has not __version__ attr

except Exception as e: # noqa: F841
pass

try:
from ditorch import torch_mlu_adapter
import ditorch.torch_dipu_adapter as adapter # noqa: F811

framework = "torch_mlu:" + torch_mlu_adapter.torch_mlu.__version__
except Exception as e: # noqa: F841
except ImportError as e: # noqa: F841
pass

try:
from ditorch import torch_biren_adapter
import ditorch.torch_mlu_adapter as adapter # noqa: F811
except ImportError as e: # noqa: F841
pass

framework = "torch_br:" + torch_biren_adapter.torch_br.__version__
except Exception as e: # noqa: F841
try:
import ditorch.torch_biren_adapter as adapter # noqa: F811
except ImportError as e: # noqa: F841
pass


from ditorch import common_adapter # noqa: F401,E402

print(f"ditorch.framework: {framework} pid: {os.getpid()}")
if adapter is not None and int(os.getenv("DITORCH_DISABLE_MOCK", "0")) <= 0:
adapter.mock()
common_adapter.mock_common()

print(f"ditorch: {adapter.arch} {adapter.framework.__name__}:{adapter.framework.__version__} pid: {os.getpid()}")
5 changes: 5 additions & 0 deletions ditorch/common_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .common_mock import mock_tensor_device


def mock_common():
    """Apply the vendor-independent mocks shared by all ditorch adapters.

    Currently this only installs the tensor-device mock (which reports
    vendor devices as "cuda" via a TorchFunctionMode override).
    """
    mock_tensor_device()
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import torch
import os

if torch.__version__ >= "2.0.0":

def mock_tensor_device():
if torch.__version__ < "2.0.0":
return
from torch.overrides import TorchFunctionMode, resolve_name

class DeviceMock(TorchFunctionMode):
Expand All @@ -15,8 +18,11 @@ def __torch_function__(self, func, types, args, kwargs=None):
name = None
result = func(*args, **(kwargs or {}))
if name == "torch.Tensor.device.__get__":
if result.type != "cpu":
result = torch.device("cuda" + (":" + str(result.index)) if result.index is not None else "")
if result.type not in ["cpu", "mps", "xpu", "xla", "meta"]:
device_str = "cuda"
if result.index is not None:
device_str += f":{result.index}"
result = torch.device(device_str)
if name == "torch.Tensor.__repr__":
device = args[0].device
if device.type != "cpu":
Expand Down
149 changes: 149 additions & 0 deletions ditorch/test/individual/test_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
import torch
import ditorch # noqa: F401
import torch.distributed as dist
import torch.multiprocessing as mp
import unittest

world_size = torch.cuda.device_count()


def setup(rank, world_size):
    """Initialize the default process group for one spawned worker.

    NOTE(review): MASTER_PORT is hard-coded, so two concurrent runs on the
    same host would collide — confirm this is acceptable for CI.
    """
    rendezvous = {"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"}
    os.environ.update(rendezvous)
    dist.init_process_group(backend="hccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


# Tear down the process group created by setup().
def cleanup():
    dist.destroy_process_group()


# Launcher: spawn one process per rank, each running test_func(rank, world_size),
# and block until all workers exit.
def run_distributed_test(test_func, world_size):
    mp.spawn(test_func, args=(world_size,), nprocs=world_size, join=True)


# Worker body for all_reduce: averaging each rank's constant tensor must
# yield mean(0..world_size-1) on every rank.
def all_reduce_test(rank, world_size):
    setup(rank, world_size)

    buf = rank * torch.ones(10).float().cuda(rank)
    print(f"Rank {rank} before all_reduce: {buf}")

    dist.all_reduce(buf, op=dist.ReduceOp.AVG)

    print(f"Rank {rank} after all_reduce: {buf}")
    # Every process must end up with the same averaged value.
    mean_rank = sum(range(world_size)) / world_size
    want = torch.ones(10).float().cuda(rank) * mean_rank
    print(rank, "expected_tensor:", want)
    print(rank, "tensor:", buf)
    assert torch.equal(buf, want), f"Rank {rank} all_reduce failed!"
    cleanup()


# Worker body for reduce_scatter (list-of-tensors variant) with AVG.
def reduce_scatter_test(rank, world_size):
    setup(rank, world_size)

    # e.g. world_size == 2:
    #   rank 0 holds [tensor([0.]), tensor([1.])]
    #   rank 1 holds [tensor([2.]), tensor([3.])]
    start = world_size * rank
    chunks = list(torch.arange(start, start + world_size).float().cuda(rank).chunk(world_size))
    out = torch.zeros(1).float().cuda(rank)
    print(f"Rank {rank} before reduce_scatter: {chunks}")

    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.AVG)
    # Expected (world_size == 2): rank 0 -> tensor([1.]), rank 1 -> tensor([2.])
    # Slot `rank` of rank i's list holds i*world_size + rank; average over i.
    total = sum(i * world_size for i in range(world_size)) + world_size * rank
    want = torch.tensor([total / world_size]).cuda(rank)
    print(f"Rank {rank} after reduce_scatter: {out}")
    assert torch.equal(out, want), f"Rank {rank} reduce_scatter failed!"
    cleanup()


def _reduce_scatter_tensor_test(rank, world_size, func):
    """Shared worker body for the single-tensor reduce_scatter variants.

    `func` is either dist.reduce_scatter_tensor or dist._reduce_scatter_base.
    """
    setup(rank, world_size)

    # e.g. world_size == 2: rank 0 -> tensor([0., 1.]), rank 1 -> tensor([2., 3.])
    start = world_size * rank
    inp = torch.arange(start, start + world_size).float().cuda(rank)
    out = torch.zeros(inp.numel() // world_size).float().cuda(rank)
    print(f"Rank {rank} before reduce_scatter: {inp}")

    func(out, inp, op=dist.ReduceOp.AVG)
    # Expected (world_size == 2): rank 0 -> tensor([1.]), rank 1 -> tensor([2.])
    total = sum(i * world_size for i in range(world_size)) + world_size * rank
    want = torch.tensor([total / world_size]).cuda(rank)
    print(f"Rank {rank} after reduce_scatter: {out}")
    assert torch.equal(out, want), f"Rank {rank} reduce_scatter failed!"
    cleanup()


# Worker body: exercise dist.reduce_scatter_tensor via the shared helper.
def reduce_scatter_tensor_test(rank, world_size):
    _reduce_scatter_tensor_test(rank, world_size, dist.reduce_scatter_tensor)


# Worker body: exercise the legacy/private dist._reduce_scatter_base via the
# shared helper (kept to cover frameworks still calling the underscore API).
def reduce_scatter_base_test(rank, world_size):
    _reduce_scatter_tensor_test(rank, world_size, dist._reduce_scatter_base)


# Worker body for reduce(AVG) to rank 0; non-root buffers must stay untouched.
def reduce_test(rank, world_size):
    setup(rank, world_size)

    buf = (rank + 1) * torch.ones(10).cuda(rank)
    print(f"Rank {rank} before reduce: {buf}")

    dist.reduce(buf, dst=0, op=dist.ReduceOp.AVG)

    if rank == 0:
        # Root receives the average of the contributions 1..world_size.
        want = torch.ones(10).cuda(rank) * sum(range(1, world_size + 1)) / world_size
        print(f"Rank {rank} after reduce (on root): {buf}")
        assert torch.equal(buf, want), "Reduce failed on root!"
    else:
        print(f"Rank {rank} after reduce (non-root): {buf}")
        # Non-root ranks keep their original contribution unchanged.
        want = (rank + 1) * torch.ones(10).cuda(rank)
        print(f"Rank {rank} after reduce (no root) expected_tensor: {want}")
        assert torch.equal(buf, want), "Reduce failed on non-root!"
    cleanup()


class TestDist(unittest.TestCase):
    """Launcher-side wrappers: each test spawns world_size worker processes."""

    # Shared skip decorator: these collectives need at least two devices.
    _needs_multi_device = unittest.skipIf(world_size < 2, "Communication test requires at least two cards")

    @_needs_multi_device
    def test_all_reduce(self, world_size=world_size):
        """pytest wrapper for all_reduce test"""
        run_distributed_test(all_reduce_test, world_size)

    @_needs_multi_device
    def test_reduce_scatter(self, world_size=world_size):
        """pytest wrapper for reduce_scatter test"""
        run_distributed_test(reduce_scatter_test, world_size)

    @_needs_multi_device
    def test_reduce_scatter_tensor(self, world_size=world_size):
        """pytest wrapper for reduce_scatter test"""
        run_distributed_test(reduce_scatter_tensor_test, world_size)

    @_needs_multi_device
    def test__reduce_scatter_base(self, world_size=world_size):
        """pytest wrapper for reduce_scatter test"""
        run_distributed_test(reduce_scatter_base_test, world_size)

    @_needs_multi_device
    def test_reduce(self, world_size=world_size):
        """pytest wrapper for reduce test"""
        run_distributed_test(reduce_test, world_size)


# Allow direct execution: `python test_dist.py` runs the whole suite.
if __name__ == "__main__":
    unittest.main()
30 changes: 15 additions & 15 deletions ditorch/test/summary_test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,23 @@ def write_test_info_to_json(test_infos, pytorch_test_result): # noqa: C901
skipped_test_case = {}
failed_test_case = {}
for info in test_infos:
if info['file'] not in passed_test_case:
passed_test_case[info['file']] = []
if info['file'] not in skipped_test_case:
skipped_test_case[info['file']] = []
if info['file'] not in failed_test_case:
failed_test_case[info['file']] = []
if info["file"] not in passed_test_case:
passed_test_case[info["file"]] = []
if info["file"] not in skipped_test_case:
skipped_test_case[info["file"]] = []
if info["file"] not in failed_test_case:
failed_test_case[info["file"]] = []

case_name = info["classname"] + "." + info["name"]

if info["status"] == "passed":
if case_name not in passed_test_case[info['file']]:
if case_name not in passed_test_case[info["file"]]:
passed_test_case[info["file"]].append(case_name)
elif info["status"] == "skipped":
if case_name not in skipped_test_case[info['file']]:
if case_name not in skipped_test_case[info["file"]]:
skipped_test_case[info["file"]].append(case_name)
elif info["status"] == "error":
if case_name not in failed_test_case[info['file']]:
if case_name not in failed_test_case[info["file"]]:
failed_test_case[info["file"]].append(case_name)

passed_case_file_name = pytorch_test_result + "/passed_test_case.json"
Expand All @@ -113,9 +113,9 @@ def write_test_info_to_json(test_infos, pytorch_test_result): # noqa: C901

for info in test_infos:
case_name = info["classname"] + "." + info["name"]
if info['file'] in all_test_case.keys():
if case_name in all_test_case[info['file']]:
all_test_case[info['file']].remove(case_name)
if info["file"] in all_test_case.keys():
if case_name in all_test_case[info["file"]]:
all_test_case[info["file"]].remove(case_name)
with open(never_device_tested_case_file_name, "w") as f:
f.write(json.dumps(all_test_case))

Expand All @@ -129,9 +129,9 @@ def write_test_info_to_json(test_infos, pytorch_test_result): # noqa: C901

for info in test_infos:
case_name = info["classname"] + "." + info["name"]
if info['file'] in all_test_case.keys():
if case_name in all_test_case[info['file']]:
all_test_case[info['file']].remove(case_name)
if info["file"] in all_test_case.keys():
if case_name in all_test_case[info["file"]]:
all_test_case[info["file"]].remove(case_name)
with open(never_cpu_tested_case_file_name, "w") as f:
f.write(json.dumps(all_test_case))

Expand Down
3 changes: 0 additions & 3 deletions ditorch/torch_biren_adapter.py

This file was deleted.

10 changes: 10 additions & 0 deletions ditorch/torch_biren_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024, DeepLink.
import torch_br  # noqa: F401


def mock():
    # Import is for its side effects only: presumably transfer_to_supa
    # patches torch's CUDA entry points to route to the Biren SUPA backend —
    # TODO(review): confirm against torch_br documentation.
    from torch_br.contrib import transfer_to_supa  # noqa: F401


# Adapter contract consumed by ditorch/__init__.py: the wrapped framework
# module and a vendor architecture tag.
framework = torch_br
arch = "biren"
3 changes: 0 additions & 3 deletions ditorch/torch_dipu_adapter.py

This file was deleted.

12 changes: 12 additions & 0 deletions ditorch/torch_dipu_adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2024, DeepLink.
import torch  # noqa: F401
import torch_dipu  # noqa: F401


def mock():
    """Make torch_dipu conform to the adapter contract used by ditorch.

    torch_dipu may lack a __version__ attribute; fall back to torch's own
    version string so callers can always read framework.__version__.
    """
    try:
        getattr(torch_dipu, "__version__")
    except AttributeError:
        torch_dipu.__version__ = torch.__version__


# Adapter contract: the wrapped framework module and its vendor tag.
framework = torch_dipu
arch = torch_dipu.vendor_type
Loading