-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add f6E2M3FN type #107999
[MLIR] Add f6E2M3FN type #107999
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-ods Author: Sergey Kozub (sergey-kozub) ChangesThis PR adds
f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs
Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125 Related PRs:
Full diff: https://github.com/llvm/llvm-project/pull/107999.diff 24 Files Affected:
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 24531baecaa353..cc6da482a1c369 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
/// Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
+/// Returns the typeID of an Float6E2M3FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f6E2M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type);
+
+/// Creates an f6E2M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx);
+
/// Returns the typeID of an Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5ac3a04b1c26ba..196d34e12d9b28 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
Attribute metadata = Attribute());
// Types.
+ FloatType getFloat6E2M3FNType();
FloatType getFloat6E3M2FNType();
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 87ccc041f19758..f2231e9507570e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
static FloatType getFloat8E3M4(MLIRContext *ctx);
+ static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
@@ -414,11 +415,15 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
- return llvm::isa<Float6E3M2FNType, Float8E5M2Type, Float8E4M3Type,
- Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
- Float16Type, FloatTF32Type, Float32Type, Float64Type,
- Float80Type, Float128Type>(type);
+ return llvm::isa<Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
+ Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+ Float64Type, Float80Type, Float128Type>(type);
+}
+
+inline FloatType FloatType::getFloat6E2M3FN(MLIRContext *ctx) {
+ return Float6E2M3FNType::get(ctx);
}
inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b54d4ee4b7eb7a..09c2d34dc7dd1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,27 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Float6E2M3FNType
+
+def Builtin_Float6E2M3FN : Builtin_FloatType<"Float6E2M3FN", "f6E2M3FN"> {
+ let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
+ let description = [{
+ An 6-bit floating point type with 1 sign bit, 2 bits exponent and 3 bits
+ mantissa. This is not a standard type as defined by IEEE-754, but it
+ follows similar conventions with the following characteristics:
+
+ * bit encoding: S1E2M3
+ * exponent bias: 1
+ * infinities: Not supported
+ * NaNs: Not supported
+ * denormals when exponent is 0
+
+ Open Compute Project (OCP) microscaling formats (MX) specification:
+ https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Float6E3M2FNType
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 09eab50f53a540..3cc1c95f1ed37a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -344,6 +344,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+ BuildableType<"$_builder.getFloat6E2M3FNType()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
BuildableType<"$_builder.getFloat6E3M2FNType()">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index b6a307fd7cb0fe..8b6f365fbda02e 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,6 +125,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
+ bool isFloat6E2M3FN() const;
bool isFloat6E3M2FN() const;
bool isFloat8E5M2() const;
bool isFloat8E4M3() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fa18cbe9e2b901..6ae64a17d1fadb 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
TOK_KEYWORD(f8E4M3FNUZ)
TOK_KEYWORD(f8E4M3B11FNUZ)
TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f6E2M3FN)
TOK_KEYWORD(f6E3M2FN)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 05276031211fa9..a3798ca8d90b1b 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -39,6 +39,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_tuple:
case Token::kw_vector:
case Token::inttype:
+ case Token::kw_f6E2M3FN:
case Token::kw_f6E3M2FN:
case Token::kw_f8E5M2:
case Token::kw_f8E4M3:
@@ -304,6 +305,9 @@ Type Parser::parseNonFunctionType() {
}
// float-type
+ case Token::kw_f6E2M3FN:
+ consumeToken(Token::kw_f6E2M3FN);
+ return builder.getFloat6E2M3FNType();
case Token::kw_f6E3M2FN:
consumeToken(Token::kw_f6E3M2FN);
return builder.getFloat6E3M2FNType();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 1cb429d9ca7b2d..6b64bc3c9d6f63 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType<PyFloatType> {
}
};
+/// Floating Point Type subclass - Float6E2M3FNType.
+class PyFloat6E2M3FNType
+ : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E2M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E2M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+ return PyFloat6E2M3FNType(context->getRef(), t);
+ },
+ py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
+ }
+};
+
/// Floating Point Type subclass - Float6E3M2FNType.
class PyFloat6E3M2FNType
: public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
@@ -901,6 +922,7 @@ void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
+ PyFloat6E2M3FNType::bind(m);
PyFloat6E3M2FNType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 254650d66a67e6..f943bf726b172c 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) {
return llvm::cast<FloatType>(unwrap(type)).getWidth();
}
+MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
+ return wrap(Float6E2M3FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
+ return unwrap(type).isFloat6E2M3FN();
+}
+
+MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
+ return wrap(FloatType::getFloat6E2M3FN(unwrap(ctx)));
+}
+
MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
return wrap(Float6E3M2FNType::getTypeID());
}
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b2c54bb3212edb..51a1b91338c6a0 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -250,7 +250,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
- type.isFloat6E3M2FN())
+ type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
return IntegerType::get(&getContext(), type.getWidth());
return type;
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a5ee6edc6320d5..5e5e10b1fa1c2b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -55,6 +55,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
+ .Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5142b462820786..c7ed158aabb6e7 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2575,6 +2575,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
opaqueTy.getTypeData());
})
.Case<IndexType>([&](Type) { os << "index"; })
+ .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
.Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 71f622b02adee0..144a13df2179b7 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
+FloatType Builder::getFloat6E2M3FNType() {
+ return FloatType::getFloat6E2M3FN(context);
+}
+
FloatType Builder::getFloat6E3M2FNType() {
return FloatType::getFloat6E3M2FN(context);
}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e46b6a4a6bb693..702d98ec31427b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -101,6 +101,8 @@ unsigned FloatType::getWidth() {
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
+ if (llvm::isa<Float6E2M3FNType>(*this))
+ return APFloat::Float6E2M3FN();
if (llvm::isa<Float6E3M2FNType>(*this))
return APFloat::Float6E3M2FN();
if (llvm::isa<Float8E5M2Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2851e6457ea3cb..1684566626886c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -221,6 +221,7 @@ class MLIRContextImpl {
llvm::DenseMap<StringRef, AbstractType *> nameToType;
/// Cached Type Instances.
+ Float6E2M3FNType f6E2M3FNTy;
Float6E3M2FNType f6E3M2FNTy;
Float8E5M2Type f8E5M2Ty;
Float8E4M3Type f8E4M3Ty;
@@ -314,6 +315,7 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting)
//// Types.
/// Floating-point Types.
+ impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
@@ -1015,6 +1017,9 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
+Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
+ return context->getImpl().f6E2M3FNTy;
+}
Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
return context->getImpl().f6E3M2FNTy;
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fa093664cf77f1..c828fd3766eaa7 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,6 +34,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
+bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); }
bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
bool Type::isFloat8E4M3() const { return llvm::isa<Float8E4M3Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 7b4fac7275bfc6..17a02b0bd445a7 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
"F32Type",
"F64Type",
"FlatSymbolRefAttr",
+ "Float6E2M3FNType",
"Float6E3M2FNType",
"Float8E3M4Type",
"Float8E4M3B11FNUZType",
@@ -1540,6 +1541,19 @@ class FlatSymbolRefAttr(Attribute):
Returns the value of the FlatSymbolRef attribute as a string
"""
+class Float6E2M3FNType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Optional[Context] = None) -> Float6E2M3FNType:
+ """
+ Create a float6_e2m3fn type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
class Float6E3M2FNType(FloatType):
static_typeid: ClassVar[TypeID]
@staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 0c6ece91d8b94a..4be425f220c978 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
F16Type,
F32Type,
F64Type,
+ Float6E2M3FNType,
Float6E3M2FNType,
Float8E3M4Type,
Float8E4M3B11FNUZType,
@@ -75,6 +76,7 @@ def ui(width):
f8E4M3FN = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
f8E3M4 = lambda: Float8E3M4Type.get()
+f6E2M3FN = lambda: Float6E2M3FNType.get()
f6E3M2FN = lambda: Float6E3M2FNType.get()
none = lambda: NoneType.get()
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 38cbf9d5d2b579..23dbf0c292c2c3 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -36,6 +36,10 @@ func.func @any_attr_of_fail() {
//===----------------------------------------------------------------------===//
func.func @float_attrs_pass() {
+ "test.float_attrs"() {
+ // CHECK: float_attr = 2.000000e+00 : f6E2M3FN
+ float_attr = 2. : f6E2M3FN
+ } : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f6E3M2FN
float_attr = 2. : f6E3M2FN
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 04be037978c8f6..7eca1a40373054 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -42,6 +42,9 @@ llvm.mlir.global internal @int_global_undef() : i64
// CHECK: @externally_initialized_global = internal externally_initialized global i32 0
llvm.mlir.global internal @externally_initialized_global(0 : i32) {externally_initialized} : i32
+// CHECK: @f6E2M3FN_global_as_i6 = internal global i6 12
+llvm.mlir.global internal @f6E2M3FN_global_as_i6(1.5 : f6E2M3FN) : i6
+
// CHECK: @f6E3M2FN_global_as_i6 = internal global i6 14
llvm.mlir.global internal @f6E3M2FN_global_as_i6(1.5 : f6E3M2FN) : i6
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b72ef4de0bd6dd..bc3ba4cd0b1448 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
def testFloatTypeSubclasses():
ctx = Context()
# CHECK: True
+ print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
+ # CHECK: True
print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
@@ -235,6 +237,8 @@ def testIndexType():
@run
def testFloatType():
with Context():
+ # CHECK: float: f6E2M3FN
+ print("float:", Float6E2M3FNType.get())
# CHECK: float: f6E3M2FN
print("float:", Float6E3M2FNType.get())
# CHECK: float: f8E3M4
@@ -613,6 +617,7 @@ def testTypeIDs():
types = [
(IntegerType, IntegerType.get_signless(16)),
(IndexType, IndexType.get()),
+ (Float6E2M3FNType, Float6E2M3FNType.get()),
(Float6E3M2FNType, Float6E3M2FNType.get()),
(Float8E3M4Type, Float8E3M4Type.get()),
(Float8E4M3Type, Float8E4M3Type.get()),
@@ -639,6 +644,7 @@ def testTypeIDs():
# CHECK: IntegerType(i16)
# CHECK: IndexType(index)
+ # CHECK: Float6E2M3FNType(f6E2M3FN)
# CHECK: Float6E3M2FNType(f6E3M2FN)
# CHECK: Float8E3M4Type(f8E3M4)
# CHECK: Float8E4M3Type(f8E4M3)
@@ -719,6 +725,9 @@ def print_downcasted(typ):
# CHECK: F64Type
# CHECK: F64Type(f64)
print_downcasted(F64Type.get())
+ # CHECK: Float6E2M3FNType
+ # CHECK: Float6E2M3FNType(f6E2M3FN)
+ print_downcasted(Float6E2M3FNType.get())
# CHECK: Float6E3M2FNType
# CHECK: Float6E3M2FNType(f6E3M2FN)
print_downcasted(Float6E3M2FNType.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index fed149d03ecf31..350a0f7abea5a4 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -50,6 +50,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
"mlir::CallSiteLoc": '"loc(callsite(...))"',
"mlir::FusedLoc": '"loc(fused<...>[...])"',
"mlir::UnknownLoc": '"loc(unknown)"',
+ "mlir::Float6E2M3FNType": '"f6E2M3FN"',
"mlir::Float6E3M2FNType": '"f6E3M2FN"',
"mlir::Float8E5M2Type": '"f8E5M2"',
"mlir::Float8E4M3Type": '"f8E4M3"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index d2c66714b4b118..9df1944f6255d9 100644
--- a/mlir/utils/tree-sitter-mlir/grammar.js
+++ b/mlir/utils/tree-sitter-mlir/grammar.js
@@ -231,7 +231,7 @@ const common = {
token(seq(choice('si', 'ui', 'i'), /[1-9]/, repeat(/[0-9]/))),
float_type : $ => token(
choice('f16', 'f32', 'f64', 'f80', 'f128', 'bf16', 'f8E3M4', 'f8E4M3FN',
- 'f8E4M3', 'f8E5M2', 'f6E3M2FN')),
+ 'f8E4M3', 'f8E5M2', 'f6E2M3FN', 'f6E3M2FN')),
index_type : $ => token('index'),
none_type : $ => token('none'),
complex_type : $ => seq(token('complex'), '<', $._prim_type, '>'),
|
@llvm/pr-subscribers-mlir-core Author: Sergey Kozub (sergey-kozub) ChangesThis PR adds
f6E2M3FN
- Exponent bias: 1
- Maximum stored exponent value: 3 (binary 11)
- Maximum unbiased exponent value: 3 - 1 = 2
- Minimum stored exponent value: 1 (binary 01)
- Minimum unbiased exponent value: 1 − 1 = 0
- Has Positive and Negative zero
- Doesn't have infinity
- Doesn't have NaNs
Additional details:
- Zeros (+/-): S.00.000
- Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5
- Min normal number: S.01.000 = ±2^(0) = ±1.0
- Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875
- Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125 Related PRs:
Full diff: https://github.com/llvm/llvm-project/pull/107999.diff 24 Files Affected:
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index 24531baecaa353..cc6da482a1c369 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type);
/// Returns the bitwidth of a floating-point type.
MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type);
+/// Returns the typeID of an Float6E2M3FN type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void);
+
+/// Checks whether the given type is an f6E2M3FN type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type);
+
+/// Creates an f6E2M3FN type in the given context. The type is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx);
+
/// Returns the typeID of an Float6E3M2FN type.
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 5ac3a04b1c26ba..196d34e12d9b28 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
Attribute metadata = Attribute());
// Types.
+ FloatType getFloat6E2M3FNType();
FloatType getFloat6E3M2FNType();
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3Type();
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 87ccc041f19758..f2231e9507570e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -67,6 +67,7 @@ class FloatType : public Type {
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
static FloatType getFloat8E3M4(MLIRContext *ctx);
+ static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
@@ -414,11 +415,15 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
}
inline bool FloatType::classof(Type type) {
- return llvm::isa<Float6E3M2FNType, Float8E5M2Type, Float8E4M3Type,
- Float8E4M3FNType, Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, BFloat16Type,
- Float16Type, FloatTF32Type, Float32Type, Float64Type,
- Float80Type, Float128Type>(type);
+ return llvm::isa<Float6E2M3FNType, Float6E3M2FNType, Float8E5M2Type,
+ Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
+ Float64Type, Float80Type, Float128Type>(type);
+}
+
+inline FloatType FloatType::getFloat6E2M3FN(MLIRContext *ctx) {
+ return Float6E2M3FNType::get(ctx);
}
inline FloatType FloatType::getFloat6E3M2FN(MLIRContext *ctx) {
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index b54d4ee4b7eb7a..09c2d34dc7dd1b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -233,6 +233,27 @@ def Builtin_Float8E3M4 : Builtin_FloatType<"Float8E3M4", "f8E3M4"> {
}];
}
+//===----------------------------------------------------------------------===//
+// Float6E2M3FNType
+
+def Builtin_Float6E2M3FN : Builtin_FloatType<"Float6E2M3FN", "f6E2M3FN"> {
+ let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa";
+ let description = [{
+ An 6-bit floating point type with 1 sign bit, 2 bits exponent and 3 bits
+ mantissa. This is not a standard type as defined by IEEE-754, but it
+ follows similar conventions with the following characteristics:
+
+ * bit encoding: S1E2M3
+ * exponent bias: 1
+ * infinities: Not supported
+ * NaNs: Not supported
+ * denormals when exponent is 0
+
+ Open Compute Project (OCP) microscaling formats (MX) specification:
+ https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Float6E3M2FNType
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 09eab50f53a540..3cc1c95f1ed37a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -344,6 +344,8 @@ def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
BuildableType<"$_builder.getFloat8E3M4Type()">;
+def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
+ BuildableType<"$_builder.getFloat6E2M3FNType()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
BuildableType<"$_builder.getFloat6E3M2FNType()">;
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index b6a307fd7cb0fe..8b6f365fbda02e 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -125,6 +125,7 @@ class Type {
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const;
+ bool isFloat6E2M3FN() const;
bool isFloat6E3M2FN() const;
bool isFloat8E5M2() const;
bool isFloat8E4M3() const;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index fa18cbe9e2b901..6ae64a17d1fadb 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -101,6 +101,7 @@ TOK_KEYWORD(f8E5M2FNUZ)
TOK_KEYWORD(f8E4M3FNUZ)
TOK_KEYWORD(f8E4M3B11FNUZ)
TOK_KEYWORD(f8E3M4)
+TOK_KEYWORD(f6E2M3FN)
TOK_KEYWORD(f6E3M2FN)
TOK_KEYWORD(f128)
TOK_KEYWORD(false)
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 05276031211fa9..a3798ca8d90b1b 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -39,6 +39,7 @@ OptionalParseResult Parser::parseOptionalType(Type &type) {
case Token::kw_tuple:
case Token::kw_vector:
case Token::inttype:
+ case Token::kw_f6E2M3FN:
case Token::kw_f6E3M2FN:
case Token::kw_f8E5M2:
case Token::kw_f8E4M3:
@@ -304,6 +305,9 @@ Type Parser::parseNonFunctionType() {
}
// float-type
+ case Token::kw_f6E2M3FN:
+ consumeToken(Token::kw_f6E2M3FN);
+ return builder.getFloat6E2M3FNType();
case Token::kw_f6E3M2FN:
consumeToken(Token::kw_f6E3M2FN);
return builder.getFloat6E3M2FNType();
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 1cb429d9ca7b2d..6b64bc3c9d6f63 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType<PyFloatType> {
}
};
+/// Floating Point Type subclass - Float6E2M3FNType.
+class PyFloat6E2M3FNType
+ : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+ static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+ mlirFloat6E2M3FNTypeGetTypeID;
+ static constexpr const char *pyClassName = "Float6E2M3FNType";
+ using PyConcreteType::PyConcreteType;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_static(
+ "get",
+ [](DefaultingPyMlirContext context) {
+ MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+ return PyFloat6E2M3FNType(context->getRef(), t);
+ },
+ py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
+ }
+};
+
/// Floating Point Type subclass - Float6E3M2FNType.
class PyFloat6E3M2FNType
: public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
@@ -901,6 +922,7 @@ void mlir::python::populateIRTypes(py::module &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);
+ PyFloat6E2M3FNType::bind(m);
PyFloat6E3M2FNType::bind(m);
PyFloat8E4M3FNType::bind(m);
PyFloat8E5M2Type::bind(m);
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index 254650d66a67e6..f943bf726b172c 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) {
return llvm::cast<FloatType>(unwrap(type)).getWidth();
}
+MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
+ return wrap(Float6E2M3FNType::getTypeID());
+}
+
+bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
+ return unwrap(type).isFloat6E2M3FN();
+}
+
+MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
+ return wrap(FloatType::getFloat6E2M3FN(unwrap(ctx)));
+}
+
MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
return wrap(Float6E3M2FNType::getTypeID());
}
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index b2c54bb3212edb..51a1b91338c6a0 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -250,7 +250,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) const {
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
- type.isFloat6E3M2FN())
+ type.isFloat6E2M3FN() || type.isFloat6E3M2FN())
return IntegerType::get(&getContext(), type.getWidth());
return type;
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index a5ee6edc6320d5..5e5e10b1fa1c2b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -55,6 +55,7 @@ static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
+ .Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 5142b462820786..c7ed158aabb6e7 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2575,6 +2575,7 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
opaqueTy.getTypeData());
})
.Case<IndexType>([&](Type) { os << "index"; })
+ .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
.Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 71f622b02adee0..144a13df2179b7 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -34,6 +34,10 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//
+FloatType Builder::getFloat6E2M3FNType() {
+ return FloatType::getFloat6E2M3FN(context);
+}
+
FloatType Builder::getFloat6E3M2FNType() {
return FloatType::getFloat6E3M2FN(context);
}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index e46b6a4a6bb693..702d98ec31427b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -101,6 +101,8 @@ unsigned FloatType::getWidth() {
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
+ if (llvm::isa<Float6E2M3FNType>(*this))
+ return APFloat::Float6E2M3FN();
if (llvm::isa<Float6E3M2FNType>(*this))
return APFloat::Float6E3M2FN();
if (llvm::isa<Float8E5M2Type>(*this))
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 2851e6457ea3cb..1684566626886c 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -221,6 +221,7 @@ class MLIRContextImpl {
llvm::DenseMap<StringRef, AbstractType *> nameToType;
/// Cached Type Instances.
+ Float6E2M3FNType f6E2M3FNTy;
Float6E3M2FNType f6E3M2FNTy;
Float8E5M2Type f8E5M2Ty;
Float8E4M3Type f8E4M3Ty;
@@ -314,6 +315,7 @@ MLIRContext::MLIRContext(const DialectRegistry ®istry, Threading setting)
//// Types.
/// Floating-point Types.
+ impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
@@ -1015,6 +1017,9 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
+Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
+ return context->getImpl().f6E2M3FNTy;
+}
Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
return context->getImpl().f6E3M2FNTy;
}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
index fa093664cf77f1..c828fd3766eaa7 100644
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -34,6 +34,7 @@ Type AbstractType::replaceImmediateSubElements(Type type,
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
+bool Type::isFloat6E2M3FN() const { return llvm::isa<Float6E2M3FNType>(*this); }
bool Type::isFloat6E3M2FN() const { return llvm::isa<Float6E3M2FNType>(*this); }
bool Type::isFloat8E5M2() const { return llvm::isa<Float8E5M2Type>(*this); }
bool Type::isFloat8E4M3() const { return llvm::isa<Float8E4M3Type>(*this); }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 7b4fac7275bfc6..17a02b0bd445a7 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -120,6 +120,7 @@ __all__ = [
"F32Type",
"F64Type",
"FlatSymbolRefAttr",
+ "Float6E2M3FNType",
"Float6E3M2FNType",
"Float8E3M4Type",
"Float8E4M3B11FNUZType",
@@ -1540,6 +1541,19 @@ class FlatSymbolRefAttr(Attribute):
Returns the value of the FlatSymbolRef attribute as a string
"""
+class Float6E2M3FNType(FloatType):
+ static_typeid: ClassVar[TypeID]
+ @staticmethod
+ def get(context: Optional[Context] = None) -> Float6E2M3FNType:
+ """
+ Create a float6_e2m3fn type.
+ """
+ @staticmethod
+ def isinstance(other: Type) -> bool: ...
+ def __init__(self, cast_from_type: Type) -> None: ...
+ @property
+ def typeid(self) -> TypeID: ...
+
class Float6E3M2FNType(FloatType):
static_typeid: ClassVar[TypeID]
@staticmethod
diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py
index 0c6ece91d8b94a..4be425f220c978 100644
--- a/mlir/python/mlir/extras/types.py
+++ b/mlir/python/mlir/extras/types.py
@@ -12,6 +12,7 @@
F16Type,
F32Type,
F64Type,
+ Float6E2M3FNType,
Float6E3M2FNType,
Float8E3M4Type,
Float8E4M3B11FNUZType,
@@ -75,6 +76,7 @@ def ui(width):
f8E4M3FN = lambda: Float8E4M3FNType.get()
f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
f8E3M4 = lambda: Float8E3M4Type.get()
+f6E2M3FN = lambda: Float6E2M3FNType.get()
f6E3M2FN = lambda: Float6E3M2FNType.get()
none = lambda: NoneType.get()
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 38cbf9d5d2b579..23dbf0c292c2c3 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -36,6 +36,10 @@ func.func @any_attr_of_fail() {
//===----------------------------------------------------------------------===//
func.func @float_attrs_pass() {
+ "test.float_attrs"() {
+ // CHECK: float_attr = 2.000000e+00 : f6E2M3FN
+ float_attr = 2. : f6E2M3FN
+ } : () -> ()
"test.float_attrs"() {
// CHECK: float_attr = 2.000000e+00 : f6E3M2FN
float_attr = 2. : f6E3M2FN
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 04be037978c8f6..7eca1a40373054 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -42,6 +42,9 @@ llvm.mlir.global internal @int_global_undef() : i64
// CHECK: @externally_initialized_global = internal externally_initialized global i32 0
llvm.mlir.global internal @externally_initialized_global(0 : i32) {externally_initialized} : i32
+// CHECK: @f6E2M3FN_global_as_i6 = internal global i6 12
+llvm.mlir.global internal @f6E2M3FN_global_as_i6(1.5 : f6E2M3FN) : i6
+
// CHECK: @f6E3M2FN_global_as_i6 = internal global i6 14
llvm.mlir.global internal @f6E3M2FN_global_as_i6(1.5 : f6E3M2FN) : i6
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index b72ef4de0bd6dd..bc3ba4cd0b1448 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -113,6 +113,8 @@ def testTypeIsInstance():
def testFloatTypeSubclasses():
ctx = Context()
# CHECK: True
+ print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
+ # CHECK: True
print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
# CHECK: True
print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
@@ -235,6 +237,8 @@ def testIndexType():
@run
def testFloatType():
with Context():
+ # CHECK: float: f6E2M3FN
+ print("float:", Float6E2M3FNType.get())
# CHECK: float: f6E3M2FN
print("float:", Float6E3M2FNType.get())
# CHECK: float: f8E3M4
@@ -613,6 +617,7 @@ def testTypeIDs():
types = [
(IntegerType, IntegerType.get_signless(16)),
(IndexType, IndexType.get()),
+ (Float6E2M3FNType, Float6E2M3FNType.get()),
(Float6E3M2FNType, Float6E3M2FNType.get()),
(Float8E3M4Type, Float8E3M4Type.get()),
(Float8E4M3Type, Float8E4M3Type.get()),
@@ -639,6 +644,7 @@ def testTypeIDs():
# CHECK: IntegerType(i16)
# CHECK: IndexType(index)
+ # CHECK: Float6E2M3FNType(f6E2M3FN)
# CHECK: Float6E3M2FNType(f6E3M2FN)
# CHECK: Float8E3M4Type(f8E3M4)
# CHECK: Float8E4M3Type(f8E4M3)
@@ -719,6 +725,9 @@ def print_downcasted(typ):
# CHECK: F64Type
# CHECK: F64Type(f64)
print_downcasted(F64Type.get())
+ # CHECK: Float6E2M3FNType
+ # CHECK: Float6E2M3FNType(f6E2M3FN)
+ print_downcasted(Float6E2M3FNType.get())
# CHECK: Float6E3M2FNType
# CHECK: Float6E3M2FNType(f6E3M2FN)
print_downcasted(Float6E3M2FNType.get())
diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py
index fed149d03ecf31..350a0f7abea5a4 100644
--- a/mlir/utils/lldb-scripts/mlirDataFormatters.py
+++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py
@@ -50,6 +50,7 @@ def build_ptr_str_from_addr(addrValue: lldb.SBValue, type: lldb.SBType):
"mlir::CallSiteLoc": '"loc(callsite(...))"',
"mlir::FusedLoc": '"loc(fused<...>[...])"',
"mlir::UnknownLoc": '"loc(unknown)"',
+ "mlir::Float6E2M3FNType": '"f6E2M3FN"',
"mlir::Float6E3M2FNType": '"f6E3M2FN"',
"mlir::Float8E5M2Type": '"f8E5M2"',
"mlir::Float8E4M3Type": '"f8E4M3"',
diff --git a/mlir/utils/tree-sitter-mlir/grammar.js b/mlir/utils/tree-sitter-mlir/grammar.js
index d2c66714b4b118..9df1944f6255d9 100644
--- a/mlir/utils/tree-sitter-mlir/grammar.js
+++ b/mlir/utils/tree-sitter-mlir/grammar.js
@@ -231,7 +231,7 @@ const common = {
token(seq(choice('si', 'ui', 'i'), /[1-9]/, repeat(/[0-9]/))),
float_type : $ => token(
choice('f16', 'f32', 'f64', 'f80', 'f128', 'bf16', 'f8E3M4', 'f8E4M3FN',
- 'f8E4M3', 'f8E5M2', 'f6E3M2FN')),
+ 'f8E4M3', 'f8E5M2', 'f6E2M3FN', 'f6E3M2FN')),
index_type : $ => token('index'),
none_type : $ => token('none'),
complex_type : $ => seq(token('complex'), '<', $._prim_type, '>'),
|
mlir/include/mlir/IR/BuiltinTypes.td
Outdated
// Float6E2M3FNType | ||
|
||
def Builtin_Float6E2M3FN : Builtin_FloatType<"Float6E2M3FN", "f6E2M3FN"> { | ||
let summary = "6-bit floating point with 3 bits exponent and 2 bit mantissa"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2-bit exponent and 3-bit mantissa.
Seems that correct spelling for adjectives is 2-bit , 3-bit, etc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
This PR adds `f6E2M3FN` type to mlir. `f6E2M3FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values. ```c f6E2M3FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.000 - Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5 - Min normal number: S.01.000 = ±2^(0) = ±1.0 - Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875 - Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125 ``` Related PRs: - [PR-94735](llvm#94735) [APFloat] Add APFloat support for FP6 data types - [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR
b8e2f07
to
4e31910
Compare
This PR adds `f4E2M1FN` type to mlir. `f4E2M1FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 4-bit floating point number with bit layout S1E2M1. Unlike IEEE-754 types, there are no infinity or NaN values. ```c f4E2M1FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.0 - Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0 - Min normal number: S.01.0 = ±2^(0) = ±1.0 - Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5 ``` Related PRs: - [PR-95392](#95392) [APFloat] Add APFloat support for FP4 data type - [PR-105573](#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](#107999) [MLIR] Add f6E2M3FN type
This PR adds `f4E2M1FN` type to mlir. `f4E2M1FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 4-bit floating point number with bit layout S1E2M1. Unlike IEEE-754 types, there are no infinity or NaN values. ```c f4E2M1FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.0 - Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0 - Min normal number: S.01.0 = ±2^(0) = ±1.0 - Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5 ``` Related PRs: - [PR-95392](llvm#95392) [APFloat] Add APFloat support for FP4 data type - [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](llvm#107999) [MLIR] Add f6E2M3FN type
This PR adds `f8E8M0FNU` type to MLIR. `f8E8M0FNU` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinity, denormals, zeros or negative values. ```c f8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 − 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - [PR-107127](#107127) [APFloat] Add APFloat support for E8M0 type - [PR-105573](#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](#107999) [MLIR] Add f6E2M3FN type - [PR-108877](#108877) [MLIR] Add f4E2M1FN type
This PR adds `f4E2M1FN` type to mlir. `f4E2M1FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 4-bit floating point number with bit layout S1E2M1. Unlike IEEE-754 types, there are no infinity or NaN values. ```c f4E2M1FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.0 - Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0 - Min normal number: S.01.0 = ±2^(0) = ±1.0 - Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5 ``` Related PRs: - [PR-95392](llvm#95392) [APFloat] Add APFloat support for FP4 data type - [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](llvm#107999) [MLIR] Add f6E2M3FN type
This PR adds `f8E8M0FNU` type to MLIR. `f8E8M0FNU` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 8-bit floating point number with bit layout S0E8M0. Unlike IEEE-754 types, there are no infinity, denormals, zeros or negative values. ```c f8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 − 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - [PR-107127](llvm#107127) [APFloat] Add APFloat support for E8M0 type - [PR-105573](llvm#105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](llvm#107999) [MLIR] Add f6E2M3FN type - [PR-108877](llvm#108877) [MLIR] Add f4E2M1FN type
…U) (#2581) This is a proposal to add MX (microscaling) floating point types to StableHLO. Related links: - StableHLO [PR#2582](#2582) Add MX floating point types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU) - LLVM [PR#95392](llvm/llvm-project#95392) [APFloat] Add APFloat support for FP4 data type - LLVM [PR#94735](llvm/llvm-project#94735) [APFloat] Add APFloat support for FP6 data types - LLVM [PR#107127](llvm/llvm-project#107127) [APFloat] Add APFloat support for E8M0 type - LLVM [PR#108877](llvm/llvm-project#108877) [MLIR] Add f4E2M1FN type - LLVM [PR#107999](llvm/llvm-project#107999) [MLIR] Add f6E2M3FN type - LLVM [PR#105573](llvm/llvm-project#105573) [MLIR] Add f6E3M2FN type - LLVM [PR#111028](llvm/llvm-project#111028) [MLIR] Add f8E8M0FNU type - JAX-ML [PR#181](jax-ml/ml_dtypes#181) Add sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn - JAX-ML [PR#166](jax-ml/ml_dtypes#181) Add float8_e8m0_fnu (E8M0) OCP MX scale format
This PR adds
f6E2M3FN
type to mlir.f6E2M3FN
type is proposed in OpenCompute MX Specification. It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values.Related PRs: