Skip to content

Commit

Permalink
convert.py: get default outfile
Browse files Browse the repository at this point in the history
  • Loading branch information
mofosyne committed May 6, 2024
1 parent 7c6d166 commit ddfaad9
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,7 +1523,7 @@ def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) ->
return vocab, special_vocab


def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
def default_convention_outfile(file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> str:
quantization = {
GGMLFileType.AllF32: "F32",
GGMLFileType.MostlyF16: "F16",
Expand All @@ -1546,7 +1546,12 @@ def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Pa
elif params.path_model is not None:
name = params.path_model.name

ret = model_paths[0].parent / f"{name}{version}-{expert_count}{parameters}-{quantization}.gguf"
return f"{name}{version}-{expert_count}{parameters}-{quantization}"


def default_outfile(model_paths: list[Path], file_type: GGMLFileType, params: Params, model_params_count: int, metadata: Metadata) -> Path:
default_filename = default_convention_outfile(file_type, params, model_params_count, metadata)
ret = model_paths[0].parent / f"{default_filename}.gguf"
if ret in model_paths:
logger.error(
f"Error: Default output path ({ret}) would overwrite the input. "
Expand Down Expand Up @@ -1585,6 +1590,7 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--metadata", type=Path, help="Specify the path for a metadata file")
parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name")

args = parser.parse_args(args_in)

Expand All @@ -1598,6 +1604,16 @@ def main(args_in: list[str] | None = None) -> None:

metadata = Metadata.load(args.metadata)

if args.get_outfile:
logging.basicConfig(level=logging.CRITICAL)
model_plus = load_some_model(args.model)
params = Params.load(model_plus)
model = convert_model_names(model_plus.model, params, args.skip_unknown)
model_params_count = model_parameter_count(model_plus.model)
ftype = pick_output_type(model, args.outtype)
print(f"{default_convention_outfile(ftype, params, model_params_count, metadata)}") # noqa: NP100
return

if args.no_vocab and args.vocab_only:
raise ValueError("--vocab-only does not make sense with --no-vocab")

Expand Down

0 comments on commit ddfaad9

Please sign in to comment.