Skip to content

Commit

Permalink
[mlir][sparse] Add more error messages and avoid crashing in new pars…
Browse files Browse the repository at this point in the history
…er (#67034)

Updates:
1. Added more invalid encodings to test the robustness of the new syntax
2. Changed the asserts that caused crashes into checks that return booleans
3. Made some error messages clearer and handled failures when parsing
quoted strings as keywords for level formats and properties.
  • Loading branch information
yinying-lisa-li authored Sep 22, 2023
1 parent 62a3d84 commit 8466eb7
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 44 deletions.
15 changes: 9 additions & 6 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ using namespace mlir::sparse_tensor::ir_detail;

FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
StringRef base;
FAILURE_IF_FAILED(parser.parseOptionalKeyword(&base));
uint8_t properties = 0;
const auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
"expected valid level format (e.g. dense, compressed or singleton)")
uint8_t properties = 0;

ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalParen,
Expand All @@ -73,19 +74,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
} else if (base.compare("singleton") == 0) {
properties |= static_cast<uint8_t>(LevelFormat::Singleton);
} else {
parser.emitError(loc, "unknown level format");
parser.emitError(loc, "unknown level format: ") << base;
return failure();
}

ERROR_IF(!isValidDLT(static_cast<DimLevelType>(properties)),
"invalid level type");
"invalid level type: level format doesn't support the properties");
return properties;
}

ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
uint8_t *properties) const {
StringRef strVal;
FAILURE_IF_FAILED(parser.parseOptionalKeyword(&strVal));
auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.compare("nonunique") == 0) {
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::Nonunique);
} else if (strVal.compare("nonordered") == 0) {
Expand All @@ -95,7 +98,7 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
} else if (strVal.compare("block2_4") == 0) {
*properties |= static_cast<uint8_t>(LevelNondefaultProperty::Block2_4);
} else {
parser.emitError(parser.getCurrentLocation(), "unknown level property");
parser.emitError(loc, "unknown level property: ") << strVal;
return failure();
}
return success();
Expand Down
50 changes: 19 additions & 31 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,34 +196,25 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) {
return pair1 <= pair2 ? sm1 : sm2;
}

LLVM_ATTRIBUTE_UNUSED static void
assertInternalConsistency(VarEnv const &env, VarInfo::ID id, StringRef name) {
#ifndef NDEBUG
bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) {
const auto &var = env.access(id);
assert(var.getName() == name && "found inconsistent name");
assert(var.getID() == id && "found inconsistent VarInfo::ID");
#endif // NDEBUG
return (var.getName() == name && var.getID() == id);
}

// NOTE(wrengr): if we can actually obtain an `AsmParser` for `minSMLoc`
// (or find some other way to convert SMLoc to FileLineColLoc), then this
// would no longer be `const VarEnv` (and couldn't be a free-function either).
LLVM_ATTRIBUTE_UNUSED static void assertUsageConsistency(VarEnv const &env,
VarInfo::ID id,
llvm::SMLoc loc,
VarKind vk) {
#ifndef NDEBUG
bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc,
VarKind vk) {
const auto &var = env.access(id);
assert(var.getKind() == vk &&
"a variable of that name already exists with a different VarKind");
// Since the same variable can occur at several locations,
// it would not be appropriate to do `assert(var.getLoc() == loc)`.
/* TODO(wrengr):
const auto minLoc = minSMLoc(_, var.getLoc(), loc);
assert(minLoc && "Location mismatch/incompatibility");
var.loc = minLoc;
// */
#endif // NDEBUG
return var.getKind() == vk;
}

std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
Expand All @@ -236,24 +227,23 @@ std::optional<VarInfo::ID> VarEnv::lookup(StringRef name) const {
if (iter == ids.end())
return std::nullopt;
const auto id = iter->second;
#ifndef NDEBUG
assertInternalConsistency(*this, id, name);
#endif // NDEBUG
if (!isInternalConsistent(*this, id, name))
return std::nullopt;
return id;
}

std::pair<VarInfo::ID, bool> VarEnv::create(StringRef name, llvm::SMLoc loc,
VarKind vk, bool verifyUsage) {
std::optional<std::pair<VarInfo::ID, bool>>
VarEnv::create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage) {
const auto &[iter, didInsert] = ids.try_emplace(name, nextID());
const auto id = iter->second;
if (didInsert) {
vars.emplace_back(id, name, loc, vk);
} else {
#ifndef NDEBUG
assertInternalConsistency(*this, id, name);
if (verifyUsage)
assertUsageConsistency(*this, id, loc, vk);
#endif // NDEBUG
if (!isInternalConsistent(*this, id, name))
return std::nullopt;
if (verifyUsage)
if (!isUsageConsistent(*this, id, loc, vk))
return std::nullopt;
}
return std::make_pair(id, didInsert);
}
Expand All @@ -265,20 +255,18 @@ VarEnv::lookupOrCreate(Policy creationPolicy, StringRef name, llvm::SMLoc loc,
case Policy::MustNot: {
const auto oid = lookup(name);
if (!oid)
return std::nullopt; // Doesn't exist, but must not create.
#ifndef NDEBUG
assertUsageConsistency(*this, *oid, loc, vk);
#endif // NDEBUG
return std::nullopt; // Doesn't exist, but must not create.
if (!isUsageConsistent(*this, *oid, loc, vk))
return std::nullopt;
return std::make_pair(*oid, false);
}
case Policy::May:
return create(name, loc, vk, /*verifyUsage=*/true);
case Policy::Must: {
const auto res = create(name, loc, vk, /*verifyUsage=*/false);
// const auto id = res.first;
const auto didCreate = res.second;
const auto didCreate = res->second;
if (!didCreate)
return std::nullopt; // Already exists, but must create.
return std::nullopt; // Already exists, but must create.
return res;
}
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/SparseTensor/IR/Detail/Var.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,8 @@ class VarEnv final {
/// for the variable with the given name (i.e., either the newly created
/// variable, or the pre-existing variable), and a bool indicating whether
/// a new variable was created.
std::pair<VarInfo::ID, bool> create(StringRef name, llvm::SMLoc loc,
VarKind vk, bool verifyUsage = false);
std::optional<std::pair<VarInfo::ID, bool>>
create(StringRef name, llvm::SMLoc loc, VarKind vk, bool verifyUsage = false);

/// Attempts to lookup or create a variable according to the given
/// `Policy`. Returns nullopt in one of two circumstances:
Expand Down
163 changes: 158 additions & 5 deletions mlir/test/Dialect/SparseTensor/invalid_encoding.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,49 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// expected-error@+1 {{expected a non-empty array for lvlTypes}}
#a = #sparse_tensor.encoding<{lvlTypes = []}>
// expected-error@+1 {{expected '(' in dimension-specifier list}}
#a = #sparse_tensor.encoding<{map = []}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----

// expected-error@+1 {{expected '->'}}
#a = #sparse_tensor.encoding<{map = ()}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----

// expected-error@+1 {{expected ')' in dimension-specifier list}}
#a = #sparse_tensor.encoding<{map = (d0 -> d0)}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----

// expected-error@+1 {{expected '(' in dimension-specifier list}}
#a = #sparse_tensor.encoding<{map = d0 -> d0}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----

// expected-error@+1 {{expected '(' in level-specifier list}}
#a = #sparse_tensor.encoding<{map = (d0) -> d0}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----

// expected-error@+1 {{expected ':'}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0)}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----

// expected-error@+1 {{expected valid level format (e.g. dense, compressed or singleton)}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0:)}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----

// expected-error@+1 {{expected valid level format (e.g. dense, compressed or singleton)}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : (compressed))}>
func.func private @scalar(%arg0: tensor<f64, #a>) -> ()

// -----
Expand All @@ -18,17 +60,61 @@ func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

#a = #sparse_tensor.encoding<{lvlTypes = [1]}> // expected-error {{expected a string value in lvlTypes}}
// expected-error@+1 {{unexpected dimToLvl mapping from 2 to 1}}
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense)}>
func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{expected bare identifier}}
#a = #sparse_tensor.encoding<{map = (1)}>
func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{unexpected key: nap}}
#a = #sparse_tensor.encoding<{nap = (d0) -> (d0 : dense)}>
func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{expected '(' in dimension-specifier list}}
#a = #sparse_tensor.encoding<{map = -> (d0 : dense)}>
func.func private @tensor_type_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

#a = #sparse_tensor.encoding<{lvlTypes = ["strange"]}> // expected-error {{unexpected level-type: strange}}
// expected-error@+1 {{unknown level format: strange}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : strange)}>
func.func private @tensor_value_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

#a = #sparse_tensor.encoding<{dimToLvl = "wrong"}> // expected-error {{expected an affine map for dimToLvl}}
// expected-error@+1 {{expected valid level format (e.g. dense, compressed or singleton)}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : "wrong")}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{expected valid level property (e.g. nonordered, nonunique or high)}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed("wrong"))}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----
// expected-error@+1 {{expected ')' in level-specifier list}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed[high])}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{unknown level property: wrong}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed(wrong))}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----

// expected-error@+1 {{use of undeclared identifier}}
#a = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed, dense)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()

// -----
Expand All @@ -39,6 +125,73 @@ func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> ()

// -----

// expected-error@+1 {{unexpected character}}
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed; d1 : dense)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----

// expected-error@+1 {{expected attribute value}}
#a = #sparse_tensor.encoding<{map = (d0: d1) -> (d0 : compressed, d1 : dense)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----

// expected-error@+1 {{expected ':'}}
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 = compressed, d1 = dense)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----

// expected-error@+1 {{expected attribute value}}
#a = #sparse_tensor.encoding<{map = (d0 : compressed, d1 : compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----

// expected-error@+1 {{use of undeclared identifier}}
#a = #sparse_tensor.encoding<{map = (d0 = compressed, d1 = compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----

// expected-error@+1 {{use of undeclared identifier}}
#a = #sparse_tensor.encoding<{map = (d0 = l0, d1 = l1) {l0, l1} -> (l0 = d0 : dense, l1 = d1 : compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----

// expected-error@+1 {{expected '='}}
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : d0 = dense, l1 : d1 = compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----
// expected-error@+1 {{use of undeclared identifier 'd0'}}
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : l0 = dense, d1 : l1 = compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----
// expected-error@+1 {{use of undeclared identifier 'd0'}}
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 : dense, d1 : compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----
// expected-error@+1 {{expected '='}}
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 : dense, l1 : compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----
// expected-error@+1 {{use of undeclared identifier}}
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (l0 = dense, l1 = compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----
// expected-error@+1 {{use of undeclared identifier 'd0'}}
#a = #sparse_tensor.encoding<{map = {l0, l1} (d0 = l0, d1 = l1) -> (d0 = l0 : dense, d1 = l1 : compressed)}>
func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<16x32xi32, #a>) -> ()

// -----

#a = #sparse_tensor.encoding<{posWidth = "x"}> // expected-error {{expected an integral position bitwidth}}
func.func private @tensor_no_int_ptr(%arg0: tensor<16x32xf32, #a>) -> ()

Expand Down

0 comments on commit 8466eb7

Please sign in to comment.