diff --git a/compiler/rustc_middle/src/ty/layout.rs b/compiler/rustc_middle/src/ty/layout.rs index 7cf2984a63f90..cd4b23fca3932 100644 --- a/compiler/rustc_middle/src/ty/layout.rs +++ b/compiler/rustc_middle/src/ty/layout.rs @@ -2592,6 +2592,22 @@ where pointee_info } + + fn is_adt(this: TyAndLayout<'tcx>) -> bool { + matches!(this.ty.kind(), ty::Adt(..)) + } + + fn is_never(this: TyAndLayout<'tcx>) -> bool { + this.ty.kind() == &ty::Never + } + + fn is_tuple(this: TyAndLayout<'tcx>) -> bool { + matches!(this.ty.kind(), ty::Tuple(..)) + } + + fn is_unit(this: TyAndLayout<'tcx>) -> bool { + matches!(this.ty.kind(), ty::Tuple(list) if list.len() == 0) + } } impl<'tcx> ty::Instance<'tcx> { diff --git a/compiler/rustc_middle/src/ty/list.rs b/compiler/rustc_middle/src/ty/list.rs index adba7d131592e..197dc9205b480 100644 --- a/compiler/rustc_middle/src/ty/list.rs +++ b/compiler/rustc_middle/src/ty/list.rs @@ -61,6 +61,10 @@ impl List { static EMPTY_SLICE: InOrder = InOrder(0, MaxAlign); unsafe { &*(&EMPTY_SLICE as *const _ as *const List) } } + + pub fn len(&self) -> usize { + self.len + } } impl List { diff --git a/compiler/rustc_target/src/abi/call/mod.rs b/compiler/rustc_target/src/abi/call/mod.rs index ce564d1455bfc..afce10ff1cbe8 100644 --- a/compiler/rustc_target/src/abi/call/mod.rs +++ b/compiler/rustc_target/src/abi/call/mod.rs @@ -696,7 +696,13 @@ impl<'a, Ty> FnAbi<'a, Ty> { "sparc" => sparc::compute_abi_info(cx, self), "sparc64" => sparc64::compute_abi_info(cx, self), "nvptx" => nvptx::compute_abi_info(self), - "nvptx64" => nvptx64::compute_abi_info(self), + "nvptx64" => { + if cx.target_spec().adjust_abi(abi) == spec::abi::Abi::PtxKernel { + nvptx64::compute_ptx_kernel_abi_info(cx, self) + } else { + nvptx64::compute_abi_info(self) + } + } "hexagon" => hexagon::compute_abi_info(self), "riscv32" | "riscv64" => riscv::compute_abi_info(cx, self), "wasm32" | "wasm64" => { diff --git a/compiler/rustc_target/src/abi/call/nvptx64.rs b/compiler/rustc_target/src/abi/call/nvptx64.rs index 16f331b16d561..fc16f1c97a452 100644 --- a/compiler/rustc_target/src/abi/call/nvptx64.rs +++ b/compiler/rustc_target/src/abi/call/nvptx64.rs @@ -1,21 +1,35 @@ -// Reference: PTX Writer's Guide to Interoperability -// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability - -use crate::abi::call::{ArgAbi, FnAbi}; +use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform}; +use crate::abi::{HasDataLayout, TyAbiInterface}; fn classify_ret(ret: &mut ArgAbi<'_, Ty>) { if ret.layout.is_aggregate() && ret.layout.size.bits() > 64 { ret.make_indirect(); - } else { - ret.extend_integer_width_to(64); } } fn classify_arg(arg: &mut ArgAbi<'_, Ty>) { if arg.layout.is_aggregate() && arg.layout.size.bits() > 64 { arg.make_indirect(); - } else { - arg.extend_integer_width_to(64); + } +} + +fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>) +where + Ty: TyAbiInterface<'a, C> + Copy, + C: HasDataLayout, +{ + if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) { + let align_bytes = arg.layout.align.abi.bytes(); + + let unit = match align_bytes { + 1 => Reg::i8(), + 2 => Reg::i16(), + 4 => Reg::i32(), + 8 => Reg::i64(), + 16 => Reg::i128(), + _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"), + }; + arg.cast_to(Uniform { unit, total: Size::from_bytes(2 * align_bytes) }); } } @@ -31,3 +45,20 @@ pub fn compute_abi_info(fn_abi: &mut FnAbi<'_, Ty>) { classify_arg(arg); } } + +pub fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>) +where + Ty: TyAbiInterface<'a, C> + Copy, + C: HasDataLayout, +{ + if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() { + panic!("Kernels should not return anything other than () or !"); + } + + for arg in &mut fn_abi.args { + if arg.is_ignore() { + continue; + } + classify_arg_kernel(cx, arg); + } +} diff --git a/compiler/rustc_target/src/abi/mod.rs b/compiler/rustc_target/src/abi/mod.rs index 169167f69bf8c..0e8fd9cc93fd1 100644 --- a/compiler/rustc_target/src/abi/mod.rs +++ b/compiler/rustc_target/src/abi/mod.rs @@ -1355,6 +1355,10 @@ pub trait TyAbiInterface<'a, C>: Sized { cx: &C, offset: Size, ) -> Option; + fn is_adt(this: TyAndLayout<'a, Self>) -> bool; + fn is_never(this: TyAndLayout<'a, Self>) -> bool; + fn is_tuple(this: TyAndLayout<'a, Self>) -> bool; + fn is_unit(this: TyAndLayout<'a, Self>) -> bool; } impl<'a, Ty> TyAndLayout<'a, Ty> { @@ -1396,6 +1400,34 @@ impl<'a, Ty> TyAndLayout<'a, Ty> { _ => false, } } + + pub fn is_adt(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_adt(self) + } + + pub fn is_never(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_never(self) + } + + pub fn is_tuple(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_tuple(self) + } + + pub fn is_unit(self) -> bool + where + Ty: TyAbiInterface<'a, C>, + { + Ty::is_unit(self) + } } impl<'a, Ty> TyAndLayout<'a, Ty> { diff --git a/src/test/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs b/src/test/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs new file mode 100644 index 0000000000000..5bf44f949fdf6 --- /dev/null +++ b/src/test/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs @@ -0,0 +1,254 @@ +// assembly-output: ptx-linker +// compile-flags: --crate-type cdylib -C target-cpu=sm_86 +// only-nvptx64 +// ignore-nvptx64 + +// The following ABI tests are made with nvcc 11.6 does. +// +// The PTX ABI stability is tied to major versions of the PTX ISA +// These tests assume major version 7 +// +// +// The following correspondence between types are assumed: +// u - uint_t +// i - int_t +// [T, N] - std::array +// &T - T const* +// &mut T - T* + +// CHECK: .version 7 + +#![feature(abi_ptx, lang_items, no_core)] +#![no_core] + +#[lang = "sized"] +trait Sized {} +#[lang = "copy"] +trait Copy {} + +#[repr(C)] +pub struct SingleU8 { + f: u8, +} + +#[repr(C)] +pub struct DoubleU8 { + f: u8, + g: u8, +} + +#[repr(C)] +pub struct TripleU8 { + f: u8, + g: u8, + h: u8, +} + +#[repr(C)] +pub struct TripleU16 { + f: u16, + g: u16, + h: u16, +} +#[repr(C)] +pub struct TripleU32 { + f: u32, + g: u32, + h: u32, +} +#[repr(C)] +pub struct TripleU64 { + f: u64, + g: u64, + h: u64, +} + +#[repr(C)] +pub struct DoubleFloat { + f: f32, + g: f32, +} + +#[repr(C)] +pub struct TripleFloat { + f: f32, + g: f32, + h: f32, +} + +#[repr(C)] +pub struct TripleDouble { + f: f64, + g: f64, + h: f64, +} + +#[repr(C)] +pub struct ManyIntegers { + f: u8, + g: u16, + h: u32, + i: u64, +} + +#[repr(C)] +pub struct ManyNumerics { + f: u8, + g: u16, + h: u32, + i: u64, + j: f32, + k: f64, +} + +// CHECK: .visible .entry f_u8_arg( +// CHECK: .param .u8 f_u8_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_u8_arg(_a: u8) {} + +// CHECK: .visible .entry f_u16_arg( +// CHECK: .param .u16 f_u16_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_u16_arg(_a: u16) {} + +// CHECK: .visible .entry f_u32_arg( +// CHECK: .param .u32 f_u32_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_u32_arg(_a: u32) {} + +// CHECK: .visible .entry f_u64_arg( +// CHECK: .param .u64 f_u64_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_u64_arg(_a: u64) {} + +// CHECK: .visible .entry f_u128_arg( +// CHECK: .param .align 16 .b8 f_u128_arg_param_0[16] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_u128_arg(_a: u128) {} + +// CHECK: .visible .entry f_i8_arg( +// CHECK: .param .u8 f_i8_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_i8_arg(_a: i8) {} + +// CHECK: .visible .entry f_i16_arg( +// CHECK: .param .u16 f_i16_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_i16_arg(_a: i16) {} + +// CHECK: .visible .entry f_i32_arg( +// CHECK: .param .u32 f_i32_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_i32_arg(_a: i32) {} + +// CHECK: .visible .entry f_i64_arg( +// CHECK: .param .u64 f_i64_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_i64_arg(_a: i64) {} + +// CHECK: .visible .entry f_i128_arg( +// CHECK: .param .align 16 .b8 f_i128_arg_param_0[16] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_i128_arg(_a: i128) {} + +// CHECK: .visible .entry f_f32_arg( +// CHECK: .param .f32 f_f32_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_f32_arg(_a: f32) {} + +// CHECK: .visible .entry f_f64_arg( +// CHECK: .param .f64 f_f64_arg_param_0 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_f64_arg(_a: f64) {} + +// CHECK: .visible .entry f_single_u8_arg( +// CHECK: .param .align 1 .b8 f_single_u8_arg_param_0[1] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_single_u8_arg(_a: SingleU8) {} + +// CHECK: .visible .entry f_double_u8_arg( +// CHECK: .param .align 1 .b8 f_double_u8_arg_param_0[2] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_double_u8_arg(_a: DoubleU8) {} + +// CHECK: .visible .entry f_triple_u8_arg( +// CHECK: .param .align 1 .b8 f_triple_u8_arg_param_0[3] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_triple_u8_arg(_a: TripleU8) {} + +// CHECK: .visible .entry f_triple_u16_arg( +// CHECK: .param .align 2 .b8 f_triple_u16_arg_param_0[6] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_triple_u16_arg(_a: TripleU16) {} + +// CHECK: .visible .entry f_triple_u32_arg( +// CHECK: .param .align 4 .b8 f_triple_u32_arg_param_0[12] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_triple_u32_arg(_a: TripleU32) {} + +// CHECK: .visible .entry f_triple_u64_arg( +// CHECK: .param .align 8 .b8 f_triple_u64_arg_param_0[24] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_triple_u64_arg(_a: TripleU64) {} + +// CHECK: .visible .entry f_many_integers_arg( +// CHECK: .param .align 8 .b8 f_many_integers_arg_param_0[16] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_many_integers_arg(_a: ManyIntegers) {} + +// CHECK: .visible .entry f_double_float_arg( +// CHECK: .param .align 4 .b8 f_double_float_arg_param_0[8] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_double_float_arg(_a: DoubleFloat) {} + +// CHECK: .visible .entry f_triple_float_arg( +// CHECK: .param .align 4 .b8 f_triple_float_arg_param_0[12] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_triple_float_arg(_a: TripleFloat) {} + +// CHECK: .visible .entry f_triple_double_arg( +// CHECK: .param .align 8 .b8 f_triple_double_arg_param_0[24] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_triple_double_arg(_a: TripleDouble) {} + +// CHECK: .visible .entry f_many_numerics_arg( +// CHECK: .param .align 8 .b8 f_many_numerics_arg_param_0[32] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_many_numerics_arg(_a: ManyNumerics) {} + +// CHECK: .visible .entry f_byte_array_arg( +// CHECK: .param .align 1 .b8 f_byte_array_arg_param_0[5] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_byte_array_arg(_a: [u8; 5]) {} + +// CHECK: .visible .entry f_float_array_arg( +// CHECK: .param .align 4 .b8 f_float_array_arg_param_0[20] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_float_array_arg(_a: [f32; 5]) {} + +// CHECK: .visible .entry f_u128_array_arg( +// CHECK: .param .align 16 .b8 f_u128_array_arg_param_0[80] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_u128_array_arg(_a: [u128; 5]) {} + +// CHECK: .visible .entry f_u32_slice_arg( +// CHECK: .param .u64 f_u32_slice_arg_param_0 +// CHECK: .param .u64 f_u32_slice_arg_param_1 +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_u32_slice_arg(_a: &[u32]) {} + +// CHECK: .visible .entry f_tuple_u8_u8_arg( +// CHECK: .param .align 1 .b8 f_tuple_u8_u8_arg_param_0[2] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_arg(_a: (u8, u8)) {} + +// CHECK: .visible .entry f_tuple_u32_u32_arg( +// CHECK: .param .align 4 .b8 f_tuple_u32_u32_arg_param_0[8] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_tuple_u32_u32_arg(_a: (u32, u32)) {} + + +// CHECK: .visible .entry f_tuple_u8_u8_u32_arg( +// CHECK: .param .align 4 .b8 f_tuple_u8_u8_u32_arg_param_0[8] +#[no_mangle] +pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_u32_arg(_a: (u8, u8, u32)) {}