diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index 3275d50a85a4bf..f645ad25cbd86d 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -173,16 +173,20 @@ TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
+TARGET_BUILTIN(__nvvm_fmin_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
+TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "yyy", "",
                AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
+TARGET_BUILTIN(__nvvm_fmin_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "V2yV2yV2y", "",
                AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
+TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
                AND(SM_86, PTX72))
 BUILTIN(__nvvm_fmin_f, "fff", "")
 BUILTIN(__nvvm_fmin_ftz_f, "fff", "")
@@ -215,16 +219,20 @@ TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
+TARGET_BUILTIN(__nvvm_fmax_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
+TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "yyy", "",
                AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
+TARGET_BUILTIN(__nvvm_fmax_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "V2yV2yV2y", "",
                AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
"ZUiZUiZUi", "", +TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "", AND(SM_86, PTX72)) BUILTIN(__nvvm_fmax_f, "fff", "") BUILTIN(__nvvm_fmax_ftz_f, "fff", "") @@ -352,10 +360,10 @@ TARGET_BUILTIN(__nvvm_fma_rn_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42)) TARGET_BUILTIN(__nvvm_fma_rn_ftz_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42)) TARGET_BUILTIN(__nvvm_fma_rn_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70)) -TARGET_BUILTIN(__nvvm_fma_rn_bf16, "UsUsUsUs", "", AND(SM_80, PTX70)) -TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "UsUsUsUs", "", AND(SM_80, PTX70)) -TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70)) -TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fma_rn_bf16, "yyyy", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "yyyy", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70)) BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "") BUILTIN(__nvvm_fma_rn_f, "ffff", "") BUILTIN(__nvvm_fma_rz_ftz_f, "ffff", "") @@ -543,20 +551,20 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "") BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "") BUILTIN(__nvvm_f2h_rn, "Usf", "") -TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "V2yff", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "V2yff", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "V2yff", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "V2yff", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_f2bf16_rn, "yf", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "yf", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_f2bf16_rz, "yf", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70)) @@ -1024,10 +1032,10 @@ TARGET_BUILTIN(__nvvm_cp_async_wait_all, "v", "", AND(SM_80,PTX70)) // bf16, bf16x2 abs, neg -TARGET_BUILTIN(__nvvm_abs_bf16, "UsUs", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_abs_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_neg_bf16, "UsUs", "", AND(SM_80,PTX70)) -TARGET_BUILTIN(__nvvm_neg_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_abs_bf16, "yy", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_abs_bf16x2, "V2yV2y", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_neg_bf16, "yy", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_neg_bf16x2, "V2yV2y", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_mapa, "v*v*i", "", AND(SM_90, PTX78)) TARGET_BUILTIN(__nvvm_mapa_shared_cluster, "v*3v*3i", 
"", AND(SM_90, PTX78)) diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 75cb6835049c67..353f3ebb608c2b 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -899,13 +899,13 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a // CHECK-LABEL: nvvm_cvt_sm80 __device__ void nvvm_cvt_sm80() { #if __CUDA_ARCH__ >= 800 - // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00) __nvvm_ff2bf16x2_rn(1, 1); - // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00) __nvvm_ff2bf16x2_rn_relu(1, 1); - // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00) __nvvm_ff2bf16x2_rz(1, 1); - // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00) + // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00) __nvvm_ff2bf16x2_rz_relu(1, 1); // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00) @@ -917,13 +917,13 @@ __device__ void nvvm_cvt_sm80() { // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00) __nvvm_ff2f16x2_rz_relu(1, 1); - // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00) + // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00) __nvvm_f2bf16_rn(1); - // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00) + // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00) __nvvm_f2bf16_rn_relu(1); - // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00) + // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz(float 1.000000e+00) __nvvm_f2bf16_rz(1); - // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00) + // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00) __nvvm_f2bf16_rz_relu(1); // CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00) @@ -932,32 +932,32 @@ __device__ void nvvm_cvt_sm80() { // CHECK: ret void } +#define NAN32 0x7FBFFFFF +#define NAN16 (__bf16)0x7FBF +#define BF16 (__bf16)0.1f +#define BF16_2 (__bf16)0.2f +#define NANBF16 (__bf16)0xFFC1 +#define BF16X2 {(__bf16)0.1f, (__bf16)0.1f} +#define BF16X2_2 {(__bf16)0.2f, (__bf16)0.2f} +#define NANBF16X2 {NANBF16, NANBF16} + // CHECK-LABEL: nvvm_abs_neg_bf16_bf16x2_sm80 __device__ void nvvm_abs_neg_bf16_bf16x2_sm80() { #if __CUDA_ARCH__ >= 800 - // CHECK_PTX70_SM80: call i16 @llvm.nvvm.abs.bf16(i16 -1) - __nvvm_abs_bf16(0xFFFF); - // CHECK_PTX70_SM80: call i32 @llvm.nvvm.abs.bf16x2(i32 -1) - __nvvm_abs_bf16x2(0xFFFFFFFF); + // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.abs.bf16(bfloat 0xR3DCD) + __nvvm_abs_bf16(BF16); + // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> ) + __nvvm_abs_bf16x2(BF16X2); - // CHECK_PTX70_SM80: call i16 @llvm.nvvm.neg.bf16(i16 -1) - __nvvm_neg_bf16(0xFFFF); - // CHECK_PTX70_SM80: call i32 @llvm.nvvm.neg.bf16x2(i32 -1) - __nvvm_neg_bf16x2(0xFFFFFFFF); + // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.neg.bf16(bfloat 
+  __nvvm_neg_bf16(BF16);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
+  __nvvm_neg_bf16x2(BF16X2);
 #endif
   // CHECK: ret void
 }
 
-#define NAN32 0x7FBFFFFF
-#define NAN16 0x7FBF
-#define BF16 0x1234
-#define BF16_2 0x4321
-#define NANBF16 0xFFC1
-#define BF16X2 0x12341234
-#define BF16X2_2 0x32343234
-#define NANBF16X2 0xFFC1FFC1
-
 // CHECK-LABEL: nvvm_min_max_sm80
 __device__ void nvvm_min_max_sm80() {
 #if __CUDA_ARCH__ >= 800
@@ -967,14 +967,22 @@ __device__ void nvvm_min_max_sm80() {
   // CHECK_PTX70_SM80: call float @llvm.nvvm.fmin.ftz.nan.f
   __nvvm_fmin_ftz_nan_f(0.1f, (float)NAN32);
-  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.bf16
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.bf16
   __nvvm_fmin_bf16(BF16, BF16_2);
-  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.nan.bf16
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.bf16
+  __nvvm_fmin_ftz_bf16(BF16, BF16_2);
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.nan.bf16
   __nvvm_fmin_nan_bf16(BF16, NANBF16);
-  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.bf16x2
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.nan.bf16
+  __nvvm_fmin_ftz_nan_bf16(BF16, NANBF16);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.bf16x2
   __nvvm_fmin_bf16x2(BF16X2, BF16X2_2);
-  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.nan.bf16x2
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.bf16x2
+  __nvvm_fmin_ftz_bf16x2(BF16X2, BF16X2_2);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2
   __nvvm_fmin_nan_bf16x2(BF16X2, NANBF16X2);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.nan.bf16x2
+  __nvvm_fmin_ftz_nan_bf16x2(BF16X2, NANBF16X2);
   // CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
   __nvvm_fmax_nan_f(0.1f, 0.11f);
   // CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -984,14 +992,22 @@ __device__ void nvvm_min_max_sm80() {
   __nvvm_fmax_nan_f(0.1f, (float)NAN32);
   // CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
   __nvvm_fmax_ftz_nan_f(0.1f, (float)NAN32);
-  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.bf16
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.bf16
   __nvvm_fmax_bf16(BF16, BF16_2);
-  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.nan.bf16
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.bf16
+  __nvvm_fmax_ftz_bf16(BF16, BF16_2);
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.nan.bf16
   __nvvm_fmax_nan_bf16(BF16, NANBF16);
-  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.bf16x2
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.nan.bf16
+  __nvvm_fmax_ftz_nan_bf16(BF16, NANBF16);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.bf16x2
   __nvvm_fmax_bf16x2(BF16X2, BF16X2_2);
-  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.nan.bf16x2
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.bf16x2
+  __nvvm_fmax_ftz_bf16x2(BF16X2, BF16X2_2);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2
   __nvvm_fmax_nan_bf16x2(NANBF16X2, BF16X2);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.nan.bf16x2
+  __nvvm_fmax_ftz_nan_bf16x2(NANBF16X2, BF16X2);
   // CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
   __nvvm_fmax_nan_f(0.1f, (float)NAN32);
   // CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -1004,14 +1020,14 @@ __device__ void nvvm_min_max_sm80() {
 // CHECK-LABEL: nvvm_fma_bf16_bf16x2_sm80
 __device__ void nvvm_fma_bf16_bf16x2_sm80() {
 #if __CUDA_ARCH__ >= 800
-  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.bf16
-  __nvvm_fma_rn_bf16(0x1234, 0x7FBF, 0x1234);
-  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.relu.bf16
-  __nvvm_fma_rn_relu_bf16(0x1234, 0x7FBF, 0x1234);
-  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.bf16x2
-  __nvvm_fma_rn_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
-  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.relu.bf16x2
-  __nvvm_fma_rn_relu_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.bf16
+  __nvvm_fma_rn_bf16(BF16, BF16_2, BF16_2);
+  // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.relu.bf16
+  __nvvm_fma_rn_relu_bf16(BF16, BF16_2, BF16_2);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2
+  __nvvm_fma_rn_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+  // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2
+  __nvvm_fma_rn_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
 #endif
   // CHECK: ret void
 }
 
@@ -1020,13 +1036,13 @@ __device__ void nvvm_fma_bf16_bf16x2_sm80() {
 __device__ void nvvm_min_max_sm86() {
 #if __CUDA_ARCH__ >= 860
-  // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.xorsign.abs.bf16
+  // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16
   __nvvm_fmin_xorsign_abs_bf16(BF16, BF16_2);
-  // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16
+  // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16
   __nvvm_fmin_nan_xorsign_abs_bf16(BF16, NANBF16);
-  // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2
+  // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2
   __nvvm_fmin_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
-  // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
+  // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
   __nvvm_fmin_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
   // CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.xorsign.abs.f
   __nvvm_fmin_xorsign_abs_f(-0.1f, 0.1f);
@@ -1037,13 +1053,13 @@ __device__ void nvvm_min_max_sm86() {
   // CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f
   __nvvm_fmin_ftz_nan_xorsign_abs_f(-0.1f, (float)NAN32);
-  // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.xorsign.abs.bf16
+  // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16
   __nvvm_fmax_xorsign_abs_bf16(BF16, BF16_2);
-  // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16
+  // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16
   __nvvm_fmax_nan_xorsign_abs_bf16(BF16, NANBF16);
-  // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2
+  // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2
   __nvvm_fmax_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
-  // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
+  // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
   __nvvm_fmax_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
   // CHECK_PTX72_SM86: call float @llvm.nvvm.fmax.xorsign.abs.f
   __nvvm_fmax_xorsign_abs_f(-0.1f, 0.1f);
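The `_nan` variants exercised above map to PTX `min.NaN`/`max.NaN`, which propagate a canonical NaN when either input is NaN, unlike the default IEEE-style behaviour. A hedged sketch of the difference:

```cuda
// Illustrative sketch only; assumes sm_80 with PTX 7.0+.
__device__ __bf16 min_ieee(__bf16 a, __bf16 b) {
  return __nvvm_fmin_bf16(a, b);      // a NaN operand is ignored if the other is numeric
}

__device__ __bf16 min_nanprop(__bf16 a, __bf16 b) {
  return __nvvm_fmin_nan_bf16(a, b);  // returns NaN if either operand is NaN
}
```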
diff --git a/clang/test/CodeGenCUDA/bf16.cu b/clang/test/CodeGenCUDA/bf16.cu
index 32082904c4d81c..3c443420dbd36a 100644
--- a/clang/test/CodeGenCUDA/bf16.cu
+++ b/clang/test/CodeGenCUDA/bf16.cu
@@ -8,7 +8,7 @@
 
 // CHECK-LABEL: .visible .func _Z8test_argPDF16bDF16b(
 // CHECK:       .param .b64 _Z8test_argPDF16bDF16b_param_0,
-// CHECK:       .param .b16 _Z8test_argPDF16bDF16b_param_1
+// CHECK:       .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
 //
 __device__ void test_arg(__bf16 *out, __bf16 in) {
 // CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
@@ -20,8 +20,8 @@ __device__ void test_arg(__bf16 *out, __bf16 in) {
 }
 
-// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retDF16b(
-// CHECK:       .param .b16 _Z8test_retDF16b_param_0
+// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z8test_retDF16b(
+// CHECK:       .param .align 2 .b8 _Z8test_retDF16b_param_0[2]
 __device__ __bf16 test_ret( __bf16 in) {
 // CHECK:  ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_retDF16b_param_0];
   return in;
@@ -31,12 +31,12 @@ __device__ __bf16 test_ret( __bf16 in) {
 
 __device__ __bf16 external_func( __bf16 in);
 
-// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callDF16b(
-// CHECK:       .param .b16 _Z9test_callDF16b_param_0
+// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z9test_callDF16b(
+// CHECK:       .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
 __device__ __bf16 test_call( __bf16 in) {
 // CHECK:  ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
 // CHECK:  st.param.b16 [param0+0], %[[R]];
-// CHECK:  .param .b32 retval0;
+// CHECK:  .param .align 2 .b8 retval0[2];
 // CHECK:  call.uni (retval0),
 // CHECK-NEXT: _Z13external_funcDF16b,
 // CHECK-NEXT: (
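The PTX checks above encode the ABI change: a scalar `__bf16` parameter or return value is now declared as a 2-byte-aligned byte array, matching how aggregates are passed, instead of being widened to a `.b32` scalar. A minimal reproducer under the same assumptions as the test:

```cuda
// Sketch: compiled for NVPTX, this should now emit
//   .visible .func (.param .align 2 .b8 func_retval0[2]) ...(
//       .param .align 2 .b8 ..._param_0[2]
// rather than the old .b32/.b16 scalar parameter declarations.
__device__ __bf16 passthrough(__bf16 v) { return v; }
```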
"_rn_sat_bf16x2", + "_rn_ftz_sat_bf16x2", "_rn_relu_bf16x2", "_rn_ftz_relu_bf16x2"] in { def int_nvvm_fma # variant : ClangBuiltin, - DefaultAttrsIntrinsic<[llvm_i32_ty], - [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], + DefaultAttrsIntrinsic<[llvm_v2bf16_ty], + [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty], [IntrNoMem, IntrSpeculatable]>; } @@ -1232,14 +1236,19 @@ let TargetPrefix = "nvvm" in { def int_nvvm_f2h_rn : ClangBuiltin<"__nvvm_f2h_rn">, DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>; + def int_nvvm_bf2h_rn_ftz : ClangBuiltin<"__nvvm_bf2h_rn_ftz">, + DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>; + def int_nvvm_bf2h_rn : ClangBuiltin<"__nvvm_bf2h_rn">, + DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>; + def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">, - Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">, - Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_ff2bf16x2_rz : ClangBuiltin<"__nvvm_ff2bf16x2_rz">, - Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_ff2bf16x2_rz_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rz_relu">, - Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>; + Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>; def int_nvvm_ff2f16x2_rn : ClangBuiltin<"__nvvm_ff2f16x2_rn">, Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; @@ -1251,13 +1260,13 @@ let TargetPrefix = "nvvm" in { Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_f2bf16_rn : ClangBuiltin<"__nvvm_f2bf16_rn">, - Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_f2bf16_rn_relu : ClangBuiltin<"__nvvm_f2bf16_rn_relu">, - Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_f2bf16_rz : ClangBuiltin<"__nvvm_f2bf16_rz">, - Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_f2bf16_rz_relu : ClangBuiltin<"__nvvm_f2bf16_rz_relu">, - Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">, Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp index f53f32f749fef5..d26f39b16bb356 100644 --- a/llvm/lib/IR/AutoUpgrade.cpp +++ b/llvm/lib/IR/AutoUpgrade.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/IntrinsicsARM.h" +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/IntrinsicsRISCV.h" #include "llvm/IR/IntrinsicsWebAssembly.h" #include 
"llvm/IR/IntrinsicsX86.h" @@ -591,6 +592,71 @@ static bool UpgradeX86IntrinsicFunction(Function *F, StringRef Name, return false; } +static Intrinsic::ID ShouldUpgradeNVPTXBF16Intrinsic(StringRef Name) { + return StringSwitch(Name) + .Case("abs.bf16", Intrinsic::nvvm_abs_bf16) + .Case("abs.bf16x2", Intrinsic::nvvm_abs_bf16x2) + .Case("fma.rn.bf16", Intrinsic::nvvm_fma_rn_bf16) + .Case("fma.rn.bf16x2", Intrinsic::nvvm_fma_rn_bf16x2) + .Case("fma.rn.ftz_bf16", Intrinsic::nvvm_fma_rn_ftz_bf16) + .Case("fma.rn.ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2) + .Case("fma.rn.ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16) + .Case("fma.rn.ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2) + .Case("fma.rn.ftz_sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16) + .Case("fma.rn.ftz_sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2) + .Case("fma.rn.relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16) + .Case("fma.rn.relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2) + .Case("fma.rn.sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16) + .Case("fma.rn.sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2) + .Case("fmax.bf16", Intrinsic::nvvm_fmax_bf16) + .Case("fmax.bf16x2", Intrinsic::nvvm_fmax_bf16x2) + .Case("fmax.ftz.bf16", Intrinsic::nvvm_fmax_ftz_bf16) + .Case("fmax.ftz.bf16x2", Intrinsic::nvvm_fmax_ftz_bf16x2) + .Case("fmax.ftz.nan.bf16", Intrinsic::nvvm_fmax_ftz_nan_bf16) + .Case("fmax.ftz.nan.bf16x2", Intrinsic::nvvm_fmax_ftz_nan_bf16x2) + .Case("fmax.ftz.nan.xorsign.abs.bf16", + Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16) + .Case("fmax.ftz.nan.xorsign.abs.bf16x2", + Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16x2) + .Case("fmax.ftz.xorsign.abs.bf16", + Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16) + .Case("fmax.ftz.xorsign.abs.bf16x2", + Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16x2) + .Case("fmax.nan.bf16", Intrinsic::nvvm_fmax_nan_bf16) + .Case("fmax.nan.bf16x2", Intrinsic::nvvm_fmax_nan_bf16x2) + .Case("fmax.nan.xorsign.abs.bf16", + Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16) + .Case("fmax.nan.xorsign.abs.bf16x2", + Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16x2) + .Case("fmax.xorsign.abs.bf16", Intrinsic::nvvm_fmax_xorsign_abs_bf16) + .Case("fmax.xorsign.abs.bf16x2", Intrinsic::nvvm_fmax_xorsign_abs_bf16x2) + .Case("fmin.bf16", Intrinsic::nvvm_fmin_bf16) + .Case("fmin.bf16x2", Intrinsic::nvvm_fmin_bf16x2) + .Case("fmin.ftz.bf16", Intrinsic::nvvm_fmin_ftz_bf16) + .Case("fmin.ftz.bf16x2", Intrinsic::nvvm_fmin_ftz_bf16x2) + .Case("fmin.ftz.nan_bf16", Intrinsic::nvvm_fmin_ftz_nan_bf16) + .Case("fmin.ftz.nan_bf16x2", Intrinsic::nvvm_fmin_ftz_nan_bf16x2) + .Case("fmin.ftz.nan.xorsign.abs.bf16", + Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16) + .Case("fmin.ftz.nan.xorsign.abs.bf16x2", + Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16x2) + .Case("fmin.ftz.xorsign.abs.bf16", + Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16) + .Case("fmin.ftz.xorsign.abs.bf16x2", + Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16x2) + .Case("fmin.nan.bf16", Intrinsic::nvvm_fmin_nan_bf16) + .Case("fmin.nan.bf16x2", Intrinsic::nvvm_fmin_nan_bf16x2) + .Case("fmin.nan.xorsign.abs.bf16", + Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16) + .Case("fmin.nan.xorsign.abs.bf16x2", + Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16x2) + .Case("fmin.xorsign.abs.bf16", Intrinsic::nvvm_fmin_xorsign_abs_bf16) + .Case("fmin.xorsign.abs.bf16x2", Intrinsic::nvvm_fmin_xorsign_abs_bf16x2) + .Case("neg.bf16", Intrinsic::nvvm_neg_bf16) + .Case("neg.bf16x2", Intrinsic::nvvm_neg_bf16x2) + .Default(Intrinsic::not_intrinsic); +} + static bool UpgradeIntrinsicFunction1(Function *F, 
   assert(F && "Illegal to upgrade a non-existent Function.");
@@ -1082,7 +1148,12 @@ static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
                                         {F->getReturnType()});
     return true;
   }
-
+  IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
+  if (IID != Intrinsic::not_intrinsic &&
+      !F->getReturnType()->getScalarType()->isBFloatTy()) {
+    NewFn = nullptr;
+    return true;
+  }
   // The following nvvm intrinsics correspond exactly to an LLVM idiom, but
   // not to an intrinsic alone.  We expand them in UpgradeIntrinsicCall.
   //
@@ -4049,11 +4120,34 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
                                              {Arg->getType()}),
                                Arg, "ctpop");
     Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
-  } else if (IsNVVM && Name == "h2f") {
-    Rep = Builder.CreateCall(Intrinsic::getDeclaration(
+  } else if (IsNVVM) {
+    if (Name == "h2f") {
+      Rep =
+          Builder.CreateCall(Intrinsic::getDeclaration(
                                  F->getParent(), Intrinsic::convert_from_fp16,
                                  {Builder.getFloatTy()}),
                              CI->getArgOperand(0), "h2f");
+    } else {
+      Intrinsic::ID IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
+      if (IID != Intrinsic::not_intrinsic &&
+          !F->getReturnType()->getScalarType()->isBFloatTy()) {
+        rename(F);
+        NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
+        SmallVector<Value *, 2> Args;
+        for (size_t I = 0; I < NewFn->arg_size(); ++I) {
+          Value *Arg = CI->getArgOperand(I);
+          Type *OldType = Arg->getType();
+          Type *NewType = NewFn->getArg(I)->getType();
+          Args.push_back((OldType->isIntegerTy() &&
+                          NewType->getScalarType()->isBFloatTy())
+                             ? Builder.CreateBitCast(Arg, NewType)
+                             : Arg);
+        }
+        Rep = Builder.CreateCall(NewFn, Args);
+        if (F->getReturnType()->isIntegerTy())
+          Rep = Builder.CreateBitCast(Rep, F->getReturnType());
+      }
+    }
   } else if (IsARM) {
     Rep = UpgradeARMIntrinsicCall(Name, CI, F, Builder);
   } else if (IsAMDGCN) {
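The net effect of the auto-upgrade, shown as a hedged before/after sketch of the IR (illustration only; the bitcast wrapping mirrors the argument loop above):

```
; Old bitcode (pre-patch integer types):
;   %r = call i16 @llvm.nvvm.abs.bf16(i16 %x)
; After UpgradeIntrinsicCall (sketch):
;   %0 = bitcast i16 %x to bfloat
;   %1 = call bfloat @llvm.nvvm.abs.bf16(bfloat %0)
;   %r = bitcast bfloat %1 to i16
```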
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 179306b59b0ffc..fd032676dcf64e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -272,6 +272,10 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
     MCOp = MCOperand::createExpr(
       NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
     break;
+  case Type::BFloatTyID:
+    MCOp = MCOperand::createExpr(
+      NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
+    break;
   case Type::FloatTyID:
     MCOp = MCOperand::createExpr(
       NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
@@ -330,6 +334,11 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
   return MCOperand::createExpr(Expr);
 }
 
+static bool ShouldPassAsArray(Type *Ty) {
+  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
+         Ty->isHalfTy() || Ty->isBFloatTy();
+}
+
 void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   const DataLayout &DL = getDataLayout();
   const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -341,11 +350,11 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
   if (Ty->getTypeID() == Type::VoidTyID)
     return;
-
   O << " (";
 
   if (isABI) {
-    if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
+    if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
+        !ShouldPassAsArray(Ty)) {
       unsigned size = 0;
       if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
         size = ITy->getBitWidth();
@@ -353,16 +362,12 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
         assert(Ty->isFloatingPointTy() && "Floating point type expected here");
         size = Ty->getPrimitiveSizeInBits();
       }
-      // PTX ABI requires all scalar return values to be at least 32
-      // bits in size.  fp16 normally uses .b16 as its storage type in
-      // PTX, so its size must be adjusted here, too.
       size = promoteScalarArgumentSize(size);
-
       O << ".param .b" << size << " func_retval0";
     } else if (isa<PointerType>(Ty)) {
       O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
         << " func_retval0";
-    } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
+    } else if (ShouldPassAsArray(Ty)) {
       unsigned totalsz = DL.getTypeAllocSize(Ty);
       unsigned retAlignment = 0;
       if (!getAlign(*F, 0, retAlignment))
@@ -1355,8 +1360,10 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
     }
     break;
   }
+  case Type::BFloatTyID:
   case Type::HalfTyID:
-    // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
+    // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
+    // PTX assembly.
     return "b16";
   case Type::FloatTyID:
     return "f32";
@@ -1510,7 +1517,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
     };
 
     if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
-      if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
+      if (ShouldPassAsArray(Ty)) {
        // Just print .param .align <a> .b8 .param[size];
        // <a>  = optimal alignment for the element type; always multiple of
        //        PAL.getParamAlignment
@@ -1581,12 +1588,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
       } else if (PTy) {
         assert(PTySizeInBits && "Invalid pointer size");
         sz = PTySizeInBits;
-      } else if (Ty->isHalfTy())
-        // PTX ABI requires all scalar parameters to be at least 32
-        // bits in size. fp16 normally uses .b16 as its storage type
-        // in PTX, so its size must be adjusted here, too.
-        sz = 32;
-      else
+      } else
         sz = Ty->getPrimitiveSizeInBits();
       if (isABI)
         O << "\t.param .b" << sz << " ";
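A small usage sketch of what the new `createConstantBFPHalf` lowering, together with `tryConstantFP` in the next file, buys: a bf16 floating-point literal can now be materialized directly with a `mov.b16` of the immediate instead of failing to select:

```cuda
// Illustrative sketch only: returning a bf16 literal forces the backend to
// materialize an immediate; with this patch it selects LOAD_CONST_BF16
// (a mov.b16) for the bf16 ConstantFP node.
__device__ __bf16 one_and_a_half() { return (__bf16)1.5f; }
```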
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index cb8a1867c44f01..db69431cceefcd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -500,7 +500,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     SelectAddrSpaceCast(N);
     return;
   case ISD::ConstantFP:
-    if (tryConstantFP16(N))
+    if (tryConstantFP(N))
       return;
     break;
   default:
@@ -524,15 +524,17 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
   }
 }
 
-// There's no way to specify FP16 immediates in .f16 ops, so we have to
-// load them into an .f16 register first.
-bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
-  if (N->getValueType(0) != MVT::f16)
+// There's no way to specify FP16 and BF16 immediates in .(b)f16 ops, so we
+// have to load them into an .(b)f16 register first.
+bool NVPTXDAGToDAGISel::tryConstantFP(SDNode *N) {
+  if (N->getValueType(0) != MVT::f16 && N->getValueType(0) != MVT::bf16)
     return false;
   SDValue Val = CurDAG->getTargetConstantFP(
-      cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16);
-  SDNode *LoadConstF16 =
-      CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val);
+      cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), N->getValueType(0));
+  SDNode *LoadConstF16 = CurDAG->getMachineNode(
+      (N->getValueType(0) == MVT::f16 ? NVPTX::LOAD_CONST_F16
+                                      : NVPTX::LOAD_CONST_BF16),
+      SDLoc(N), N->getValueType(0), Val);
   ReplaceNode(N, LoadConstF16);
   return true;
 }
@@ -612,9 +614,9 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
 
   // We only care about f16x2 as it's the only real vector type we
   // need to deal with.
-  if (Vector.getSimpleValueType() != MVT::v2f16)
+  MVT VT = Vector.getSimpleValueType();
+  if (!(VT == MVT::v2f16 || VT == MVT::v2bf16))
     return false;
-
   // Find and record all uses of this vector that extract element 0 or 1.
   SmallVector<SDNode *, 4> E0, E1;
   for (auto *U : Vector.getNode()->uses()) {
@@ -640,8 +642,9 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
 
   // Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
   // into f16,f16 SplitF16x2(V)
-  SDNode *ScatterOp = CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N),
-                                             MVT::f16, MVT::f16, Vector);
+  MVT EltVT = VT.getVectorElementType();
+  SDNode *ScatterOp =
+      CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
   for (auto *Node : E0)
     ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
   for (auto *Node : E1)
@@ -1258,10 +1261,11 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     NumElts = EltVT.getVectorNumElements();
     EltVT = EltVT.getVectorElementType();
     // vectors of f16 are loaded/stored as multiples of v2f16 elements.
-    if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) {
-      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
-      EltVT = MVT::v2f16;
-      NumElts /= 2;
+    if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
+        (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
+      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+      EltVT = N->getValueType(0);
+      NumElts /= 2;
     }
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 2a8ee5089ca02b..25bb73cd553612 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -71,7 +71,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryTextureIntrinsic(SDNode *N);
   bool trySurfaceIntrinsic(SDNode *N);
   bool tryBFE(SDNode *N);
-  bool tryConstantFP16(SDNode *N);
+  bool tryConstantFP(SDNode *N);
   bool SelectSETP_F16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
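For context before the lowering changes below: `v2bf16` becomes a legal packed type carried in a 32-bit register, mirroring `v2f16`, so element extraction splits the packed register rather than scalarizing. A hedged sketch of code that exercises this path (the typedef is an assumption for illustration):

```cuda
// Illustrative sketch only.
typedef __bf16 bf16x2 __attribute__((ext_vector_type(2)));

__device__ __bf16 sum_halves(bf16x2 v) {
  // Extracting .x and .y from one packed value can now map onto a single
  // register split (see tryEXTRACT_VECTOR_ELEMENT above) instead of
  // per-element loads; the bf16 add itself is promoted to f32 before sm_90.
  return v.x + v.y;
}
```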
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 98a877cbafec97..fa050bcdc34121 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -149,6 +149,14 @@ static bool IsPTXVectorType(MVT VT) {
   }
 }
 
+static bool Isv2f16Orv2bf16Type(EVT VT) {
+  return (VT == MVT::v2f16 || VT == MVT::v2bf16);
+}
+
+static bool Isf16Orbf16Type(MVT VT) {
+  return (VT.SimpleTy == MVT::f16 || VT.SimpleTy == MVT::bf16);
+}
+
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
 /// into their primitive components.
@@ -199,7 +207,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
       // Vectors with an even number of f16 elements will be passed to
       // us as an array of v2f16/v2bf16 elements. We must match this so we
       // stay in sync with Ins/Outs.
-      if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+      if ((Isf16Orbf16Type(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
         EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
         NumElts /= 2;
       }
@@ -404,6 +412,21 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
   };
 
+  auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+                                    LegalizeAction NoBF16Action) {
+    bool IsOpSupported = STI.hasBF16Math();
+    // Few instructions are available on sm_90 only
+    switch(Op) {
+      case ISD::FADD:
+      case ISD::FMUL:
+      case ISD::FSUB:
+        IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
+        break;
+    }
+    setOperationAction(
+        Op, VT, IsOpSupported ? Action : NoBF16Action);
+  };
+
   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
@@ -426,6 +449,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
   setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
 
+  // Conversion to/from BFP16/BFP16x2 is always legal.
+  setOperationAction(ISD::SINT_TO_FP, MVT::bf16, Legal);
+  setOperationAction(ISD::FP_TO_SINT, MVT::bf16, Legal);
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom);
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
+  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);
+
+  setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
+  setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
+
   // Operations not directly supported by NVPTX.
   for (MVT VT : {MVT::f16, MVT::v2f16, MVT::f32, MVT::f64, MVT::i1, MVT::i8,
                  MVT::i16, MVT::i32, MVT::i64}) {
@@ -482,17 +515,25 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // Turn FP extload into load/fpextend
   setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
   // Turn FP truncstore into trunc + store.
   // FIXME: vector types should also be expanded
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
 
   // PTX does not support load / store predicate registers
@@ -569,9 +610,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
                        ISD::SREM, ISD::UREM});
 
-  // setcc for f16x2 needs special handling to prevent
-  // legalizer's attempt to scalarize it due to v2i1 not being legal.
-  if (STI.allowFP16Math())
+  // setcc for f16x2 and bf16x2 needs special handling to prevent
+  // legalizer's attempt to scalarize it due to v2i1 not being legal.
+  if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -583,6 +624,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
     setFP16OperationAction(Op, MVT::f16, Legal, Promote);
     setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
+    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
+    // bf16 must be promoted to f32.
+    if (getOperationAction(Op, MVT::bf16) == Promote)
+      AddPromotedToType(Op, MVT::bf16, MVT::f32);
   }
 
   // f16/f16x2 neg was introduced in PTX 60, SM_53.
@@ -593,19 +639,25 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::FNEG, VT,
                      IsFP16FP16x2NegAvailable ? Legal : Expand);
 
+  setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
+  setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
   // (would be) Library functions.
 
   // These map to conversion instructions for scalar FP types.
   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
                          ISD::FROUNDEVEN, ISD::FTRUNC}) {
+    setOperationAction(Op, MVT::bf16, Legal);
    setOperationAction(Op, MVT::f16, Legal);
    setOperationAction(Op, MVT::f32, Legal);
    setOperationAction(Op, MVT::f64, Legal);
    setOperationAction(Op, MVT::v2f16, Expand);
+    setOperationAction(Op, MVT::v2bf16, Expand);
   }
 
   setOperationAction(ISD::FROUND, MVT::f16, Promote);
   setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
+  setOperationAction(ISD::FROUND, MVT::bf16, Promote);
+  setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
   setOperationAction(ISD::FROUND, MVT::f32, Custom);
   setOperationAction(ISD::FROUND, MVT::f64, Custom);
 
@@ -613,6 +665,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // 'Expand' implements FCOPYSIGN without calling an external library.
   setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
+  setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+  setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
 
@@ -622,9 +676,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   for (const auto &Op :
        {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
     setOperationAction(Op, MVT::f16, Promote);
+    setOperationAction(Op, MVT::bf16, Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
+    setOperationAction(Op, MVT::v2bf16, Expand);
   }
 
   // max.f16, max.f16x2 and max.NaN are supported on sm_80+.
   auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
     bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
     return IsAtLeastSm80 ? Legal : NotSm80Action;
   };
   for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
     setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
+    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
   }
   for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
     setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
+    setFP16OperationAction(Op, MVT::bf16, Legal, Expand);
     setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
   }
 
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
@@ -1258,7 +1318,7 @@ NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
   if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
       VT.getScalarType() == MVT::i1)
     return TypeSplitVector;
-  if (VT == MVT::v2f16)
+  if (Isv2f16Orv2bf16Type(VT))
     return TypeLegal;
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
@@ -1321,6 +1381,11 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
 }
 
+static bool IsTypePassedAsArray(const Type *Ty) {
+  return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
+         Ty->isHalfTy() || Ty->isBFloatTy();
+}
+
 std::string NVPTXTargetLowering::getPrototype(
     const DataLayout &DL, Type *retTy, const ArgListTy &Args,
     const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1341,7 +1406,8 @@ std::string NVPTXTargetLowering::getPrototype(
     O << "()";
   } else {
     O << "(";
-    if (retTy->isFloatingPointTy() || (retTy->isIntegerTy() && !retTy->isIntegerTy(128))) {
+    if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
+        !IsTypePassedAsArray(retTy)) {
       unsigned size = 0;
      if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
        size = ITy->getBitWidth();
@@ -1358,8 +1424,7 @@ std::string NVPTXTargetLowering::getPrototype(
       O << ".param .b" << size << " _";
     } else if (isa<PointerType>(retTy)) {
       O << ".param .b" << PtrVT.getSizeInBits() << " _";
-    } else if (retTy->isAggregateType() || retTy->isVectorTy() ||
-               retTy->isIntegerTy(128)) {
+    } else if (IsTypePassedAsArray(retTy)) {
       O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
         << " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
     } else {
@@ -1381,7 +1446,7 @@ std::string NVPTXTargetLowering::getPrototype(
     first = false;
 
     if (!Outs[OIdx].Flags.isByVal()) {
-      if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
+      if (IsTypePassedAsArray(Ty)) {
         unsigned ParamAlign = 0;
         const CallInst *CallI = cast<CallInst>(&CB);
         // +1 because index 0 is reserved for return type alignment
@@ -1408,13 +1473,9 @@ std::string NVPTXTargetLowering::getPrototype(
         sz = promoteScalarArgumentSize(sz);
       } else if (isa<PointerType>(Ty)) {
         sz = PtrVT.getSizeInBits();
-      } else if (Ty->isHalfTy())
-        // PTX ABI requires all scalar parameters to be at least 32
-        // bits in size. fp16 normally uses .b16 as its storage type
-        // in PTX, so its size must be adjusted here, too.
-        sz = 32;
-      else
+      } else {
         sz = Ty->getPrimitiveSizeInBits();
+      }
       O << ".param .b" << sz << " ";
       O << "_";
       continue;
@@ -1577,6 +1638,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
 
     bool NeedAlign; // Does argument declaration specify alignment?
+    bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty);
     if (IsVAArg) {
       if (ParamCount == FirstVAArg) {
         SDValue DeclareParamOps[] = {
@@ -1586,10 +1648,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
         VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
                                              DeclareParamVTs, DeclareParamOps);
       }
-      NeedAlign = IsByVal || Ty->isAggregateType() || Ty->isVectorTy() ||
-                  Ty->isIntegerTy(128);
-    } else if (IsByVal || Ty->isAggregateType() || Ty->isVectorTy() ||
-               Ty->isIntegerTy(128)) {
+      NeedAlign = PassAsArray;
+    } else if (PassAsArray) {
      // declare .param .align <align> .b8 .param<n>[<size>];
       SDValue DeclareParamOps[] = {
           Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
@@ -1739,15 +1799,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     ComputeValueVTs(*this, DL, RetTy, resvtparts);
 
     // Declare
-    //  .param .align 16 .b8 retval0[<size-in-bytes>], or
+    //  .param .align N .b8 retval0[<size-in-bytes>], or
     //  .param .b<size-in-bits> retval0
     unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
-    // Emit ".param .b<size-in-bits> retval0" instead of byte arrays only for
-    // these three types to match the logic in
-    // NVPTXAsmPrinter::printReturnValStr and NVPTXTargetLowering::getPrototype.
-    // Plus, this behavior is consistent with nvcc's.
-    if (RetTy->isFloatingPointTy() || RetTy->isPointerTy() ||
-        (RetTy->isIntegerTy() && !RetTy->isIntegerTy(128))) {
+    if (!IsTypePassedAsArray(RetTy)) {
       resultsz = promoteScalarArgumentSize(resultsz);
       SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
       SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -2043,7 +2098,7 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
 // generates good SASS in both cases.
 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                SelectionDAG &DAG) const {
-  if (!(Op->getValueType(0) == MVT::v2f16 &&
+  if (!(Isv2f16Orv2bf16Type(Op->getValueType(0)) &&
         isa<ConstantFPSDNode>(Op->getOperand(0)) &&
         isa<ConstantFPSDNode>(Op->getOperand(1))))
     return Op;
@@ -2054,7 +2109,7 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
       cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
   SDValue Const =
       DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
-  return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+  return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
 }
 
 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -2415,7 +2470,7 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // loads and have to handle it here.
-  if (Op.getValueType() == MVT::v2f16) {
+  if (Isv2f16Orv2bf16Type(Op.getValueType())) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2460,7 +2515,7 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // stores and have to handle it here.
-  if (VT == MVT::v2f16 &&
+  if (Isv2f16Orv2bf16Type(VT) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
@@ -2551,7 +2606,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       // v8f16 is a special case. PTX doesn't have st.v8.f16
       // instruction. Instead, we split the vector into v2f16 chunks and
       // store them with st.v4.b32.
-      assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+      assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
              "Wrong type for the vector.");
       Opcode = NVPTXISD::StoreV4;
       StoreF16x2 = true;
@@ -2567,11 +2622,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       // Combine f16,f16 -> v2f16
       NumElts /= 2;
       for (unsigned i = 0; i < NumElts; ++i) {
-        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                  DAG.getIntPtrConstant(i * 2, DL));
-        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                  DAG.getIntPtrConstant(i * 2 + 1, DL));
-        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+        EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
+        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
         Ops.push_back(V2);
       }
     } else {
@@ -2672,7 +2728,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
 
       if (theArgs[i]->use_empty()) {
         // argument is dead
-        if (Ty->isAggregateType() || Ty->isIntegerTy(128)) {
+        if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
           SmallVector<EVT, 16> vtparts;
 
           ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
@@ -2737,9 +2793,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
           EVT LoadVT = EltVT;
           if (EltVT == MVT::i1)
             LoadVT = MVT::i8;
-          else if (EltVT == MVT::v2f16)
+          else if (Isv2f16Orv2bf16Type(EltVT))
             // getLoad needs a vector type, but it can't handle
-            // vectors which contain v2f16 elements. So we must load
+            // vectors which contain v2f16 or v2bf16 elements. So we must load
             // using i32 here and then bitcast back.
             LoadVT = MVT::i32;
 
@@ -2763,8 +2819,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
             if (EltVT == MVT::i1)
               Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
             // v2f16 was loaded as an i32. Now we must bitcast it back.
-            else if (EltVT == MVT::v2f16)
-              Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt);
+            else if (Isv2f16Orv2bf16Type(EltVT))
+              Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
 
             // If a promoted integer type is used, truncate down to the original
             MVT PromotedVT;
@@ -5194,7 +5250,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
     // v8f16 is a special case. PTX doesn't have ld.v8.f16
    // instruction. Instead, we split the vector into v2f16 chunks and
    // load them with ld.v4.b32.
-    assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+    assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
           "Unsupported v8 vector type.");
     LoadF16x2 = true;
     Opcode = NVPTXISD::LoadV4;
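One user-visible consequence of the `setBF16OperationAction` gating above, shown as a hedged sketch: plain bf16 arithmetic only selects native `add.bf16`/`mul.bf16` (or fused forms) on sm_90 with PTX 7.8+; on sm_80 the same source still compiles, via promotion to f32 and conversion back.

```cuda
// Illustrative sketch only.
__device__ __bf16 scale_add(__bf16 a, __bf16 b) {
  return a * b + a;  // sm_90/PTX 7.8+: native bf16 mul/add;
                     // older targets: promoted to f32, then converted back
}
```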
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 54237af13d8dc2..b98f76ed4b38d9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -19,6 +19,8 @@ let hasSideEffects = false in {
 let OperandType = "OPERAND_IMMEDIATE" in {
   def f16imm : Operand<f16>;
+  def bf16imm : Operand<bf16>;
+
 }
 
 // List of vector specific properties
@@ -154,6 +156,7 @@ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
 def useShortPtr : Predicate<"useShortPointers()">;
 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
+def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">;
 
 // Helper class to aid conversion between ValueType and a matching RegisterClass.
@@ -304,6 +307,31 @@ multiclass F3<string OpcStr, SDNode OpNode> {
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
                [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
                Requires<[useFP16Math]>;
+  def bf16rr_ftz :
+    NVPTXInst<(outs Int16Regs:$dst),
+              (ins Int16Regs:$a, Int16Regs:$b),
+              !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+              [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+              Requires<[hasBF16Math, doF32FTZ]>;
+  def bf16rr :
+    NVPTXInst<(outs Int16Regs:$dst),
+              (ins Int16Regs:$a, Int16Regs:$b),
+              !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+              [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+              Requires<[hasBF16Math]>;
+
+  def bf16x2rr_ftz :
+    NVPTXInst<(outs Int32Regs:$dst),
+              (ins Int32Regs:$a, Int32Regs:$b),
+              !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+              [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+              Requires<[hasBF16Math, doF32FTZ]>;
+  def bf16x2rr :
+    NVPTXInst<(outs Int32Regs:$dst),
+              (ins Int32Regs:$a, Int32Regs:$b),
+              !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+              [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+              Requires<[hasBF16Math]>;
 }
 
 // Template for instructions which take three FP args.  The
@@ -378,7 +406,31 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
                [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
                Requires<[useFP16Math, allowFMA]>;
+  def bf16rr_ftz :
+    NVPTXInst<(outs Int16Regs:$dst),
+              (ins Int16Regs:$a, Int16Regs:$b),
+              !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+              [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+              Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
+  def bf16rr :
+    NVPTXInst<(outs Int16Regs:$dst),
+              (ins Int16Regs:$a, Int16Regs:$b),
+              !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+              [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+              Requires<[hasBF16Math, allowFMA]>;
+  def bf16x2rr_ftz :
+    NVPTXInst<(outs Int32Regs:$dst),
+              (ins Int32Regs:$a, Int32Regs:$b),
+              !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+              [(set (v2bf16 Int32Regs:$dst), (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+              Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
+  def bf16x2rr :
+    NVPTXInst<(outs Int32Regs:$dst),
+              (ins Int32Regs:$a, Int32Regs:$b),
+              !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+              [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+              Requires<[hasBF16Math, allowFMA]>;
 
 // These have strange names so we don't perturb existing mir tests.
def _rnf64rr : NVPTXInst<(outs Float64Regs:$dst), @@ -440,6 +492,30 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> { !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"), [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>, Requires<[useFP16Math, noFMA]>; + def _rnbf16rr_ftz : + NVPTXInst<(outs Int16Regs:$dst), + (ins Int16Regs:$a, Int16Regs:$b), + !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"), + [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>, + Requires<[hasBF16Math, noFMA, doF32FTZ]>; + def _rnbf16rr : + NVPTXInst<(outs Int16Regs:$dst), + (ins Int16Regs:$a, Int16Regs:$b), + !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"), + [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>, + Requires<[hasBF16Math, noFMA]>; + def _rnbf16x2rr_ftz : + NVPTXInst<(outs Int32Regs:$dst), + (ins Int32Regs:$a, Int32Regs:$b), + !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"), + [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>, + Requires<[hasBF16Math, noFMA, doF32FTZ]>; + def _rnbf16x2rr : + NVPTXInst<(outs Int32Regs:$dst), + (ins Int32Regs:$a, Int32Regs:$b), + !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"), + [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>, + Requires<[hasBF16Math, noFMA]>; } // Template for operations which take two f32 or f64 operands. Provides three @@ -470,62 +546,86 @@ let hasSideEffects = false in { // Generate a cvt to the given type from all possible types. Each instance // takes a CvtMode immediate that defines the conversion mode to use. It can // be CvtNONE to omit a conversion mode. - multiclass CVT_FROM_ALL<string FromName, RegisterClass RC> { + multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> { def _s8 : NVPTXInst<(outs RC:$dst), (ins Int16Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".s8 \t$dst, $src;"), []>; + ToType, ".s8 \t$dst, $src;"), []>, + Requires<Preds>; def _u8 : NVPTXInst<(outs RC:$dst), (ins Int16Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".u8 \t$dst, $src;"), []>; + ToType, ".u8 \t$dst, $src;"), []>, + Requires<Preds>; def _s16 : NVPTXInst<(outs RC:$dst), (ins Int16Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".s16 \t$dst, $src;"), []>; + ToType, ".s16 \t$dst, $src;"), []>, + Requires<Preds>; def _u16 : NVPTXInst<(outs RC:$dst), (ins Int16Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".u16 \t$dst, $src;"), []>; + ToType, ".u16 \t$dst, $src;"), []>, + Requires<Preds>; def _s32 : NVPTXInst<(outs RC:$dst), (ins Int32Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".s32 \t$dst, $src;"), []>; + ToType, ".s32 \t$dst, $src;"), []>, + Requires<Preds>; def _u32 : NVPTXInst<(outs RC:$dst), (ins Int32Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".u32 \t$dst, $src;"), []>; + ToType, ".u32 \t$dst, $src;"), []>, + Requires<Preds>; def _s64 : NVPTXInst<(outs RC:$dst), (ins Int64Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".s64 \t$dst, $src;"), []>; + ToType, ".s64 \t$dst, $src;"), []>, + Requires<Preds>; def _u64 : NVPTXInst<(outs RC:$dst), (ins Int64Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".u64 \t$dst, $src;"), []>; + ToType, ".u64 \t$dst, $src;"), []>, + Requires<Preds>; def _f16 : NVPTXInst<(outs RC:$dst), (ins Int16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".f16 \t$dst, $src;"), []>; + ToType, ".f16 \t$dst, $src;"), []>, + Requires<Preds>; + def _bf16 : + NVPTXInst<(outs RC:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:relu}${mode:sat}.", + ToType, ".bf16 \t$dst, $src;"), []>, + Requires<!if(!eq(ToType, "f32"), + // bf16->f32 was introduced early. + [hasPTX<71>, hasSM<80>], + // bf16->everything else needs sm90/ptx78 + [hasPTX<78>, hasSM<90>])>; def _f32 : NVPTXInst<(outs RC:$dst), (ins Float32Regs:$src, CvtMode:$mode), - !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".f32 \t$dst, $src;"), []>; + !strconcat("cvt${mode:base}${mode:ftz}${mode:relu}${mode:sat}.", + ToType, ".f32 \t$dst, $src;"), []>, + Requires<!if(!eq(ToType, "bf16"), + // f32->bf16 was introduced early. + [hasPTX<70>, hasSM<80>], + Preds)>; def _f64 : NVPTXInst<(outs RC:$dst), (ins Float64Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", - FromName, ".f64 \t$dst, $src;"), []>; + ToType, ".f64 \t$dst, $src;"), []>, + Requires<Preds>; } // Generate cvts from all types to all types. @@ -538,6 +638,7 @@ let hasSideEffects = false in { defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>; defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>; defm CVT_f16 : CVT_FROM_ALL<"f16", Int16Regs>; + defm CVT_bf16 : CVT_FROM_ALL<"bf16", Int16Regs, [hasPTX<78>, hasSM<90>]>; defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>; defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>; @@ -556,18 +657,7 @@ let hasSideEffects = false in { def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src), "cvt.s64.s32 \t$dst, $src;", []>; -multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> { - def _f32 : - NVPTXInst<(outs RC:$dst), - (ins Float32Regs:$src, CvtMode:$mode), - !strconcat("cvt${mode:base}${mode:relu}.", - FromName, ".f32 \t$dst, $src;"), []>, - Requires<[hasPTX<70>, hasSM<80>]>; - } - - defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>; - - multiclass CVT_FROM_FLOAT_V2_SM80 { + multiclass CVT_FROM_FLOAT_V2_SM80 { def _f32 : NVPTXInst<(outs RC:$dst), (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode), @@ -641,6 +731,7 @@ defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>; defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>; defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>; defm SELP_f16 : SELP_PATTERN<"b16", f16, Int16Regs, f16imm, fpimm>; +defm SELP_bf16 : SELP_PATTERN<"b16", bf16, Int16Regs, bf16imm, fpimm>; defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>; defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>; @@ -1005,7 +1096,9 @@ def DoubleConst1 : PatLeaf<(fpimm), [{ def LOAD_CONST_F16 : NVPTXInst<(outs Int16Regs:$dst), (ins f16imm:$a), "mov.b16 \t$dst, $a;", []>; - +def LOAD_CONST_BF16 : + NVPTXInst<(outs Int16Regs:$dst), (ins bf16imm:$a), + "mov.b16 \t$dst, $a;", []>; defm FADD : F3_fma_component<"add", fadd>; defm FSUB : F3_fma_component<"sub", fsub>; defm FMUL : F3_fma_component<"mul", fmul>; @@ -1033,6 +1126,20 @@ def FNEG16 : FNEG_F16_F16X2<"neg.f16", f16, Int16Regs, True>; def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>; def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", v2f16, Int32Regs, True>; +// +// BF16 NEG +// + +class FNEG_BF16_F16X2<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> : + NVPTXInst<(outs RC:$dst), (ins RC:$src), + !strconcat(OpcStr, " \t$dst, $src;"), + [(set RC:$dst, (fneg (T RC:$src)))]>, + Requires<[hasBF16Math, hasPTX<70>, hasSM<80>, Pred]>; +def BFNEG16_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16", bf16, Int16Regs, doF32FTZ>; +def BFNEG16 : FNEG_BF16_F16X2<"neg.bf16", bf16, Int16Regs,
True>; +def BFNEG16x2_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>; +def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>; + // // F64 division // @@ -1211,13 +1318,24 @@ multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> Requires<[useFP16Math, Pred]>; } -defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>; -defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>; -defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>; -defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>; -defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>; -defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>; -defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>; +multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> { + def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), + !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), + [(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>, + Requires<[hasBF16Math, Pred]>; +} + +defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>; +defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>; +defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>; +defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>; +defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>; +defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>; +defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>; +defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>; +defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>; +defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>; +defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>; // sin/cos def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), @@ -1661,6 +1779,18 @@ def SETP_f16x2rr : "setp${cmp:base}${cmp:ftz}.f16x2 \t$p|$q, $a, $b;", []>, Requires<[useFP16Math]>; +def SETP_bf16rr : + NVPTXInst<(outs Int1Regs:$dst), + (ins Int16Regs:$a, Int16Regs:$b, CmpMode:$cmp), + "setp${cmp:base}${cmp:ftz}.bf16 \t$dst, $a, $b;", + []>, Requires<[hasBF16Math]>; + +def SETP_bf16x2rr : + NVPTXInst<(outs Int1Regs:$p, Int1Regs:$q), + (ins Int32Regs:$a, Int32Regs:$b, CmpMode:$cmp), + "setp${cmp:base}${cmp:ftz}.bf16x2 \t$p|$q, $a, $b;", + []>, + Requires<[hasBF16Math]>; // FIXME: This doesn't appear to be correct.
The "set" mnemonic has the form @@ -1691,6 +1821,7 @@ defm SET_b64 : SET<"b64", Int64Regs, i64imm>; defm SET_s64 : SET<"s64", Int64Regs, i64imm>; defm SET_u64 : SET<"u64", Int64Regs, i64imm>; defm SET_f16 : SET<"f16", Int16Regs, f16imm>; +defm SET_bf16 : SET<"bf16", Int16Regs, bf16imm>; defm SET_f32 : SET<"f32", Float32Regs, f32imm>; defm SET_f64 : SET<"f64", Float64Regs, f64imm>; @@ -1959,6 +2090,26 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Int16Regs:$b, Mode)>, Requires<[useFP16Math]>; + // bf16 -> pred + def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))), + (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>, + Requires<[hasBF16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))), + (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>, + Requires<[hasBF16Math]>; + def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), fpimm:$b)), + (SETP_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>, + Requires<[hasBF16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), fpimm:$b)), + (SETP_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>, + Requires<[hasBF16Math]>; + def : Pat<(i1 (OpNode fpimm:$a, (bf16 Int16Regs:$b))), + (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, ModeFTZ)>, + Requires<[hasBF16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode fpimm:$a, (bf16 Int16Regs:$b))), + (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, Mode)>, + Requires<[hasBF16Math]>; + // f32 -> pred def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)), (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, @@ -2004,6 +2155,26 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Int16Regs:$b, Mode)>, Requires<[useFP16Math]>; + // bf16 -> i32 + def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))), + (SET_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>, + Requires<[hasBF16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))), + (SET_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>, + Requires<[hasBF16Math]>; + def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), fpimm:$b)), + (SET_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>, + Requires<[hasBF16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), fpimm:$b)), + (SET_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>, + Requires<[hasBF16Math]>; + def : Pat<(i32 (OpNode fpimm:$a, (bf16 Int16Regs:$b))), + (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, ModeFTZ)>, + Requires<[hasBF16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode fpimm:$a, (bf16 Int16Regs:$b))), + (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, Mode)>, + Requires<[hasBF16Math]>; + // f32 -> i32 def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)), (SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, @@ -2430,7 +2601,7 @@ def MoveParamSymbolI32 : MoveParamSymbolInst; def MoveParamI16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src), - "cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ? + "cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ? [(set i16:$dst, (MoveParam i16:$src))]>; def MoveParamF64 : MoveParamInst; def MoveParamF32 : MoveParamInst; @@ -2776,7 +2947,7 @@ def: Pat<(vt (bitconvert (i16 Int16Regs:$a))), def: Pat<(i16 (bitconvert (vt Int16Regs:$a))), (ProxyRegI16 Int16Regs:$a)>; } - + // NOTE: pred->fp are currently sub-optimal due to an issue in TableGen where // we cannot specify floating-point literals in isel patterns. Therefore, we // use an integer selp to select either 1 or 0 and then cvt to floating-point.
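// Illustrative sketch, not part of the patch: the NOTE above also applies to
// the new bf16 patterns in the next hunk, so IR such as
//   %r = uitofp i1 %p to bfloat
// is expected to lower roughly to
//   selp.u32 %r1, 1, 0, %p1;
//   cvt.rn.bf16.u32 %rs1, %r1;
// (register names hypothetical), i.e. exactly the
// (CVT_bf16_u32 (SELP_u32ii 1, 0, ...), CvtRN) pattern added below.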
@@ -2801,6 +2972,26 @@ def : Pat<(f16 (uint_to_fp Int32Regs:$a)), def : Pat<(f16 (uint_to_fp Int64Regs:$a)), (CVT_f16_u64 Int64Regs:$a, CvtRN)>; +// sint -> bf16 +def : Pat<(bf16 (sint_to_fp Int1Regs:$a)), + (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(bf16 (sint_to_fp Int16Regs:$a)), + (CVT_bf16_s16 Int16Regs:$a, CvtRN)>; +def : Pat<(bf16 (sint_to_fp Int32Regs:$a)), + (CVT_bf16_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(bf16 (sint_to_fp Int64Regs:$a)), + (CVT_bf16_s64 Int64Regs:$a, CvtRN)>; + +// uint -> bf16 +def : Pat<(bf16 (uint_to_fp Int1Regs:$a)), + (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(bf16 (uint_to_fp Int16Regs:$a)), + (CVT_bf16_u16 Int16Regs:$a, CvtRN)>; +def : Pat<(bf16 (uint_to_fp Int32Regs:$a)), + (CVT_bf16_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(bf16 (uint_to_fp Int64Regs:$a)), + (CVT_bf16_u64 Int64Regs:$a, CvtRN)>; + // sint -> f32 def : Pat<(f32 (sint_to_fp Int1Regs:$a)), (CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; @@ -2862,6 +3053,25 @@ def : Pat<(i32 (fp_to_uint (f16 Int16Regs:$a))), def : Pat<(i64 (fp_to_uint (f16 Int16Regs:$a))), (CVT_u64_f16 Int16Regs:$a, CvtRZI)>; +// bf16 -> sint +def : Pat<(i1 (fp_to_sint (bf16 Int16Regs:$a))), + (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>; +def : Pat<(i16 (fp_to_sint (bf16 Int16Regs:$a))), + (CVT_s16_bf16 (bf16 Int16Regs:$a), CvtRZI)>; +def : Pat<(i32 (fp_to_sint (bf16 Int16Regs:$a))), + (CVT_s32_bf16 (bf16 Int16Regs:$a), CvtRZI)>; +def : Pat<(i64 (fp_to_sint (bf16 Int16Regs:$a))), + (CVT_s64_bf16 Int16Regs:$a, CvtRZI)>; + +// bf16 -> uint +def : Pat<(i1 (fp_to_uint (bf16 Int16Regs:$a))), + (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>; +def : Pat<(i16 (fp_to_uint (bf16 Int16Regs:$a))), + (CVT_u16_bf16 Int16Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_uint (bf16 Int16Regs:$a))), + (CVT_u32_bf16 Int16Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_uint (bf16 Int16Regs:$a))), + (CVT_u64_bf16 Int16Regs:$a, CvtRZI)>; // f32 -> sint def : Pat<(i1 (fp_to_sint Float32Regs:$a)), (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; @@ -3009,6 +3219,9 @@ def : Pat<(select Int32Regs:$pred, Int64Regs:$a, Int64Regs:$b), def : Pat<(select Int32Regs:$pred, (f16 Int16Regs:$a), (f16 Int16Regs:$b)), (SELP_f16rr Int16Regs:$a, Int16Regs:$b, (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; +def : Pat<(select Int32Regs:$pred, (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)), + (SELP_bf16rr Int16Regs:$a, Int16Regs:$b, + (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b), (SELP_f32rr Float32Regs:$a, Float32Regs:$b, (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; @@ -3080,6 +3293,13 @@ def : Pat<(f16 (extractelt (v2f16 Int32Regs:$src), 1)), def : Pat<(v2f16 (build_vector (f16 Int16Regs:$a), (f16 Int16Regs:$b))), (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>; +def : Pat<(bf16 (extractelt (v2bf16 Int32Regs:$src), 0)), + (I32toI16L Int32Regs:$src)>; +def : Pat<(bf16 (extractelt (v2bf16 Int32Regs:$src), 1)), + (I32toI16H Int32Regs:$src)>; +def : Pat<(v2bf16 (build_vector (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))), + (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>; + // Count leading zeros let hasSideEffects = false in { def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), @@ -3147,10 +3367,17 @@ def : Pat<(i32 (zext (i16 (ctpop Int16Regs:$a)))), def : Pat<(f16 (fpround Float32Regs:$a)), (CVT_f16_f32 Float32Regs:$a, CvtRN)>; +// fpround f32 -> bf16 +def : Pat<(bf16 (fpround Float32Regs:$a)), + (CVT_bf16_f32 Float32Regs:$a, CvtRN)>; + // fpround 
f64 -> f16 def : Pat<(f16 (fpround Float64Regs:$a)), (CVT_f16_f64 Float64Regs:$a, CvtRN)>; +// fpround f64 -> bf16 +def : Pat<(bf16 (fpround Float64Regs:$a)), + (CVT_bf16_f64 Float64Regs:$a, CvtRN)>; // fpround f64 -> f32 def : Pat<(f32 (fpround Float64Regs:$a)), (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; @@ -3162,11 +3389,20 @@ def : Pat<(f32 (fpextend (f16 Int16Regs:$a))), (CVT_f32_f16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fpextend (f16 Int16Regs:$a))), (CVT_f32_f16 Int16Regs:$a, CvtNONE)>; +// fpextend bf16 -> f32 +def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))), + (CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))), + (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>; // fpextend f16 -> f64 def : Pat<(f64 (fpextend (f16 Int16Regs:$a))), (CVT_f64_f16 Int16Regs:$a, CvtNONE)>; +// fpextend bf16 -> f64 +def : Pat<(f64 (fpextend (bf16 Int16Regs:$a))), + (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>; + // fpextend f32 -> f64 def : Pat<(f64 (fpextend Float32Regs:$a)), (CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; @@ -3181,6 +3417,8 @@ def retglue : SDNode<"NVPTXISD::RET_GLUE", SDTNone, multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { def : Pat<(OpNode (f16 Int16Regs:$a)), (CVT_f16_f16 Int16Regs:$a, Mode)>; + def : Pat<(OpNode (bf16 Int16Regs:$a)), + (CVT_bf16_bf16 Int16Regs:$a, Mode)>; def : Pat<(OpNode Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(OpNode Float32Regs:$a), diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index bfc79d383191bf..f0de0144d410e9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -998,6 +998,18 @@ multiclass FMA_INST { FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, Int16Regs, [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, Int16Regs, + [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, Int16Regs, + [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, Int16Regs, + [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs, + [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, Int16Regs, + [hasPTX<70>, hasSM<80>]>, + FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, Int32Regs, [hasPTX<42>, hasSM<53>]>, FMA_TUPLE<"_rn_ftz_f16x2", int_nvvm_fma_rn_ftz_f16x2, Int32Regs, @@ -1010,11 +1022,6 @@ multiclass FMA_INST { [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2, Int32Regs, [hasPTX<70>, hasSM<80>]>, - - FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs, - [hasPTX<70>, hasSM<80>]>, - FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, Int32Regs, [hasPTX<70>, hasSM<80>]>, FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, Int32Regs, @@ -2207,10 +2214,6 @@ defm INT_PTX_LDU_G_v2i16_ELE : VLDU_G_ELE_V2<"v2.u16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>; defm INT_PTX_LDU_G_v2i32_ELE : VLDU_G_ELE_V2<"v2.u32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>; -defm INT_PTX_LDU_G_v2f16_ELE - : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>; -defm INT_PTX_LDU_G_v2f16x2_ELE - : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>; defm
INT_PTX_LDU_G_v2f32_ELE : VLDU_G_ELE_V2<"v2.f32 \t{{$dst1, $dst2}}, [$src];", Float32Regs>; defm INT_PTX_LDU_G_v2i64_ELE diff --git a/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp b/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp index 5ec1b2425e68fc..95125eb41bc058 100644 --- a/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp @@ -34,6 +34,11 @@ void NVPTXFloatMCExpr::printImpl(raw_ostream &OS, const MCAsmInfo *MAI) const { NumHex = 4; APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); break; + case VK_NVPTX_BFLOAT_PREC_FLOAT: + OS << "0x"; + NumHex = 4; + APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &Ignored); + break; case VK_NVPTX_SINGLE_PREC_FLOAT: OS << "0f"; NumHex = 8; diff --git a/llvm/lib/Target/NVPTX/NVPTXMCExpr.h b/llvm/lib/Target/NVPTX/NVPTXMCExpr.h index 440fa1310003e0..ef99def06c4da3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXMCExpr.h +++ b/llvm/lib/Target/NVPTX/NVPTXMCExpr.h @@ -21,6 +21,7 @@ class NVPTXFloatMCExpr : public MCTargetExpr { public: enum VariantKind { VK_NVPTX_None, + VK_NVPTX_BFLOAT_PREC_FLOAT, // FP constant in bfloat-precision VK_NVPTX_HALF_PREC_FLOAT, // FP constant in half-precision VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision @@ -40,6 +41,11 @@ class NVPTXFloatMCExpr : public MCTargetExpr { static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt, MCContext &Ctx); + static const NVPTXFloatMCExpr *createConstantBFPHalf(const APFloat &Flt, + MCContext &Ctx) { + return create(VK_NVPTX_BFLOAT_PREC_FLOAT, Flt, Ctx); + } + static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt, MCContext &Ctx) { return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx); diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp index 2347f46449d5f4..7fa64af196b936 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -26,7 +26,6 @@ static cl::opt<bool> NoF16Math("nvptx-no-f16-math", cl::Hidden, cl::desc("NVPTX Specific: Disable generation of f16 math ops."), cl::init(false)); - // Pin the vtable to this file.
void NVPTXSubtarget::anchor() {} diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h index 920f5bb94689d9..93af11c258b480 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -76,6 +76,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo { inline bool hasHWROT32() const { return SmVersion >= 32; } bool hasImageHandles() const; bool hasFP16Math() const { return SmVersion >= 53; } + bool hasBF16Math() const { return SmVersion >= 80; } bool allowFP16Math() const; bool hasMaskOperator() const { return PTXVersion >= 71; } bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; } diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index f39934ae13e808..c73721da46e359 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -204,6 +204,14 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) { return {Intrinsic::fma, FTZ_MustBeOff, true}; case Intrinsic::nvvm_fma_rn_ftz_f16x2: return {Intrinsic::fma, FTZ_MustBeOn, true}; + case Intrinsic::nvvm_fma_rn_bf16: + return {Intrinsic::fma, FTZ_MustBeOff, true}; + case Intrinsic::nvvm_fma_rn_ftz_bf16: + return {Intrinsic::fma, FTZ_MustBeOn, true}; + case Intrinsic::nvvm_fma_rn_bf16x2: + return {Intrinsic::fma, FTZ_MustBeOff, true}; + case Intrinsic::nvvm_fma_rn_ftz_bf16x2: + return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fmax_d: return {Intrinsic::maxnum, FTZ_Any}; case Intrinsic::nvvm_fmax_f: diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll new file mode 100644 index 00000000000000..3373cf1401aae4 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll @@ -0,0 +1,194 @@ +; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck --check-prefixes=CHECK,SM80 %s +; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK,SM90 %s +; RUN: %if ptxas-11.8 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 | %ptxas-verify -arch=sm_80 %} +; RUN: %if ptxas-11.8 %{ llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %} + +; LDST: .b8 bfloat_array[8] = {1, 2, 3, 4, 5, 6, 7, 8}; +@"bfloat_array" = addrspace(1) constant [4 x bfloat] + [bfloat 0xR0201, bfloat 0xR0403, bfloat 0xR0605, bfloat 0xR0807] + +; CHECK-LABEL: test_fadd( +; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_fadd_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_fadd_param_1]; +; SM90: add.rn.bf16 [[R:%rs[0-9]+]], [[A]], [[B]]; +; +; SM80-DAG: cvt.f32.bf16 [[FA:%f[0-9]+]], [[A]]; +; SM80-DAG: cvt.f32.bf16 [[FB:%f[0-9]+]], [[B]]; +; SM80: add.rn.f32 [[FR:%f[0-9]+]], [[FA]], [[FB]]; +; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[FR]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; + +define bfloat @test_fadd(bfloat %0, bfloat %1) { + %3 = fadd bfloat %0, %1 + ret bfloat %3 +} + +; CHECK-LABEL: test_fsub( +; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_fsub_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_fsub_param_1]; +; SM90: sub.rn.bf16 [[R:%rs[0-9]+]], [[A]], [[B]]; +; +; SM80-DAG: cvt.f32.bf16 [[FA:%f[0-9]+]], [[A]]; +; SM80-DAG: cvt.f32.bf16 [[FB:%f[0-9]+]], [[B]]; +; SM80: sub.rn.f32 [[FR:%f[0-9]+]], [[FA]], [[FB]]; +; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[FR]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; 
+ +define bfloat @test_fsub(bfloat %0, bfloat %1) { + %3 = fsub bfloat %0, %1 + ret bfloat %3 +} + +; CHECK-LABEL: test_faddx2( +; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_faddx2_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_faddx2_param_1]; +; SM90: add.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]]; + +; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; +; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]; +; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; +; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]]; +; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]]; +; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; +; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; +; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; +; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; +; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; + +define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fadd <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} + +; CHECK-LABEL: test_fsubx2( +; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fsubx2_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fsubx2_param_1]; +; SM90: sub.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]]; + +; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; +; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]; +; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; +; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]]; +; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]]; +; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; +; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; +; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; +; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; +; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; + +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; + +define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fsub <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} + +; CHECK-LABEL: test_fmulx2( +; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fmulx2_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fmulx2_param_1]; +; SM90: mul.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]]; + +; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]; +; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]; +; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; +; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]]; +; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]]; +; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; +; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; +; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; +; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; +; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}; + +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; + +define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fmul <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} + +; CHECK-LABEL: test_fdiv( +; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fdiv_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fdiv_param_1]; +; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]] +; CHECK-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]] +; CHECK-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]]; +; 
CHECK-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; +; CHECK-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]]; +; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; +; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; +; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]; +; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]; +; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]} +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; + +define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fdiv <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} + +; CHECK-LABEL: test_extract_0( +; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_extract_0_param_0]; +; CHECK: st.param.b16 [func_retval0+0], [[A]]; +; CHECK: ret; + +define bfloat @test_extract_0(<2 x bfloat> %a) #0 { + %e = extractelement <2 x bfloat> %a, i32 0 + ret bfloat %e +} + +; CHECK-LABEL: test_extract_1( +; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_extract_1_param_0+2]; +; CHECK: st.param.b16 [func_retval0+0], [[A]]; +; CHECK: ret; + +define bfloat @test_extract_1(<2 x bfloat> %a) #0 { + %e = extractelement <2 x bfloat> %a, i32 1 + ret bfloat %e +} + +; CHECK-LABEL: test_fpext_float( +; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fpext_float_param_0]; +; CHECK: cvt.f32.bf16 [[R:%f[0-9]+]], [[A]]; +; CHECK: st.param.f32 [func_retval0+0], [[R]]; +; CHECK: ret; +define float @test_fpext_float(bfloat %a) #0 { + %r = fpext bfloat %a to float + ret float %r +} + +; CHECK-LABEL: test_fptrunc_float( +; CHECK: ld.param.f32 [[A:%f[0-9]+]], [test_fptrunc_float_param_0]; +; CHECK: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[A]]; +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK: ret; +define bfloat @test_fptrunc_float(float %a) #0 { + %r = fptrunc float %a to bfloat + ret bfloat %r +} + +; CHECK-LABEL: test_fadd_imm_1( +; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fadd_imm_1_param_0]; +; SM90: mov.b16 [[B:%rs[0-9]+]], 0x3F80; +; SM90: add.rn.bf16 [[R:%rs[0-9]+]], [[A]], [[B]]; + +; SM80-DAG: cvt.f32.bf16 [[FA:%f[0-9]+]], [[A]]; +; SM80: add.rn.f32 [[FR:%f[0-9]+]], [[FA]], 0f3F800000; +; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[FR]]; + +; CHECK: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; +define bfloat @test_fadd_imm_1(bfloat %a) #0 { + %r = fadd bfloat %a, 1.0 + ret bfloat %r +} diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll index 6aac2dd18775eb..4e30cebfe90251 100644 --- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll +++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll @@ -3,45 +3,45 @@ ; CHECK-LABEL: cvt_rn_bf16x2_f32 -define i32 @cvt_rn_bf16x2_f32(float %f1, float %f2) { +define <2 x bfloat> @cvt_rn_bf16x2_f32(float %f1, float %f2) { ; CHECK: cvt.rn.bf16x2.f32 - %val = call i32 @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2); + %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2); -ret i32 %val +ret <2 x bfloat> %val } ; CHECK-LABEL: cvt_rn_relu_bf16x2_f32 -define i32 @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) { +define <2 x bfloat> @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) { ; CHECK: cvt.rn.relu.bf16x2.f32 -%val = call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2); +%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2); -ret i32 %val +ret <2 x bfloat> %val } ; CHECK-LABEL: cvt_rz_bf16x2_f32 -define i32 @cvt_rz_bf16x2_f32(float %f1, float %f2) { +define <2 x bfloat> @cvt_rz_bf16x2_f32(float %f1, float %f2) { ; CHECK: cvt.rz.bf16x2.f32 - %val = call 
i32 @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2); + %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2); -ret i32 %val +ret <2 x bfloat> %val } ; CHECK-LABEL: cvt_rz_relu_bf16x2_f32 -define i32 @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) { +define <2 x bfloat> @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) { ; CHECK: cvt.rz.relu.bf16x2.f32 -%val = call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2); +%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2); -ret i32 %val +ret <2 x bfloat> %val } -declare i32 @llvm.nvvm.ff2bf16x2.rn(float, float) -declare i32 @llvm.nvvm.ff2bf16x2.rn.relu(float, float) -declare i32 @llvm.nvvm.ff2bf16x2.rz(float, float) -declare i32 @llvm.nvvm.ff2bf16x2.rz.relu(float, float) +declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float, float) +declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float, float) +declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float, float) +declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float, float) ; CHECK-LABEL: cvt_rn_f16x2_f32 define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) { @@ -85,45 +85,45 @@ declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float) declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float) ; CHECK-LABEL: cvt_rn_bf16_f32 -define i16 @cvt_rn_bf16_f32(float %f1) { +define bfloat @cvt_rn_bf16_f32(float %f1) { ; CHECK: cvt.rn.bf16.f32 - %val = call i16 @llvm.nvvm.f2bf16.rn(float %f1); + %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1); -ret i16 %val +ret bfloat %val } ; CHECK-LABEL: cvt_rn_relu_bf16_f32 -define i16 @cvt_rn_relu_bf16_f32(float %f1) { +define bfloat @cvt_rn_relu_bf16_f32(float %f1) { ; CHECK: cvt.rn.relu.bf16.f32 -%val = call i16 @llvm.nvvm.f2bf16.rn.relu(float %f1); +%val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1); -ret i16 %val +ret bfloat %val } ; CHECK-LABEL: cvt_rz_bf16_f32 -define i16 @cvt_rz_bf16_f32(float %f1) { +define bfloat @cvt_rz_bf16_f32(float %f1) { ; CHECK: cvt.rz.bf16.f32 - %val = call i16 @llvm.nvvm.f2bf16.rz(float %f1); + %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1); -ret i16 %val +ret bfloat %val } ; CHECK-LABEL: cvt_rz_relu_bf16_f32 -define i16 @cvt_rz_relu_bf16_f32(float %f1) { +define bfloat @cvt_rz_relu_bf16_f32(float %f1) { ; CHECK: cvt.rz.relu.bf16.f32 -%val = call i16 @llvm.nvvm.f2bf16.rz.relu(float %f1); +%val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1); -ret i16 %val +ret bfloat %val } -declare i16 @llvm.nvvm.f2bf16.rn(float) -declare i16 @llvm.nvvm.f2bf16.rn.relu(float) -declare i16 @llvm.nvvm.f2bf16.rz(float) -declare i16 @llvm.nvvm.f2bf16.rz.relu(float) +declare bfloat @llvm.nvvm.f2bf16.rn(float) +declare bfloat @llvm.nvvm.f2bf16.rn.relu(float) +declare bfloat @llvm.nvvm.f2bf16.rz(float) +declare bfloat @llvm.nvvm.f2bf16.rz.relu(float) ; CHECK-LABEL: cvt_rna_tf32_f32 define i32 @cvt_rna_tf32_f32(float %f1) { diff --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll index 55fde7837487b5..deea2e3b557f16 100644 --- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll +++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll @@ -246,11 +246,11 @@ declare half @test_callee(half %a, half %b) #0 ; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_call_param_0]; ; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_call_param_1]; ; CHECK: { -; CHECK-DAG: .param .b32 param0; -; CHECK-DAG: .param .b32 param1; +; CHECK-DAG: .param .align 2 .b8 param0[2]; +; CHECK-DAG: .param .align 2 .b8 param1[2]; ; CHECK-DAG: st.param.b16 [param0+0], [[A]]; ; CHECK-DAG: st.param.b16 [param1+0], 
[[B]]; -; CHECK-DAG: .param .b32 retval0; +; CHECK-DAG: .param .align 2 .b8 retval0[2]; ; CHECK: call.uni (retval0), ; CHECK-NEXT: test_callee, ; CHECK: ); @@ -267,11 +267,11 @@ define half @test_call(half %a, half %b) #0 { ; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_call_flipped_param_0]; ; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_call_flipped_param_1]; ; CHECK: { -; CHECK-DAG: .param .b32 param0; -; CHECK-DAG: .param .b32 param1; +; CHECK-DAG: .param .align 2 .b8 param0[2]; +; CHECK-DAG: .param .align 2 .b8 param1[2]; ; CHECK-DAG: st.param.b16 [param0+0], [[B]]; ; CHECK-DAG: st.param.b16 [param1+0], [[A]]; -; CHECK-DAG: .param .b32 retval0; +; CHECK-DAG: .param .align 2 .b8 retval0[2]; ; CHECK: call.uni (retval0), ; CHECK-NEXT: test_callee, ; CHECK: ); @@ -288,11 +288,11 @@ define half @test_call_flipped(half %a, half %b) #0 { ; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_tailcall_flipped_param_0]; ; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_tailcall_flipped_param_1]; ; CHECK: { -; CHECK-DAG: .param .b32 param0; -; CHECK-DAG: .param .b32 param1; +; CHECK-DAG: .param .align 2 .b8 param0[2]; +; CHECK-DAG: .param .align 2 .b8 param1[2]; ; CHECK-DAG: st.param.b16 [param0+0], [[B]]; ; CHECK-DAG: st.param.b16 [param1+0], [[A]]; -; CHECK-DAG: .param .b32 retval0; +; CHECK-DAG: .param .align 2 .b8 retval0[2]; ; CHECK: call.uni (retval0), ; CHECK-NEXT: test_callee, ; CHECK: ); diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-autoupgrade.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-autoupgrade.ll new file mode 100644 index 00000000000000..34b9c085093269 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-autoupgrade.ll @@ -0,0 +1,366 @@ +; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s +; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %} + +declare i16 @llvm.nvvm.abs.bf16(i16) +declare i32 @llvm.nvvm.abs.bf16x2(i32) +declare i16 @llvm.nvvm.neg.bf16(i16) +declare i32 @llvm.nvvm.neg.bf16x2(i32) + +declare float @llvm.nvvm.fmin.nan.f(float, float) +declare float @llvm.nvvm.fmin.ftz.nan.f(float, float) +declare half @llvm.nvvm.fmin.f16(half, half) +declare half @llvm.nvvm.fmin.ftz.f16(half, half) +declare half @llvm.nvvm.fmin.nan.f16(half, half) +declare half @llvm.nvvm.fmin.ftz.nan.f16(half, half) +declare <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half>, <2 x half>) +declare <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half>, <2 x half>) +declare <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half>, <2 x half>) +declare <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half>, <2 x half>) +declare i16 @llvm.nvvm.fmin.bf16(i16, i16) +declare i16 @llvm.nvvm.fmin.nan.bf16(i16, i16) +declare i32 @llvm.nvvm.fmin.bf16x2(i32, i32) +declare i32 @llvm.nvvm.fmin.nan.bf16x2(i32, i32) + +declare float @llvm.nvvm.fmax.nan.f(float, float) +declare float @llvm.nvvm.fmax.ftz.nan.f(float, float) +declare half @llvm.nvvm.fmax.f16(half, half) +declare half @llvm.nvvm.fmax.ftz.f16(half, half) +declare half @llvm.nvvm.fmax.nan.f16(half, half) +declare half @llvm.nvvm.fmax.ftz.nan.f16(half, half) +declare <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half>, <2 x half>) +declare <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half>, <2 x half>) +declare <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half>, <2 x half>) +declare <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half>, <2 x half>) +declare i16 @llvm.nvvm.fmax.bf16(i16, i16) +declare i16 @llvm.nvvm.fmax.nan.bf16(i16, i16) +declare i32 @llvm.nvvm.fmax.bf16x2(i32, i32) 
+declare i32 @llvm.nvvm.fmax.nan.bf16x2(i32, i32) + +declare half @llvm.nvvm.fma.rn.relu.f16(half, half, half) +declare half @llvm.nvvm.fma.rn.ftz.relu.f16(half, half, half) +declare <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half>, <2 x half>, <2 x half>) +declare <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half>, <2 x half>, <2 x half>) +declare i16 @llvm.nvvm.fma.rn.bf16(i16, i16, i16) +declare i16 @llvm.nvvm.fma.rn.relu.bf16(i16, i16, i16) +declare i32 @llvm.nvvm.fma.rn.bf16x2(i32, i32, i32) +declare i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32, i32, i32) + +; CHECK-LABEL: abs_bf16 +define i16 @abs_bf16(i16 %0) { + ; CHECK-NOT: call + ; CHECK: abs.bf16 + %res = call i16 @llvm.nvvm.abs.bf16(i16 %0); + ret i16 %res +} + +; CHECK-LABEL: abs_bf16x2 +define i32 @abs_bf16x2(i32 %0) { + ; CHECK-NOT: call + ; CHECK: abs.bf16x2 + %res = call i32 @llvm.nvvm.abs.bf16x2(i32 %0); + ret i32 %res +} + +; CHECK-LABEL: neg_bf16 +define i16 @neg_bf16(i16 %0) { + ; CHECK-NOT: call + ; CHECK: neg.bf16 + %res = call i16 @llvm.nvvm.neg.bf16(i16 %0); + ret i16 %res +} + +; CHECK-LABEL: neg_bf16x2 +define i32 @neg_bf16x2(i32 %0) { + ; CHECK-NOT: call + ; CHECK: neg.bf16x2 + %res = call i32 @llvm.nvvm.neg.bf16x2(i32 %0); + ret i32 %res +} + +; CHECK-LABEL: fmin_nan_f +define float @fmin_nan_f(float %0, float %1) { + ; CHECK-NOT: call + ; CHECK: min.NaN.f32 + %res = call float @llvm.nvvm.fmin.nan.f(float %0, float %1); + ret float %res +} + +; CHECK-LABEL: fmin_ftz_nan_f +define float @fmin_ftz_nan_f(float %0, float %1) { + ; CHECK-NOT: call + ; CHECK: min.ftz.NaN.f32 + %res = call float @llvm.nvvm.fmin.ftz.nan.f(float %0, float %1); + ret float %res +} + +; CHECK-LABEL: fmin_f16 +define half @fmin_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: min.f16 + %res = call half @llvm.nvvm.fmin.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmin_ftz_f16 +define half @fmin_ftz_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: min.ftz.f16 + %res = call half @llvm.nvvm.fmin.ftz.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmin_nan_f16 +define half @fmin_nan_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: min.NaN.f16 + %res = call half @llvm.nvvm.fmin.nan.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmin_ftz_nan_f16 +define half @fmin_ftz_nan_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: min.ftz.NaN.f16 + %res = call half @llvm.nvvm.fmin.ftz.nan.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmin_f16x2 +define <2 x half> @fmin_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: min.f16x2 + %res = call <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmin_ftz_f16x2 +define <2 x half> @fmin_ftz_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: min.ftz.f16x2 + %res = call <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmin_nan_f16x2 +define <2 x half> @fmin_nan_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: min.NaN.f16x2 + %res = call <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmin_ftz_nan_f16x2 +define <2 x half> @fmin_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: min.ftz.NaN.f16x2 + %res = call <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmin_bf16 +define i16 @fmin_bf16(i16 %0, i16 %1) { + ; CHECK-NOT: 
call + ; CHECK: min.bf16 + %res = call i16 @llvm.nvvm.fmin.bf16(i16 %0, i16 %1) + ret i16 %res +} + +; CHECK-LABEL: fmin_nan_bf16 +define i16 @fmin_nan_bf16(i16 %0, i16 %1) { + ; CHECK-NOT: call + ; CHECK: min.NaN.bf16 + %res = call i16 @llvm.nvvm.fmin.nan.bf16(i16 %0, i16 %1) + ret i16 %res +} + +; CHECK-LABEL: fmin_bf16x2 +define i32 @fmin_bf16x2(i32 %0, i32 %1) { + ; CHECK-NOT: call + ; CHECK: min.bf16x2 + %res = call i32 @llvm.nvvm.fmin.bf16x2(i32 %0, i32 %1) + ret i32 %res +} + +; CHECK-LABEL: fmin_nan_bf16x2 +define i32 @fmin_nan_bf16x2(i32 %0, i32 %1) { + ; CHECK-NOT: call + ; CHECK: min.NaN.bf16x2 + %res = call i32 @llvm.nvvm.fmin.nan.bf16x2(i32 %0, i32 %1) + ret i32 %res +} + +; CHECK-LABEL: fmax_nan_f +define float @fmax_nan_f(float %0, float %1) { + ; CHECK-NOT: call + ; CHECK: max.NaN.f32 + %res = call float @llvm.nvvm.fmax.nan.f(float %0, float %1); + ret float %res +} + +; CHECK-LABEL: fmax_ftz_nan_f +define float @fmax_ftz_nan_f(float %0, float %1) { + ; CHECK-NOT: call + ; CHECK: max.ftz.NaN.f32 + %res = call float @llvm.nvvm.fmax.ftz.nan.f(float %0, float %1); + ret float %res +} + +; CHECK-LABEL: fmax_f16 +define half @fmax_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: max.f16 + %res = call half @llvm.nvvm.fmax.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmax_ftz_f16 +define half @fmax_ftz_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: max.ftz.f16 + %res = call half @llvm.nvvm.fmax.ftz.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmax_nan_f16 +define half @fmax_nan_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: max.NaN.f16 + %res = call half @llvm.nvvm.fmax.nan.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmax_ftz_nan_f16 +define half @fmax_ftz_nan_f16(half %0, half %1) { + ; CHECK-NOT: call + ; CHECK: max.ftz.NaN.f16 + %res = call half @llvm.nvvm.fmax.ftz.nan.f16(half %0, half %1) + ret half %res +} + +; CHECK-LABEL: fmax_f16x2 +define <2 x half> @fmax_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: max.f16x2 + %res = call <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmax_ftz_f16x2 +define <2 x half> @fmax_ftz_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: max.ftz.f16x2 + %res = call <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmax_nan_f16x2 +define <2 x half> @fmax_nan_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: max.NaN.f16x2 + %res = call <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmax_ftz_nan_f16x2 +define <2 x half> @fmax_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) { + ; CHECK-NOT: call + ; CHECK: max.ftz.NaN.f16x2 + %res = call <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half> %0, <2 x half> %1) + ret <2 x half> %res +} + +; CHECK-LABEL: fmax_bf16 +define i16 @fmax_bf16(i16 %0, i16 %1) { + ; CHECK-NOT: call + ; CHECK: max.bf16 + %res = call i16 @llvm.nvvm.fmax.bf16(i16 %0, i16 %1) + ret i16 %res +} + +; CHECK-LABEL: fmax_nan_bf16 +define i16 @fmax_nan_bf16(i16 %0, i16 %1) { + ; CHECK-NOT: call + ; CHECK: max.NaN.bf16 + %res = call i16 @llvm.nvvm.fmax.nan.bf16(i16 %0, i16 %1) + ret i16 %res +} + +; CHECK-LABEL: fmax_bf16x2 +define i32 @fmax_bf16x2(i32 %0, i32 %1) { + ; CHECK-NOT: call + ; CHECK: max.bf16x2 + %res = call i32 @llvm.nvvm.fmax.bf16x2(i32 %0, i32 %1) + ret i32 %res +} + +; CHECK-LABEL: fmax_nan_bf16x2 +define i32 
@fmax_nan_bf16x2(i32 %0, i32 %1) { + ; CHECK-NOT: call + ; CHECK: max.NaN.bf16x2 + %res = call i32 @llvm.nvvm.fmax.nan.bf16x2(i32 %0, i32 %1) + ret i32 %res +} + +; CHECK-LABEL: fma_rn_relu_f16 +define half @fma_rn_relu_f16(half %0, half %1, half %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.relu.f16 + %res = call half @llvm.nvvm.fma.rn.relu.f16(half %0, half %1, half %2) + ret half %res +} + +; CHECK-LABEL: fma_rn_ftz_relu_f16 +define half @fma_rn_ftz_relu_f16(half %0, half %1, half %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.ftz.relu.f16 + %res = call half @llvm.nvvm.fma.rn.ftz.relu.f16(half %0, half %1, half %2) + ret half %res +} + +; CHECK-LABEL: fma_rn_relu_f16x2 +define <2 x half> @fma_rn_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.relu.f16x2 + %res = call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) + ret <2 x half> %res +} + +; CHECK-LABEL: fma_rn_ftz_relu_f16x2 +define <2 x half> @fma_rn_ftz_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.ftz.relu.f16x2 + %res = call <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) + ret <2 x half> %res +} + +; CHECK-LABEL: fma_rn_bf16 +define i16 @fma_rn_bf16(i16 %0, i16 %1, i16 %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.bf16 + %res = call i16 @llvm.nvvm.fma.rn.bf16(i16 %0, i16 %1, i16 %2) + ret i16 %res +} + +; CHECK-LABEL: fma_rn_relu_bf16 +define i16 @fma_rn_relu_bf16(i16 %0, i16 %1, i16 %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.relu.bf16 + %res = call i16 @llvm.nvvm.fma.rn.relu.bf16(i16 %0, i16 %1, i16 %2) + ret i16 %res +} + +; CHECK-LABEL: fma_rn_bf16x2 +define i32 @fma_rn_bf16x2(i32 %0, i32 %1, i32 %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.bf16x2 + %res = call i32 @llvm.nvvm.fma.rn.bf16x2(i32 %0, i32 %1, i32 %2) + ret i32 %res +} + +; CHECK-LABEL: fma_rn_relu_bf16x2 +define i32 @fma_rn_relu_bf16x2(i32 %0, i32 %1, i32 %2) { + ; CHECK-NOT: call + ; CHECK: fma.rn.relu.bf16x2 + %res = call i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32 %0, i32 %1, i32 %2) + ret i32 %res +} diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll index 34b9c085093269..fe05c8e5ec734e 100644 --- a/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll +++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll @@ -1,10 +1,10 @@ ; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s ; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %} -declare i16 @llvm.nvvm.abs.bf16(i16) -declare i32 @llvm.nvvm.abs.bf16x2(i32) -declare i16 @llvm.nvvm.neg.bf16(i16) -declare i32 @llvm.nvvm.neg.bf16x2(i32) +declare bfloat @llvm.nvvm.abs.bf16(bfloat) +declare <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat>) +declare bfloat @llvm.nvvm.neg.bf16(bfloat) +declare <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat>) declare float @llvm.nvvm.fmin.nan.f(float, float) declare float @llvm.nvvm.fmin.ftz.nan.f(float, float) @@ -16,10 +16,10 @@ declare <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half>, <2 x half>) declare <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half>, <2 x half>) declare <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half>, <2 x half>) declare <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half>, <2 x half>) -declare i16 @llvm.nvvm.fmin.bf16(i16, i16) -declare i16 @llvm.nvvm.fmin.nan.bf16(i16, i16) -declare i32 @llvm.nvvm.fmin.bf16x2(i32, i32) -declare i32 @llvm.nvvm.fmin.nan.bf16x2(i32, i32) 
+declare bfloat @llvm.nvvm.fmin.bf16(bfloat, bfloat)
+declare bfloat @llvm.nvvm.fmin.nan.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmin.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2(<2 x bfloat>, <2 x bfloat>)
 declare float @llvm.nvvm.fmax.nan.f(float, float)
 declare float @llvm.nvvm.fmax.ftz.nan.f(float, float)
@@ -31,50 +31,50 @@ declare <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half>, <2 x half>)
 declare <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half>, <2 x half>)
 declare <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half>, <2 x half>)
 declare <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half>, <2 x half>)
-declare i16 @llvm.nvvm.fmax.bf16(i16, i16)
-declare i16 @llvm.nvvm.fmax.nan.bf16(i16, i16)
-declare i32 @llvm.nvvm.fmax.bf16x2(i32, i32)
-declare i32 @llvm.nvvm.fmax.nan.bf16x2(i32, i32)
+declare bfloat @llvm.nvvm.fmax.bf16(bfloat, bfloat)
+declare bfloat @llvm.nvvm.fmax.nan.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmax.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2(<2 x bfloat>, <2 x bfloat>)
 declare half @llvm.nvvm.fma.rn.relu.f16(half, half, half)
 declare half @llvm.nvvm.fma.rn.ftz.relu.f16(half, half, half)
 declare <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
 declare <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
-declare i16 @llvm.nvvm.fma.rn.bf16(i16, i16, i16)
-declare i16 @llvm.nvvm.fma.rn.relu.bf16(i16, i16, i16)
-declare i32 @llvm.nvvm.fma.rn.bf16x2(i32, i32, i32)
-declare i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32, i32, i32)
+declare bfloat @llvm.nvvm.fma.rn.bf16(bfloat, bfloat, bfloat)
+declare bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat, bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>)

 ; CHECK-LABEL: abs_bf16
-define i16 @abs_bf16(i16 %0) {
+define bfloat @abs_bf16(bfloat %0) {
 ; CHECK-NOT: call
 ; CHECK: abs.bf16
-  %res = call i16 @llvm.nvvm.abs.bf16(i16 %0);
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.abs.bf16(bfloat %0);
+  ret bfloat %res
 }

 ; CHECK-LABEL: abs_bf16x2
-define i32 @abs_bf16x2(i32 %0) {
+define <2 x bfloat> @abs_bf16x2(<2 x bfloat> %0) {
 ; CHECK-NOT: call
 ; CHECK: abs.bf16x2
-  %res = call i32 @llvm.nvvm.abs.bf16x2(i32 %0);
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> %0);
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: neg_bf16
-define i16 @neg_bf16(i16 %0) {
+define bfloat @neg_bf16(bfloat %0) {
 ; CHECK-NOT: call
 ; CHECK: neg.bf16
-  %res = call i16 @llvm.nvvm.neg.bf16(i16 %0);
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.neg.bf16(bfloat %0);
+  ret bfloat %res
 }

 ; CHECK-LABEL: neg_bf16x2
-define i32 @neg_bf16x2(i32 %0) {
+define <2 x bfloat> @neg_bf16x2(<2 x bfloat> %0) {
 ; CHECK-NOT: call
 ; CHECK: neg.bf16x2
-  %res = call i32 @llvm.nvvm.neg.bf16x2(i32 %0);
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> %0);
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmin_nan_f
@@ -158,35 +158,35 @@ define <2 x half> @fmin_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
 }

 ; CHECK-LABEL: fmin_bf16
-define i16 @fmin_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_bf16(bfloat %0, bfloat %1) {
 ; CHECK-NOT: call
 ; CHECK: min.bf16
-  %res = call i16 @llvm.nvvm.fmin.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmin.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmin_nan_bf16
-define i16 @fmin_nan_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_nan_bf16(bfloat %0, bfloat %1) {
 ; CHECK-NOT: call
 ; CHECK: min.NaN.bf16
-  %res = call i16 @llvm.nvvm.fmin.nan.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmin.nan.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmin_bf16x2
-define i32 @fmin_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
 ; CHECK-NOT: call
 ; CHECK: min.bf16x2
-  %res = call i32 @llvm.nvvm.fmin.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmin.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmin_nan_bf16x2
-define i32 @fmin_nan_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_nan_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
 ; CHECK-NOT: call
 ; CHECK: min.NaN.bf16x2
-  %res = call i32 @llvm.nvvm.fmin.nan.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmax_nan_f
@@ -270,35 +270,35 @@ define <2 x half> @fmax_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
 }

 ; CHECK-LABEL: fmax_bf16
-define i16 @fmax_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_bf16(bfloat %0, bfloat %1) {
 ; CHECK-NOT: call
 ; CHECK: max.bf16
-  %res = call i16 @llvm.nvvm.fmax.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmax.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmax_nan_bf16
-define i16 @fmax_nan_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_nan_bf16(bfloat %0, bfloat %1) {
 ; CHECK-NOT: call
 ; CHECK: max.NaN.bf16
-  %res = call i16 @llvm.nvvm.fmax.nan.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmax.nan.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmax_bf16x2
-define i32 @fmax_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
 ; CHECK-NOT: call
 ; CHECK: max.bf16x2
-  %res = call i32 @llvm.nvvm.fmax.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmax.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmax_nan_bf16x2
-define i32 @fmax_nan_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_nan_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
 ; CHECK-NOT: call
 ; CHECK: max.NaN.bf16x2
-  %res = call i32 @llvm.nvvm.fmax.nan.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fma_rn_relu_f16
@@ -334,33 +334,33 @@ define <2 x half> @fma_rn_ftz_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half
 }

 ; CHECK-LABEL: fma_rn_bf16
-define i16 @fma_rn_bf16(i16 %0, i16 %1, i16 %2) {
+define bfloat @fma_rn_bf16(bfloat %0, bfloat %1, bfloat %2) {
 ; CHECK-NOT: call
 ; CHECK: fma.rn.bf16
-  %res = call i16 @llvm.nvvm.fma.rn.bf16(i16 %0, i16 %1, i16 %2)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fma.rn.bf16(bfloat %0, bfloat %1, bfloat %2)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fma_rn_relu_bf16
-define i16 @fma_rn_relu_bf16(i16 %0, i16 %1, i16 %2) {
+define bfloat @fma_rn_relu_bf16(bfloat %0, bfloat %1, bfloat %2) {
 ; CHECK-NOT: call
 ; CHECK: fma.rn.relu.bf16
-  %res = call i16 @llvm.nvvm.fma.rn.relu.bf16(i16 %0, i16 %1, i16 %2)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat %0, bfloat %1, bfloat %2)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fma_rn_bf16x2
-define i32 @fma_rn_bf16x2(i32 %0, i32 %1, i32 %2) {
+define <2 x bfloat> @fma_rn_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2) {
 ; CHECK-NOT: call
 ; CHECK: fma.rn.bf16x2
-  %res = call i32 @llvm.nvvm.fma.rn.bf16x2(i32 %0, i32 %1, i32 %2)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fma_rn_relu_bf16x2
-define i32 @fma_rn_relu_bf16x2(i32 %0, i32 %1, i32 %2) {
+define <2 x bfloat> @fma_rn_relu_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2) {
 ; CHECK-NOT: call
 ; CHECK: fma.rn.relu.bf16x2
-  %res = call i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32 %0, i32 %1, i32 %2)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2)
+  ret <2 x bfloat> %res
 }
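The new test file added below intentionally keeps the legacy i16/i32 prototypes for the bf16 intrinsics: its purpose is to verify that LLVM's intrinsic auto-upgrade path still accepts old-style bitcode and maps it onto the new bfloat signatures. A rough, illustrative sketch of that rewrite (the value names %a, %b, %r are ours, and the exact logic lives in LLVM's AutoUpgrade support, not in this test):

  ; old-style call as it appears in stale bitcode
  %r = call i16 @llvm.nvvm.fmin.bf16(i16 %a, i16 %b)

  ; after upgrade: same-width bitcasts bridge to the bfloat intrinsic
  %a.bf = bitcast i16 %a to bfloat
  %b.bf = bitcast i16 %b to bfloat
  %r.bf = call bfloat @llvm.nvvm.fmin.bf16(bfloat %a.bf, bfloat %b.bf)
  %r = bitcast bfloat %r.bf to i16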
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72-autoupgrade.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72-autoupgrade.ll
new file mode 100644
index 00000000000000..b745df484bab22
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72-autoupgrade.ll
@@ -0,0 +1,292 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_86 -mattr=+ptx72 | FileCheck %s
+; RUN: %if ptxas-11.2 %{ llc < %s -march=nvptx64 -mcpu=sm_86 -mattr=+ptx72 | %ptxas-verify -arch=sm_86 %}
+
+declare half @llvm.nvvm.fmin.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmin.ftz.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmin.nan.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16(half, half)
+declare <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16, i16)
+declare i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16, i16)
+declare i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32, i32)
+declare i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32, i32)
+declare float @llvm.nvvm.fmin.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmin.ftz.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmin.nan.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f(float, float)
+
+declare half @llvm.nvvm.fmax.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmax.ftz.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmax.nan.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16(half, half)
+declare <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16, i16)
+declare i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16, i16)
+declare i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32, i32)
+declare i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32, i32)
+declare float @llvm.nvvm.fmax.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmax.ftz.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmax.nan.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f(float, float)
+
+; CHECK-LABEL: fmin_xorsign_abs_f16
+define half @fmin_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmin.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmin_ftz_xorsign_abs_f16
+define half @fmin_ftz_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.ftz.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmin.ftz.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_f16
+define half @fmin_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.NaN.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmin.nan.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f16
+define half @fmin_ftz_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.ftz.NaN.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_f16x2
+define <2 x half> @fmin_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_ftz_xorsign_abs_f16x2
+define <2 x half> @fmin_ftz_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.ftz.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_f16x2
+define <2 x half> @fmin_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.NaN.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f16x2
+define <2 x half> @fmin_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.ftz.NaN.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_bf16
+define i16 @fmin_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.xorsign.abs.bf16
+  %res = call i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16 %0, i16 %1)
+  ret i16 %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_bf16
+define i16 @fmin_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.NaN.xorsign.abs.bf16
+  %res = call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16 %0, i16 %1)
+  ret i16 %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_bf16x2
+define i32 @fmin_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.xorsign.abs.bf16x2
+  %res = call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32 %0, i32 %1)
+  ret i32 %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_bf16x2
+define i32 @fmin_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.NaN.xorsign.abs.bf16x2
+  %res = call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
+  ret i32 %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_f
+define float @fmin_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmin.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
+
+; CHECK-LABEL: fmin_ftz_xorsign_abs_f
+define float @fmin_ftz_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.ftz.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmin.ftz.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_f
+define float @fmin_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.NaN.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmin.nan.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f
+define float @fmin_ftz_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: min.ftz.NaN.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_f16
+define half @fmax_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmax.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmax_ftz_xorsign_abs_f16
+define half @fmax_ftz_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.ftz.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmax.ftz.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_f16
+define half @fmax_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.NaN.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmax.nan.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f16
+define half @fmax_ftz_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.ftz.NaN.xorsign.abs.f16
+  %res = call half @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16(half %0, half %1)
+  ret half %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_f16x2
+define <2 x half> @fmax_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_ftz_xorsign_abs_f16x2
+define <2 x half> @fmax_ftz_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.ftz.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_f16x2
+define <2 x half> @fmax_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.NaN.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f16x2
+define <2 x half> @fmax_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.ftz.NaN.xorsign.abs.f16x2
+  %res = call <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_bf16
+define i16 @fmax_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.xorsign.abs.bf16
+  %res = call i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16 %0, i16 %1)
+  ret i16 %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_bf16
+define i16 @fmax_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.NaN.xorsign.abs.bf16
+  %res = call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16 %0, i16 %1)
+  ret i16 %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_bf16x2
+define i32 @fmax_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.xorsign.abs.bf16x2
+  %res = call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32 %0, i32 %1)
+  ret i32 %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_bf16x2
+define i32 @fmax_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.NaN.xorsign.abs.bf16x2
+  %res = call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
+  ret i32 %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_f
+define float @fmax_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmax.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
+
+; CHECK-LABEL: fmax_ftz_xorsign_abs_f
+define float @fmax_ftz_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.ftz.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmax.ftz.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_f
+define float @fmax_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.NaN.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmax.nan.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f
+define float @fmax_ftz_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
+  ; CHECK: max.ftz.NaN.xorsign.abs.f
+  %res = call float @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f(float %0, float %1)
+  ret float %res
+}
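For context on the xorsign.abs variants exercised above: per the PTX ISA description, the min/max is taken over the absolute values of the operands, and the sign of the result is the XOR of the two input signs. A small worked example with our own numbers:

  ; min.xorsign.abs(-3.0, 2.0)
  ;   |-3.0| = 3.0, |2.0| = 2.0            ->  minimum magnitude is 2.0
  ;   sign(-3.0) XOR sign(2.0) = negative  ->  result is -2.0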
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
index b745df484bab22..6d430b052d8fe0 100644
--- a/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
@@ -9,10 +9,10 @@ declare <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> , <2 x half>)
 declare <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
 declare <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
 declare <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
-declare i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16, i16)
-declare i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16, i16)
-declare i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32, i32)
-declare i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32, i32)
+declare bfloat @llvm.nvvm.fmin.xorsign.abs.bf16(bfloat, bfloat)
+declare bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
 declare float @llvm.nvvm.fmin.xorsign.abs.f(float, float)
 declare float @llvm.nvvm.fmin.ftz.xorsign.abs.f(float, float)
 declare float @llvm.nvvm.fmin.nan.xorsign.abs.f(float, float)
@@ -26,10 +26,10 @@ declare <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> , <2 x half>)
 declare <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
 declare <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
 declare <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
-declare i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16, i16)
-declare i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16, i16)
-declare i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32, i32)
-declare i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32, i32)
+declare bfloat @llvm.nvvm.fmax.xorsign.abs.bf16(bfloat, bfloat)
+declare bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
 declare float @llvm.nvvm.fmax.xorsign.abs.f(float, float)
 declare float @llvm.nvvm.fmax.ftz.xorsign.abs.f(float, float)
 declare float @llvm.nvvm.fmax.nan.xorsign.abs.f(float, float)
@@ -100,35 +100,35 @@ define <2 x half> @fmin_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1)
 }

 ; CHECK-LABEL: fmin_xorsign_abs_bf16
-define i16 @fmin_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_xorsign_abs_bf16(bfloat %0, bfloat %1) {
   ; CHECK-NOT: call
   ; CHECK: min.xorsign.abs.bf16
-  %res = call i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmin_nan_xorsign_abs_bf16
-define i16 @fmin_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_nan_xorsign_abs_bf16(bfloat %0, bfloat %1) {
   ; CHECK-NOT: call
   ; CHECK: min.NaN.xorsign.abs.bf16
-  %res = call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmin_xorsign_abs_bf16x2
-define i32 @fmin_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
   ; CHECK-NOT: call
   ; CHECK: min.xorsign.abs.bf16x2
-  %res = call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmin_nan_xorsign_abs_bf16x2
-define i32 @fmin_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_nan_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
   ; CHECK-NOT: call
   ; CHECK: min.NaN.xorsign.abs.bf16x2
-  %res = call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmin_xorsign_abs_f
@@ -228,35 +228,35 @@ define <2 x half> @fmax_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1)
 }

 ; CHECK-LABEL: fmax_xorsign_abs_bf16
-define i16 @fmax_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_xorsign_abs_bf16(bfloat %0, bfloat %1) {
   ; CHECK-NOT: call
   ; CHECK: max.xorsign.abs.bf16
-  %res = call i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmax_nan_xorsign_abs_bf16
-define i16 @fmax_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_nan_xorsign_abs_bf16(bfloat %0, bfloat %1) {
   ; CHECK-NOT: call
   ; CHECK: max.NaN.xorsign.abs.bf16
-  %res = call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16 %0, i16 %1)
-  ret i16 %res
+  %res = call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16(bfloat %0, bfloat %1)
+  ret bfloat %res
 }

 ; CHECK-LABEL: fmax_xorsign_abs_bf16x2
-define i32 @fmax_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
   ; CHECK-NOT: call
   ; CHECK: max.xorsign.abs.bf16x2
-  %res = call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmax_nan_xorsign_abs_bf16x2
-define i32 @fmax_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_nan_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
   ; CHECK-NOT: call
   ; CHECK: max.NaN.xorsign.abs.bf16x2
-  %res = call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
-  ret i32 %res
+  %res = call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+  ret <2 x bfloat> %res
 }

 ; CHECK-LABEL: fmax_xorsign_abs_f
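The param-load-store.ll updates below record an ABI-visible side effect of the change: a scalar half or bfloat argument is now passed and returned as a 2-byte aligned byte array (.param .align 2 .b8 ...[2]) rather than being widened to .param .b32, and <2 x bfloat> becomes a 4-byte array. A minimal sketch of the shape the new CHECK lines expect (the function name @id is ours, and the PTX is abbreviated):

  define bfloat @id(bfloat %x) {
    ret bfloat %x
  }
  ; roughly the PTX prototype the test checks for:
  ; .func (.param .align 2 .b8 func_retval0[2]) id (.param .align 2 .b8 id_param_0[2])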
diff --git a/llvm/test/CodeGen/NVPTX/param-load-store.ll b/llvm/test/CodeGen/NVPTX/param-load-store.ll
index b05fbaea17087c..313a0915d2030f 100644
--- a/llvm/test/CodeGen/NVPTX/param-load-store.ll
+++ b/llvm/test/CodeGen/NVPTX/param-load-store.ll
@@ -381,13 +381,13 @@ define <5 x i16> @test_v5i16(<5 x i16> %a) {
   ret <5 x i16> %r;
 }

-; CHECK: .func (.param .b32 func_retval0)
+; CHECK: .func (.param .align 2 .b8 func_retval0[2])
 ; CHECK-LABEL: test_f16(
-; CHECK-NEXT: .param .b32 test_f16_param_0
+; CHECK-NEXT: .param .align 2 .b8 test_f16_param_0[2]
 ; CHECK: ld.param.b16 [[E:%rs[0-9]+]], [test_f16_param_0];
-; CHECK: .param .b32 param0;
+; CHECK: .param .align 2 .b8 param0[2];
 ; CHECK: st.param.b16 [param0+0], [[E]];
-; CHECK: .param .b32 retval0;
+; CHECK: .param .align 2 .b8 retval0[2];
 ; CHECK: call.uni (retval0),
 ; CHECK-NEXT: test_f16,
 ; CHECK: ld.param.b16 [[R:%rs[0-9]+]], [retval0+0];
@@ -415,6 +415,41 @@ define <2 x half> @test_v2f16(<2 x half> %a) {
   ret <2 x half> %r;
 }

+; CHECK: .func (.param .align 2 .b8 func_retval0[2])
+; CHECK-LABEL: test_bf16(
+; CHECK-NEXT: .param .align 2 .b8 test_bf16_param_0[2]
+; CHECK: ld.param.b16 [[E:%rs[0-9]+]], [test_bf16_param_0];
+; CHECK: .param .align 2 .b8 param0[2];
+; CHECK: st.param.b16 [param0+0], [[E]];
+; CHECK: .param .align 2 .b8 retval0[2];
+; CHECK: call.uni (retval0),
+; CHECK-NEXT: test_bf16,
+; CHECK: ld.param.b16 [[R:%rs[0-9]+]], [retval0+0];
+; CHECK: st.param.b16 [func_retval0+0], [[R]]
+; CHECK-NEXT: ret;
+define bfloat @test_bf16(bfloat %a) {
+  %r = tail call bfloat @test_bf16(bfloat %a);
+  ret bfloat %r;
+}
+
+; CHECK: .func (.param .align 4 .b8 func_retval0[4])
+; CHECK-LABEL: test_v2bf16(
+; CHECK-NEXT: .param .align 4 .b8 test_v2bf16_param_0[4]
+; CHECK: ld.param.b32 [[E:%r[0-9]+]], [test_v2bf16_param_0];
+; CHECK: .param .align 4 .b8 param0[4];
+; CHECK: st.param.b32 [param0+0], [[E]];
+; CHECK: .param .align 4 .b8 retval0[4];
+; CHECK: call.uni (retval0),
+; CHECK-NEXT: test_v2bf16,
+; CHECK: ld.param.b32 [[R:%r[0-9]+]], [retval0+0];
+; CHECK: st.param.b32 [func_retval0+0], [[R]]
+; CHECK-NEXT: ret;
+define <2 x bfloat> @test_v2bf16(<2 x bfloat> %a) {
+  %r = tail call <2 x bfloat> @test_v2bf16(<2 x bfloat> %a);
+  ret <2 x bfloat> %r;
+}
+
+
 ; CHECK:.func (.param .align 8 .b8 func_retval0[8])
 ; CHECK-LABEL: test_v3f16(
 ; CHECK: .param .align 8 .b8 test_v3f16_param_0[8]
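Taken together, these test updates mean a frontend can now emit bfloat-typed IR for these operations directly, with no integer bit-pattern plumbing. A minimal end-to-end sketch (our own function and name, using the same sm_80/PTX 7.0 RUN configuration as the tests above):

  declare bfloat @llvm.nvvm.fma.rn.bf16(bfloat, bfloat, bfloat)

  define bfloat @bf16_muladd(bfloat %a, bfloat %b, bfloat %c) {
    ; expected to lower to a single fma.rn.bf16 instruction on sm_80+/PTX 7.0+
    %r = call bfloat @llvm.nvvm.fma.rn.bf16(bfloat %a, bfloat %b, bfloat %c)
    ret bfloat %r
  }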