From f26b2f1af220901b110c2abdc98c1d612d27eb3a Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Fri, 18 Aug 2023 14:13:42 -0700 Subject: [PATCH] tied weights for adapters (#6928) * wip Signed-off-by: arendu * wip Signed-off-by: arendu * tied lora weights Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * lora and adapter tying Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * layer selection wip Signed-off-by: arendu * added layer selection Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make ln optional Signed-off-by: arendu * layer section Signed-off-by: arendu * small dim pos emb Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * adapter w/o layer norm and weight tying Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * eval works with all pos embeddings strategy Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: arendu * mlp transform of pos embeddings Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * zero init position bias Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * merge Signed-off-by: arendu * minor fix Signed-off-by: arendu --------- Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../conf/megatron_gpt_peft_tuning_config.yaml | 11 +- .../tuning/megatron_gpt_peft_tuning.py | 12 +- .../megatron_gpt_peft_models.py | 188 +++++++++++++++++- .../megatron/adapters/parallel_adapters.py | 162 ++++++++++++++- .../nlp/modules/common/megatron/attention.py | 2 + .../modules/common/megatron/transformer.py | 9 +- 6 files changed, 369 insertions(+), 15 deletions(-) diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml index 890029f911ae..7098fe73abff 100755 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml @@ -85,16 +85,22 @@ model: type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' adapter_dim: 32 adapter_dropout: 0.0 - norm_position: 'pre' # This can be set to 'pre' or 'post', 'pre' is normally what is used. + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True lora_tuning: adapter_dim: 32 adapter_dropout: 0.0 column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True # Used for p-tuning peft training p_tuning: @@ -102,6 +108,9 @@ model: bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck embedding_dim: 1024 # the size of the prompt encoder embeddings init_std: 0.023 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers data: train_ds: diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py index f9f8e1ee952f..cc7cb8060be1 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py @@ -25,9 +25,11 @@ from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import ( MegatronGPTAdapterModel, + MegatronGPTAdapterModelWeightTying, MegatronGPTAdapterPTuningModel, MegatronGPTIA3Model, MegatronGPTLoRAModel, + MegatronGPTLoRAModelWeightTying, MegatronGPTPTuningModel, ) from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTModel @@ -114,7 +116,10 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): def _get_peft_scheme(cfg): if cfg.peft.peft_scheme == "adapter": - peft_cls = MegatronGPTAdapterModel + if cfg.peft.adapter_tuning.weight_tying: + peft_cls = MegatronGPTAdapterModelWeightTying + else: + peft_cls = MegatronGPTAdapterModel elif cfg.peft.peft_scheme == "ia3": peft_cls = MegatronGPTIA3Model elif cfg.peft.peft_scheme == "ptuning": @@ -122,7 +127,10 @@ def _get_peft_scheme(cfg): elif cfg.peft.peft_scheme == "adapter_and_ptuning": peft_cls = MegatronGPTAdapterPTuningModel elif cfg.peft.peft_scheme == "lora": - peft_cls = MegatronGPTLoRAModel + if cfg.peft.lora_tuning.weight_tying: + peft_cls = MegatronGPTLoRAModelWeightTying + else: + peft_cls = MegatronGPTLoRAModel else: raise RuntimeError("Invalid Peft scheme") return peft_cls diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py index 776c0558d5ab..c32c9a8c5d23 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_peft_models.py @@ -21,8 +21,10 @@ AdapterName, InfusedAdapterConfig, LoraKQVAdapterConfig, + LoraKQVAdapterWeightTyingConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, + ParallelLinearAdapterWeightTyingConfig, PromptEncoderAdapterConfig, ) from nemo.core.classes.mixins import adapter_mixins @@ -131,7 +133,37 @@ def setup_optimizer_param_groups(self): logging.info(f"Optimizer groups set:\n{self.summarize()}") -class MegatronGPTAdapterModel(MegatronGPTPEFTModel): +class MegatronGPTLayerwisePEFTModel(MegatronGPTPEFTModel): + def __init__( + self, cfg: DictConfig, trainer: Trainer, + ): + super().__init__(cfg, trainer) + + def init_peft_modules(self): + """ + Randomly initialize the peft params and add them to the appropriate modules. + """ + assert len(self.peft_name_keys) > 0, "peft_name_keys have not been set no PEFT modules will be added" + assert len(self.name_key_to_cfg) > 0, "name_key_to_cfg has not been set no PEFT modules will be added" + logging.info(f"Before adding PEFT params:\n{self.summarize()}") + for layer in self.model.language_model.encoder.layers: + if layer.layer_number in self.layer_selection: + for _, module in layer.named_modules(): + if isinstance(module, adapter_mixins.AdapterModuleMixin): + for peft_key in self.peft_name_keys: + peft_cfg = self.name_key_to_cfg[peft_key] + if ( + model_utils.import_class_by_path(peft_cfg._target_) + in module.get_accepted_adapter_types() + ): + module.add_adapter( + name=peft_key, cfg=peft_cfg, + ) + logging.info(f"After adding PEFT params:\n{self.summarize()}") + return True + + +class MegatronGPTAdapterModel(MegatronGPTLayerwisePEFTModel): """ MegatronGPTAdapterLearningModel is a model that combines a base model (GPTSFTModel) with a adapters. This class only supports the canonical Adapter training described in Houlsby et al. (https://arxiv.org/pdf/1902.00751.pdf) @@ -151,7 +183,6 @@ def __init__( AdapterName.POST_ATTN_ADAPTER, ] adapter_tuning_cfg = cfg.peft.adapter_tuning - adapter_cfg = ParallelLinearAdapterConfig( in_features=cfg.hidden_size, out_features=cfg.hidden_size, @@ -167,10 +198,73 @@ def __init__( for k in self.peft_name_keys: self.name_key_to_cfg[k] = adapter_cfg + self.layer_selection = adapter_tuning_cfg.get("layer_selection", None) + if self.layer_selection is None: + self.layer_selection = list(range(1, cfg.num_layers + 1)) super().__init__(cfg, trainer) -class MegatronGPTIA3Model(MegatronGPTPEFTModel): +class MegatronGPTAdapterModelWeightTying(MegatronGPTLayerwisePEFTModel): + """ + TODO + """ + + def __init__( + self, cfg: DictConfig, trainer: Trainer, + ): + self.peft_name_keys = [ + AdapterName.PRE_ATTN_ADAPTER, + AdapterName.POST_ATTN_ADAPTER, + ] + adapter_tuning_cfg = cfg.peft.adapter_tuning + + adapter_cfg = ParallelLinearAdapterWeightTyingConfig( + in_features=cfg.hidden_size, + out_features=cfg.hidden_size, + dim=adapter_tuning_cfg.adapter_dim, + norm_position=adapter_tuning_cfg.get("norm_position", "pre"), + norm_type=adapter_tuning_cfg.get("norm_type", "mixedfusedlayernorm"), + column_init_method=adapter_tuning_cfg.get("column_init_method", "xavier"), + row_init_method=adapter_tuning_cfg.get("row_init_method", "zero"), + dropout=adapter_tuning_cfg.adapter_dropout, + num_position_embeddings=cfg.num_layers * 2, + dim_position_embeddings=cfg.hidden_size, + position_embedding_strategy=adapter_tuning_cfg.get("position_embedding_strategy", None), + ) + + self.name_key_to_cfg = {} + for k in self.peft_name_keys: + self.name_key_to_cfg[k] = adapter_cfg + + self.layer_selection = adapter_tuning_cfg.get("layer_selection", None) + if self.layer_selection is None: + self.layer_selection = list(range(1, cfg.num_layers + 1)) + super().__init__(cfg, trainer) + self.tie_weights() + + def tie_weights(self,): + pos_idx = 0 + layer0 = self.model.language_model.encoder.layers[0] + for adapter_name in layer0.adapter_layer: + adapter = layer0.get_adapter_module(adapter_name) + print(adapter_name, pos_idx) + adapter.set_position(pos_idx) + pos_idx += 1 + + for layer in self.model.language_model.encoder.layers[1:]: + for adapter_name in layer.adapter_layer: + print(adapter_name, pos_idx) + adapter_l = layer.get_adapter_module(adapter_name) + adapter_0 = layer0.get_adapter_module(adapter_name) + if hasattr(adapter_0, "layer_norm"): + lnorm = adapter_0.layer_norm + else: + lnorm = None + adapter_l.tie_weights(pos_idx, adapter_0) + pos_idx += 1 + + +class MegatronGPTIA3Model(MegatronGPTLayerwisePEFTModel): """ MegatronGPTInfusedAdapterModel is a model that combines a base model (GPTSFTModel) with a "Infused Adapter that can Inhibiting and Amplify Inner Activations", known as IA3. This class supports the addition of IA3 into a transformer based LM as described in Liu et al. (https://arxiv.org/pdf/2205.05638.pdf) @@ -330,7 +424,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.virtual_tokens = cfg.peft.p_tuning.virtual_tokens -class MegatronGPTLoRAModel(MegatronGPTPEFTModel): +class MegatronGPTLoRAModel(MegatronGPTLayerwisePEFTModel): """ MegatronGPTLoRAModel is a model that combines a base model (GPTSFTModel) with a low-rank adapters. The lora adapters will be added in `nemo/collections/nlp/modules/common/megatron/attention.py` @@ -360,8 +454,8 @@ def __init__( in_features=cfg.hidden_size, out_features=3 * projection_size, dim=lora_cfg.adapter_dim, - norm_position="none", - norm_type="none", + norm_position=None, + norm_type=None, activation="identity", column_init_method=lora_cfg.get("column_init_method", "normal"), row_init_method=lora_cfg.get("row_init_method", "zero"), @@ -372,5 +466,87 @@ def __init__( self.name_key_to_cfg = {} for k in self.peft_name_keys: self.name_key_to_cfg[k] = adapter_cfg + self.layer_selection = lora_cfg.get("layer_selection", None) + if self.layer_selection is None: + self.layer_selection = list(range(1, cfg.num_layers + 1)) + super().__init__(cfg, trainer) + +class MegatronGPTLoRAModelWeightTying(MegatronGPTLayerwisePEFTModel): + """ + TODO + """ + + def __init__( + self, cfg: DictConfig, trainer: Trainer, + ): + self.peft_name_keys = [ + AdapterName.LORA_KQV_ADAPTER, + ] + lora_cfg = cfg.peft.lora_tuning + if cfg.get("kv_channels", None) is None: + assert ( + cfg.hidden_size % cfg.num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = cfg.hidden_size // cfg.num_attention_heads + else: + kv_channels = cfg.kv_channels + projection_size = kv_channels * cfg.num_attention_heads + position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None) + if position_embedding_strategy is None: + dim_position_embeddings = 0 + elif position_embedding_strategy == "add": + dim_position_embeddings = cfg.hidden_size + elif position_embedding_strategy == "biasadd": + dim_position_embeddings = 3 * projection_size + elif position_embedding_strategy == "concat": + dim_position_embeddings = lora_cfg.adapter_dim + elif position_embedding_strategy == "mlpconcat": + dim_position_embeddings = lora_cfg.adapter_dim + else: + raise RuntimeError(f"Unknown position embedding strategy {position_embedding_strategy} for tied weights") + + adapter_cfg = LoraKQVAdapterWeightTyingConfig( + in_features=cfg.hidden_size, + out_features=3 * projection_size, + dim=lora_cfg.adapter_dim, + norm_position=None, + norm_type=None, + activation="identity", + column_init_method=lora_cfg.get("column_init_method", "normal"), + row_init_method=lora_cfg.get("row_init_method", "zero"), + gather_output=False, + dropout=lora_cfg.adapter_dropout, + num_position_embeddings=cfg.num_layers, + dim_position_embeddings=dim_position_embeddings, + position_embedding_strategy=position_embedding_strategy, + ) + + self.name_key_to_cfg = {} + for k in self.peft_name_keys: + self.name_key_to_cfg[k] = adapter_cfg + self.layer_selection = lora_cfg.get("layer_selection", None) + if self.layer_selection is None: + self.layer_selection = list(range(1, cfg.num_layers + 1)) super().__init__(cfg, trainer) + self.tie_weights() + + def tie_weights(self,): + pos_idx = 0 + layer0 = self.model.language_model.encoder.layers[0] + for adapter_name in layer0.self_attention.adapter_layer: + adapter = layer0.self_attention.get_adapter_module(adapter_name) + print(adapter_name, pos_idx) + adapter.set_position(pos_idx) + pos_idx += 1 + + for layer in self.model.language_model.encoder.layers[1:]: + for adapter_name in layer.self_attention.adapter_layer: + print(adapter_name, pos_idx) + adapter_l = layer.self_attention.get_adapter_module(adapter_name) + adapter_0 = layer0.self_attention.get_adapter_module(adapter_name) + position_embeddings_0 = None + if adapter_0.position_embedding_strategy: + position_embeddings_0 = adapter_0.position_embeddings + adapter_l.tie_weights(pos_idx, adapter_0) + pos_idx += 1 diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index d4a75aa18fb1..576366b90ddd 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -17,7 +17,7 @@ import enum import logging from dataclasses import dataclass - +from typing import Optional import torch import torch.nn as nn import torch.nn.init as init @@ -106,8 +106,8 @@ def __init__( out_features: int, dim: int, activation: str = 'swish', - norm_position: str = 'post', - norm_type: str = 'mixedfusedlayernorm', + norm_position: Optional[str] = 'post', + norm_type: Optional[str] = 'mixedfusedlayernorm', column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise. row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise. gather_output: bool = True, @@ -161,6 +161,8 @@ def __init__( self.layer_norm = nn.LayerNorm(ln_features) else: raise NotImplementedError("norm_type should be either mixedfusedlayernorm or layernorm") + else: + self.layer_norm = None if dropout > 0.0: self.dropout = nn.Dropout(dropout) @@ -215,8 +217,8 @@ class ParallelLinearAdapterConfig: out_features: int dim: int activation: str = 'swish' - norm_position: str = 'post' - norm_type: str = 'mixedfusedlayernorm' + norm_position: Optional[str] = 'post' + norm_type: Optional[str] = 'mixedfusedlayernorm' column_init_method: str = 'xavier' row_init_method: str = 'zero' gather_output: bool = True @@ -375,3 +377,153 @@ class PromptEncoderAdapterConfig: init_std: float output_dim: int _target_: str = "{0}.{1}".format(PromptEncoderAdapter.__module__, PromptEncoderAdapter.__name__) + + +class ParallelLinearAdapterWeightTying(ParallelLinearAdapter): + """ + Extends parallel linear adapter for weight tying by providing a position embedding and convenience methods for tying weights + """ + + def __init__( + self, + in_features: int, + out_features: int, + dim: int, + activation: str = 'swish', + norm_position: Optional[str] = 'post', + norm_type: Optional[str] = 'mixedfusedlayernorm', + column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise. + row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise. + gather_output: bool = True, + dropout: float = 0.0, + num_position_embeddings: int = 1, + dim_position_embeddings: int = 1024, + position_embedding_strategy: Optional[str] = "add", + ): + self.position_embeddings = None + self.mlp = None + self.position_embedding_strategy = position_embedding_strategy + assert self.position_embedding_strategy in ["add", "concat", "mlpconcat", "biasadd", None] + if self.position_embedding_strategy == "concat": + in_features += dim_position_embeddings + elif self.position_embedding_strategy == "mlpconcat": + in_features += dim_position_embeddings + elif self.position_embedding_strategy == "biasadd": + assert ( + out_features == dim_position_embeddings + ), "adapter output feature size should match position emb size to bias add" + elif self.position_embedding_strategy == "add": + assert ( + in_features == dim_position_embeddings + ), "adapter input feature size should match position emb size to add" + super().__init__( + in_features, + out_features, + dim, + activation, + norm_position, + norm_type, + column_init_method, + row_init_method, + gather_output, + dropout, + ) + if self.position_embedding_strategy: + self.position_embeddings = torch.nn.Embedding(num_position_embeddings, dim_position_embeddings) + self.position_embeddings.weight.data.fill_(0.0) + if self.position_embedding_strategy == "mlpconcat": + self.mlp = torch.nn.Sequential( + torch.nn.Linear(dim_position_embeddings, dim_position_embeddings, bias=False), + torch.nn.GELU(), + torch.nn.Linear(dim_position_embeddings, dim_position_embeddings, bias=False), + ) + self.register_buffer("position_id", torch.LongTensor([1]), persistent=False) + + def set_position(self, position_id): + self.position_id *= position_id + + def tie_weights(self, position_id, adapter): + + self.set_position(position_id) + if self.linear_in: + self.linear_in.weight = adapter.linear_in.weight + if self.linear_out: + self.linear_out.weight = adapter.linear_out.weight + if self.layer_norm: + self.layer_norm.weight = adapter.layer_norm.weight + self.layer_norm.bias = adapter.layer_norm.bias + if self.mlp: + self.mlp[0].weight = adapter.mlp[0].weight + self.mlp[2].weight = adapter.mlp[2].weight + if self.position_embeddings: + self.position_embeddings.weight = adapter.position_embeddings.weight + + return True + + def forward(self, x): + + if self.position_embedding_strategy: + pos = self.position_embeddings(self.position_id).unsqueeze(0) + if self.position_embedding_strategy == "add": + pos = pos.expand_as(x) + x = x + pos + + elif self.position_embedding_strategy == "concat": + pos = pos.expand(x.shape[0], x.shape[1], pos.shape[2]) + x = torch.cat((x, pos), dim=2) + elif self.position_embedding_strategy == "mlpconcat": + pos = pos.expand(x.shape[0], x.shape[1], pos.shape[2]) + pos = self.mlp(pos) + x = torch.cat((x, pos), dim=2) + + if self.norm_position == 'pre': + x = self.layer_norm(x) + + x, _ = self.linear_in(x) # (@adithyare) ColumnLinear returns output and bias, we are ignoring the bias term. + x = self.activation(x) + x, _ = self.linear_out(x) + if self.norm_position == 'post': + x = self.layer_norm(x) + + if self.position_embedding_strategy == "biasadd": + pos = pos.expand_as(x) + x = x + pos + + # Add dropout if available + if self.dropout is not None: + x = self.dropout(x) + + return x + + +@dataclass +class ParallelLinearAdapterWeightTyingConfig: + in_features: int + out_features: int + dim: int + activation: str = 'swish' + norm_position: Optional[str] = 'post' + norm_type: Optional[str] = 'mixedfusedlayernorm' + column_init_method: str = 'xavier' + row_init_method: str = 'zero' + gather_output: bool = True + dropout: float = 0.0 + num_position_embeddings: int = 1 + dim_position_embeddings: int = 1024 + position_embedding_strategy: Optional[str] = "concat" + _target_: str = "{0}.{1}".format( + ParallelLinearAdapterWeightTying.__module__, ParallelLinearAdapterWeightTying.__name__ + ) + + +class LoraKQVAdapterWeightTying(ParallelLinearAdapterWeightTying): + """ + TODO + """ + + pass + + +@dataclass +class LoraKQVAdapterWeightTyingConfig(ParallelLinearAdapterWeightTyingConfig): + _target_: str = "{0}.{1}".format(LoraKQVAdapterWeightTying.__module__, LoraKQVAdapterWeightTying.__name__) diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index a5a8b86b85bf..6b8189194333 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -22,6 +22,7 @@ AdapterName, InfusedAdapterConfig, LoraKQVAdapterConfig, + LoraKQVAdapterWeightTyingConfig, LoraKVAdapterConfig, LoraQAdapterConfig, ) @@ -143,6 +144,7 @@ def __init__( LoraKQVAdapterConfig._target_, LoraQAdapterConfig._target_, LoraKVAdapterConfig._target_, + LoraKQVAdapterWeightTyingConfig._target_, ] ) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index 9cdcccf6e685..045daaf1151a 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -25,6 +25,7 @@ from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, ParallelLinearAdapterConfig, + ParallelLinearAdapterWeightTyingConfig, ) from nemo.collections.nlp.modules.common.megatron.attention import ParallelAttention, ParallelChunkedCrossAttention from nemo.collections.nlp.modules.common.megatron.fused_bias_dropout_add import ( @@ -188,7 +189,13 @@ def __init__( self.position_embedding_type = position_embedding_type self.param_dtype = utils_funcs.dtype_from_precision(precision, megatron_amp_O2) - self.set_accepted_adapter_types([LinearAdapterConfig._target_, ParallelLinearAdapterConfig._target_]) + self.set_accepted_adapter_types( + [ + LinearAdapterConfig._target_, + ParallelLinearAdapterConfig._target_, + ParallelLinearAdapterWeightTyingConfig._target_, + ] + ) if not bias and bias_dropout_add_fusion: raise ValueError(