Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Olson <[email protected]>
jolson-ibm committed Oct 3, 2023
1 parent 24de8fb commit 2c5fed6
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions examples/run_peft_tuning.py
Original file line number Diff line number Diff line change
@@ -241,7 +241,12 @@ def register_common_arguments(subparsers: Tuple[argparse.ArgumentParser]) -> Non
default="float32",
choices=["float16", "bfloat16", "float32"],
)

subparser.add_argument(
"--compatibility_test",
help="Execute bootstrap and training (epoch = 1) as a quick compute test",
default=False,
action="store_true"
)

def register_multitask_prompt_tuning_args(subparser: argparse.ArgumentParser):
"""Register additional configuration options for MP(rompt)T subtask.
@@ -358,6 +363,7 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
)
print_strs = [
"Experiment Configuration",
"- Compatibility test: [{}]".format(args.compatibility_test),
"- Model Name: [{}]".format(args.model_name),
" |- Inferred Model Resource Type: [{}]".format(model_type),
"- Tuning Type: [{}]".format(args.tuning_type),
@@ -393,12 +399,18 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
train_stream = dataset_info.dataset_loader()[0]
if args.num_shots is not None:
train_stream = subsample_stream(train_stream, args.num_shots)
compat_string = ""
epoch_string = ""
if args.compatibility_test:
compat_string = " for a compatibility test."
args.num_epochs = 1
epoch_string = f" Setting num_epochs to {args.num_epochs}"
# Init the resource & Build the tuning config from our dataset/arg info
print_colored("[Loading the base model resource...]")
print_colored(f"[Loading the base model resource{compat_string}...]")
base_model = model_type.bootstrap(args.model_name, tokenizer_name=args.model_name)
tuning_config = build_tuning_config(args, dataset_info)
# Then actually train the model & save it
print_colored("[Starting the training...]")
print_colored(f"[Starting the training{compat_string}{epoch_string}...]")
model = PeftPromptTuning.train(
base_model,
train_stream,
@@ -415,5 +427,8 @@ def show_experiment_configuration(args, dataset_info, model_type) -> None:
accumulate_steps=args.accumulate_steps,
torch_dtype=args.torch_dtype,
)
model.save(args.output_dir, save_base_model=not args.prompt_only)
print_colored("[Training Complete]")
if not args.compatibility_test:
model.save(args.output_dir, save_base_model=not args.prompt_only)
print_colored("[Training Complete]")
else:
print_colored("[Compatibility Test Successfully Complete]")

0 comments on commit 2c5fed6

Please sign in to comment.