-
Notifications
You must be signed in to change notification settings - Fork 539
/
Copy pathLowerToBackendContract.cpp
450 lines (421 loc) · 18.5 KB
/
LowerToBackendContract.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
//===- LowerToBackendContract.cpp --------------------------------*- C++-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "torch-lower-to-backend-contract"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
//===----------------------------------------------------------------------===//
// Checking the backend contract.
//===----------------------------------------------------------------------===//
// Forward declaration; the definition is at the end of this file. It is
// declared here so that `getBackendContractTarget` can call it when building
// the conversion target used for the backend-contract legality check.
static void markDecomposedOpsAsIllegal(MLIRContext *context,
ConversionTarget &target,
ArrayRef<std::string> backendLegalOps);
// Checks whether `type` is allowed to appear in a program that satisfies the
// backend contract.
//
// `op` is the operation that diagnostics are attached to. When
// `actuallyEmitDiagnostics` is false, the function silently returns
// `failure()` for a disallowed type; when it is true, an error describing the
// violation is emitted on `op` as well.
static LogicalResult checkType(Operation *op, Type type,
                               bool actuallyEmitDiagnostics) {
  // Scalar types that every backend is expected to be able to handle.
  if (type.isa<Torch::IntType, Torch::FloatType, Torch::BoolType,
               Torch::DeviceType>())
    return success();

  // Backends are not expected to compute dynamically on these types, but they
  // show up frequently as op parameters that a backend can pattern-match
  // statically and eliminate from the program. For example, a tensor operand
  // might be optional, and the backend statically matches whether a tensor or
  // None was passed.
  if (type.isa<Torch::NoneType, Torch::StringType>())
    return success();

  // Non-value-semantic tensors are prohibited across the board. All of our
  // backends are currently based on value-semantic tensors, so lowering every
  // non-value-semantic tensor to a value-semantic one is our responsibility.
  if (type.isa<NonValueTensorType>()) {
    if (!actuallyEmitDiagnostics)
      return failure();
    return op
        ->emitError("unsupported by backend contract: non-value tensor type")
        .attachNote()
        .append("this is likely due to a missing case in the "
                "MaximizeValueSemantics pass");
  }

  // Value-semantic tensors must have at least a known rank and a known dtype.
  // We are not aware of a situation where our backends can handle an unranked
  // tensor type or a tensor with a dynamic dtype.
  //
  // The reasons are fairly fundamental. Unranked codegen is a completely
  // different problem from ranked codegen (a ranked tensor corresponds to a
  // fixed loop nest structure). In every codegen system we are aware of, the
  // program must be reduced to ranked tensors at some point during
  // compilation, and we know of no backend with a general solution to that
  // problem before codegen. Eliminating unranked tensors is therefore our
  // responsibility.
  //
  // Likewise, we know of no backend with infrastructure to represent dynamic
  // dtypes, let alone transform and optimize them. Even a backend that
  // supported dynamic dtypes in some form is unlikely to have a sufficiently
  // rich system for representing PyTorch type promotion rules. Ensuring that
  // all dtypes are statically known is therefore our responsibility too.
  if (auto valueTensor = type.dyn_cast<ValueTensorType>()) {
    // Rank is checked before dtype, so a tensor lacking both reports the
    // missing rank.
    if (!valueTensor.hasSizes()) {
      if (!actuallyEmitDiagnostics)
        return failure();
      return op
          ->emitError(
              "unsupported by backend contract: tensor with unknown rank")
          .attachNote()
          .append("this is likely due to a missing transfer function "
                  "in abstract_interp_lib_gen.py");
    }
    if (!valueTensor.hasDtype()) {
      if (!actuallyEmitDiagnostics)
        return failure();
      return op
          ->emitError(
              "unsupported by backend contract: tensor with unknown dtype")
          .attachNote()
          .append("this is likely due to a missing case in RefineTypes");
    }
    return success();
  }

  // Optional types are another category that backends are not expected to
  // compute with dynamically, but which can be pattern matched in many
  // practically necessary cases.
  if (auto optional = type.dyn_cast<OptionalType>()) {
    Type inner = optional.getContainedType();
    // TODO: Be stricter about tensor types.
    // See the comment for ListType below.
    if (inner.isa<ValueTensorType>())
      return success();
    return checkType(op, inner, actuallyEmitDiagnostics);
  }

  // List types fall in the same category: not dynamically computed with, but
  // pattern matchable in many practically necessary cases. For example, the
  // strides of a convolution op are represented as a list.
  if (auto list = type.dyn_cast<ListType>()) {
    Type element = list.getContainedType();
    // TODO: Be stricter about tensor types.
    // For the moment, there are cases (such as for torch.cat) where we end
    // up with `!torch.list<vtensor>`, whose contained type carries no shape
    // or dtype. Somehow this slips through and works. We should be stricter
    // about this and properly infer the contained type and shape.
    if (element.isa<ValueTensorType>())
      return success();
    return checkType(op, element, actuallyEmitDiagnostics);
  }

  // Tuple types fall in the same category as well. Every contained type must
  // pass the check; the first failing one ends the check.
  if (auto tuple = type.dyn_cast<Torch::TupleType>()) {
    for (auto memberType : tuple.getContainedTypes()) {
      if (failed(checkType(op, memberType, actuallyEmitDiagnostics)))
        return failure();
    }
    return success();
  }

  // Anything else is unsupported.
  if (!actuallyEmitDiagnostics)
    return failure();
  return op->emitError("unsupported by backend contract: type ") << type;
}
// Checks that `op` is legal according to `target`.
//
// Returns `success()` when the target reports the op as legal. Otherwise
// returns failure, additionally emitting an error on `op` when
// `actuallyEmitDiagnostics` is set.
static LogicalResult checkOpIsBackendLegal(Operation *op,
                                           const ConversionTarget &target,
                                           bool actuallyEmitDiagnostics) {
  if (target.isLegal(op))
    return success();

  if (!actuallyEmitDiagnostics)
    return failure();

  return op->emitError("found an op that was marked as backend illegal")
      .attachNote()
      .append("this is likely due to DecomposeComplexOps being unable to "
              "decompose this op");
}
// Returns true if `module` satisfies the backend contract with respect to
// `target`, and false otherwise.
//
// When `actuallyEmitDiagnostics` is true, the first violation found is also
// reported as an error; by default the check is silent.
static bool satisfiesBackendContract(ModuleOp module,
                                     const ConversionTarget &target,
                                     bool actuallyEmitDiagnostics = false) {
  // `torch.global_slot`s are not permitted in the backend contract, since
  // support for them is not widespread, and this does not align with PyTorch's
  // more tracing-based direction.
  //
  // Looking for the GlobalSlotModuleInitializerOp is enough, since its
  // verifier ensures that the set of global slots matches those initialized by
  // the module initializer.
  auto initializerWalk =
      module.walk([&](Torch::GlobalSlotModuleInitializerOp initializer) {
        if (actuallyEmitDiagnostics) {
          // Attach the error to the terminator rather than the initializer so
          // that the whole initializer, which can contain pages of ops, is not
          // dumped.
          Operation *terminator = initializer.getBody()->getTerminator();
          terminator
              ->emitError("unsupported by backend contract: module initializers")
              .attachNote()
              .append("this is likely due to InlineGlobalSlots being unable to "
                      "inline a global slot");
        }
        return WalkResult::interrupt();
      });
  if (initializerWalk.wasInterrupted())
    return false;

  // Check the types of all Values in the program and the legality of all ops.
  //
  // A pre-order walk gives a more intuitive "first error".
  // TODO: Should we report more than the first error?
  // How do we avoid making it too spammy?
  auto contractWalk = module.walk<WalkOrder::PreOrder>([&](Block *block) {
    // Block arguments have no defining op, so violations are reported on the
    // op that owns the block.
    for (BlockArgument blockArg : block->getArguments()) {
      if (failed(checkType(block->getParentOp(), blockArg.getType(),
                           actuallyEmitDiagnostics)))
        return WalkResult::interrupt();
    }

    for (Operation &nestedOp : *block) {
      // Legality of the op itself is checked before its result types.
      if (failed(checkOpIsBackendLegal(&nestedOp, target,
                                       actuallyEmitDiagnostics)))
        return WalkResult::interrupt();

      for (OpResult opResult : nestedOp.getResults()) {
        if (failed(checkType(&nestedOp, opResult.getType(),
                             actuallyEmitDiagnostics)))
          return WalkResult::interrupt();
      }
    }
    return WalkResult::advance();
  });

  return !contractWalk.wasInterrupted();
}
// Builds the conversion target that spells out which ops and dialects are
// allowed, and which are not, in the backend contract.
//
// The `func` and `torch` dialects are legal. When `decompose` is true, ops
// with decompositions are additionally marked illegal, except for those named
// in `backendLegalOps`.
static ConversionTarget
getBackendContractTarget(MLIRContext *context, bool decompose,
                         ArrayRef<std::string> backendLegalOps) {
  ConversionTarget contractTarget(*context);
  contractTarget.addLegalDialect<func::FuncDialect>();
  contractTarget.addLegalDialect<Torch::TorchDialect>();
  if (decompose) {
    markDecomposedOpsAsIllegal(context, contractTarget, backendLegalOps);
  }
  return contractTarget;
}
namespace {
/// Pass that runs the Torch simplification pipeline on the module over and
/// over until the module satisfies the backend contract. The pass fails if the
/// contract is still not satisfied after `maxIterations` pipeline runs, or if
/// the pipeline itself fails.
class LowerToBackendContractPass
    : public LowerToBackendContractBase<LowerToBackendContractPass> {
public:
  LowerToBackendContractPass() = default;

  /// Creates the pass with its options set from the given arguments.
  LowerToBackendContractPass(int maxIterations, bool decompose,
                             ArrayRef<std::string> backendLegalOps) {
    this->maxIterations = maxIterations;
    this->decompose = decompose;
    this->backendLegalOps = backendLegalOps;
  }

  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *context = &getContext();
    ConversionTarget target =
        getBackendContractTarget(context, decompose, backendLegalOps);

    // The simplification pipeline is built once and re-run on every
    // iteration.
    OpPassManager pm(module.getOperationName());
    TorchLoweringPipelineOptions options;
    options.decompose = decompose;
    options.backendLegalOps = backendLegalOps;
    createTorchSimplificationPipeline(pm, options);

    // Counts how many times the pipeline has been started.
    int numIterations = 0;
    for (;;) {
      if (numIterations == maxIterations) {
        LLVM_DEBUG({
          llvm::dbgs() << "LowerToBackendContractPass: "
                       << "failed to satisfy backend contract after "
                       << maxIterations
                       << " iterations of the simplification pipeline\n";
        });
        // Re-run the check purely for its diagnostics; the result itself is
        // not needed since the pass is failing regardless.
        (void)satisfiesBackendContract(module, target,
                                       /*actuallyEmitDiagnostics=*/true);
        signalPassFailure();
        return;
      }
      ++numIterations;

      if (failed(runPipeline(pm, module))) {
        signalPassFailure();
        return;
      }

      if (satisfiesBackendContract(module, target))
        break;
    }

    LLVM_DEBUG({
      llvm::dbgs() << "LowerToBackendContractPass: "
                   << "succeeded after " << numIterations
                   << " iterations of the simplification pipeline\n";
    });
  }
};
/// Pass that only verifies the backend contract, without running any
/// simplification. The target is built with decompositions disabled and no
/// backend-legal ops; any violation is reported as an error and fails the
/// pass.
class VerifyBackendContractNoDecompositionsPass
    : public VerifyBackendContractNoDecompositionsBase<
          VerifyBackendContractNoDecompositionsPass> {
public:
  VerifyBackendContractNoDecompositionsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ConversionTarget target = getBackendContractTarget(
        context, /*decompose=*/false, /*backendLegalOps=*/{});

    const bool contractHolds = satisfiesBackendContract(
        getOperation(), target, /*actuallyEmitDiagnostics=*/true);
    if (!contractHolds)
      signalPassFailure();
  }
};
} // namespace
// Factory for `LowerToBackendContractPass`; the arguments are forwarded
// unchanged to the pass constructor.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createLowerToBackendContractPass(
    int maxIterations, bool decompose, ArrayRef<std::string> backendLegalOps) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerToBackendContractPass>(maxIterations, decompose,
                                                   backendLegalOps);
  return pass;
}
// Factory for `VerifyBackendContractNoDecompositionsPass`.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::Torch::createVerifyBackendContractNoDecompositionsPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<VerifyBackendContractNoDecompositionsPass>();
  return pass;
}
// The backend contract guarantees that ops with decompositions available will
// be decomposed. The only way to have an op reach the backend contract without
// getting decomposed is by having the user explicitly specify that op in the
// `backendLegalOps` argument to the `LowerToBackendContractPass`. Therefore,
// here we mark as illegal all ops with decompositions except for those in
// `backendLegalOps`.
//
// The legality check takes place here instead of in the `DecomposeComplexOps`
// pass for two reasons:
// 1. Makes sure the `DecomposeComplexOps` pass always succeeds, allowing it to
// run multiple times. This is needed for graphs where static information such
// as dtypes and shapes takes multiple iterations to propagate through the
// entire graph. `DecomposeComplexOps` pass failing would cause the entire
// `LowerToBackendContractPass` to fail
// 2. Makes the legality requirements in the backend contract for ops with
// decompositions explicit in this file
//
// Parameters:
//   context         - MLIR context used to build the `OperationName` of each
//                     entry in `backendLegalOps`.
//   target          - conversion target that is updated in place.
//   backendLegalOps - names of ops that are marked legal; these are registered
//                     last, after all the illegal/dynamic registrations.
static void markDecomposedOpsAsIllegal(MLIRContext *context,
                                       ConversionTarget &target,
                                       ArrayRef<std::string> backendLegalOps) {
  target.addIllegalOp<AtenSoftmaxIntOp>();
  target.addIllegalOp<Aten_SoftmaxOp>();
  target.addIllegalOp<Aten_LogSoftmaxOp>();
  target.addIllegalOp<AtenLogSoftmaxIntOp>();
  target.addIllegalOp<AtenEmptyLikeOp>();
  target.addIllegalOp<AtenOnesLikeOp>();
  target.addIllegalOp<AtenZerosLikeOp>();
  target.addIllegalOp<AtenRollOp>();
  target.addIllegalOp<AtenRepeatOp>();
  target.addIllegalOp<AtenExpandOp>();
  target.addIllegalOp<AtenFlattenUsingIntsOp>();
  target.addIllegalOp<AtenWhereScalarOp>();
  target.addIllegalOp<AtenWhereScalarOtherOp>();
  target.addIllegalOp<AtenWhereScalarSelfOp>();
  target.addIllegalOp<AtenMaskedFillScalarOp>();
  target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
  target.addIllegalOp<AtenSizeOp>();
  target.addIllegalOp<AtenReshapeOp>();
  target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
  target.addIllegalOp<AtenTanhBackwardOp>();
  target.addIllegalOp<AtenAddmmOp>();
  target.addIllegalOp<AtenMeanOp>();
  target.addIllegalOp<AtenMeanDimOp>();
  target.addIllegalOp<AtenNormScalarOptDimOp>();
  target.addIllegalOp<AtenSelectIntOp>();
  target.addIllegalOp<AtenMvOp>();
  target.addIllegalOp<AtenTOp>();
  target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
  target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
    std::optional<unsigned> lhsRank = getTensorRank(op.getSelf());
    std::optional<unsigned> rhsRank = getTensorRank(op.getOther());
    // If either rank is not available, the op is treated as illegal.
    if (!lhsRank || !rhsRank)
      return false;
    // aten.matmul is illegal exactly when both operands have rank 2 or both
    // have rank 3; every other combination of known ranks is legal.
    return (*lhsRank != 2 || *rhsRank != 2) && (*lhsRank != 3 || *rhsRank != 3);
  });
  target.addIllegalOp<AtenAddcmulOp>();
  target.addIllegalOp<AtenAddcdivOp>();
  target.addIllegalOp<AtenLayerNormOp>();
  target.addIllegalOp<AtenNativeLayerNormOp>();
  target.addIllegalOp<AtenNativeBatchNormOp>();
  target.addIllegalOp<AtenConvolutionOverrideableOp>();
  target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
  target.addIllegalOp<AtenConvolutionBackwardOp>();
  target.addIllegalOp<AtenConv2dOp>();
  target.addIllegalOp<AtenConvTranspose2dInputOp>();
  target.addIllegalOp<AtenArangeOp>();
  target.addIllegalOp<AtenArangeStartOp>();
  target.addIllegalOp<AtenArgmaxOp>();
  target.addIllegalOp<AtenSquareOp>();
  target.addIllegalOp<AtenVarOp>();
  target.addIllegalOp<AtenStdOp>();
  target.addIllegalOp<Aten_UnsafeViewOp>();
  target.addIllegalOp<Aten_ReshapeAliasOp>();
  target.addIllegalOp<AtenBernoulliOp>();
  target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
  target.addIllegalOp<AtenBernoulliPOp>();
  target.addIllegalOp<AtenBernoulliTensorOp>();
  target.addIllegalOp<AtenZeroOp>();
  target.addIllegalOp<AtenRandLikeOp>();
  target.addIllegalOp<AtenHardsigmoidOp>();
  target.addIllegalOp<AtenRelu6Op>();
  target.addIllegalOp<AtenHardswishOp>();
  target.addIllegalOp<AtenSoftplusOp>();
  target.addIllegalOp<AtenSiluOp>();
  target.addIllegalOp<AtenNewZerosOp>();
  target.addIllegalOp<AtenNewOnesOp>();
  target.addIllegalOp<AtenHardtanhOp>();
  target.addIllegalOp<AtenFullOp>();
  target.addIllegalOp<AtenLinearOp>();
  target.addIllegalOp<AtenMishOp>();
  target.addIllegalOp<AtenFullLikeOp>();
  target.addIllegalOp<AtenIndexPutOp>();
  target.addIllegalOp<AtenExpandAsOp>();
  target.addIllegalOp<Aten_ToCopyOp>();
  target.addIllegalOp<AtenDropoutOp>();
  target.addIllegalOp<AtenNewEmptyOp>();
  target.addIllegalOp<AtenIndexPutHackedTwinOp>();
  target.addIllegalOp<AtenPadOp>();
  target.addIllegalOp<AtenToDtypeLayoutOp>();
  target.addIllegalOp<AtenToDeviceOp>();
  target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
  target.addIllegalOp<AtenClampMinOp>();
  target.addIllegalOp<AtenClampMaxOp>();
  target.addIllegalOp<AtenBaddbmmOp>();
  target.addIllegalOp<AtenFloorDivideOp>();
  target.addIllegalOp<AtenNumpyTOp>();
  target.addIllegalOp<AtenSelectScatterOp>();
  target.addIllegalOp<AtenVarDimOp>();
  target.addIllegalOp<AtenAmaxOp>();
  target.addIllegalOp<AtenVarCorrectionOp>();
  target.addIllegalOp<AtenStdDimOp>();
  target.addIllegalOp<AtenStdCorrectionOp>();
  target.addIllegalOp<AtenNarrowOp>();
  target.addIllegalOp<Aten_EmbeddingBagOp>();
  target.addIllegalOp<AtenLiftFreshCopyOp>();
  target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
  target.addIllegalOp<AtenMseLossOp>();
  target.addIllegalOp<AtenRandintLowOp>();
  target.addIllegalOp<AtenVarMeanCorrectionOp>();
  target.addIllegalOp<PrimsConvertElementTypeOp>();
  target.addIllegalOp<PrimsVarOp>();
  target.addIllegalOp<PrimsSqrtOp>();
  target.addIllegalOp<AtenRandnOp>();
  target.addIllegalOp<AtenRandnGeneratorOp>();
  target.addIllegalOp<AtenRandnLikeOp>();
  target.addIllegalOp<AtenVarMeanOp>();
  target.addIllegalOp<AtenNewEmptyStridedOp>();
  target.addIllegalOp<AtenBucketizeTensorOp>();
  // Registered last so that user-specified backend-legal ops take effect
  // after the illegal/dynamic registrations above. Iterate by const reference
  // to avoid copying each op name.
  for (const std::string &opName : backendLegalOps) {
    target.addLegalOp(OperationName(opName, context));
  }
}