diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py
index f1d36796c..1641db714 100644
--- a/gptqmodel/nn_modules/hooked_linear.py
+++ b/gptqmodel/nn_modules/hooked_linear.py
@@ -2,8 +2,12 @@
 class HookedLinear(torch.nn.Linear):
-    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
-        super().__init__(in_features, out_features, bias, device, dtype)
+    def __init__(self, in_features: int, out_features: int) -> None:
+        # avoid calling super().__init__() as it would allocate memory based on in/out features
+        torch.nn.Module.__init__(self)
+        self.in_features = in_features
+        self.out_features = out_features
+        self.forward_hook = None
 
     @staticmethod
@@ -16,15 +20,13 @@ def replace_linear_with_hooked_linear(module):
 
     @staticmethod
     def from_linear(linear: torch.nn.Linear):
-        custom_linear = HookedLinear(linear.in_features, linear.out_features, bias=linear.bias is not None,
-                                     device=linear.weight.device, dtype=linear.weight.dtype)
+        custom_linear = HookedLinear(linear.in_features, linear.out_features)
         custom_linear.weight = linear.weight
-        if linear.bias is not None:
-            custom_linear.bias = linear.bias
+        custom_linear.bias = linear.bias
         return custom_linear
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         output = super().forward(input)
         if self.forward_hook:
             self.forward_hook(self, input, output)
-        return output
\ No newline at end of file
+        return output
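
For context, a minimal sketch of how the patched HookedLinear is intended to be used: from_linear() wraps an existing nn.Linear without allocating new parameter storage (it shares the original weight and bias), and forward_hook lets callers observe per-layer inputs/outputs, e.g. during calibration. The usage below is illustrative only and not part of this diff.

    import torch
    from gptqmodel.nn_modules.hooked_linear import HookedLinear

    linear = torch.nn.Linear(16, 8)
    hooked = HookedLinear.from_linear(linear)  # shares linear.weight / linear.bias, no new allocation

    captured = []

    def capture(module, inp, out):
        # forward() invokes the hook as forward_hook(self, input, output)
        captured.append((inp.detach(), out.detach()))

    hooked.forward_hook = capture

    x = torch.randn(4, 16)
    y = hooked(x)
    assert torch.allclose(y, linear(x))  # same parameters, same result
    assert len(captured) == 1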