Fix LoRA Fuse/Unfuse in Hybrid Engine #3563

Merged · 20 commits · Jul 5, 2023
8 changes: 8 additions & 0 deletions deepspeed/module_inject/containers/bloom.py
@@ -35,6 +35,14 @@ def attention_qkv_mp(self, mp_replace, reversed_dim=False):
self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)

def get_lora_matched_pair(self):
"""
Necessary to implement for `HybridEngineContainer`
"""
fc1_lora, fc2_lora, qkv_lora, out_lora = self.get_lora_params()
ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (qkv_lora, self.qkvw), (out_lora, self.dense_w)]
return ret

def set_lora_params(self):
"""
Necessary to implement for `HybridEngineContainer`
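Note on the new container contract: set_lora_params()/get_lora_params() yield one entry per wrapped linear layer, and each entry is either empty (no LoRA attached to that projection) or a (lora_right_weight, lora_left_weight, lora_scaling) triple, which is why the fusion code below checks len(...) == 3. get_lora_matched_pair() then pairs each entry explicitly with the base weight it modifies instead of relying on positional order. A minimal sketch of that lookup, using a hypothetical duck-typed helper rather than the real maybe_get_lora:

def maybe_get_lora_sketch(module):
    # Hypothetical stand-in for deepspeed.module_inject.policy.maybe_get_lora:
    # return the LoRA factors if the module carries them, otherwise an empty list.
    if hasattr(module, "lora_right_weight"):
        return [module.lora_right_weight, module.lora_left_weight, module.lora_scaling]
    return []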
23 changes: 23 additions & 0 deletions deepspeed/module_inject/containers/features/hybrid_engine.py
@@ -55,6 +55,29 @@ def set_lora_params(self):
"""
raise NotImplementedError("A set_lora_params() function must be defined for the relevant parameters.")

@abstractmethod
def get_lora_matched_pair(self):
"""Get the pair of lora params and its matched model parameters."""
raise NotImplementedError("get_lora_matched_pair() must be defined for the relevant parameters.")

def fuse_lora(self):
"""Fuse the LoRA parameters for the inference mode."""
for maybe_lora_param, param in self.get_lora_matched_pair():
if len(maybe_lora_param) == 3:
lora_right_weight, \
lora_left_weight, \
lora_scaling = maybe_lora_param
param.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())

def unfuse_lora(self):
"""Unfuse the LoRA parameters for the training mode."""
for maybe_lora_param, param in self.get_lora_matched_pair():
if len(maybe_lora_param) == 3:
lora_right_weight, \
lora_left_weight, \
lora_scaling = maybe_lora_param
param.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())

def apply_tensor_parallelism(self, mp_replace, reversed_dim=False):
"""
Add support for reversed dim in tensor parallelism. If necessary, override
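The default fuse_lora()/unfuse_lora() above fold the low-rank update directly into each matched base weight before generation and subtract the same term again before training resumes, so the fused and unfused forward passes stay numerically equivalent. A minimal standalone sketch of that round trip, with hypothetical shapes chosen to match the transposed layout assumed here (this is not DeepSpeed's actual LoRA module):

import torch

out_f, in_f, rank, scaling = 8, 16, 4, 0.5
weight = torch.randn(out_f, in_f)        # base weight, (out_features, in_features)
lora_right = torch.randn(in_f, rank)     # applied to the input first
lora_left = torch.randn(rank, out_f)     # projects back to the output dimension

x = torch.randn(2, in_f)
unfused_out = x @ weight.t() + scaling * (x @ lora_right @ lora_left)

# Fuse: fold the LoRA delta into the base weight, as fuse_lora() does per pair.
weight += scaling * torch.matmul(lora_left.t(), lora_right.t())
fused_out = x @ weight.t()
assert torch.allclose(unfused_out, fused_out, atol=1e-4)

# Unfuse: subtract the identical delta to recover the original training weight.
weight -= scaling * torch.matmul(lora_left.t(), lora_right.t())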
6 changes: 6 additions & 0 deletions deepspeed/module_inject/containers/gptj.py
@@ -42,6 +42,12 @@ def set_lora_params(self):
]
]

def get_lora_matched_pair(self):
fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw),
(k_lora, self.kw), (v_lora, self.vw)]
return ret

def set_q_k_v(self):
"""
Necessary to implement for `HybridSplitQKVContainer`
24 changes: 22 additions & 2 deletions deepspeed/module_inject/containers/gptneo.py
@@ -5,7 +5,7 @@

from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.hybrid_engine import HybridEngineContainer
from .features.split_qkv import HybridSplitQKVContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from torch.nn.parameter import Parameter
@@ -17,7 +17,7 @@
from ..policy import maybe_get_lora


class DS_GPTNEOContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):
class DS_GPTNEOContainer(MetaTensorContainer, HybridSplitQKVContainer, BaseTransformerContainer):

def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -42,6 +42,26 @@ def set_lora_params(self):
]
]

def set_q_k_v(self):
"""
Necessary to implement for `HybridSplitQKVContainer`
"""
self.qw = self.policy.client_module.attn.attention.q_proj.weight
self.qb = None
self.kw = self.policy.client_module.attn.attention.k_proj.weight
self.kb = None
self.vw = self.policy.client_module.attn.attention.v_proj.weight
self.vb = None

def get_lora_matched_pair(self):
"""
Necessary to implement for `HybridEngineContainer`
"""
fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw),
(k_lora, self.kw), (v_lora, self.vw)]
return ret

def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
'attn.attention.q_proj.weight', \
8 changes: 8 additions & 0 deletions deepspeed/module_inject/containers/gptneox.py
@@ -34,6 +34,14 @@ def create_module(self, config=None):

return self.module

def get_lora_matched_pair(self):
"""
Necessary to implement for `HybridEngineContainer`
"""
fc1_lora, fc2_lora, qkv_lora, out_lora = self.get_lora_params()
ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (qkv_lora, self.qkvw), (out_lora, self.dense_w)]
return ret

def set_lora_params(self):
"""
Necessary to implement for `HybridEngineContainer`
6 changes: 6 additions & 0 deletions deepspeed/module_inject/containers/llama.py
@@ -50,6 +50,12 @@ def set_lora_params(self):
]
]

def get_lora_matched_pair(self):
up_proj_lora, gate_proj_lora, down_proj_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
ret = [(up_proj_lora, self.inter_up_w), (gate_proj_lora, self.inter_gate_w), (down_proj_lora, self._4hh_w),
(out_lora, self.dense_w), (q_lora, self.qw), (k_lora, self.kw), (v_lora, self.vw)]
return ret

def set_q_k_v(self):
"""
Necessary to implement for `HybridSplitQKVContainer`
6 changes: 6 additions & 0 deletions deepspeed/module_inject/containers/opt.py
@@ -55,6 +55,12 @@ def set_q_k_v(self):
self.vw = self.policy.client_module.self_attn.v_proj.weight
self.vb = self.policy.client_module.self_attn.v_proj.bias

def get_lora_matched_pair(self):
fc1_lora, fc2_lora, q_lora, k_lora, v_lora, out_lora = self.get_lora_params()
ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (out_lora, self.dense_w), (q_lora, self.qw),
(k_lora, self.kw), (v_lora, self.vw)]
return ret

def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
'self_attn.q_proj.weight', \
5 changes: 4 additions & 1 deletion deepspeed/runtime/engine.py
@@ -2303,7 +2303,10 @@ def _get_gradients_for_reduction(self):
expert_grads[key] = []

for param_name, param in self.module.named_parameters():
if param.grad is None and param.requires_grad:
if not param.requires_grad:
continue

if param.grad is None:
# In cases where there is an imbalance of empty grads across
# ranks we must create empty grads, this will ensure that every
# rank is reducing the same size. In some cases it may make
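The engine.py change splits the old combined condition: parameters that do not require gradients (for example base weights kept frozen while only LoRA adapters train) are now skipped outright, and only trainable parameters with a missing gradient receive an empty one so that every rank reduces the same set of tensors. A simplified sketch of the guarded loop with a hypothetical two-layer model (the real code additionally buckets gradients and handles expert parameters):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].weight.requires_grad_(False)      # e.g. a frozen base weight under LoRA

model(torch.randn(1, 4)).sum().backward()

grads = []
for name, param in model.named_parameters():
    if not param.requires_grad:            # frozen params never join the reduction
        continue
    if param.grad is None:                 # trainable params with no local grad still contribute zeros
        param.grad = torch.zeros_like(param)
    grads.append(param.grad)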
40 changes: 18 additions & 22 deletions deepspeed/runtime/hybrid_engine.py
@@ -83,11 +83,19 @@ def _replace_linear_layer(r_module, parent_type=None, prev_type=None):

def new_inference_container(self, orig_layer, policy_cls, layer_id):
policy = policy_cls(orig_layer, inference=True)

if self._config.fp16_enabled:
inference_dtype = torch.float16
elif self._config.bfloat16_enabled:
inference_dtype = torch.bfloat16
else:
inference_dtype = torch.float32

_container = policy_to_ds_container(
policy=policy,
config=DeepSpeedInferenceConfig(
set_empty_params=True,
dtype=torch.float16 if self._config.fp16_enabled else torch.float32,
dtype=inference_dtype,
max_out_tokens=self._config.hybrid_engine.max_out_tokens,
min_out_tokens=self._config.hybrid_engine.max_out_tokens,
transposed_mode=True,
@@ -127,31 +135,19 @@ def populate_all_inference_policies(self):
OPTLearnedPositionalEmbedding: (OPTEmbedding, )
})

def _fuse_lora(self, params, lora_params):
maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
for lora_param, weight in zip(lora_params, maybe_has_lora_params):
if len(lora_param) == 3:
lora_right_weight, \
lora_left_weight, \
lora_scaling = lora_param
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
def _fuse_lora_layer(self, layer_id):
self._inference_containers[layer_id].fuse_lora()

def fuse_lora_weight(self):
for layer_id in range(len(self.layer_params)):
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
self._fuse_lora_layer(layer_id)

def _unfuse_lora(self, params, lora_params):
maybe_has_lora_params = [p for p in params if len(p.shape) > 1]
for lora_param, weight in zip(lora_params, maybe_has_lora_params):
if len(lora_param) == 3:
lora_right_weight, \
lora_left_weight, \
lora_scaling = lora_param
weight.data -= lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
def _unfuse_lora_layer(self, layer_id):
self._inference_containers[layer_id].unfuse_lora()

def unfuse_lora_weight(self):
for layer_id in range(len(self.layer_params)):
self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
self._unfuse_lora_layer(layer_id)

def unfuse_lora_weight_non_pinned(self):
for layer_id in range(len(self.layer_params)):
@@ -160,7 +156,7 @@ def unfuse_lora_weight_non_pinned(self):
non_active_params.extend(non_active_lora_params)

with GatheredParameters(non_active_params):
self._unfuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
self._unfuse_lora_layer(layer_id)

def retake_inference_cache(self):
if self._config.hybrid_engine.release_inference_cache:
@@ -204,7 +200,7 @@ def generate(self, *inputs, **kwargs):
for layer_id in range(lg * partition_size,
min(len(self.layer_params), (lg + 1) * partition_size), 1):
if len(self.all_lora_params) > 0:
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
self._fuse_lora_layer(layer_id)

if self.mpu is not None:
self._inference_containers[layer_id].apply_tensor_parallelism(self.mp_replace,
@@ -375,7 +371,7 @@ def run_forward(*inputs, **kwargs):
if len(self.all_lora_params) > 0:
# Use the is_lora_fused flag to prevent multiple fusion in Z3 with non-pinned memory
if not self.is_lora_fused:
self._fuse_lora(self.layer_params[layer_id], self.lora_params[layer_id])
self._fuse_lora_layer(layer_id)
# Set the is_lora_fused to true when reaching the last layer
if layer_id == len(self.layer_params) - 1:
self.is_lora_fused = True
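Taken together, the runtime no longer zips self.lora_params[layer_id] positionally against every parameter with more than one dimension; each container reports its own (lora, weight) pairs and _fuse_lora_layer()/_unfuse_lora_layer() simply delegate to them. A small illustration of why the explicit pairing is safer, using hypothetical stand-in tensors:

import torch

# A fused QKV weight with LoRA attached and a dense output weight without LoRA.
rank, hidden = 2, 4
qkvw = torch.zeros(3 * hidden, hidden)
dense_w = torch.zeros(hidden, hidden)

qkv_lora = (torch.randn(hidden, rank), torch.randn(rank, 3 * hidden), 0.5)  # (right, left, scaling)
dense_lora = ()                                                             # no LoRA on this module

# Old approach: zip the LoRA entries against [p for p in params if len(p.shape) > 1];
# if that filtered list happens to be ordered [dense_w, qkvw], the QKV update lands on the wrong weight.
# New approach: the container states the pairing explicitly, as get_lora_matched_pair() does.
matched_pairs = [(qkv_lora, qkvw), (dense_lora, dense_w)]
for maybe_lora, weight in matched_pairs:
    if len(maybe_lora) == 3:
        lora_right, lora_left, scaling = maybe_lora
        weight.data += scaling * torch.matmul(lora_left.t(), lora_right.t())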