Skip to content

Commit

Permalink
[ADMGPU] Replace isInlinableLiteral16 with specific version
Browse files Browse the repository at this point in the history
The current implementation of `isInlinableLiteral16` assumes, a 16-bit inlinable
literal is either an `i16` or a `fp16`. This is not always true because of
`bf16`. However, we can't tell `fp16` and `bf16` apart by just looking at the
value. This patch splits `isInlinableLiteral16` into three versions, `i16`,
`fp16`, `bf16` respectively, and call the corresponding version.
  • Loading branch information
shiltian committed Mar 8, 2024
1 parent cb6f657 commit c14c089
Show file tree
Hide file tree
Showing 47 changed files with 1,256 additions and 1,118 deletions.
52 changes: 29 additions & 23 deletions llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3327,35 +3327,41 @@ bool AMDGPUDAGToDAGISel::SelectWMMAVISrc(SDValue In, SDValue &Src) const {

// 16 bit splat
SDValue SplatSrc32 = stripBitcast(In);
if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32)) {
if (auto *SplatSrc32BV = dyn_cast<BuildVectorSDNode>(SplatSrc32))
if (SDValue Splat32 = SplatSrc32BV->getSplatValue()) {
SDValue SplatSrc16 = stripBitcast(Splat32);
if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16)) {
if (auto *SplatSrc16BV = dyn_cast<BuildVectorSDNode>(SplatSrc16))
if (SDValue Splat = SplatSrc16BV->getSplatValue()) {

// f16
if (isInlineImmediate(Splat.getNode())) {
const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat);
int64_t Imm = C->getValueAPF().bitcastToAPInt().getSExtValue();
Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i16);
return true;
}

// bf16
if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat)) {
const SIInstrInfo *TII = Subtarget->getInstrInfo();
APInt BF16Value = C->getAPIntValue();
APInt F32Value = BF16Value.zext(32).shl(16);
if (TII->isInlineConstant(F32Value)) {
int64_t Imm = F32Value.getSExtValue();
Src = CurDAG->getTargetConstant(Imm, SDLoc(In), MVT::i32);
return true;
}
const SIInstrInfo *TII = Subtarget->getInstrInfo();
std::optional<APInt> RawValue;
if (const ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(Splat))
RawValue = C->getValueAPF().bitcastToAPInt();
else if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat))
RawValue = C->getAPIntValue();

if (RawValue.has_value()) {
EVT VT = In.getValueType().getScalarType();
if (VT.getSimpleVT() == MVT::f16 || VT.getSimpleVT() == MVT::bf16) {
APFloat FloatVal(VT.getSimpleVT() == MVT::f16
? APFloatBase::IEEEhalf()
: APFloatBase::BFloat(),
RawValue.value());
if (TII->isInlineConstant(FloatVal)) {
Src = CurDAG->getTargetConstant(RawValue.value(), SDLoc(In),
MVT::i16);
return true;
}
} else if (VT.getSimpleVT() == MVT::i16) {
if (TII->isInlineConstant(RawValue.value())) {
Src = CurDAG->getTargetConstant(RawValue.value(), SDLoc(In),
MVT::i16);
return true;
}
} else
llvm_unreachable("unknown 16-bit type");
}
}
}
}
}

return false;
}
Expand Down
97 changes: 75 additions & 22 deletions llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1926,6 +1926,11 @@ static const fltSemantics *getFltSemantics(MVT VT) {

static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
switch (OperandType) {
// When floating-point immediate is used as operand of type i16, the 32-bit
// representation of the constant truncated to the 16 LSBs should be used.
case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
case AMDGPU::OPERAND_REG_IMM_INT32:
case AMDGPU::OPERAND_REG_IMM_FP32:
case AMDGPU::OPERAND_REG_IMM_FP32_DEFERRED:
Expand All @@ -1949,13 +1954,10 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) {
case AMDGPU::OPERAND_REG_INLINE_C_FP64:
case AMDGPU::OPERAND_REG_INLINE_AC_FP64:
return &APFloat::IEEEdouble();
case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
case AMDGPU::OPERAND_REG_IMM_V2FP16:
Expand Down Expand Up @@ -2001,13 +2003,15 @@ static bool isSafeTruncation(int64_t Val, unsigned Size) {
}

static bool isInlineableLiteralOp16(int64_t Val, MVT VT, bool HasInv2Pi) {
if (VT.getScalarType() == MVT::i16) {
// FP immediate values are broken.
return isInlinableIntLiteral(Val);
}
if (VT.getScalarType() == MVT::i16)
return isInlinableLiteral32(Val, HasInv2Pi);

if (VT.getScalarType() == MVT::f16)
return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);

// f16/v2f16 operands work correctly for all values.
return AMDGPU::isInlinableLiteral16(Val, HasInv2Pi);
assert(VT.getScalarType() == MVT::bf16);

return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
}

bool AMDGPUOperand::isInlinableImm(MVT type) const {
Expand Down Expand Up @@ -2041,9 +2045,30 @@ bool AMDGPUOperand::isInlinableImm(MVT type) const {
return false;

if (type.getScalarSizeInBits() == 16) {
return isInlineableLiteralOp16(
static_cast<int16_t>(FPLiteral.bitcastToAPInt().getZExtValue()),
type, AsmParser->hasInv2PiInlineImm());
bool Lost = false;
switch (type.getScalarType().SimpleTy) {
default:
llvm_unreachable("unknown 16-bit type");
case MVT::bf16:
FPLiteral.convert(APFloatBase::BFloat(), APFloat::rmNearestTiesToEven,
&Lost);
break;
case MVT::f16:
FPLiteral.convert(APFloatBase::IEEEhalf(), APFloat::rmNearestTiesToEven,
&Lost);
break;
case MVT::i16:
FPLiteral.convert(APFloatBase::IEEEsingle(),
APFloat::rmNearestTiesToEven, &Lost);
break;
}
// We need to use 32-bit representation here because when a floating-point
// inline constant is used as an i16 operand, its 32-bit representation
// representation will be used. We will need the 32-bit value to check if
// it is FP inline constant.
uint32_t ImmVal = FPLiteral.bitcastToAPInt().getZExtValue();
return isInlineableLiteralOp16(ImmVal, type,
AsmParser->hasInv2PiInlineImm());
}

// Check if single precision literal is inlinable
Expand Down Expand Up @@ -2375,15 +2400,26 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
return;

case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val))) {
Inst.addOperand(MCOperand::createImm(Val & 0xffffffff));
setImmKindConst();
return;
}

Inst.addOperand(MCOperand::createImm(Val & 0xffff));
setImmKindLiteral();
return;

case AMDGPU::OPERAND_REG_INLINE_C_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
if (isSafeTruncation(Val, 16) &&
AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm())) {
AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm())) {
Inst.addOperand(MCOperand::createImm(Val));
setImmKindConst();
return;
Expand All @@ -2410,12 +2446,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo
return;

case AMDGPU::OPERAND_REG_INLINE_C_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableIntLiteral(static_cast<int16_t>(Val)));
Inst.addOperand(MCOperand::createImm(Val));
return;
}
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: {
assert(isSafeTruncation(Val, 16));
assert(AMDGPU::isInlinableLiteral16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm()));
assert(AMDGPU::isInlinableLiteralFP16(static_cast<int16_t>(Val),
AsmParser->hasInv2PiInlineImm()));

Inst.addOperand(MCOperand::createImm(Val));
return;
Expand Down Expand Up @@ -3542,7 +3583,7 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
if (OperandType == AMDGPU::OPERAND_REG_IMM_INT16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_C_INT16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_INT16)
return AMDGPU::isInlinableIntLiteral(Val);
return AMDGPU::isInlinableLiteralI16(Val, hasInv2PiInlineImm());

if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2INT16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2INT16 ||
Expand All @@ -3559,7 +3600,19 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst,
OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16)
return AMDGPU::isInlinableLiteralV2BF16(Val);

return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm());
if (OperandType == AMDGPU::OPERAND_REG_IMM_FP16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_C_FP16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_FP16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED)
return AMDGPU::isInlinableLiteralFP16(Val, hasInv2PiInlineImm());

if (OperandType == AMDGPU::OPERAND_REG_IMM_BF16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_C_BF16 ||
OperandType == AMDGPU::OPERAND_REG_INLINE_AC_BF16 ||
OperandType == AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED)
return AMDGPU::isInlinableLiteralBF16(Val, hasInv2PiInlineImm());

llvm_unreachable("invalid operand type");
}
default:
llvm_unreachable("invalid operand size");
Expand Down
29 changes: 15 additions & 14 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -451,19 +451,20 @@ void AMDGPUInstPrinter::printVINTRPDst(const MCInst *MI, unsigned OpNo,
void AMDGPUInstPrinter::printImmediateInt16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int16_t SImm = static_cast<int16_t>(Imm);
int32_t SImm = static_cast<int32_t>(Imm);
if (isInlinableIntLiteral(SImm)) {
O << SImm;
} else {
uint64_t Imm16 = static_cast<uint16_t>(Imm);
O << formatHex(Imm16);
return;
}

if (printImmediateFloat32(Imm, STI, O))
return;

O << formatHex(static_cast<uint64_t>(Imm & 0xffff));
}

// This must accept a 32-bit immediate value to correctly handle packed 16-bit
// operations.
static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O) {
static bool printImmediateFP16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O) {
if (Imm == 0x3C00)
O << "1.0";
else if (Imm == 0xBC00)
Expand Down Expand Up @@ -529,17 +530,17 @@ void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm,
O << formatHex(static_cast<uint64_t>(Imm));
}

void AMDGPUInstPrinter::printImmediate16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
void AMDGPUInstPrinter::printImmediateF16(uint32_t Imm,
const MCSubtargetInfo &STI,
raw_ostream &O) {
int16_t SImm = static_cast<int16_t>(Imm);
if (isInlinableIntLiteral(SImm)) {
O << SImm;
return;
}

uint16_t HImm = static_cast<uint16_t>(Imm);
if (printImmediateFloat16(HImm, STI, O))
if (printImmediateFP16(HImm, STI, O))
return;

uint64_t Imm16 = static_cast<uint16_t>(Imm);
Expand All @@ -566,7 +567,7 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType,
case AMDGPU::OPERAND_REG_INLINE_C_V2FP16:
case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16:
if (isUInt<16>(Imm) &&
printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O))
printImmediateFP16(static_cast<uint16_t>(Imm), STI, O))
return;
break;
case AMDGPU::OPERAND_REG_IMM_V2BF16:
Expand Down Expand Up @@ -845,7 +846,7 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo,
case AMDGPU::OPERAND_REG_INLINE_AC_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
printImmediate16(Op.getImm(), STI, O);
printImmediateF16(Op.getImm(), STI, O);
break;
case AMDGPU::OPERAND_REG_INLINE_C_BF16:
case AMDGPU::OPERAND_REG_INLINE_AC_BF16:
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ class AMDGPUInstPrinter : public MCInstPrinter {
raw_ostream &O);
void printImmediateInt16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateF16(uint32_t Imm, const MCSubtargetInfo &STI,
raw_ostream &O);
void printImmediateV216(uint32_t Imm, uint8_t OpType,
const MCSubtargetInfo &STI, raw_ostream &O);
bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI,
Expand Down
11 changes: 5 additions & 6 deletions llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,6 @@ static uint32_t getIntInlineImmEncoding(IntTy Imm) {
return 0;
}

static uint32_t getLit16IntEncoding(uint16_t Val, const MCSubtargetInfo &STI) {
uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
return IntImm == 0 ? 255 : IntImm;
}

static uint32_t getLit16Encoding(uint16_t Val, const MCSubtargetInfo &STI) {
uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val));
if (IntImm != 0)
Expand Down Expand Up @@ -214,6 +209,10 @@ static uint32_t getLit32Encoding(uint32_t Val, const MCSubtargetInfo &STI) {
return 255;
}

static uint32_t getLit16IntEncoding(uint32_t Val, const MCSubtargetInfo &STI) {
return getLit32Encoding(Val, STI);
}

static uint32_t getLit64Encoding(uint64_t Val, const MCSubtargetInfo &STI) {
uint32_t IntImm = getIntInlineImmEncoding(static_cast<int64_t>(Val));
if (IntImm != 0)
Expand Down Expand Up @@ -296,7 +295,7 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO,
case AMDGPU::OPERAND_REG_IMM_INT16:
case AMDGPU::OPERAND_REG_INLINE_C_INT16:
case AMDGPU::OPERAND_REG_INLINE_AC_INT16:
return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI);
return getLit16IntEncoding(static_cast<uint32_t>(Imm), STI);

case AMDGPU::OPERAND_REG_IMM_FP16:
case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED:
Expand Down
28 changes: 22 additions & 6 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15495,16 +15495,32 @@ bool SITargetLowering::checkAsmConstraintVal(SDValue Op, StringRef Constraint,
llvm_unreachable("Invalid asm constraint");
}

bool SITargetLowering::checkAsmConstraintValA(SDValue Op,
uint64_t Val,
bool SITargetLowering::checkAsmConstraintValA(SDValue Op, uint64_t Val,
unsigned MaxSize) const {
unsigned Size = std::min<unsigned>(Op.getScalarValueSizeInBits(), MaxSize);
bool HasInv2Pi = Subtarget->hasInv2PiInlineImm();
if ((Size == 16 && AMDGPU::isInlinableLiteral16(Val, HasInv2Pi)) ||
(Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi))) {
return true;
if (Size == 16) {
MVT VT = Op.getSimpleValueType();
switch (VT.SimpleTy) {
default:
return false;
case MVT::i16:
return AMDGPU::isInlinableLiteralI16(Val, HasInv2Pi);
case MVT::f16:
return AMDGPU::isInlinableLiteralFP16(Val, HasInv2Pi);
case MVT::bf16:
return AMDGPU::isInlinableLiteralBF16(Val, HasInv2Pi);
case MVT::v2i16:
return AMDGPU::getInlineEncodingV2I16(Val).has_value();
case MVT::v2f16:
return AMDGPU::getInlineEncodingV2F16(Val).has_value();
case MVT::v2bf16:
return AMDGPU::getInlineEncodingV2BF16(Val).has_value();
}
}
if ((Size == 32 && AMDGPU::isInlinableLiteral32(Val, HasInv2Pi)) ||
(Size == 64 && AMDGPU::isInlinableLiteral64(Val, HasInv2Pi)))
return true;
return false;
}

Expand Down
Loading

0 comments on commit c14c089

Please sign in to comment.