Merge pull request #215 from jiangtann/main
fix lm_head type changed bug
shibing624 authored Sep 20, 2023
2 parents 6fd2713 + 82659b6 commit d25c205
Showing 1 changed file with 17 additions and 3 deletions.
supervised_finetuning.py (17 additions, 3 deletions)

@@ -191,11 +191,25 @@ class PeftArguments(TrainingArguments):
     qlora: bool = field(default=False, metadata={"help": "Whether to use qlora"})
 
 
-class CastOutputToFloat(torch.nn.Sequential):
+class CastOutputToFloat(torch.nn.Module):
     """Cast the output of the model to float"""
+    def __init__(self, ori_linear: torch.nn.Linear) -> None:
+        super().__init__()
+        self.in_features = ori_linear.in_features
+        self.out_features = ori_linear.out_features
+        self.weight = ori_linear.weight
+        if ori_linear.bias is not None:
+            self.bias = ori_linear.bias
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input):
+        return torch.nn.functional.linear(input, self.weight, self.bias).to(torch.float32)
 
-    def forward(self, x):
-        return super().forward(x).to(torch.float32)
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}'.format(
+            self.in_features, self.out_features, self.bias is not None
+        )
 
 
 @dataclass
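For context: the old wrapper subclassed `torch.nn.Sequential`, so replacing `model.lm_head` with it changed the module's type and hid the `Linear` attributes (`weight`, `in_features`, `out_features`, `bias`) that downstream code inspects, which is presumably the "lm_head type changed" bug the commit title refers to. The new wrapper is a drop-in `torch.nn.Module` that forwards those attributes and only casts the output to float32. A minimal usage sketch, assuming the `CastOutputToFloat` class from the diff above is in scope (the layer sizes and variable names here are illustrative, not taken from the repository):

```python
import torch

# Stand-in for a causal LM's lm_head; the sizes are made up for illustration.
lm_head = torch.nn.Linear(4096, 32000, bias=False)
wrapped = CastOutputToFloat(lm_head)

# Unlike the old nn.Sequential-based version, the wrapper still looks like a
# Linear to introspection code: the attributes are forwarded, not hidden.
assert wrapped.in_features == 4096
assert wrapped.out_features == 32000
assert wrapped.bias is None
assert wrapped.weight is lm_head.weight  # same Parameter object, no copy

# The output is always cast to float32, e.g. for a numerically stable loss
# when the rest of the model runs in half precision.
logits = wrapped(torch.randn(2, 4096))
assert logits.dtype == torch.float32
```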
