Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][DLTI] Pretty parsing and printing for DLTI attrs #113365

Merged
merged 5 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/test/Fir/tco-default-datalayout.fir
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ module {
// CHECK: module attributes {
// CHECK-SAME: dlti.dl_spec = #dlti.dl_spec<
// ...
// CHECK-SAME: #dlti.dl_entry<i64, dense<[32, 64]> : vector<2xi64>>,
// CHECK-SAME: i64 = dense<[32, 64]> : vector<2xi64>,
// ...
// CHECK-SAME: llvm.data_layout = ""
2 changes: 1 addition & 1 deletion flang/test/Fir/tco-explicit-datalayout.fir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i6
// CHECK: module attributes {
// CHECK-SAME: dlti.dl_spec = #dlti.dl_spec<
// ...
// CHECK-SAME: #dlti.dl_entry<i64, dense<128> : vector<2xi64>>,
// CHECK-SAME: i64 = dense<128> : vector<2xi64>,
// ...
// CHECK-SAME: llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:128-i128:128-f80:128-n8:16:32:64-S128"
57 changes: 31 additions & 26 deletions mlir/include/mlir/Dialect/DLTI/DLTIAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,15 @@ def DLTI_DataLayoutSpecAttr :

/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) {
return llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
return ::llvm::cast<mlir::DataLayoutSpecInterface>(*this).queryHelper(key);
}
}];
}

//===----------------------------------------------------------------------===//
// MapAttr
//===----------------------------------------------------------------------===//

def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
let summary = "A mapping of DLTI-information by way of key-value pairs";
let description = [{
Expand All @@ -106,18 +110,16 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {

Consider the following flat encoding of a single-key dictionary
```
#dlti.map<#dlti.dl_entry<"CPU::cache::L1::size_in_bytes", 65536 : i32>>
#dlti.map<"CPU::cache::L1::size_in_bytes" = 65536 : i32>>
```
versus nested maps, which make it possible to obtain sub-dictionaries of
related information (with the following example making use of other
attributes that also implement the `DLTIQueryInterface`):
```
#dlti.target_system_spec<"CPU":
#dlti.target_device_spec<#dlti.dl_entry<"cache",
#dlti.map<#dlti.dl_entry<"L1",
#dlti.map<#dlti.dl_entry<"size_in_bytes", 65536 : i32>>>,
#dlti.dl_entry<"L1d",
#dlti.map<#dlti.dl_entry<"size_in_bytes", 32768 : i32>>> >>>>
#dlti.target_system_spec<"CPU" =
#dlti.target_device_spec<"cache" =
#dlti.map<"L1" = #dlti.map<"size_in_bytes" = 65536 : i32>,
"L1d" = #dlti.map<"size_in_bytes" = 32768 : i32> >>>
```

With the flat encoding, the implied structure of the key is ignored, that is
Expand All @@ -132,14 +134,13 @@ def DLTI_MapAttr : DLTIAttr<"Map", [DLTIQueryInterface]> {
`transform.dlti.query ["CPU","cache","L1","size_in_bytes"] at %op` gives
back the first leaf value contained. To access the other leaf, we need to do
`transform.dlti.query ["CPU","cache","L1d","size_in_bytes"] at %op`.
```
}];
let parameters = (ins
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
);
let mnemonic = "map";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) {
Expand Down Expand Up @@ -167,20 +168,23 @@ def DLTI_TargetSystemSpecAttr :
```
dlti.target_system_spec =
#dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.L1_cache_size_in_bytes", 4096: ui32>>,
"GPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>,
"XPU": #dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 4096 : ui32>>>
"CPU" = #dlti.target_device_spec<
"L1_cache_size_in_bytes" = 4096: ui32>,
"GPU" = #dlti.target_device_spec<
"max_vector_op_width" = 64 : ui32>,
"XPU" = #dlti.target_device_spec<
"max_vector_op_width" = 4096 : ui32>>
```

The verifier checks that keys are strings and pointed to values implement
DLTI's TargetDeviceSpecInterface.
}];
let parameters = (ins
ArrayRefParameter<"DeviceIDTargetDeviceSpecPair", "">:$entries
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
);
let mnemonic = "target_system_spec";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
/// Return the device specification that matches the given device ID
std::optional<TargetDeviceSpecInterface>
Expand All @@ -189,16 +193,18 @@ def DLTI_TargetSystemSpecAttr :

/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
return llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
return ::llvm::cast<mlir::TargetSystemSpecInterface>(*this).queryHelper(key);
}
}];
let extraClassDefinition = [{
std::optional<TargetDeviceSpecInterface>
$cppClass::getDeviceSpecForDeviceID(
TargetSystemSpecInterface::DeviceID deviceID) {
for (const auto& entry : getEntries()) {
if (entry.first == deviceID)
return entry.second;
if (entry.getKey() == DataLayoutEntryKey(deviceID))
if (auto deviceSpec =
::llvm::dyn_cast<TargetDeviceSpecInterface>(entry.getValue()))
return deviceSpec;
}
return std::nullopt;
}
Expand All @@ -219,21 +225,20 @@ def DLTI_TargetDeviceSpecAttr :

Example:
```
#dlti.target_device_spec<
#dlti.dl_entry<"dlti.max_vector_op_width", 64 : ui32>>
#dlti.target_device_spec<"max_vector_op_width" = 64 : ui32>
```
}];
let parameters = (ins
ArrayRefParameter<"DataLayoutEntryInterface", "">:$entries
ArrayRefParameter<"DataLayoutEntryInterface">:$entries
);
let mnemonic = "target_device_spec";
let genVerifyDecl = 1;
let assemblyFormat = "`<` $entries `>`";
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
/// Returns the attribute associated with the key.
FailureOr<Attribute> query(DataLayoutEntryKey key) const {
return llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
return ::llvm::cast<mlir::TargetDeviceSpecInterface>(*this).queryHelper(key);
}
}];
}
Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef MLIR_INTERFACES_DATALAYOUTINTERFACES_H
#define MLIR_INTERFACES_DATALAYOUTINTERFACES_H

#include "mlir/IR/Attributes.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/DenseMap.h"
Expand All @@ -32,10 +33,7 @@ using DataLayoutEntryKey = llvm::PointerUnion<Type, StringAttr>;
using DataLayoutEntryList = llvm::SmallVector<DataLayoutEntryInterface, 4>;
using DataLayoutEntryListRef = llvm::ArrayRef<DataLayoutEntryInterface>;
using TargetDeviceSpecListRef = llvm::ArrayRef<TargetDeviceSpecInterface>;
using DeviceIDTargetDeviceSpecPair =
std::pair<StringAttr, TargetDeviceSpecInterface>;
using DeviceIDTargetDeviceSpecPairListRef =
llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>;
using TargetDeviceSpecEntry = std::pair<StringAttr, TargetDeviceSpecInterface>;
class DataLayoutOpInterface;
class DataLayoutSpecInterface;
class ModuleOp;
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Interfaces/DataLayoutInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def TargetDeviceSpecInterface : AttrInterface<"TargetDeviceSpecInterface", [DLTI
/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
::mlir::FailureOr<::mlir::Attribute>
queryHelper(::mlir::DataLayoutEntryKey key) const {
if (auto strKey = llvm::dyn_cast<StringAttr>(key))
if (auto strKey = ::llvm::dyn_cast<StringAttr>(key))
if (DataLayoutEntryInterface spec = getSpecForIdentifier(strKey))
return spec.getValue();
return ::mlir::failure();
Expand Down Expand Up @@ -304,7 +304,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
let methods = [
InterfaceMethod<
/*description=*/"Returns the list of layout entries.",
/*retTy=*/"llvm::ArrayRef<DeviceIDTargetDeviceSpecPair>",
/*retTy=*/"::llvm::ArrayRef<DataLayoutEntryInterface>",
/*methodName=*/"getEntries",
/*args=*/(ins)
>,
Expand Down Expand Up @@ -334,7 +334,7 @@ def TargetSystemSpecInterface : AttrInterface<"TargetSystemSpecInterface", [DLTI
/// Helper for default implementation of `DLTIQueryInterface`'s `query`.
::mlir::FailureOr<::mlir::Attribute>
queryHelper(::mlir::DataLayoutEntryKey key) const {
if (auto strKey = llvm::dyn_cast<::mlir::StringAttr>(key))
if (auto strKey = ::llvm::dyn_cast<::mlir::StringAttr>(key))
if (auto deviceSpec = getDeviceSpecForDeviceID(strKey))
return *deviceSpec;
return ::mlir::failure();
Expand Down
Loading
Loading