Skip to content

Commit

Permalink
fix forward hook (#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
LRL-ModelCloud authored Jan 5, 2025
1 parent 55c594b commit c88f50f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions gptqmodel/nn_modules/hooked_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def from_conv1d(conv1d: Conv1D):
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, input, output)
self.forward_hook(self, (input,), output)
return output


Expand All @@ -41,7 +41,7 @@ def from_conv2d(conv2d: torch.nn.Conv2d):
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, input, output)
self.forward_hook(self, (input,), output)
return output


Expand All @@ -64,7 +64,7 @@ def from_linear(linear: torch.nn.Linear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = super().forward(input)
if self.forward_hook:
self.forward_hook(self, input, output)
self.forward_hook(self, (input,), output)
return output


Expand Down

0 comments on commit c88f50f

Please sign in to comment.