forked from epfml/disco
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining_information.ts
185 lines (174 loc) · 7.25 KB
/
training_information.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import { AggregatorChoice } from '../aggregator/get'
import { Preprocessing } from '../dataset/data/preprocessing'
import { isModelCompileData, ModelCompileData } from './model_compile_data'
/**
 * Runtime type guard for `TrainingInformation`.
 *
 * Accepts `raw` only if:
 * - it is a non-null object;
 * - the required scalar fields have the expected primitive types
 *   (`dataType`, `scheme`, `modelID` are strings; `epochs`, `batchSize`,
 *   `roundDuration`, `validationSplit` are numbers);
 * - every optional field that is present has the expected type;
 * - data-type-specific fields are present:
 *   - `'image'` requires numeric `IMAGE_H` and `IMAGE_W`;
 *   - `'text'` / `'tabular'` require `inputColumns` and `outputColumns` to be string arrays;
 * - `modelCompileData` passes `isModelCompileData`.
 *
 * Some nested values are only shallowly checked: `preprocessingFunctions` is
 * only required to be an array, and its elements are not inspected.
 *
 * @param raw - value of unknown shape to validate
 * @returns `true` if `raw` can be treated as a `TrainingInformation`
 */
export function isTrainingInformation(raw: unknown): raw is TrainingInformation {
  if (typeof raw !== 'object' || raw === null) {
    return false
  }

  type Fields =
    | 'dataType'
    | 'scheme'
    | 'epochs'
    | 'roundDuration'
    | 'validationSplit'
    | 'batchSize'
    | 'modelCompileData'
    | 'modelID'
    | 'preprocessingFunctions'
    | 'inputColumns'
    | 'outputColumns'
    | 'IMAGE_H'
    | 'IMAGE_W'
    | 'modelURL'
    | 'learningRate'
    | 'decentralizedSecure'
    | 'byzantineRobustAggregator'
    | 'tauPercentile'
    | 'maxShareValue'
    | 'minimumReadyPeers'
    | 'LABEL_LIST'
    | 'noiseScale'
    | 'clippingRadius'
    | 'aggregator'

  const {
    dataType,
    scheme,
    epochs,
    roundDuration,
    validationSplit,
    batchSize,
    modelCompileData,
    modelID,
    preprocessingFunctions,
    inputColumns,
    outputColumns,
    IMAGE_H,
    IMAGE_W,
    modelURL,
    learningRate,
    decentralizedSecure,
    byzantineRobustAggregator,
    tauPercentile,
    maxShareValue,
    minimumReadyPeers,
    LABEL_LIST,
    noiseScale,
    clippingRadius,
    aggregator,
  } = raw as Record<Fields, unknown>

  const isStringArray = (value: unknown): value is string[] =>
    Array.isArray(value) && value.every((e) => typeof e === 'string')

  // Required fields must have the right type; optional fields are only
  // type-checked when present.
  // NOTE(review): `scheme` is only required to be a string. The interface
  // declares 'Federated' | 'Decentralized', but lowercase variants (and a
  // 'local' scheme) may also be in use — confirm the accepted spellings before
  // narrowing this check further.
  if (
    typeof dataType !== 'string' ||
    typeof scheme !== 'string' ||
    typeof modelID !== 'string' ||
    typeof epochs !== 'number' ||
    typeof batchSize !== 'number' ||
    typeof roundDuration !== 'number' ||
    typeof validationSplit !== 'number' ||
    (modelURL !== undefined && typeof modelURL !== 'string') ||
    (noiseScale !== undefined && typeof noiseScale !== 'number') ||
    (clippingRadius !== undefined && typeof clippingRadius !== 'number') ||
    (learningRate !== undefined && typeof learningRate !== 'number') ||
    (decentralizedSecure !== undefined && typeof decentralizedSecure !== 'boolean') ||
    (byzantineRobustAggregator !== undefined && typeof byzantineRobustAggregator !== 'boolean') ||
    (tauPercentile !== undefined && typeof tauPercentile !== 'number') ||
    (maxShareValue !== undefined && typeof maxShareValue !== 'number') ||
    (minimumReadyPeers !== undefined && typeof minimumReadyPeers !== 'number') ||
    // NOTE(review): assumes AggregatorChoice is a numeric enum — confirm
    (aggregator !== undefined && typeof aggregator !== 'number')
  ) {
    return false
  }

  // Interdependencies on the data type.
  if (dataType === 'image') {
    if (typeof IMAGE_H !== 'number' || typeof IMAGE_W !== 'number') {
      return false
    }
  } else if (dataType === 'text' || dataType === 'tabular') {
    // Compare by value: `x in array` would test array indices, not elements.
    if (!isStringArray(inputColumns) || !isStringArray(outputColumns)) {
      return false
    }
  }

  if (!isModelCompileData(modelCompileData)) {
    return false
  }

  if (LABEL_LIST !== undefined && !isStringArray(LABEL_LIST)) {
    return false
  }

  if (preprocessingFunctions !== undefined && !Array.isArray(preprocessingFunctions)) {
    return false
  }

  return true
}
/**
 * Configuration describing how a task's model is trained: the model and its
 * compilation settings, training-loop hyper-parameters, data layout, the
 * distributed training scheme, and privacy / robustness options.
 */
export interface TrainingInformation {
  // ---- Model ----

  /** Unique ID for the model. */
  modelID: string
  /** URL to download the base task model from. Useful for pretrained and/or hosted models. */
  modelURL?: string
  /** Additional training information: optimizer, loss and metrics. */
  modelCompileData: ModelCompileData
  /** Learning rate for the optimizer. */
  learningRate?: number

  // ---- Training loop ----

  /** Number of epochs to run training for. */
  epochs: number
  /** Batch size of the training data. */
  batchSize: number
  /**
   * Number of batches between two weight-sharing rounds in the distributed
   * setting. For example, `3` means weights are shared after every 3 batches.
   */
  roundDuration: number
  /** Fraction of the data kept for validation. Note: this only works for image data. */
  validationSplit: number

  // ---- Data ----

  /** Kind of data, e.g. image, tabular or text. */
  dataType: string
  /** Preprocessing functions such as resize and normalize. */
  preprocessingFunctions?: Preprocessing[]
  /** For tabular data: the columns chosen as input data for the model. */
  inputColumns?: string[]
  /** For tabular data: the columns to be predicted by the model. */
  outputColumns?: string[]
  /**
   * Image height. Use RESIZED_IMAGE_H instead if `ImagePreprocessing.Resize`
   * is in `preprocessingFunctions`.
   */
  IMAGE_H?: number
  /**
   * Image width. Use RESIZED_IMAGE_W instead if `ImagePreprocessing.Resize`
   * is in `preprocessingFunctions`.
   */
  IMAGE_W?: number
  /**
   * List of class labels. For example, with two classes of images (dogs and
   * cats) this would be `['dogs', 'cats']`.
   */
  LABEL_LIST?: string[]

  // ---- Distributed training ----

  /** Distributed training scheme: federated or decentralized. */
  scheme: 'Federated' | 'Decentralized'
  /**
   * Aggregator used by the server for federated learning, or by the peers for
   * decentralized learning. Default is 'average'; other options include, for
   * instance, 'bandit'.
   */
  aggregator?: AggregatorChoice
  /**
   * Decentralized learning: minimum number of peers that must be ready to
   * participate in aggregation before model updates are shared between clients.
   * Default is 3; range is [3, totalNumberOfPeersParticipating].
   */
  minimumReadyPeers?: number

  // ---- Privacy and robustness ----

  /**
   * Differential privacy (DP): affects the variance of the Gaussian noise added
   * to the models / model updates. If undefined, no noise is added.
   */
  noiseScale?: number
  /**
   * Privacy (DP and secure aggregation). If undefined, model updates are not
   * clipped. If set, model updates are scaled down when their norm exceeds
   * `clippingRadius`.
   */
  clippingRadius?: number
  /**
   * Secure aggregation on/off: `true` to use secure aggregation when the
   * training scheme is decentralized, `false` otherwise.
   */
  decentralizedSecure?: boolean
  /**
   * Secure aggregation: maximum absolute value of a number in a randomly
   * generated share.
   * - Default is 100 and must be a positive number.
   * - Only relevant when secure aggregation is enabled (federated or
   *   decentralized learning).
   * - See ~/disco/information/PRIVACY.md for why the choice of this value matters.
   */
  maxShareValue?: number
  /**
   * Byzantine-robust aggregator on/off: `true` to use Byzantine-robust
   * aggregation when the training scheme is federated, `false` otherwise.
   */
  byzantineRobustAggregator?: boolean
  /**
   * Percentile used to choose tau for the Byzantine-robust aggregator. Must be
   * a number strictly between 0 and 1. Only used when
   * `byzantineRobustAggregator` is true.
   */
  tauPercentile?: number
}