Skip to content

Commit

Permalink
use resolve_model
Browse files Browse the repository at this point in the history
Signed-off-by: reidliu <[email protected]>
  • Loading branch information
reidliu committed Feb 21, 2025
1 parent 30f97a0 commit 71c737d
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions llama_stack/cli/model/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.models.llama.sku_list import resolve_model


class ModelRemove(Subcommand):
Expand Down Expand Up @@ -44,13 +44,15 @@ def _add_arguments(self):
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku

model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model)
prompt_guard = prompt_guard_model_sku()
if args.model == prompt_guard.model_id:
model = prompt_guard
else:
model = resolve_model(args.model)

model_list = []
for model in all_registered_models() + [prompt_guard_model_sku()]:
model_list.append(model.descriptor().replace(":", "-"))
model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))

if args.model not in model_list or os.path.isdir(model_path):
if model is None or not os.path.isdir(model_path):
print(f"'{args.model}' is not a valid llama model or does not exist.")
return

Expand Down

0 comments on commit 71c737d

Please sign in to comment.