Skip to content

Commit

Permalink
Merge branch 'main' into invoke_lambda_type
Browse files Browse the repository at this point in the history
  • Loading branch information
Revaj authored Mar 22, 2024
2 parents 7f71edf + 7c53bbd commit 64c8d37
Show file tree
Hide file tree
Showing 91 changed files with 603 additions and 781 deletions.
2 changes: 1 addition & 1 deletion cub/cub/warp/warp_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ CUB_NAMESPACE_BEGIN
//! struct CustomLess
//! {
//! template <typename DataType>
//! __host__ bool operator()(const DataType &lhs, const DataType &rhs)
//! __device__ bool operator()(const DataType &lhs, const DataType &rhs)
//! {
//! return lhs < rhs;
//! }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,7 @@ __completion_mechanism __dispatch_memcpy_async_global_to_shared(_Group const & _
#if __cccl_ptx_isa >= 800
NV_IF_TARGET(NV_PROVIDES_SM_90, (
const bool __can_use_complete_tx = __allowed_completions & uint32_t(__completion_mechanism::__mbarrier_complete_tx);
_LIBCUDACXX_UNUSED_VAR(__can_use_complete_tx);
_LIBCUDACXX_DEBUG_ASSERT(__can_use_complete_tx == (nullptr != __bar_handle), "Pass non-null bar_handle if and only if can_use_complete_tx.");
if _LIBCUDACXX_CONSTEXPR_AFTER_CXX14 (_Align >= 16) {
if (__can_use_complete_tx && __isShared(__bar_handle)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v)

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return hcos(__v);), (return __nv_bfloat16(::cos(float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __nv_bfloat16(::cos(float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,33 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD
// trigonometric functions
inline _LIBCUDACXX_INLINE_VISIBILITY __half sin(__half __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return hsin(__v);), ({
float __vf = __v;
__vf = ::sin(__vf);
__half_raw __ret_repr = ::__float2half_rn(__vf);
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_53, (
return ::hsin(__v);
), (
{
float __vf = __v;
__vf = ::sin(__vf);
__half_raw __ret_repr = ::__float2half_rn(__vf);

uint16_t __repr = __half_raw(__v).x;
switch (__repr)
{
case 12979:
case 45747:
__ret_repr.x += 1;
break;
uint16_t __repr = __half_raw(__v).x;
switch (__repr)
{
case 12979:
case 45747:
__ret_repr.x += 1;
break;

case 23728:
case 56496:
__ret_repr.x -= 1;
break;
case 23728:
case 56496:
__ret_repr.x -= 1;
break;

default:;
}
default:;
}

return __ret_repr;
}))
return __ret_repr;
}
))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half sinh(__half __v)
Expand All @@ -69,10 +73,9 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half sinh(__half __v)
// clang-format off
inline _LIBCUDACXX_INLINE_VISIBILITY __half cos(__half __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE,
(
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_53, (
return ::hcos(__v);
),(
), (
{
float __vf = __v;
__vf = ::cos(__vf);
Expand Down Expand Up @@ -103,10 +106,9 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half cosh(__half __v)
// clang-format off
inline _LIBCUDACXX_INLINE_VISIBILITY __half exp(__half __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE,
(
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_53, (
return ::hexp(__v);
),(
), (
{
float __vf = __v;
__vf = ::exp(__vf);
Expand Down Expand Up @@ -142,10 +144,9 @@ inline _LIBCUDACXX_INLINE_VISIBILITY __half atan2(__half __x, __half __y)
// clang-format off
inline _LIBCUDACXX_INLINE_VISIBILITY __half log(__half __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE,
(
NV_IF_ELSE_TARGET(NV_PROVIDES_SM_53, (
return ::hlog(__x);
),(
), (
{
float __vf = __x;
__vf = ::log(__vf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ _CCCL_DIAG_POP
# include "../__type_traits/is_same.h"
# include "../cmath"

# if !defined(_CCCL_COMPILER_NVRTC)
# include <sstream> // for std::basic_ostringstream
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <>
Expand Down Expand Up @@ -72,8 +76,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__nv_bfloat16
: __repr(__re, __im)
{}

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions

_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<float>& __c)
: __repr(__c.real(), __c.imag())
Expand All @@ -82,7 +86,7 @@ _CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
: __repr(__c.real(), __c.imag())
{}

_CCCL_DIAG_POP
_CCCL_DIAG_POP

# if !defined(_CCCL_COMPILER_NVRTC)
template <class _Up>
Expand Down Expand Up @@ -228,6 +232,25 @@ inline _LIBCUDACXX_INLINE_VISIBILITY complex<__nv_bfloat16> acos(const complex<_
return complex<__nv_bfloat16>{_CUDA_VSTD::acos(complex<float>{__x.real(), __x.imag()})};
}

# if !defined(_CCCL_COMPILER_NVRTC)
// Stream extraction for complex<__nv_bfloat16>.
// The value is read from __is as a ::std::complex<float> and then assigned to
// __x. The assignment is unconditional: the stream state is not inspected
// between the read and the assignment. Returns __is.
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>&
operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<__nv_bfloat16>& __x)
{
  ::std::complex<float> __parsed;
  __is >> __parsed;
  __x = __parsed;
  return __is;
}

template <class _CharT, class _Traits>
::std::basic_ostream<_CharT, _Traits>&
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const complex<__nv_bfloat16>& __x)
{
return __os << complex<float>{__x.real(), __x.imag()};
}
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_END_NAMESPACE_STD

#endif /// _LIBCUDACXX_HAS_NVBF16
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
# include "../__type_traits/is_same.h"
# include "../cmath"

# if !defined(_CCCL_COMPILER_NVRTC)
# include <sstream> // for std::basic_ostringstream
# endif // !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <>
Expand Down Expand Up @@ -69,8 +73,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__half2)) com
: __repr(__re, __im)
{}

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions

_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<float>& __c)
: __repr(__c.real(), __c.imag())
Expand All @@ -79,7 +83,7 @@ _CCCL_DIAG_SUPPRESS_MSVC(4244) // narrowing conversions
: __repr(__c.real(), __c.imag())
{}

_CCCL_DIAG_POP
_CCCL_DIAG_POP

# if !defined(_CCCL_COMPILER_NVRTC)
template <class _Up>
Expand Down Expand Up @@ -225,6 +229,24 @@ inline _LIBCUDACXX_INLINE_VISIBILITY complex<__half> acos(const complex<__half>&
return complex<__half>{_CUDA_VSTD::acos(complex<float>{__x.real(), __x.imag()})};
}

# if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION) && !defined(_CCCL_COMPILER_NVRTC)
// Stream extraction for complex<__half>.
// The value is read from __is as a ::std::complex<float> and then assigned to
// __x. The assignment is unconditional: the stream state is not inspected
// between the read and the assignment. Returns __is.
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>& operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<__half>& __x)
{
  ::std::complex<float> __parsed;
  __is >> __parsed;
  __x = __parsed;
  return __is;
}

template <class _CharT, class _Traits>
::std::basic_ostream<_CharT, _Traits>&
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const complex<__half>& __x)
{
return __os << complex<float>{__x.real(), __x.imag()};
}
# endif // !_LIBCUDACXX_HAS_NO_LOCALIZATION && !_CCCL_COMPILER_NVRTC

_LIBCUDACXX_END_NAMESPACE_STD

#endif /// _LIBCUDACXX_HAS_NVFP16
Expand Down
21 changes: 20 additions & 1 deletion libcudacxx/include/cuda/std/detail/libcxx/include/complex
Original file line number Diff line number Diff line change
Expand Up @@ -1829,7 +1829,26 @@ tan(const complex<_Tp>& __x)
return complex<_Tp>(__z.imag(), -__z.real());
}

#ifndef __cuda_std__
#ifdef __cuda_std__
# if !defined(_CCCL_COMPILER_NVRTC)
// Stream extraction for complex<_Tp> from a host ::std::basic_istream.
// The value is read as a ::std::complex<_Tp> and then assigned to __x. The
// assignment is unconditional: the stream state is not inspected between the
// read and the assignment. Returns __is.
template <class _Tp, class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>&
operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<_Tp>& __x)
{
  ::std::complex<_Tp> __parsed;
  __is >> __parsed;
  __x = __parsed;
  return __is;
}

// Stream insertion for complex<_Tp> into a host ::std::basic_ostream.
// __x is converted to ::std::complex<_Tp> with static_cast and that value is
// written to __os. Returns the result of the insertion.
template <class _Tp, class _CharT, class _Traits>
::std::basic_ostream<_CharT, _Traits>&
operator<<(::std::basic_ostream<_CharT, _Traits>& __os, const complex<_Tp>& __x)
{
  const ::std::complex<_Tp> __host_value = static_cast<::std::complex<_Tp>>(__x);
  return __os << __host_value;
}
# endif // !_CCCL_COMPILER_NVRTC
#else // ^^^ __cuda_std__ ^^^ / vvv !__cuda_std__ vvv
#if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION)
template<class _Tp, class _CharT, class _Traits>
basic_istream<_CharT, _Traits>&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,17 +216,18 @@ CUtensorMap map_encode(T *tensor_ptr, const cuda::std::array<uint64_t, num_dims>

// The stride is the number of bytes to traverse from the first element of one row to the next.
// It must be a multiple of 16.
uint64_t stride[num_dims - 1];
constexpr int num_strides = num_dims - 1;
cuda::std::array<uint64_t, num_strides> stride;
uint64_t base_stride = sizeof(T);
for (size_t i = 0; i < num_dims - 1; ++i) {
for (size_t i = 0; i < stride.size(); ++i) {
base_stride *= gmem_dims[i];
stride[i] = base_stride;
}

// The distance between elements in units of sizeof(element). A stride of 2
// can be used to load only the real component of a complex-valued tensor, for instance.
uint32_t elem_stride[num_dims]; // = {1, .., 1};
for (size_t i = 0; i < num_dims; ++i) {
cuda::std::array<uint32_t, num_dims> elem_stride; // = {1, .., 1};
for (size_t i = 0; i < elem_stride.size(); ++i) {
elem_stride[i] = 1;
}

Expand All @@ -240,9 +241,9 @@ CUtensorMap map_encode(T *tensor_ptr, const cuda::std::array<uint64_t, num_dims>
num_dims, // cuuint32_t tensorRank,
tensor_ptr, // void *globalAddress,
gmem_dims.data(), // const cuuint64_t *globalDim,
stride, // const cuuint64_t *globalStrides,
stride.data(), // const cuuint64_t *globalStrides,
smem_dims.data(), // const cuuint32_t *boxDim,
elem_stride, // const cuuint32_t *elementStrides,
elem_stride.data(), // const cuuint32_t *elementStrides,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//===----------------------------------------------------------------------===//
//
// Part of the libcu++ Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#include <cuda/std/complex>
#include <cuda/std/cassert>

#include "test_macros.h"

// Checks assignment from cuda::std::complex<U> to cuda::std::complex<T>.
//
// Three sources are exercised: one with only a real part, one with only an
// imaginary part, and one with both parts non-zero. Every destination starts
// out as (-1, 1), a value that matches none of the expected results, so both
// components have to be overwritten for the asserts to hold.
template <class T, class U>
__host__ __device__ void test_assignment() {
  // Case 1: source is (42, 0).
  {
    const cuda::std::complex<U> only_real{static_cast<U>(42.0), static_cast<U>(0.0)};
    cuda::std::complex<T> from_only_real{static_cast<T>(-1.0), static_cast<T>(1.0)};

    from_only_real = only_real;

    assert(from_only_real.real() == static_cast<T>(42.0));
    assert(from_only_real.imag() == static_cast<T>(0.0));
  }

  // Case 2: source is (0, 42).
  {
    const cuda::std::complex<U> only_imag{static_cast<U>(0.0), static_cast<U>(42.0)};
    cuda::std::complex<T> from_only_imag{static_cast<T>(-1.0), static_cast<T>(1.0)};

    from_only_imag = only_imag;

    assert(from_only_imag.real() == static_cast<T>(0.0));
    assert(from_only_imag.imag() == static_cast<T>(42.0));
  }

  // Case 3: source is (42, 112).
  {
    const cuda::std::complex<U> real_imag{static_cast<U>(42.0), static_cast<U>(112.0)};
    cuda::std::complex<T> from_real_imag{static_cast<T>(-1.0), static_cast<T>(1.0)};

    from_real_imag = real_imag;

    assert(from_real_imag.real() == static_cast<T>(42.0));
    assert(from_real_imag.imag() == static_cast<T>(112.0));
  }
}

// Instantiates test_assignment for __half and __nv_bfloat16 paired with float
// and double, in both directions (as destination type and as source type).
// Each group is compiled only when its _LIBCUDACXX_HAS_* macro is defined.
__host__ __device__ void test() {
#if defined(_LIBCUDACXX_HAS_NVFP16)
  // __half as the destination type.
  test_assignment<__half, float>();
  test_assignment<__half, double>();
  // __half as the source type.
  test_assignment<float, __half>();
  test_assignment<double, __half>();
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
  // __nv_bfloat16 as the destination type.
  test_assignment<__nv_bfloat16, float>();
  test_assignment<__nv_bfloat16, double>();
  // __nv_bfloat16 as the source type.
  test_assignment<float, __nv_bfloat16>();
  test_assignment<double, __nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
}

// Test entry point: runs test() and returns 0.
// The command-line arguments are not used, so the parameters are left unnamed
// to avoid unused-parameter warnings.
int main(int, char**) {
  test();
  return 0;
}
Loading

0 comments on commit 64c8d37

Please sign in to comment.