Skip to content

Commit

Permalink
Merge pull request #25 from togethercomputer/justusc/enable-wandb
Browse files Browse the repository at this point in the history
Enable WandB option in the CLI
  • Loading branch information
justusc authored Sep 8, 2023
2 parents 89a19f4 + 81dd9ed commit 0f1d894
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ resp = together.Finetune.create(
batch_size = 4,
learning_rate = 1e-5,
suffix = 'my-demo-finetune',
wandb_api_key = '1a2b3c4d5e.......',
)

fine_tune_id = resp['id']
Expand Down
32 changes: 17 additions & 15 deletions src/together/commands/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import json
import os

from together import Finetune, extract_time

Expand Down Expand Up @@ -136,21 +137,21 @@ def _add_create(parser: argparse._SubParsersAction[argparse.ArgumentParser]) ->
help="Up to 40 characters that will be added to your fine-tuned model name.",
type=str,
)
# subparser.add_argument(
# "--wandb-api-key",
# "-wb",
# metavar="WANDB_API_KEY",
# default=os.getenv("WANDB_API_KEY"),
# help="Wandb API key to report metrics to wandb.ai. If not set WANDB_API_KEY environment variable is used.",
# type=str,
# )
# subparser.add_argument(
# "--no-wandb-api-key",
# "-nwb",
# default=False,
# help="Do not report metrics to wandb.ai.",
# action="store_true",
# )
subparser.add_argument(
"--wandb-api-key",
"-wb",
metavar="WANDB_API_KEY",
default=os.getenv("WANDB_API_KEY"),
help="Wandb API key to report metrics to wandb.ai. If not set WANDB_API_KEY environment variable is used.",
type=str,
)
subparser.add_argument(
"--no-wandb-api-key",
"-nwb",
default=False,
help="Do not report metrics to wandb.ai.",
action="store_true",
)

subparser.set_defaults(func=_run_create)

Expand Down Expand Up @@ -296,6 +297,7 @@ def _run_create(args: argparse.Namespace) -> None:
# checkpoint_steps=args.checkpoint_steps,
suffix=args.suffix,
estimate_price=args.estimate_price,
wandb_api_key=args.wandb_api_key if not args.no_wandb_api_key else None,
)

print(json.dumps(response, indent=4))
Expand Down
3 changes: 2 additions & 1 deletion src/together/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def create(
str
] = None, # resulting finetuned model name will include the suffix
estimate_price: bool = False,
wandb_api_key: Optional[str] = None,
) -> Dict[Any, Any]:
if n_epochs is None or n_epochs < 1:
logger.fatal("The number of epochs must be specified")
Expand Down Expand Up @@ -102,7 +103,7 @@ def create(
# "seed": seed,
# "fp16": fp16,
"suffix": suffix,
# "wandb_key": wandb_api_key,
"wandb_key": wandb_api_key,
}

# check if model name is one of the models available for finetuning
Expand Down

0 comments on commit 0f1d894

Please sign in to comment.