diff --git a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py index 5b6778cc2..0e67efcf9 100755 --- a/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py +++ b/applications/DeepSpeed-Chat/dschat/rlhf/rlhf_engine.py @@ -268,23 +268,14 @@ def _init_reward(self, critic_model_name_or_path): # If critic is ZeRO-3 then we use it for everything, otherwise assume we have enough memory zero_stage = 0 - ds_config = get_eval_ds_config(offload=self.args.offload, + ds_config = get_eval_ds_config(offload=self.args.offload_reward_model, dtype=self.args.dtype, stage=zero_stage) - ds_config[ - 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size - ds_config[ - 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( - ) * self.args.gradient_accumulation_steps - - ds_eval_config = get_eval_ds_config(offload=False, - dtype=self.args.dtype, - stage=zero_stage) # We need to set train batch size and micro batch size here to pass the sanity check of DeepSpeed engine. - ds_eval_config[ + ds_config[ 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size - ds_eval_config[ + ds_config[ 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( ) * self.args.gradient_accumulation_steps @@ -292,7 +283,7 @@ def _init_reward(self, critic_model_name_or_path): reward_model = create_critic_model( model_name_or_path=critic_model_name_or_path, tokenizer=self.tokenizer, - ds_config=ds_eval_config, + ds_config=ds_config, num_padding_at_beginning=self.args.num_padding_at_beginning, rlhf_training=True, dropout=self.args.critic_dropout, diff --git a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py index 97d3bff15..050819a22 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py @@ -10,7 +10,7 @@ AutoModel, ) from huggingface_hub import snapshot_download -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig from dschat.utils.model.reward_model import RewardModel from dschat.utils.utils import load_state_dict_into_model, print_rank_0 diff --git a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py index a5be5671b..1378dc4e6 100644 --- a/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py +++ b/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py @@ -246,6 +246,9 @@ def parse_args(): '--offload_reference_model', action='store_true', help='Enable ZeRO Offload techniques for reference model') + parser.add_argument('--offload_reward_model', + action='store_true', + help='Enable ZeRO Offload techniques for reward model') parser.add_argument( '--actor_zero_stage', type=int, diff --git a/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py b/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py index eb9db9428..1407c1dfc 100755 --- a/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py +++ b/applications/DeepSpeed-VisualChat/utils/model/modeling_dsvl.py @@ -15,7 +15,7 @@ os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) import data.DST as DST # default special tokens from torch.utils.data import DataLoader -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig import numpy as np from .vis_proj import VisProjection_vit, VisProjection_perceiver diff --git a/inference/huggingface/zero_inference/run_model.py b/inference/huggingface/zero_inference/run_model.py index 230d601cb..d0e16eca3 100644 --- a/inference/huggingface/zero_inference/run_model.py +++ b/inference/huggingface/zero_inference/run_model.py @@ -19,7 +19,7 @@ from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM, BloomForCausalLM, OPTForCausalLM, LlamaForCausalLM, ) -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations.deepspeed import HfDeepSpeedConfig from utils import (GB, add_model_hooks, cache_bytes, get_filename, get_quant_config, hidden_bytes, meta_to_cpu, model_bytes, write_benchmark_log)