Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge main -> google #5500

Merged
merged 12 commits into from
Apr 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ echo "Building with Ninja"
cd "${CMAKE_BUILD_DIR?}"
ninja

export CTEST_PARALLEL_LEVEL=${CTEST_PARALLEL_LEVEL:-$(nproc)}
# Limit parallelism dramatically to avoid exhausting GPU memory
# TODO(#5162): Handle this more robustly
export CTEST_PARALLEL_LEVEL=${CTEST_PARALLEL_LEVEL:-1}

# Only test drivers that use the GPU, since we run all tests on non-GPU machines
# as well.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
// CHECK-NEXT: %[[IN0_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list<index>
// CHECK-NEXT: %[[IN0_D0:.+]] = shapex.ranked_dim %[[IN0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[IN0_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[IN0_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list<index>
// CHECK-NEXT: br ^bb4
// CHECK-NEXT: ^bb2:
// CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index
Expand All @@ -50,10 +50,10 @@
// CHECK-NEXT: %[[IN1_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list<index>
// CHECK-NEXT: %[[IN1_D0:.+]] = shapex.ranked_dim %[[IN1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[IN1_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[IN1_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list<index>
// CHECK-NEXT: br ^bb4
// CHECK-NEXT: ^bb4:
// CHECK-NEXT: return
Expand All @@ -64,15 +64,15 @@
// CHECK: %[[IS_0:.+]] = cmpi eq, %[[INDEX]], %c0 : index
// CHECK-NEXT: cond_br %[[IS_0]], ^bb1, ^bb2
// CHECK-NEXT: ^bb1:
// CHECK-NEXT: %[[IN0_D0:.+]] = iree.list.get %[[LIST]], %c0 : !iree.list<index>
// CHECK-NEXT: %[[IN0_D0:.+]] = iree.list.get %[[LIST]][%c0] : !iree.list<index>
// CHECK-NEXT: %[[IN0_SHAPE:.+]] = shapex.make_ranked_shape %[[IN0_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: flow.variable.store %[[IN0_SHAPE]], @_tflite_dynamicEntry_input0_shape : !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: br ^bb4
// CHECK-NEXT: ^bb2:
// CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index
// CHECK-NEXT: cond_br %[[IS_1]], ^bb3, ^bb4
// CHECK-NEXT: ^bb3:
// CHECK-NEXT: %[[IN1_D0:.+]] = iree.list.get %[[LIST]], %c0 : !iree.list<index>
// CHECK-NEXT: %[[IN1_D0:.+]] = iree.list.get %[[LIST]][%c0] : !iree.list<index>
// CHECK-NEXT: %[[IN1_SHAPE:.+]] = shapex.make_ranked_shape %[[IN1_D0]] : (index) -> !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: flow.variable.store %[[IN1_SHAPE]], @_tflite_dynamicEntry_input1_shape : !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: br ^bb4
Expand All @@ -90,10 +90,10 @@
// CHECK-NEXT: %[[OUT0_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_output0_shape : !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list<index>
// CHECK-NEXT: %[[OUT0_D0:.+]] = shapex.ranked_dim %[[OUT0_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[OUT0_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[OUT0_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list<index>
// CHECK-NEXT: br ^bb4
// CHECK-NEXT: ^bb2:
// CHECK-NEXT: %[[IS_1:.+]] = cmpi eq, %[[INDEX]], %c1 : index
Expand All @@ -102,10 +102,10 @@
// CHECK-NEXT: %[[OUT1_SHAPE:.+]] = flow.variable.load @_tflite_dynamicEntry_output1_shape : !shapex.ranked_shape<[?,8,8,3]>
// CHECK-NEXT: iree.list.resize %[[LIST]], %c4 : !iree.list<index>
// CHECK-NEXT: %[[OUT1_D0:.+]] = shapex.ranked_dim %[[OUT1_SHAPE]][0] : !shapex.ranked_shape<[?,8,8,3]> -> index
// CHECK-NEXT: iree.list.set %[[LIST]], %c0, %[[OUT1_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c1, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c2, %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]], %c3, %c3 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c0], %[[OUT1_D0]] : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c1], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c2], %c8 : !iree.list<index>
// CHECK-NEXT: iree.list.set %[[LIST]][%c3], %c3 : !iree.list<index>
// CHECK-NEXT: br ^bb4
// CHECK-NEXT: ^bb4:
// CHECK-NEXT: return
Expand Down
1 change: 0 additions & 1 deletion iree/compiler/Conversion/LinalgToLLVM/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ void buildLLVMTransformPassPipeline(OpPassManager &passManager,
if (options.usingLinalgOnTensors) {
passManager.addPass(createMaterializeCPULaunchConfigurationPass());
OpPassManager &nestedModulePM = passManager.nest<ModuleOp>();
nestedModulePM.addPass(createInlinerPass());
// TODO(ataei): We want to enable when tensor -> vector pass is fully
// supported which requires first moving vector-tiling before this step.
if (options.useLinalgOnTensorsToVectors) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ static PassRegistration<OutlineLargeConstantsPass> pass(
"iree-flow-outline-large-constants",
"Outlines large tensor constants into flow.variables at the module level.",
[] {
return std::make_unique<OutlineLargeConstantsPass>(kMinLargeConstantSize);
// TODO(#5493): add a flag for this.
return std::make_unique<OutlineLargeConstantsPass>(256);
});

} // namespace Flow
Expand Down
6 changes: 3 additions & 3 deletions iree/compiler/Dialect/Flow/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ std::unique_ptr<OperationPass<ModuleOp>> createExportBenchmarkFuncsPass();

// Outlines large tensor constants into flow.variables at the module level.
//
// NOTE: a total guess :) this feels like about the most per-dispatch-buffer
// data we'd want to embed in the command buffer.
static constexpr size_t kMinLargeConstantSize = 256;
// TODO(#5493): implement the support for inlining constants into the command
// buffer and raise this value to one that is measured to be good.
static constexpr size_t kMinLargeConstantSize = 1;
std::unique_ptr<OperationPass<ModuleOp>> createOutlineLargeConstantsPass(
size_t minLargeConstantSize = kMinLargeConstantSize);

Expand Down
7 changes: 4 additions & 3 deletions iree/compiler/Dialect/HAL/Target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,12 @@ class CUDATargetBackend final : public TargetBackend {

llvmModule->setDataLayout(targetMachine->createDataLayout());

std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine);
FlatbufferBuilder builder;
iree_CUDAExecutableDef_start_as_root(builder);

// Serialize cuda kernel into the binary that we will embed in the
// final flatbuffer.
FlatbufferBuilder builder;
std::string targetISA = translateModuleToISA(*llvmModule, *targetMachine);
auto ptxCudeRef = flatbuffers_uint8_vec_create(
builder, reinterpret_cast<const uint8_t *>(targetISA.c_str()),
targetISA.size());
Expand All @@ -168,7 +170,6 @@ class CUDATargetBackend final : public TargetBackend {
}
auto blockSizesRef = iree_CUDABlockSizeDef_vec_end(builder);

iree_CUDAExecutableDef_start_as_root(builder);
iree_CUDAExecutableDef_entry_points_add(builder, entryPointsRef);
iree_CUDAExecutableDef_block_sizes_add(builder, blockSizesRef);
iree_CUDAExecutableDef_ptx_image_add(builder, ptxCudeRef);
Expand Down
5 changes: 3 additions & 2 deletions iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,11 @@ class LLVMAOTTargetBackend final : public TargetBackend {
linkArtifacts.keepAllFiles();
}

FlatbufferBuilder builder;
iree_DyLibExecutableDef_start_as_root(builder);

// Embed debug symbols at the end of the flatbuffer by adding first in the
// bottoms-up builder.
FlatbufferBuilder builder;
flatbuffers_uint8_vec_ref_t debugDatabaseRef = 0;
flatbuffers_string_ref_t debugDatabaseFilenameRef = 0;
if (options_.debugSymbols && linkArtifacts.debugFile.outputFile) {
Expand All @@ -328,7 +330,6 @@ class LLVMAOTTargetBackend final : public TargetBackend {
<< linkArtifacts.libraryFile.path;
}

iree_DyLibExecutableDef_start_as_root(builder);
iree_DyLibExecutableDef_library_embedded_add(builder, libraryEmbeddedRef);
iree_DyLibExecutableDef_debug_database_filename_add(
builder, debugDatabaseFilenameRef);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend {

// 4. Pack the MTLLibrary and metadata into a flatbuffer.
FlatbufferBuilder builder;
iree_MetalExecutableDef_start_as_root(builder);

auto shaderSourcesRef = builder.createStringVec(llvm::map_range(
mslShaders, [&](const MetalShader &shader) { return shader.source; }));
Expand All @@ -135,7 +136,6 @@ class MetalSPIRVTargetBackend : public SPIRVTargetBackend {

auto entryPointNamesRef = builder.createStringVec(entryPointNames);

iree_MetalExecutableDef_start_as_root(builder);
iree_MetalExecutableDef_entry_points_add(builder, entryPointNamesRef);
iree_MetalExecutableDef_threadgroup_sizes_add(builder, threadgroupSizesRef);
iree_MetalExecutableDef_shader_sources_add(builder, shaderSourcesRef);
Expand Down
5 changes: 3 additions & 2 deletions iree/compiler/Dialect/HAL/Target/VMLA/VMLATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ class VMLATargetBackend final : public TargetBackend {

LogicalResult serializeExecutable(IREE::HAL::ExecutableTargetOp targetOp,
OpBuilder &executableBuilder) override {
// Serialize the VM module to bytes directly into a flatbuffer.
FlatbufferBuilder builder;
iree_VMLAExecutableDef_start_as_root(builder);

// Serialize the VM module to bytes directly into a flatbuffer.
IREE::VM::BytecodeTargetOptions bytecodeOptions;
auto dataRef = builder.streamUint8Vec([&](raw_ostream &stream) {
return succeeded(translateModuleToBytecode(targetOp.getInnerModule(),
Expand All @@ -115,7 +117,6 @@ class VMLATargetBackend final : public TargetBackend {

// Pack the executable definition and get the bytes with the proper header.
// The header is used to verify the contents at runtime.
iree_VMLAExecutableDef_start_as_root(builder);
iree_VMLAExecutableDef_bytecode_module_add(builder, dataRef);
iree_VMLAExecutableDef_end_as_root(builder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
ModuleOp innerModuleOp = targetOp.getInnerModule();
auto spvModuleOp = *innerModuleOp.getOps<spirv::ModuleOp>().begin();

FlatbufferBuilder builder;
iree_SpirVExecutableDef_start_as_root(builder);

// Serialize the spirv::ModuleOp into the binary that we will embed in the
// final flatbuffer.
FlatbufferBuilder builder;
SmallVector<uint32_t, 256> spvBinary;
if (failed(spirv::serialize(spvModuleOp, spvBinary)) || spvBinary.empty()) {
return targetOp.emitError() << "failed to serialize spv.module";
Expand All @@ -157,7 +159,6 @@ class VulkanSPIRVTargetBackend : public SPIRVTargetBackend {
}
auto entryPointsRef = builder.createStringVec(entryPointNames);

iree_SpirVExecutableDef_start_as_root(builder);
iree_SpirVExecutableDef_entry_points_add(builder, entryPointsRef);
iree_SpirVExecutableDef_code_add(builder, spvCodeRef);
iree_SpirVExecutableDef_end_as_root(builder);
Expand Down
23 changes: 19 additions & 4 deletions iree/compiler/Dialect/IREE/IR/IREEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ IREEDialect::IREEDialect(MLIRContext* context)
Type IREEDialect::parseType(DialectAsmParser& parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
llvm::StringRef spec = parser.getFullSymbolSpec();
if (spec.consume_front("ptr")) {
if (spec == "variant") {
return IREE::VariantType::get(getContext());
} else if (spec.consume_front("ptr")) {
if (!spec.consume_front("<") || !spec.consume_back(">")) {
parser.emitError(parser.getCurrentLocation())
<< "malformed ptr type '" << parser.getFullSymbolSpec() << "'";
Expand All @@ -63,7 +65,12 @@ Type IREEDialect::parseType(DialectAsmParser& parser) const {
<< "malformed list type '" << parser.getFullSymbolSpec() << "'";
return Type();
}
auto elementType = mlir::parseType(spec, getContext());
Type elementType;
if (spec == "?") {
elementType = IREE::VariantType::get(getContext());
} else {
elementType = mlir::parseType(spec, getContext());
}
if (!elementType) {
parser.emitError(parser.getCurrentLocation())
<< "invalid list element type specification: '"
Expand All @@ -77,14 +84,22 @@ Type IREEDialect::parseType(DialectAsmParser& parser) const {
}

void IREEDialect::printType(Type type, DialectAsmPrinter& os) const {
if (auto ptrType = type.dyn_cast<IREE::PtrType>()) {
if (type.isa<IREE::VariantType>()) {
os << "variant";
} else if (auto ptrType = type.dyn_cast<IREE::PtrType>()) {
os << "ptr<" << ptrType.getTargetType() << ">";
} else if (type.isa<IREE::ByteBufferType>()) {
os << "byte_buffer";
} else if (type.isa<IREE::MutableByteBufferType>()) {
os << "mutable_byte_buffer";
} else if (auto listType = type.dyn_cast<IREE::ListType>()) {
os << "list<" << listType.getElementType() << ">";
os << "list<";
if (listType.getElementType().isa<IREE::VariantType>()) {
os << "?";
} else {
os << listType.getElementType();
}
os << ">";
} else {
llvm_unreachable("unhandled IREE type");
}
Expand Down
79 changes: 59 additions & 20 deletions iree/compiler/Dialect/IREE/IR/IREEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,43 +155,82 @@ void UnfoldableConstantOp::getCanonicalizationPatterns(
// Lists
//===----------------------------------------------------------------------===//

static ParseResult parseListType(OpAsmParser &parser, Type &listType,
Type &elementType) {
static ParseResult parseListTypeGet(OpAsmParser &parser, Type &listType,
Type &elementType) {
if (failed(parser.parseType(listType))) {
return parser.emitError(parser.getCurrentLocation(),
"expected !iree.list<> type");
"expected !iree.list<T> type");
}
auto listElementType = listType.cast<ListType>().getElementType();
if (succeeded(parser.parseOptionalArrow())) {
// Use overridden type - required for variants only.
if (failed(parser.parseType(elementType))) {
return parser.emitError(
parser.getCurrentLocation(),
"expected an element type when specifying list access types");
}
if (!ListType::canImplicitlyCast(listElementType, elementType)) {
return parser.emitError(
parser.getCurrentLocation(),
"list access types must match the same base type as the list element "
"type (when not variant)");
}
} else {
// Use list element type as the result element type.
elementType = listElementType;
}
elementType = listType.cast<ListType>().getElementType();
return success();
}

static ParseResult parseListType(OpAsmParser &parser, Type &listType,
SmallVectorImpl<Type> &elementTypes) {
if (failed(parser.parseType(listType))) {
static void printListTypeGet(OpAsmPrinter &printer, Operation *, Type listType,
Type elementType) {
printer.printType(listType);
auto listElementType = listType.cast<ListType>().getElementType();
if (listElementType != elementType) {
printer.printArrowTypeList(ArrayRef<Type>{elementType});
}
}

static ParseResult parseListTypeSet(OpAsmParser &parser, Type &listType,
Type &elementType) {
Type leadingType;
if (failed(parser.parseType(leadingType))) {
return parser.emitError(parser.getCurrentLocation(),
"expected !iree.list<> type");
"expected element type or !iree.list<T> type");
}
for (size_t i = 0; i < elementTypes.size(); ++i) {
elementTypes[i] = listType.cast<ListType>().getElementType();
if (succeeded(parser.parseOptionalArrow())) {
elementType = leadingType;
if (failed(parser.parseType(listType)) || !listType.isa<ListType>()) {
return parser.emitError(parser.getCurrentLocation(),
"expected an !iree.list<T> type");
}
} else {
if (!leadingType.isa<ListType>()) {
return parser.emitError(parser.getCurrentLocation(),
"expected an !iree.list<T> type");
}
listType = leadingType;
elementType = listType.cast<ListType>().getElementType();
}
return success();
}

static void printListType(OpAsmPrinter &printer, Operation *, Type listType,
Type elementType) {
printer.printType(listType);
}

static void printListType(OpAsmPrinter &printer, Operation *, Type listType,
TypeRange elementTypes) {
printer.printType(listType);
static void printListTypeSet(OpAsmPrinter &printer, Operation *, Type listType,
Type elementType) {
auto listElementType = listType.cast<ListType>().getElementType();
if (listElementType != elementType) {
printer.printType(elementType);
printer.printArrowTypeList(ArrayRef<Type>{listType});
} else {
printer.printType(listType);
}
}

static LogicalResult verifyListGetOp(ListGetOp &op) {
auto listType = op.list().getType().cast<IREE::ListType>();
auto elementType = listType.getElementType();
auto resultType = op.result().getType();
if (resultType != elementType) {
if (!ListType::canImplicitlyCast(elementType, resultType)) {
return op.emitError() << "list contains " << elementType
<< " and cannot be accessed as " << resultType;
}
Expand All @@ -202,7 +241,7 @@ static LogicalResult verifyListSetOp(ListSetOp &op) {
auto listType = op.list().getType().cast<IREE::ListType>();
auto elementType = listType.getElementType();
auto valueType = op.value().getType();
if (valueType != elementType) {
if (!ListType::canImplicitlyCast(valueType, elementType)) {
return op.emitError() << "list contains " << elementType
<< " and cannot be mutated as " << valueType;
}
Expand Down
Loading