diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py
index 3dce9d383da1..836b8253ea73 100644
--- a/src/transformers/models/align/modeling_align.py
+++ b/src/transformers/models/align/modeling_align.py
@@ -1196,6 +1196,7 @@ def _init_weights(self, module):
 )
 class AlignTextModel(AlignPreTrainedModel):
     config_class = AlignTextConfig
+    _no_split_modules = ["AlignTextEmbeddings"]
 
     def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
         super().__init__(config)
diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py
index 0d27d87de7f4..4cf9614ae728 100755
--- a/src/transformers/models/altclip/modeling_altclip.py
+++ b/src/transformers/models/altclip/modeling_altclip.py
@@ -1027,6 +1027,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
     config_class = AltCLIPConfig
     base_model_prefix = "altclip"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 262fc79f0d40..0c29423c87d6 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -850,6 +850,8 @@ class BertModel(BertPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
+    _no_split_modules = ["BertEmbeddings"]
+
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index 808c33f8104f..3eb6ad457910 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -549,6 +549,7 @@ class BlipTextPreTrainedModel(PreTrainedModel):
     config_class = BlipTextConfig
     base_model_prefix = "bert"
+    _no_split_modules = []
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
index d8e97c20b24c..c29f3720027a 100644
--- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -1106,6 +1106,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
     """
 
     config_class = ChineseCLIPTextConfig
+    _no_split_modules = ["ChineseCLIPTextEmbeddings"]
 
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
@@ -1277,6 +1278,7 @@ def forward(
 class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
     config_class = ChineseCLIPVisionConfig
     main_input_name = "pixel_values"
+    _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]
 
     def __init__(self, config: ChineseCLIPVisionConfig):
         super().__init__(config)
diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py
index 788b0d911396..3a4901ae5ee5 100644
--- a/src/transformers/models/depth_anything/modeling_depth_anything.py
+++ b/src/transformers/models/depth_anything/modeling_depth_anything.py
@@ -364,6 +364,8 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width)
     DEPTH_ANYTHING_START_DOCSTRING,
 )
 class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
+    _no_split_modules = ["DPTViTEmbeddings"]
+
     def __init__(self, config):
         super().__init__(config)
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py
index c25022f6ec22..c90221f145d4 100644
--- a/src/transformers/models/dinov2/modeling_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_dinov2.py
@@ -481,6 +481,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
     base_model_prefix = "dinov2"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["Dinov2SwiGLUFFN"]
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
diff --git a/src/transformers/models/efficientformer/modeling_efficientformer.py b/src/transformers/models/efficientformer/modeling_efficientformer.py
index 70075cff55d7..cc62e9cbd21e 100644
--- a/src/transformers/models/efficientformer/modeling_efficientformer.py
+++ b/src/transformers/models/efficientformer/modeling_efficientformer.py
@@ -555,6 +555,8 @@ class EfficientFormerModel(EfficientFormerPreTrainedModel):
+    _no_split_modules = ["EfficientFormerMeta4D"]
+
     def __init__(self, config: EfficientFormerConfig):
         super().__init__(config)
         self.config = config
 
         self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
         self.encoder = EfficientFormerEncoder(config)
diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py
index b169af0cbd56..4574ca378760 100755
--- a/src/transformers/models/pvt/modeling_pvt.py
+++ b/src/transformers/models/pvt/modeling_pvt.py
@@ -462,6 +462,7 @@ class PvtPreTrainedModel(PreTrainedModel):
     config_class = PvtConfig
     base_model_prefix = "pvt"
     main_input_name = "pixel_values"
+    _no_split_modules = []
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
index 9c2269a3ae54..424d657dc878 100644
--- a/src/transformers/models/vit_msn/modeling_vit_msn.py
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -421,6 +421,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
     base_model_prefix = "vit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["ViTMSNAttention"]
 
     # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
     # when creating pre-training scripts.