Use SFTConfig instead of SFTTrainer keyword args #2150
Conversation
Thanks for addressing these trl deprecations. For my understanding, some arguments have been dropped without adding them to training_args.
Which ones?
packing=data_args.packing,
dataset_kwargs={
    "append_concat_token": data_args.append_concat_token,
    "add_special_tokens": data_args.add_special_tokens,
},
dataset_text_field=data_args.dataset_text_field,
max_seq_length=data_args.max_seq_length,
Here would be an example where arguments are removed from SFTTrainer but no equivalent arguments were added to training_args.
OK, I see what you mean. When you run the script with, for example, --max_seq_length 123, the value will no longer feed data_args but training_args. Everything happens behind the scenes when the arguments are parsed.
python example.py --output_dir tmp --max_seq_length 123
Before:

from dataclasses import dataclass
from typing import Optional

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = None
    max_seq_length: int = 512


if __name__ == "__main__":
    parser = HfArgumentParser((DataTrainingArguments, TrainingArguments))
    data_args, training_args = parser.parse_args_into_dataclasses()
    print(data_args.max_seq_length)  # 123
After:

from dataclasses import dataclass
from typing import Optional

from transformers import HfArgumentParser
from trl import SFTConfig


@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = None


if __name__ == "__main__":
    parser = HfArgumentParser((DataTrainingArguments, SFTConfig))
    data_args, training_args = parser.parse_args_into_dataclasses()
    print(training_args.max_seq_length)  # 123
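The routing described above can be sketched with a toy stand-in for HfArgumentParser. This is not the real implementation, and the Config class and its fields are illustrative (Config plays the role of SFTConfig, which is the dataclass that declares max_seq_length):

```python
from dataclasses import dataclass, fields


@dataclass
class DataTrainingArguments:
    dataset_name: str = ""


@dataclass
class Config:
    # Stand-in for SFTConfig: it, not DataTrainingArguments, declares max_seq_length.
    output_dir: str = ""
    max_seq_length: int = 1024


def route(argv, dataclass_types):
    # Pair up "--flag value" tokens, then hand each value to whichever
    # dataclass declares a field with that name: the core idea behind
    # HfArgumentParser.parse_args_into_dataclasses.
    pairs = dict(zip(argv[::2], argv[1::2]))
    instances = []
    for dc in dataclass_types:
        kwargs = {
            f.name: f.type(pairs["--" + f.name])
            for f in fields(dc)
            if "--" + f.name in pairs
        }
        instances.append(dc(**kwargs))
    return instances


data_args, training_args = route(
    ["--output_dir", "tmp", "--max_seq_length", "123"],
    (DataTrainingArguments, Config),
)
print(training_args.max_seq_length)  # 123
```

Because max_seq_length moved from DataTrainingArguments to the config dataclass, the same CLI flag now lands on training_args instead of data_args, with no change to the command line itself.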
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the PR and explaining the change. I tried the updated script and it worked. LGTM.
…#2150) Update training script using trl to fix deprecations in argument usage.
SFTTrainer's keyword args like packing, dataset_kwargs, dataset_text_field, and max_seq_length have been deprecated and will soon be removed. Instead, we use SFTConfig (a subclass of TrainingArguments). This PR updates the code related to SFTTrainer.