-
Notifications
You must be signed in to change notification settings - Fork 274
/
Copy pathlist-models.ts
125 lines (114 loc) · 3.27 KB
/
list-models.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { Credentials, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";
/**
 * Model fields that `listModels` always asks the API to return: each entry is sent
 * as its own `expand` query parameter, in this order, and the resulting values are
 * what `listModels` reads to build a `ModelEntry`.
 */
const EXPAND_KEYS = [
	// visibility / classification
	"pipeline_tag", "private", "gated",
	// popularity / freshness
	"downloads", "likes", "lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];
/**
 * Every model field that may be requested through the `expand` query parameter.
 *
 * `listModels` uses this list (minus the keys already in `EXPAND_KEYS`) as the set of
 * values accepted by its `additionalFields` option.
 */
const EXPANDABLE_KEYS = [
	"author", "cardData", "config", "createdAt",
	"disabled", "downloads", "downloadsAllTime",
	"gated", "gitalyUid",
	"lastModified", "library_name", "likes",
	"model-index",
	"pipeline_tag", "private",
	"safetensors", "sha",
	// "siblings",
	"spaces",
	"tags", "transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];
/**
 * One model as yielded by `listModels`: a normalized view over the raw API payload.
 */
export interface ModelEntry {
	/** Filled from the payload's `_id` field. */
	id: string;
	/** Filled from the payload's `id` field. */
	name: string;
	/** Filled from the payload's `pipeline_tag`; optional on the entry. */
	task?: PipelineType;

	/** Filled from the payload's `private` flag. */
	private: boolean;
	/** Filled from the payload's `gated` field: `false`, or the mode `"auto"` / `"manual"`. */
	gated: "auto" | "manual" | false;

	/** Filled from the payload's `downloads` counter. */
	downloads: number;
	/** Filled from the payload's `likes` counter. */
	likes: number;
	/** The payload's `lastModified` value converted with `new Date(...)`. */
	updatedAt: Date;
}
/**
 * Lazily lists models from the Hub's `/api/models` endpoint.
 *
 * Pages are fetched one at a time (at most 500 models requested for the first page)
 * and the generator follows the `next` relation of the response's `Link` header until
 * there are no more pages or `limit` entries have been yielded.
 *
 * @param params Optional settings:
 *  - `search.owner` is sent as the `author` query parameter, `search.task` as
 *    `pipeline_tag`, and each of `search.tags` as a separate `filter` parameter.
 *  - `credentials`, when given, are sent as an `Authorization: Bearer <accessToken>` header.
 *  - `hubUrl` replaces `HUB_URL` as the base URL (an empty value falls back to `HUB_URL`).
 *  - `additionalFields` are requested as extra `expand` parameters and copied from the
 *    API payload onto each yielded entry.
 *  - `limit` caps the number of yielded entries; `0` or a negative value yields nothing
 *    and performs no request.
 *  - `fetch` replaces the global `fetch`.
 * @returns An async generator of `ModelEntry` objects, each extended with the requested
 *  `additionalFields`.
 * @throws The error produced by `createApiError` when a response is not `ok`.
 *  `checkCredentials` runs first — presumably it throws on malformed credentials (confirm in its module).
 */
export async function* listModels<
	const T extends Exclude<(typeof EXPANDABLE_KEYS)[number], (typeof EXPAND_KEYS)[number]> = never,
>(params?: {
	search?: {
		owner?: string;
		task?: PipelineType;
		tags?: string[];
	};
	credentials?: Credentials;
	hubUrl?: string;
	additionalFields?: T[];
	/**
	 * Set to limit the number of models returned.
	 */
	limit?: number;
	/**
	 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
	 */
	fetch?: typeof fetch;
}): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> {
	checkCredentials(params?.credentials);

	let totalToFetch = params?.limit ?? Infinity;
	// The remaining-count check below only runs after an entry has been yielded, so a
	// non-positive limit must be rejected up front or one entry would still leak out.
	if (totalToFetch <= 0) {
		return;
	}

	// Parameter order: paging/filter scalars, then tag filters, then the always-expanded
	// fields, then any caller-requested extra fields.
	const search = new URLSearchParams([
		...Object.entries({
			limit: String(Math.min(totalToFetch, 500)),
			...(params?.search?.owner ? { author: params.search.owner } : undefined),
			...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
		}),
		...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
		...EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
		...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
	]).toString();

	// `||` (not `??`) is deliberate: an empty `hubUrl` also falls back to the default.
	let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

	while (url) {
		const res: Response = await (params?.fetch ?? fetch)(url, {
			headers: {
				accept: "application/json",
				...(params?.credentials ? { Authorization: `Bearer ${params.credentials.accessToken}` } : undefined),
			},
		});

		if (!res.ok) {
			// Awaited so that, if `createApiError` is asynchronous, the real error is
			// propagated instead of throwing a pending Promise; a no-op if it is synchronous.
			throw await createApiError(res);
		}

		const items: ApiModelInfo[] = await res.json();

		for (const item of items) {
			yield {
				...(params?.additionalFields && pick(item, params.additionalFields)),
				id: item._id,
				name: item.id,
				private: item.private,
				task: item.pipeline_tag,
				downloads: item.downloads,
				gated: item.gated,
				likes: item.likes,
				updatedAt: new Date(item.lastModified),
			} as ModelEntry & Pick<ApiModelInfo, T>;

			totalToFetch--;

			if (totalToFetch <= 0) {
				return;
			}
		}

		// Continue with the page advertised by the `Link` header, if any.
		const linkHeader = res.headers.get("Link");

		url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
		// Could update url to reduce the limit if we don't need the whole 500 of the next batch.
	}
}