Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multi-device for more models #30409

Merged
merged 41 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
4f367a3
feat: support for dinov2
jla524 Apr 22, 2024
7e95e93
feat: support for depth_anything
jla524 Apr 22, 2024
cc0ff99
feat: support for efficientformer
jla524 Apr 22, 2024
14f4536
feat: support for bert (is this right?)
jla524 Apr 23, 2024
90015ae
update: embedding split
jla524 Apr 23, 2024
11275c0
remove: empty string
jla524 Apr 23, 2024
70da3af
feat: support for align
jla524 Apr 23, 2024
4acb68e
fix: copies
jla524 Apr 23, 2024
9ed35c4
fix: QDQBertEmbeddings
jla524 Apr 23, 2024
54aa1b2
fix: more consistency issues
jla524 Apr 23, 2024
863d392
revert: support for efficientformer
jla524 Apr 23, 2024
f7f8f48
feat: support for altclip
jla524 Apr 23, 2024
987d418
feat: support for blip_text
jla524 Apr 23, 2024
04eac19
support for ChineseCLIP
jla524 Apr 23, 2024
3f88e46
feat: support for depth anything
jla524 Apr 24, 2024
4cf422f
feat: support for dpt
jla524 Apr 24, 2024
74f6263
feat: support for dpt
jla524 Apr 24, 2024
a8b384b
feat: support for git
jla524 Apr 24, 2024
5206ff6
feat: support for groupvit
jla524 Apr 24, 2024
cfc5a43
update: format
jla524 Apr 24, 2024
e3002e4
fix: support for clip
jla524 Apr 24, 2024
3d62a4a
fix: consistency
jla524 Apr 24, 2024
e121478
feat: support for pvt
jla524 Apr 24, 2024
5d9f452
feat: support for vit_msn
jla524 Apr 24, 2024
84b9fc6
fix: consistency
jla524 Apr 24, 2024
1197b8d
fix: other copies
jla524 Apr 25, 2024
efd4271
remove: device transfer
jla524 Apr 29, 2024
dded165
revert: in-place add
jla524 Apr 29, 2024
cb0c291
update: support for align
jla524 Apr 30, 2024
df99d79
update: support for bert
jla524 Apr 30, 2024
433c39b
update: support for Chinese CLIP
jla524 Apr 30, 2024
18c2bb0
revert: changes to efficientformer
jla524 Apr 30, 2024
091fcad
update: support for dpt
jla524 Apr 30, 2024
f590c99
update: support for efficientformer
jla524 Apr 30, 2024
16a4258
revert: changes to git
jla524 Apr 30, 2024
ac40045
revert: changes to groupvit
jla524 Apr 30, 2024
14f2b26
revert: changes to roc_bert
jla524 Apr 30, 2024
8dfbd4e
update: support for vit_msn
jla524 Apr 30, 2024
675f95d
revert: changes to dpt
jla524 Apr 30, 2024
491198b
remove: extra space
jla524 Apr 30, 2024
603ef24
style: extra space
jla524 Apr 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/models/align/modeling_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ def _init_weights(self, module):
)
class AlignTextModel(AlignPreTrainedModel):
config_class = AlignTextConfig
_no_split_modules = ["AlignTextEmbeddings"]

def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
super().__init__(config)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
config_class = AltCLIPConfig
base_model_prefix = "altclip"
supports_gradient_checkpointing = True
    _no_split_modules = []

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,8 @@ class BertModel(BertPreTrainedModel):
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""

_no_split_modules = ["BertEmbeddings"]

def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/blip/modeling_blip_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ class BlipTextPreTrainedModel(PreTrainedModel):

config_class = BlipTextConfig
base_model_prefix = "bert"
_no_split_modules = []

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/chinese_clip/modeling_chinese_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
"""

config_class = ChineseCLIPTextConfig
_no_split_modules = ["ChineseCLIPTextEmbeddings"]

def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
Expand Down Expand Up @@ -1277,6 +1278,7 @@ def forward(
class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
config_class = ChineseCLIPVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]

def __init__(self, config: ChineseCLIPVisionConfig):
super().__init__(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ def forward(self, hidden_states: List[torch.Tensor], patch_height, patch_width)
DEPTH_ANYTHING_START_DOCSTRING,
)
class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
_no_split_modules = ["DPTViTEmbeddings"]

def __init__(self, config):
super().__init__(config)

Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/dinov2/modeling_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
base_model_prefix = "dinov2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["Dinov2SwiGLUFFN"]

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ class EfficientFormerModel(EfficientFormerPreTrainedModel):
def __init__(self, config: EfficientFormerConfig):
super().__init__(config)
self.config = config
_no_split_modules = ["EfficientFormerMeta4D"]

self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
self.encoder = EfficientFormerEncoder(config)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/pvt/modeling_pvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ class PvtPreTrainedModel(PreTrainedModel):
config_class = PvtConfig
base_model_prefix = "pvt"
main_input_name = "pixel_values"
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/vit_msn/modeling_vit_msn.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTMSNAttention"]

# todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
# when creating pre-training scripts.
Expand Down
Loading