Skip to content

Commit

Permalink
Remove lowerDiffsCarryChecked.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 1, 2021
1 parent a7f2bbc commit 4482c15
Showing 1 changed file with 13 additions and 31 deletions.
44 changes: 13 additions & 31 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ inline void computeIndexDiffMap(

// Apply map to compute index lower diff, from index upper diff using
// expandAffineMap.
// FIXME: study how to replace expandAffineMap with constantFold.
SmallVector<Value, 8> lowerDiffModified =
expandAffineMap(b, loc, transform, upperDiffModified).getValue();
for (unsigned iter = 0; iter < lowerDiffModified.size(); ++iter) {
Expand Down Expand Up @@ -295,14 +296,11 @@ inline void computeIndexDiffMap(
assert(lowerLayerBounds.size() == lowerIndicesModified.size());

// Carry checked lower indices.
DenseMap<int64_t, Value> lowerDiffsCarryChecked;
DenseMap<int64_t, Value> lowerIndicesCarryChecked;
for (unsigned iter = 0; iter < q.size(); ++iter) {
int64_t lowerDim = q[iter].template cast<IntegerAttr>().getInt();
lowerDiffsCarryChecked[lowerDim] = lowerDiffs[iter];
lowerIndicesCarryChecked[lowerDim] = lowerIndicesModified[iter];
}
assert(lowerDiffsCarryChecked.size() == lowerIndicesModified.size());
assert(lowerIndicesCarryChecked.size() == lowerIndicesModified.size());

// We only implement carry logic. Borrow logic would never happen as
Expand All @@ -315,26 +313,18 @@ inline void computeIndexDiffMap(

// carry logic.
auto ifCarryOp = b.create<scf::IfOp>(
loc, TypeRange{b.getIntegerType(32), b.getIntegerType(32)},
carryOp, /*withElseRegion=*/true);
loc, b.getIntegerType(32), carryOp, /*withElseRegion=*/true);
auto ifCarryThenBuilder = ifCarryOp.getThenBodyBuilder();
auto carriedLowerDiff = ifCarryThenBuilder.create<AddIOp>(
loc, lowerDiffsCarryChecked[lowerDim], oneConstantI32Op);
auto carriedLowerIndice = ifCarryThenBuilder.create<AddIOp>(
loc, lowerIndicesCarryChecked[lowerDim], oneConstantI32Op);
ifCarryThenBuilder.create<scf::YieldOp>(
loc, ValueRange{carriedLowerDiff.getResult(),
carriedLowerIndice.getResult()});
loc, carriedLowerIndice.getResult());
auto ifCarryElseBuilder = ifCarryOp.getElseBodyBuilder();
carriedLowerDiff = ifCarryElseBuilder.create<AddIOp>(
loc, lowerDiffsCarryChecked[lowerDim], zeroConstantI32Op);
carriedLowerIndice = ifCarryElseBuilder.create<AddIOp>(
loc, lowerIndicesCarryChecked[lowerDim], zeroConstantI32Op);
ifCarryElseBuilder.create<scf::YieldOp>(
loc, ValueRange{carriedLowerDiff.getResult(),
carriedLowerIndice.getResult()});
auto carriedLowerDiffResult = ifCarryOp.results()[0];
auto carriedLowerIndiceResult = ifCarryOp.results()[1];
loc, carriedLowerIndice.getResult());
auto carriedLowerIndiceResult = ifCarryOp.results()[0];

// set carry flag for the next digit.
carryOp = b.create<CmpIOp>(loc, CmpIPredicate::sge,
Expand All @@ -343,38 +333,30 @@ inline void computeIndexDiffMap(

// overflow logic.
auto ifOverflowOp = b.create<scf::IfOp>(
loc, TypeRange{b.getIntegerType(32), b.getIntegerType(32)},
carryOp, /*withElseRegion=*/true);
loc, b.getIntegerType(32), carryOp, /*withElseRegion=*/true);
auto ifOverflowThenBuilder = ifOverflowOp.getThenBodyBuilder();
auto updatedLowerDiff = ifOverflowThenBuilder.create<SubIOp>(
loc, carriedLowerDiffResult, lowerLayerBounds[iter]);
auto updatedLowerIndice = ifOverflowThenBuilder.create<SubIOp>(
loc, carriedLowerIndiceResult, lowerLayerBounds[iter]);
ifOverflowThenBuilder.create<scf::YieldOp>(
loc, ValueRange{updatedLowerDiff.getResult(),
updatedLowerIndice.getResult()});
loc, updatedLowerIndice.getResult());
auto ifOverflowElseBuilder = ifOverflowOp.getElseBodyBuilder();
updatedLowerDiff = ifOverflowElseBuilder.create<SubIOp>(
loc, carriedLowerDiffResult, zeroConstantI32Op);
updatedLowerIndice = ifOverflowElseBuilder.create<SubIOp>(
loc, carriedLowerIndiceResult, zeroConstantI32Op);
ifOverflowElseBuilder.create<scf::YieldOp>(
loc, ValueRange{updatedLowerDiff.getResult(),
updatedLowerIndice.getResult()});
loc, updatedLowerIndice.getResult());

// updatedResult is by default of i32 type.
Value updatedLowerDiffResult = ifOverflowOp.results()[0];
Value updatedLowerIndiceResult = ifOverflowOp.results()[1];
lowerDiffsCarryChecked[lowerDim] = updatedLowerDiffResult;
lowerIndicesCarryChecked[lowerDim] = updatedLowerIndiceResult;
lowerIndicesCarryChecked[lowerDim] = ifOverflowOp.results()[0];
}
assert(lowerDiffsCarryChecked.size() == lowerIndicesModified.size());
assert(lowerIndicesCarryChecked.size() == lowerIndicesModified.size());
lowerDiffs.clear();
lowerIndicesModified.clear();
for (unsigned iter = 0; iter < q.size(); ++iter) {
int64_t lowerDim = q[iter].template cast<IntegerAttr>().getInt();
lowerDiffs.push_back(lowerDiffsCarryChecked[lowerDim]);
lowerDiffs.push_back(b.create<SubIOp>(
loc, lowerIndicesCarryChecked[lowerDim],
b.create<IndexCastOp>(loc, lowerIndicesOriginal[iter],
b.getIntegerType(32))));
lowerIndicesModified.push_back(lowerIndicesCarryChecked[lowerDim]);
}
assert(lowerDiffs.size() == q.size());
Expand Down

0 comments on commit 4482c15

Please sign in to comment.