diff --git a/llama_stack/cli/model/remove.py b/llama_stack/cli/model/remove.py
index aa0651ebc1..ee8d6299d9 100644
--- a/llama_stack/cli/model/remove.py
+++ b/llama_stack/cli/model/remove.py
@@ -10,7 +10,7 @@

 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
-from llama_stack.models.llama.sku_list import all_registered_models
+from llama_stack.models.llama.sku_list import resolve_model


 class ModelRemove(Subcommand):
@@ -44,13 +44,15 @@ def _add_arguments(self):

     def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
         from .safety_models import prompt_guard_model_sku

-        model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model)
+        prompt_guard = prompt_guard_model_sku()
+        if args.model == prompt_guard.model_id:
+            model = prompt_guard
+        else:
+            model = resolve_model(args.model)

-        model_list = []
-        for model in all_registered_models() + [prompt_guard_model_sku()]:
-            model_list.append(model.descriptor().replace(":", "-"))
+        model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-"))

-        if args.model not in model_list or os.path.isdir(model_path):
+        if model is None or not os.path.isdir(model_path):
             print(f"'{args.model}' is not a valid llama model or does not exist.")
             return