gpu.py
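# gpu.py keeps the selected GPUs busy: each process holds a large chunk of GPU
# memory and loops matrix multiplications through a small DDP-wrapped model.
# Assumed launch (any env:// launcher works; torchrun sets LOCAL_RANK per process):
#   torchrun --nproc_per_node=<num_gpus> gpu.py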
import os
import sys
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

import setproctitle

# Disguise the process name so it looks like a plain Python interpreter in
# process listings.
setproctitle.setproctitle("~/anaconda3/envs/LLM/bin/python")
time_start = time.time()

# One process per GPU; NCCL is the backend for GPU collectives.
torch.distributed.init_process_group(backend="nccl")
# torchrun exports LOCAL_RANK; fall back to the global rank, which equals the
# local rank on a single node.
local_rank = int(os.environ.get("LOCAL_RANK", torch.distributed.get_rank()))
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

size = 1024        # width of the square matrices being multiplied
total_time = 30    # seconds to run before exiting (only used if time_need is True)
time_need = False  # set to True to stop after total_time seconds
class randomdata(Dataset):
    """Single-sample dataset: holds a (size, size, size) random tensor on the GPU
    and serves one (size, size) slice of it."""

    def __init__(self, size):
        self.data = torch.randn(size, size, size).to(device)
        self.len = 1

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
class get_gpu(nn.Module):
    """Busy-work model: pushes the input through a linear layer and a batched
    matmul 1000 times per forward pass."""

    def __init__(self, input_size):
        super().__init__()
        self.fc = nn.Linear(input_size, input_size)

    def forward(self, x):
        for _ in range(1000):
            y = self.fc(self.fc(x) @ self.fc(x))
        return y
dataset = randomdata(size)
# DistributedSampler gives every rank its own copy of the single sample.
data_load = DataLoader(dataset=dataset, batch_size=8, sampler=DistributedSampler(dataset))

# Hold a large allocation (16 * 1024^3 float32 values, about 64 GiB) so the GPU
# memory stays occupied; shrink this if it does not fit on your card.
_ = torch.randn((16, 1024, 1024, 1024), device=device)

model = get_gpu(size).to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
# Burn GPU cycles indefinitely (or until total_time elapses when time_need is True).
for data in data_load:
    while True:
        if time_need and (time.time() - time_start) > total_time:
            sys.exit(0)
        model(data)