diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0d922218aea5ec..ce3c79bff08e19 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1107,6 +1107,9 @@ def train(
             from torch_ort import ORTModule
             logger.info("Converting to ORTModule ....")
             model = ORTModule(self.model)
+            if args.ort_config_file:
+                from torch_ort.experimental.json_config import load_from_json
+                load_from_json(model, path=args.ort_config_file)
             self.model_wrapped = model
         if args.deepspeed:
             if args.ort:
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 4e2511b9e10ff9..1b49b37303aaf9 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -526,6 +526,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Enable Ort"},
     )
+    ort_config_file: Optional[str] = field(
+        default=None,
+        metadata={"help": "Configure ORTModule internal options"},
+    )
     label_smoothing_factor: float = field(
         default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
     )