diff --git a/packages/gguf/src/gguf.ts b/packages/gguf/src/gguf.ts index 67d121adc..152d2dbc5 100644 --- a/packages/gguf/src/gguf.ts +++ b/packages/gguf/src/gguf.ts @@ -400,8 +400,9 @@ export async function ggufAllShards( */ fetch?: typeof fetch; additionalFetchHeaders?: Record; + allowLocalFile?: boolean; } -): Promise<{ shards: GGUFParseOutput[]; parameterCount: number }> { +): Promise { const ggufShardFileInfo = parseGgufShardFilename(url); if (ggufShardFileInfo) { const total = parseInt(ggufShardFileInfo.total); @@ -414,15 +415,48 @@ export async function ggufAllShards( const PARALLEL_DOWNLOADS = 20; const shards = await promisesQueue( - urls.map((shardUrl) => () => gguf(shardUrl, { ...params, computeParametersCount: true })), + urls.map((shardUrl) => async () => { + const output = await gguf(shardUrl, { ...params, computeParametersCount: true }); + return output; + }), PARALLEL_DOWNLOADS ); + + // Sanity check split.count parameter + const output: GGUFParseOutput<{ strict: false }> = shards[0]; + const splitCount = output.metadata["split.count"]; + if (splitCount !== shards.length) { + throw new Error(`Expect to "split.count" to be ${shards.length}, but got ${splitCount}`); + } + + // Sanity check split.no parameter + for (let i = 0; i < shards.length; i++) { + const shard = shards[i]; + if (!shard.metadata["split.count"]) { + continue; + } + const splitNo = shard.metadata["split.no"]; + if (splitNo !== i) { + throw new Error(`Expect to "split.no" to be ${i}, but got ${splitNo}`); + } else if (i > 0) { + // skip first shard (already added) + output.tensorInfos = [...output.tensorInfos, ...shard.tensorInfos]; + } + } + + // Sanity check split.tensors.count parameter + const splitTensorsCount = output.metadata["split.tensors.count"]; + if (splitTensorsCount !== output.tensorInfos.length) { + throw new Error( + `Expect to "split.tensors.count" to be ${output.tensorInfos.length}, but got ${splitTensorsCount}` + ); + } + return { - shards, + ...output, 
parameterCount: shards.map(({ parameterCount }) => parameterCount).reduce((acc, val) => acc + val, 0), }; } else { - const { metadata, tensorInfos, parameterCount } = await gguf(url, { ...params, computeParametersCount: true }); - return { shards: [{ metadata, tensorInfos }], parameterCount }; + return await gguf(url, { ...params, computeParametersCount: true }); } } diff --git a/packages/gguf/src/types.spec.ts b/packages/gguf/src/types.spec.ts index 9d20bfa8c..4e47b23dc 100644 --- a/packages/gguf/src/types.spec.ts +++ b/packages/gguf/src/types.spec.ts @@ -51,5 +51,12 @@ describe("gguf-types", () => { // @ts-expect-error llama does not have ssm.* keys model["mamba.ssm.conv_kernel"] = 0; } + + if (model["split.count"]) { + model["split.no"] = 123; + } else { + // @ts-expect-error not a split (shard) model + model["split.no"] = 123; + } }); }); diff --git a/packages/gguf/src/types.ts b/packages/gguf/src/types.ts index 9e6f89dbf..c105e5fd8 100644 --- a/packages/gguf/src/types.ts +++ b/packages/gguf/src/types.ts @@ -92,6 +92,20 @@ interface NoTokenizer { "tokenizer.ggml.model"?: undefined; } +/// Splits + +interface Splits { + // Index of the current split (counting from 0) + "split.no": number; + // Total number of splits (counting from 1) + "split.count": number; + // Total number of tensors from all splits + "split.tensors.count": number; +} +interface NoSplits { + "split.count"?: undefined; +} + /// Models outside of llama.cpp: "rwkv" and "whisper" export type RWKV = GGUFGeneralInfo<"rwkv"> & @@ -126,7 +140,7 @@ export type GGUFMetadata } & GGUFModelKV & (Options extends { strict: true } ? unknown : Record); -export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer); +export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer) & (Splits | NoSplits); export interface GGUFTensorInfo { name: string;