Fix model copy after dispatch_model (#1971)
* Fix model copy after dispatch_model

* Minor hook update to fix failing test

* address reviewer comments
austinapatel authored Sep 19, 2023
Parent: 629d02c · Commit: 03deec2
Showing 2 changed files with 37 additions and 6 deletions.
src/accelerate/hooks.py (10 changes: 5 additions, 5 deletions)
```diff
@@ -155,17 +155,17 @@ def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False)
     module = hook.init_hook(module)
     module._hf_hook = hook
 
-    @functools.wraps(old_forward)
-    def new_forward(*args, **kwargs):
+    def new_forward(module, *args, **kwargs):
         args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
         if module._hf_hook.no_grad:
             with torch.no_grad():
-                output = old_forward(*args, **kwargs)
+                output = module._old_forward(*args, **kwargs)
         else:
-            output = old_forward(*args, **kwargs)
+            output = module._old_forward(*args, **kwargs)
         return module._hf_hook.post_forward(module, output)
 
-    module.forward = new_forward
+    module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
 
     return module
```


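Why this fixes `copy.deepcopy` after `dispatch_model`: Python's `copy` module treats plain functions as atomic, so the old closure-based `new_forward` stored on the instance kept pointing at the original module and its captured `old_forward` even after the module was copied. Binding the module as an explicit argument with `functools.partial`, and reading `module._old_forward` instead of a captured variable, hands `deepcopy` objects it knows how to rebind (partials and bound methods are copied through the memo), so the copied module's forward refers to the copy. The snippet below is a minimal sketch of that mechanism with a toy class, not accelerate code; `Toy`, `attach_closure_hook`, and `attach_partial_hook` are made-up names for illustration.

```python
# Minimal sketch, not accelerate code: shows why a closure-captured wrapper
# breaks copy.deepcopy while a functools.partial-bound wrapper survives it.
import copy
import functools


class Toy:
    def __init__(self, name):
        self.name = name

    def forward(self):
        return self.name


def attach_closure_hook(obj):
    # Pre-fix pattern: the wrapper closes over `obj` and `old_forward`.
    old_forward = obj.forward

    def new_forward():
        return f"hooked {old_forward()}"

    # deepcopy leaves plain functions untouched, so a copy of `obj`
    # keeps calling the original instance through this closure.
    obj.forward = new_forward
    return obj


def attach_partial_hook(obj):
    # Post-fix pattern: the instance is an explicit argument of the wrapper.
    obj._old_forward = obj.forward

    def new_forward(obj):
        return f"hooked {obj._old_forward()}"

    # partial objects and bound methods are deep-copied with the memo,
    # so the copied instance gets a forward bound to itself.
    obj.forward = functools.update_wrapper(functools.partial(new_forward, obj), obj._old_forward)
    return obj


closure_copy = copy.deepcopy(attach_closure_hook(Toy("original")))
closure_copy.name = "copy"
print(closure_copy.forward())  # hooked original  <- still tied to the original

partial_copy = copy.deepcopy(attach_partial_hook(Toy("original")))
partial_copy.name = "copy"
print(partial_copy.forward())  # hooked copy      <- correctly rebound
```

`functools.update_wrapper` replaces the old `@functools.wraps(old_forward)` decorator because the wrapper is now a `partial` object created after the function definition; it keeps `module.forward` carrying the original forward's name, docstring, and metadata for introspection.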
tests/test_big_modeling.py (33 changes: 32 additions, 1 deletion)
```diff
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import copy
 import os
 import unittest
 from tempfile import TemporaryDirectory
@@ -45,6 +45,18 @@ def forward(self, x):
         return self.linear2(self.batchnorm(self.linear1(x)))
 
 
+class ModelForTestCopy(nn.Module):
+    def __init__(self, id: int):
+        super().__init__()
+        self.id = id
+        self.linear1 = nn.Linear(3, 4)
+        self.batchnorm = nn.BatchNorm1d(4)
+        self.linear2 = nn.Linear(4, 5)
+
+    def forward(self, x):
+        return self.linear2(self.batchnorm(self.linear1(x))), self.id
+
+
 class ModelForTestTiedWeights(nn.Module):
     def __init__(self):
         super().__init__()
@@ -325,6 +337,25 @@ def test_dispatch_model_multi_gpu(self):
         output = model(x)
         self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
 
+    @require_cuda
+    def test_dispatch_model_copy(self):
+        original_model = ModelForTestCopy(id=1)
+        device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": 0}
+
+        x = torch.randn(2, 3)
+        expected, original_output_id = original_model(x)
+
+        dispatch_model(original_model, device_map)
+
+        copied_model = copy.deepcopy(original_model)
+        copied_model.id = 2
+        output, copied_output_id = copied_model(x)
+
+        self.assertEqual(original_model.id, original_output_id)
+        self.assertEqual(copied_model.id, copied_output_id)
+        self.assertFalse(copied_model.linear1.forward is original_model.linear1.forward)
+        self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
+
     @require_cuda
     def test_dispatch_model_move_offloaded_model(self):
         model = ModelForTest()
```
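For context, this is roughly the user-level behavior the new test locks in: dispatch a model across devices, then deep-copy it and get an independent, working copy. A hedged sketch assuming a machine with at least one CUDA device; the `nn.Sequential` model and its device_map are made up for illustration, only `dispatch_model` and `copy.deepcopy` come from the diff above.

```python
import copy

import torch
from torch import nn

from accelerate import dispatch_model

# A made-up model; any nn.Module with named submodules works the same way.
model = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
device_map = {"0": 0, "1": "cpu"}  # first layer on GPU 0, second on the CPU

dispatch_model(model, device_map=device_map)

# Before this commit, the copied submodules' forwards still routed through the
# original modules' hooks; with the fix, the copy is fully independent.
copied_model = copy.deepcopy(model)
output = copied_model(torch.randn(2, 3))
```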
