tied weights for adapters (#6928)
* wip

Signed-off-by: arendu <[email protected]>

* wip

Signed-off-by: arendu <[email protected]>

* tied lora weights

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lora and adapter tying

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* layer selection wip

Signed-off-by: arendu <[email protected]>

* added layer selection

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make ln optional

Signed-off-by: arendu <[email protected]>

* layer section

Signed-off-by: arendu <[email protected]>

* small dim pos emb

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adapter w/o layer norm and weight tying

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* eval works with all pos embeddings strategy

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: arendu <[email protected]>

* mlp transform of pos embeddings

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* zero init position bias

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* merge

Signed-off-by: arendu <[email protected]>

* minor fix

Signed-off-by: arendu <[email protected]>

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
arendu and pre-commit-ci[bot] authored Aug 18, 2023
1 parent b479c90 commit f26b2f1
Showing 6 changed files with 369 additions and 15 deletions.
@@ -85,23 +85,32 @@ model:
type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter'
adapter_dim: 32
adapter_dropout: 0.0
norm_position: 'pre' # This can be set to 'pre' or 'post', 'pre' is normally what is used.
norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used.
column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
norm_type: 'mixedfusedlayernorm' # IGNORED if linear_adapter is used, options are ['layernorm', 'mixedfusedlayernorm']
layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers
weight_tying: False
position_embedding_strategy: null # used only when weight_tying is True

lora_tuning:
adapter_dim: 32
adapter_dropout: 0.0
column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
weight_tying: False
position_embedding_strategy: null # used only when weight_tying is True

# Used for p-tuning peft training
p_tuning:
virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence
bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck
embedding_dim: 1024 # the size of the prompt encoder embeddings
init_std: 0.023

ia3_tuning:
layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add ia3 adapters to layer 1 (lowest) and 12. null will apply adapters to all layers

data:
train_ds:
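For readers who want to try the new options, the following is a minimal sketch of just the peft sub-config covering the fields introduced in this commit, built with OmegaConf (which the NeMo config system uses); the values shown for layer_selection and weight_tying are illustrative choices, not the defaults:

from omegaconf import OmegaConf

# Minimal sketch of the peft sub-config with the options added above.
# layer_selection=[1, 12] and weight_tying=True are illustrative; the YAML
# defaults are null and False respectively.
peft_cfg = OmegaConf.create(
    {
        "peft_scheme": "adapter",
        "adapter_tuning": {
            "type": "parallel_adapter",
            "adapter_dim": 32,
            "adapter_dropout": 0.0,
            "norm_position": "pre",  # 'pre', 'post' or None (None skips the layer norm)
            "norm_type": "mixedfusedlayernorm",
            "layer_selection": [1, 12],  # adapters only in layers 1 (lowest) and 12; None means all layers
            "weight_tying": True,
            "position_embedding_strategy": None,  # consulted only when weight_tying is True
        },
    }
)
print(OmegaConf.to_yaml(peft_cfg))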
12 changes: 10 additions & 2 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py
@@ -25,9 +25,11 @@

from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import (
MegatronGPTAdapterModel,
MegatronGPTAdapterModelWeightTying,
MegatronGPTAdapterPTuningModel,
MegatronGPTIA3Model,
MegatronGPTLoRAModel,
MegatronGPTLoRAModelWeightTying,
MegatronGPTPTuningModel,
)
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTModel
@@ -114,15 +116,21 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):

def _get_peft_scheme(cfg):
if cfg.peft.peft_scheme == "adapter":
peft_cls = MegatronGPTAdapterModel
if cfg.peft.adapter_tuning.weight_tying:
peft_cls = MegatronGPTAdapterModelWeightTying
else:
peft_cls = MegatronGPTAdapterModel
elif cfg.peft.peft_scheme == "ia3":
peft_cls = MegatronGPTIA3Model
elif cfg.peft.peft_scheme == "ptuning":
peft_cls = MegatronGPTPTuningModel
elif cfg.peft.peft_scheme == "adapter_and_ptuning":
peft_cls = MegatronGPTAdapterPTuningModel
elif cfg.peft.peft_scheme == "lora":
peft_cls = MegatronGPTLoRAModel
if cfg.peft.lora_tuning.weight_tying:
peft_cls = MegatronGPTLoRAModelWeightTying
else:
peft_cls = MegatronGPTLoRAModel
else:
raise RuntimeError("Invalid Peft scheme")
return peft_cls
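As a quick illustration of the dispatch above, the hypothetical snippet below builds a config containing only the keys _get_peft_scheme reads and shows how the weight_tying flag switches the returned class:

from omegaconf import OmegaConf

# Hypothetical minimal config; only the keys read by _get_peft_scheme are filled in.
cfg = OmegaConf.create({"peft": {"peft_scheme": "lora", "lora_tuning": {"weight_tying": True}}})
peft_cls = _get_peft_scheme(cfg)  # -> MegatronGPTLoRAModelWeightTying

cfg.peft.lora_tuning.weight_tying = False
peft_cls = _get_peft_scheme(cfg)  # -> MegatronGPTLoRAModel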
@@ -21,8 +21,10 @@
AdapterName,
InfusedAdapterConfig,
LoraKQVAdapterConfig,
LoraKQVAdapterWeightTyingConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
ParallelLinearAdapterWeightTyingConfig,
PromptEncoderAdapterConfig,
)
from nemo.core.classes.mixins import adapter_mixins
@@ -131,7 +133,37 @@ def setup_optimizer_param_groups(self):
logging.info(f"Optimizer groups set:\n{self.summarize()}")


class MegatronGPTAdapterModel(MegatronGPTPEFTModel):
class MegatronGPTLayerwisePEFTModel(MegatronGPTPEFTModel):
def __init__(
self, cfg: DictConfig, trainer: Trainer,
):
super().__init__(cfg, trainer)

def init_peft_modules(self):
"""
Randomly initialize the peft params and add them to the appropriate modules.
"""
assert len(self.peft_name_keys) > 0, "peft_name_keys have not been set; no PEFT modules will be added"
assert len(self.name_key_to_cfg) > 0, "name_key_to_cfg has not been set; no PEFT modules will be added"
logging.info(f"Before adding PEFT params:\n{self.summarize()}")
for layer in self.model.language_model.encoder.layers:
if layer.layer_number in self.layer_selection:
for _, module in layer.named_modules():
if isinstance(module, adapter_mixins.AdapterModuleMixin):
for peft_key in self.peft_name_keys:
peft_cfg = self.name_key_to_cfg[peft_key]
if (
model_utils.import_class_by_path(peft_cfg._target_)
in module.get_accepted_adapter_types()
):
module.add_adapter(
name=peft_key, cfg=peft_cfg,
)
logging.info(f"After adding PEFT params:\n{self.summarize()}")
return True
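A small self-contained illustration of the layer_selection convention used by the loop above (layer numbers are 1-based; a null/None selection means all layers, matching the defaults set in the subclasses below):

# Illustration only: 1-based layer numbers, with None expanding to every layer.
num_layers = 12
layer_selection = None  # the YAML default
if layer_selection is None:
    layer_selection = list(range(1, num_layers + 1))  # [1, 2, ..., 12]

# With an explicit selection such as [1, 12], only those layers receive PEFT modules.
layer_selection = [1, 12]
adapted_layers = [n for n in range(1, num_layers + 1) if n in layer_selection]  # [1, 12]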


class MegatronGPTAdapterModel(MegatronGPTLayerwisePEFTModel):
"""
MegatronGPTAdapterModel is a model that combines a base model (GPTSFTModel) with adapters.
This class only supports the canonical Adapter training described in Houlsby et al. (https://arxiv.org/pdf/1902.00751.pdf)
@@ -151,7 +183,6 @@ def __init__(
AdapterName.POST_ATTN_ADAPTER,
]
adapter_tuning_cfg = cfg.peft.adapter_tuning

adapter_cfg = ParallelLinearAdapterConfig(
in_features=cfg.hidden_size,
out_features=cfg.hidden_size,
@@ -167,10 +198,73 @@ def __init__(
for k in self.peft_name_keys:
self.name_key_to_cfg[k] = adapter_cfg

self.layer_selection = adapter_tuning_cfg.get("layer_selection", None)
if self.layer_selection is None:
self.layer_selection = list(range(1, cfg.num_layers + 1))
super().__init__(cfg, trainer)


class MegatronGPTIA3Model(MegatronGPTPEFTModel):
class MegatronGPTAdapterModelWeightTying(MegatronGPTLayerwisePEFTModel):
"""
Same as MegatronGPTAdapterModel, but the adapter weights are tied (shared) across all selected transformer layers.
An optional per-layer position embedding (position_embedding_strategy) lets the shared adapters condition on the layer they are inserted into.
"""

def __init__(
self, cfg: DictConfig, trainer: Trainer,
):
self.peft_name_keys = [
AdapterName.PRE_ATTN_ADAPTER,
AdapterName.POST_ATTN_ADAPTER,
]
adapter_tuning_cfg = cfg.peft.adapter_tuning

adapter_cfg = ParallelLinearAdapterWeightTyingConfig(
in_features=cfg.hidden_size,
out_features=cfg.hidden_size,
dim=adapter_tuning_cfg.adapter_dim,
norm_position=adapter_tuning_cfg.get("norm_position", "pre"),
norm_type=adapter_tuning_cfg.get("norm_type", "mixedfusedlayernorm"),
column_init_method=adapter_tuning_cfg.get("column_init_method", "xavier"),
row_init_method=adapter_tuning_cfg.get("row_init_method", "zero"),
dropout=adapter_tuning_cfg.adapter_dropout,
num_position_embeddings=cfg.num_layers * 2,
dim_position_embeddings=cfg.hidden_size,
position_embedding_strategy=adapter_tuning_cfg.get("position_embedding_strategy", None),
)

self.name_key_to_cfg = {}
for k in self.peft_name_keys:
self.name_key_to_cfg[k] = adapter_cfg

self.layer_selection = adapter_tuning_cfg.get("layer_selection", None)
if self.layer_selection is None:
self.layer_selection = list(range(1, cfg.num_layers + 1))
super().__init__(cfg, trainer)
self.tie_weights()

def tie_weights(self,):
pos_idx = 0
layer0 = self.model.language_model.encoder.layers[0]
for adapter_name in layer0.adapter_layer:
adapter = layer0.get_adapter_module(adapter_name)
print(adapter_name, pos_idx)
adapter.set_position(pos_idx)
pos_idx += 1

for layer in self.model.language_model.encoder.layers[1:]:
for adapter_name in layer.adapter_layer:
print(adapter_name, pos_idx)
adapter_l = layer.get_adapter_module(adapter_name)
adapter_0 = layer0.get_adapter_module(adapter_name)
if hasattr(adapter_0, "layer_norm"):
lnorm = adapter_0.layer_norm
else:
lnorm = None
adapter_l.tie_weights(pos_idx, adapter_0)
pos_idx += 1
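To make the tying scheme concrete, here is a toy PyTorch sketch (not the NeMo adapter classes; TinyTiedAdapter and its fields are invented for illustration) of the idea behind set_position and tie_weights above: every layer reuses the layer-0 adapter's parameters, and only a per-layer position index differs.

import torch
import torch.nn as nn


class TinyTiedAdapter(nn.Module):
    """Toy bottleneck adapter whose weights can be shared across layers."""

    def __init__(self, hidden_size, adapter_dim, num_positions):
        super().__init__()
        self.down = nn.Linear(hidden_size, adapter_dim, bias=False)
        self.up = nn.Linear(adapter_dim, hidden_size, bias=False)
        self.position_embeddings = nn.Embedding(num_positions, hidden_size)
        self.pos_idx = 0

    def set_position(self, pos_idx):
        self.pos_idx = pos_idx

    def tie_weights(self, pos_idx, adapter_0):
        # Share every parameter with the reference (layer-0) adapter; the only
        # per-layer difference that remains is which position embedding row is used.
        self.down = adapter_0.down
        self.up = adapter_0.up
        self.position_embeddings = adapter_0.position_embeddings
        self.set_position(pos_idx)

    def forward(self, x):
        pos = self.position_embeddings.weight[self.pos_idx]
        return x + self.up(torch.relu(self.down(x + pos)))


# Two "layers" sharing a single set of adapter parameters:
a0 = TinyTiedAdapter(hidden_size=16, adapter_dim=4, num_positions=24)
a1 = TinyTiedAdapter(hidden_size=16, adapter_dim=4, num_positions=24)
a0.set_position(0)
a1.tie_weights(1, a0)
assert a1.down.weight is a0.down.weight  # parameters are shared, not copied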


class MegatronGPTIA3Model(MegatronGPTLayerwisePEFTModel):
"""
MegatronGPTInfusedAdapterModel is a model that combines a base model (GPTSFTModel) with "Infused Adapters by Inhibiting and Amplifying Inner Activations" (IA3).
This class supports the addition of IA3 into a transformer-based LM as described in Liu et al. (https://arxiv.org/pdf/2205.05638.pdf)
@@ -330,7 +424,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.virtual_tokens = cfg.peft.p_tuning.virtual_tokens


class MegatronGPTLoRAModel(MegatronGPTPEFTModel):
class MegatronGPTLoRAModel(MegatronGPTLayerwisePEFTModel):
"""
MegatronGPTLoRAModel is a model that combines a base model (GPTSFTModel) with low-rank adapters.
The lora adapters will be added in `nemo/collections/nlp/modules/common/megatron/attention.py`
@@ -360,8 +454,8 @@ def __init__(
in_features=cfg.hidden_size,
out_features=3 * projection_size,
dim=lora_cfg.adapter_dim,
norm_position="none",
norm_type="none",
norm_position=None,
norm_type=None,
activation="identity",
column_init_method=lora_cfg.get("column_init_method", "normal"),
row_init_method=lora_cfg.get("row_init_method", "zero"),
@@ -372,5 +466,87 @@
self.name_key_to_cfg = {}
for k in self.peft_name_keys:
self.name_key_to_cfg[k] = adapter_cfg
self.layer_selection = lora_cfg.get("layer_selection", None)
if self.layer_selection is None:
self.layer_selection = list(range(1, cfg.num_layers + 1))
super().__init__(cfg, trainer)


class MegatronGPTLoRAModelWeightTying(MegatronGPTLayerwisePEFTModel):
"""
Same as MegatronGPTLoRAModel, but the LoRA KQV adapter weights are tied (shared) across all selected transformer layers.
position_embedding_strategy controls how a per-layer position embedding is combined with the shared adapter weights.
"""

def __init__(
self, cfg: DictConfig, trainer: Trainer,
):
self.peft_name_keys = [
AdapterName.LORA_KQV_ADAPTER,
]
lora_cfg = cfg.peft.lora_tuning
if cfg.get("kv_channels", None) is None:
assert (
cfg.hidden_size % cfg.num_attention_heads == 0
), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
kv_channels = cfg.hidden_size // cfg.num_attention_heads
else:
kv_channels = cfg.kv_channels
projection_size = kv_channels * cfg.num_attention_heads
position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None)
if position_embedding_strategy is None:
dim_position_embeddings = 0
elif position_embedding_strategy == "add":
dim_position_embeddings = cfg.hidden_size
elif position_embedding_strategy == "biasadd":
dim_position_embeddings = 3 * projection_size
elif position_embedding_strategy == "concat":
dim_position_embeddings = lora_cfg.adapter_dim
elif position_embedding_strategy == "mlpconcat":
dim_position_embeddings = lora_cfg.adapter_dim
else:
raise RuntimeError(f"Unknown position embedding strategy {position_embedding_strategy} for tied weights")

adapter_cfg = LoraKQVAdapterWeightTyingConfig(
in_features=cfg.hidden_size,
out_features=3 * projection_size,
dim=lora_cfg.adapter_dim,
norm_position=None,
norm_type=None,
activation="identity",
column_init_method=lora_cfg.get("column_init_method", "normal"),
row_init_method=lora_cfg.get("row_init_method", "zero"),
gather_output=False,
dropout=lora_cfg.adapter_dropout,
num_position_embeddings=cfg.num_layers,
dim_position_embeddings=dim_position_embeddings,
position_embedding_strategy=position_embedding_strategy,
)

self.name_key_to_cfg = {}
for k in self.peft_name_keys:
self.name_key_to_cfg[k] = adapter_cfg
self.layer_selection = lora_cfg.get("layer_selection", None)
if self.layer_selection is None:
self.layer_selection = list(range(1, cfg.num_layers + 1))
super().__init__(cfg, trainer)
self.tie_weights()

def tie_weights(self,):
pos_idx = 0
layer0 = self.model.language_model.encoder.layers[0]
for adapter_name in layer0.self_attention.adapter_layer:
adapter = layer0.self_attention.get_adapter_module(adapter_name)
print(adapter_name, pos_idx)
adapter.set_position(pos_idx)
pos_idx += 1

for layer in self.model.language_model.encoder.layers[1:]:
for adapter_name in layer.self_attention.adapter_layer:
print(adapter_name, pos_idx)
adapter_l = layer.self_attention.get_adapter_module(adapter_name)
adapter_0 = layer0.self_attention.get_adapter_module(adapter_name)
position_embeddings_0 = None
if adapter_0.position_embedding_strategy:
position_embeddings_0 = adapter_0.position_embeddings
adapter_l.tie_weights(pos_idx, adapter_0)
pos_idx += 1
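For reference, the position-embedding sizing logic in MegatronGPTLoRAModelWeightTying.__init__ can be summarized as a small helper; this is a sketch mirroring the branches above, not an existing NeMo function:

def lora_dim_position_embeddings(strategy, hidden_size, projection_size, adapter_dim):
    """Mirror of the strategy -> dim_position_embeddings mapping used above."""
    if strategy is None:
        return 0
    if strategy == "add":
        return hidden_size  # matches in_features (cfg.hidden_size)
    if strategy == "biasadd":
        return 3 * projection_size  # matches out_features (the fused KQV projection)
    if strategy in ("concat", "mlpconcat"):
        return adapter_dim  # matches the LoRA bottleneck dimension
    raise RuntimeError(f"Unknown position embedding strategy {strategy} for tied weights")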
