
Commit

support DeepSpeedChat to run on different devices besides cuda
ys950902 committed Sep 20, 2023
1 parent 902a0f6 commit 89c20e9
Showing 9 changed files with 26 additions and 17 deletions.
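
Every hunk below applies one of the same five substitutions, routing hard-coded CUDA calls through DeepSpeed's accelerator abstraction so the scripts can also run on non-CUDA backends (Intel XPU, for example):

    torch.device("cuda")              ->  torch.device(get_accelerator().device_name())
    torch.cuda.set_device(rank)       ->  get_accelerator().set_device(rank)
    tensor.cuda()                     ->  tensor.to(get_accelerator().current_device_name())
    torch.cuda.device_count()         ->  get_accelerator().device_count()
    torch.cuda.manual_seed_all(seed)  ->  get_accelerator().manual_seed_all(seed)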
@@ -21,6 +21,7 @@
 
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from deepspeed import get_accelerator
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
@@ -190,10 +191,10 @@ def main():
     args = parse_args()
 
     if args.local_rank == -1:
-        device = torch.device("cuda")
+        device = torch.device(get_accelerator().device_name())
     else:
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
+        get_accelerator().set_device(args.local_rank)
+        device = torch.device(get_accelerator().device_name(), args.local_rank)
     # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
     # torch.distributed.init_process_group(backend='nccl')
     deepspeed.init_distributed()
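
The device-selection change above repeats in each training entry point. A self-contained sketch of the same logic (the argparse setup is simplified for illustration):

import argparse

import torch
from deepspeed.accelerator import get_accelerator

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()

if args.local_rank == -1:
    # Single-process run: use the active backend's default device.
    device = torch.device(get_accelerator().device_name())
else:
    # Distributed run: bind this process to the device matching its local rank.
    get_accelerator().set_device(args.local_rank)
    device = torch.device(get_accelerator().device_name(), args.local_rank)

print(device)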
@@ -15,6 +15,7 @@
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
 from utils.model.model_utils import create_hf_model
 from utils.utils import load_hf_tokenizer
+from deepspeed import get_accelerator
 
 logger = logging.getLogger(__name__)
 
@@ -194,7 +195,7 @@ def prompt_eval(args, model_baseline, model_fintuned, tokenizer, device,
 def main():
     args = parse_args()
 
-    device = torch.device("cuda:0")
+    device = torch.device(get_accelerator().device_name()+":0")
 
     tokenizer = load_hf_tokenizer(args.model_name_or_path_baseline,
                                   fast_tokenizer=True)
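
Here the device index is pinned by string concatenation: get_accelerator().device_name() returns the bare backend name, so the result is "cuda:0" on NVIDIA hardware, "xpu:0" on Intel, and so on. A minimal sketch:

import torch
from deepspeed.accelerator import get_accelerator

# device_name() gives just the backend name ("cuda", "xpu", ...); appending
# ":0" selects the first device of that backend.
device = torch.device(get_accelerator().device_name() + ":0")
print(device)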
@@ -19,6 +19,7 @@
 
 import deepspeed
 from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
+from deepspeed.accelerator import get_accelerator
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
@@ -185,10 +186,10 @@ def main():
     args = parse_args()
 
     if args.local_rank == -1:
-        device = torch.device("cuda")
+        device = torch.device(get_accelerator().device_name())
     else:
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
+        get_accelerator().set_device(args.local_rank)
+        device = torch.device(get_accelerator().device_name(), args.local_rank)
     # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
     # torch.distributed.init_process_group(backend='nccl')
     deepspeed.init_distributed()
@@ -14,6 +14,7 @@
 from utils.model.model_utils import create_critic_model
 from utils.utils import to_device
 from utils.utils import load_hf_tokenizer
+from deepspeed import get_accelerator
 
 
 def parse_args():
@@ -100,7 +101,7 @@ def prepare_singlesample(prompt,
 def run_pair_comparison():
     args = parse_args()
 
-    device = torch.device("cuda:0")
+    device = torch.device(get_accelerator().device_name()+":0")
 
     rm_model, tokenizer = load_stuff(args.model_name_or_path,
                                      args.num_padding_at_beginning)
@@ -144,7 +145,7 @@ def run_pair_comparison():
 
 def run_single_sample():
     args = parse_args()
-    device = torch.device("cuda")
+    device = torch.device(get_accelerator().device_name())
 
     rm_model, tokenizer = load_stuff(args.model_name_or_path,
                                      args.num_padding_at_beginning)
@@ -44,6 +44,7 @@
 from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, moving_average, save_zero_three_model, load_hf_tokenizer
 from utils.module.lora import convert_lora_to_linear_layer
 from utils.perf import print_throughput_step3
+from deepspeed.accelerator import get_accelerator
 
 writer = None
 
@@ -417,10 +418,10 @@ def main():
     args = parse_args()
 
     if args.local_rank == -1:
-        device = torch.device("cuda")
+        device = torch.device(get_accelerator().device_name())
     else:
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
+        get_accelerator().set_device(args.local_rank)
+        device = torch.device(get_accelerator().device_name(), args.local_rank)
     # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
     deepspeed.init_distributed()
 
@@ -9,6 +9,7 @@
 import time
 import deepspeed
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from deepspeed.accelerator import get_accelerator
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
@@ -18,7 +19,7 @@
 
 def print_all_ranks(tag, value, rank):
     world_size = torch.distributed.get_world_size()
-    all_tensor = torch.zeros(world_size, dtype=torch.float32).cuda()
+    all_tensor = torch.zeros(world_size, dtype=torch.float32).to(get_accelerator().current_device_name())
     all_tensor[rank] = value
     torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
     print_rank_0(f'{tag} {all_tensor}', rank)
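
Note the use of current_device_name() rather than device_name(): the former also carries the index of the device this process is bound to (for example "cuda:3" after set_device(3)), which is what a .to(...) call in a multi-process run needs. A small sketch:

import torch
from deepspeed.accelerator import get_accelerator

# On an unbound single-process run this resolves to the backend's default
# device; after set_device(rank) it resolves to that rank's device.
all_tensor = torch.zeros(4, dtype=torch.float32).to(
    get_accelerator().current_device_name())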
@@ -15,6 +15,7 @@
 import hashlib
 from itertools import chain
 from . import raw_datasets
+from deepspeed.accelerator import get_accelerator
 
 
 def get_raw_dataset(dataset_name, output_path, seed, local_rank):
@@ -281,7 +282,7 @@ def create_prompt_dataset(local_rank,
     eval_fname = f"{output_path}/evaldata_{fname}.pt"
 
     cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
-    buf_create_cache = torch.ByteTensor([not cache_found]).cuda()
+    buf_create_cache = torch.ByteTensor([not cache_found]).to(get_accelerator().current_device_name())
     torch.distributed.all_reduce(buf_create_cache)
 
     if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
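
The byte tensor here is a cross-rank flag: each rank contributes whether it is missing the cache, the all_reduce sums the contributions, and any nonzero result makes rank 0 rebuild the cache for everyone. The flag must live on the accelerator because backends such as NCCL only reduce device tensors. A sketch with an illustrative stand-in for the file check:

import torch
import torch.distributed as dist
from deepspeed.accelerator import get_accelerator

cache_found = False  # illustrative; the real code checks os.path.isfile(...)
buf_create_cache = torch.ByteTensor([not cache_found]).to(
    get_accelerator().current_device_name())
if dist.is_initialized():
    # Nonzero afterwards if any rank was missing the cache.
    dist.all_reduce(buf_create_cache)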
applications/DeepSpeed-Chat/training/utils/ds_utils.py (5 changes: 3 additions & 2 deletions)

@@ -5,6 +5,7 @@
 
 import torch
 import deepspeed.comm as dist
+from deepspeed.accelerator import get_accelerator
 
 GLOBAL_BATCH_SIZE = 32
 MICRO_BATCH_SIZE = 4
@@ -39,8 +40,8 @@ def get_train_ds_config(offload,
     }
     if enable_mixed_precision_lora:
         zero_opt_dict["zero_quantized_nontrainable_weights"] = True
-        if dist.get_world_size() != torch.cuda.device_count():
-            zero_opt_dict["zero_hpz_partition_size"] = torch.cuda.device_count(
+        if dist.get_world_size() != get_accelerator().device_count():
+            zero_opt_dict["zero_hpz_partition_size"] = get_accelerator().device_count(
             )
     return {
         "train_batch_size": GLOBAL_BATCH_SIZE,
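
get_accelerator().device_count() takes over the role of torch.cuda.device_count() here: when the world size differs from the number of locally visible devices, the job spans multiple nodes, and the ZeRO++ hpZ secondary partition is capped at one node's worth of devices. A sketch under assumed illustrative values:

from deepspeed.accelerator import get_accelerator

zero_opt_dict = {"stage": 3}  # illustrative fragment of the ZeRO config
world_size = 16               # illustrative; the real code uses dist.get_world_size()

# Multi-node job: keep the hpZ secondary partition within a single node.
if world_size != get_accelerator().device_count():
    zero_opt_dict["zero_hpz_partition_size"] = get_accelerator().device_count()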
applications/DeepSpeed-Chat/training/utils/utils.py (3 changes: 2 additions & 1 deletion)

@@ -10,6 +10,7 @@
 import json
 import deepspeed
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+from deepspeed.accelerator import get_accelerator
 import torch.nn as nn
 
 
@@ -102,7 +103,7 @@ def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    get_accelerator().manual_seed_all(seed)
 
 
 def get_all_reduce_mean(tensor):
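
For completeness, a runnable version of the seeding helper with the imports the hunk assumes (a sketch; everything outside the four seeding calls is reconstructed, not shown in the diff):

import random

import numpy as np
import torch
from deepspeed.accelerator import get_accelerator


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Accelerator-agnostic replacement for torch.cuda.manual_seed_all: seeds
    # every device visible to the active backend.
    get_accelerator().manual_seed_all(seed)


set_random_seed(1234)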
