diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b3102a37d37f31..4871110f5b6ffb 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3552,6 +3552,7 @@ def from_pretrained( "device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index, + "force_hooks": True } if "skip_keys" in inspect.signature(dispatch_model).parameters: device_map_kwargs["skip_keys"] = model._skip_keys_device_placement