-
Notifications
You must be signed in to change notification settings - Fork 3k
/
backend-impl.ts
161 lines (145 loc) · 5.34 KB
/
backend-impl.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Backend} from './backend.js';
import {InferenceSession} from './inference-session.js';
/** Book-keeping for one registered backend, including the state of its one-time initialization. */
interface BackendInfo {
  /** The backend implementation passed to `registerBackend`. */
  backend: Backend;
  /** Registration priority: higher wins; a negative value marks a 'beta' backend that is not a default fallback. */
  priority: number;

  /** The pending result of `backend.init()` while an initialization is in flight. */
  initPromise?: Promise<void>;
  /** Set once `backend.init()` has completed successfully. */
  initialized?: boolean;
  /** Set once `backend.init()` has failed; `error` then holds the stringified failure. */
  aborted?: boolean;
  error?: string;
}

/** Every registered backend, keyed by its registration name. */
const backends = new Map<string, BackendInfo>();
/** Names of the backends registered with a non-negative priority, ordered from highest to lowest priority. */
const backendsSortedByPriority: string[] = [];
/**
 * Register a backend.
 *
 * Registering a name that is already taken follows these rules:
 * - a lower priority than the existing registration is ignored;
 * - the same priority is accepted only for the very same backend object, otherwise an error is thrown;
 * - a higher priority replaces the existing registration. The same backend object keeps its initialization state;
 *   a different backend object starts with a fresh, uninitialized entry.
 *
 * @param name - the name as a key to lookup as an execution provider.
 * @param backend - the backend object.
 * @param priority - an integer indicating the priority of the backend. Higher number means higher priority. if priority
 * < 0, it will be considered as a 'beta' version and will not be used as a fallback backend by default.
 * @throws TypeError if `backend` does not provide `init` and `createInferenceSessionHandler` functions.
 * @throws Error if a different backend is already registered under `name` with the same priority.
 *
 * @ignore
 */
export const registerBackend = (name: string, backend: Backend, priority: number): void => {
  if (!backend || typeof backend.init !== 'function' || typeof backend.createInferenceSessionHandler !== 'function') {
    throw new TypeError('not a valid backend');
  }

  const currentBackend = backends.get(name);
  if (currentBackend !== undefined) {
    if (currentBackend.priority > priority) {
      // same name is already registered with a higher priority. skip registration.
      return;
    }
    if (currentBackend.priority === priority && currentBackend.backend !== backend) {
      throw new Error(`cannot register backend "${name}" using priority ${priority}`);
    }
  }

  if (currentBackend === undefined || currentBackend.backend !== backend) {
    // A new name, or a different backend object taking over the name with a higher priority. The map entry must be
    // replaced, otherwise the new backend would never be resolved even though the fallback order below uses its priority.
    backends.set(name, {backend, priority});
  } else {
    // The same backend object re-registered: keep its initialization state and record the (possibly raised) priority
    // so later insertions into the fallback list compare against the current value.
    currentBackend.priority = priority;
  }

  // Re-position `name` in the fallback list. Only non-negative priorities are listed.
  const existingIndex = backendsSortedByPriority.indexOf(name);
  if (existingIndex !== -1) {
    backendsSortedByPriority.splice(existingIndex, 1);
  }
  if (priority >= 0) {
    // Keep the list in descending priority order. A (re-)registered backend goes ahead of existing entries with the
    // same priority. Every listed name has a map entry, hence the non-null assertion.
    const insertAt = backendsSortedByPriority.findIndex(n => backends.get(n)!.priority <= priority);
    backendsSortedByPriority.splice(insertAt === -1 ? backendsSortedByPriority.length : insertAt, 0, name);
  }
};
/**
 * Look up a registered backend by name and make sure its `init()` has completed.
 *
 * `init()` is started at most once per registered entry:
 * - callers that arrive while it is pending wait on the same promise;
 * - a success is cached in `initialized`;
 * - a failure is cached in `aborted`/`error`, so later calls return the same message without calling `init()` again.
 *
 * @param backendName - the name of the backend.
 * @returns the backend instance if resolved and initialized successfully, or an error message if failed.
 */
const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backend|string> => {
  const info = backends.get(backendName);
  if (!info) {
    return 'backend not found.';
  }
  if (info.initialized) {
    return info.backend;
  }
  if (info.aborted) {
    return info.error!;
  }

  // When another caller has already started init(), join its promise instead of starting a second initialization.
  const joiningPendingInit = Boolean(info.initPromise);
  try {
    if (!joiningPendingInit) {
      info.initPromise = info.backend.init(backendName);
    }
    await info.initPromise;
    info.initialized = true;
    return info.backend;
  } catch (e) {
    // Only the caller that started init() records the failure; joined callers report the message it stored.
    if (!joiningPendingInit) {
      info.error = `${e}`;
      info.aborted = true;
    }
    return info.error!;
  } finally {
    delete info.initPromise;
  }
};
/**
 * Resolve execution providers from the specific session options.
 *
 * Which backends are tried depends on `options.executionProviders`:
 * - when it is empty, every backend registered with a non-negative priority is tried, from highest to lowest priority;
 * - otherwise only the requested names are tried, in the given order.
 *
 * Every candidate is resolved and initialized, and the first one that succeeds is selected. The returned options
 * expose an `executionProviders` list containing only the requested entries that resolved to that same backend object.
 * A warning is logged for each requested entry that failed to resolve.
 *
 * @param options - the session options object.
 * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with
 * filtered EP list.
 * @throws Error if none of the candidate backends could be resolved and initialized.
 *
 * @ignore
 */
export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions):
    Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => {
  // extract backend hints from session options
  const eps = options.executionProviders || [];
  const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
  // Iterate over a snapshot of the fallback list. The loop below awaits between candidates, and a registerBackend()
  // call made meanwhile splices the shared array, which would make a live iteration skip or repeat entries.
  const backendNames = backendHints.length === 0 ? [...backendsSortedByPriority] : backendHints;

  // try to resolve and initialize all requested backends
  let backend: Backend|undefined;
  const errors: Array<{name: string; err: string}> = [];
  const availableBackendNames = new Set<string>();
  for (const backendName of backendNames) {
    const resolveResult = await tryResolveAndInitializeBackend(backendName);
    if (typeof resolveResult === 'string') {
      errors.push({name: backendName, err: resolveResult});
    } else {
      if (!backend) {
        backend = resolveResult;
      }
      // several names may map to the same backend object; only those names are kept as execution providers.
      if (backend === resolveResult) {
        availableBackendNames.add(backendName);
      }
    }
  }

  // if no backend is available, throw error.
  if (!backend) {
    throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
  }

  // for each explicitly requested backend, if it's not available, output warning message.
  for (const {name, err} of errors) {
    if (backendHints.includes(name)) {
      // eslint-disable-next-line no-console
      console.warn(`removing requested execution provider "${
          name}" from session options because it is not available: ${err}`);
    }
  }

  const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name));

  // Expose the filtered EP list without mutating the caller's options object.
  return [
    backend, new Proxy(options, {
      get: (target, prop) => {
        if (prop === 'executionProviders') {
          return filteredEps;
        }
        return Reflect.get(target, prop);
      }
    })
  ];
};