-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mthreads] Support base/benchmarks: add bandwidth test with torch_musa #771
Conversation
@@ -51,7 +51,7 @@ def main(config, case_config, rank, world_size, local_rank): | |||
|
|||
Melements = case_config.Melements | |||
torchsize = (Melements, 1024, 1024) | |||
tensor = torch.rand(torchsize, dtype=torch.float32).cuda() | |||
tensor = torch.rand(torchsize, dtype=torch.float32).to(local_rank) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为避免前后代码不一致性,此处请处理为:if mthreads in config.vendor xxx.to(local_rank), else xxx.cuda()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修复b97e82b
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
基础信息修改一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修复76e61a1
新增基于torch_musa支持各类带宽测试
关于torch_musa详情可参考:https://github.com/MooreThreads/torch_musa