Skip to content

Commit

Permalink
check only gko::complex<half> WIP: remove the compile option
Browse files Browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Nov 28, 2024
1 parent a2ec00c commit 05471a3
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 18 deletions.
10 changes: 10 additions & 0 deletions accessor/sycl_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ namespace gko {
class half;


template <typename V>
class complex;


namespace acc {
namespace detail {

Expand Down Expand Up @@ -81,6 +85,12 @@ struct sycl_type<std::complex<T>> {
};


template <>
struct sycl_type<std::complex<gko::half>> {
using type = gko::complex<typename sycl_type<gko::half>::type>;
};


} // namespace detail


Expand Down
76 changes: 64 additions & 12 deletions dpcpp/base/complex.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
// path in sycl module.
// They start to do this from LIBSYCL 7.1.0.

namespace std {
namespace gko {

template <typename>
class complex;
Expand Down Expand Up @@ -53,7 +53,7 @@ class complex<sycl::half> {
{}

template <typename T, typename = std::enable_if_t<std::is_scalar<T>::value>>
complex(const complex<T>& other)
complex(const std::complex<T>& other)
: real_(static_cast<value_type>(other.real())),
imag_(static_cast<value_type>(other.imag()))
{}
Expand All @@ -64,6 +64,13 @@ class complex<sycl::half> {

inline operator std::complex<float>() const noexcept;

bool operator!=(const complex& r) const { return !this->operator==(r); }

bool operator==(const complex& r) const
{
return real_ == r.real() && imag_ == r.imag();
}

template <typename V>
complex& operator=(const V& val)
{
Expand Down Expand Up @@ -107,26 +114,44 @@ class complex<sycl::half> {
}

template <typename T>
complex& operator+=(const complex<T>& val)
complex& operator+=(const std::complex<T>& val)
{
real_ += val.real();
imag_ += val.imag();
return *this;
}

template <typename T>
complex& operator-=(const complex<T>& val)
complex& operator-=(const std::complex<T>& val)
{
real_ -= val.real();
imag_ -= val.imag();
return *this;
}

template <typename T>
inline complex& operator*=(const complex<T>& val);
inline complex& operator*=(const std::complex<T>& val);

template <typename T>
inline complex& operator/=(const complex<T>& val);
inline complex& operator/=(const std::complex<T>& val);

complex& operator+=(const gko::complex<value_type>& val)
{
real_ += val.real();
imag_ += val.imag();
return *this;
}

complex& operator-=(const gko::complex<value_type>& val)
{
real_ -= val.real();
imag_ -= val.imag();
return *this;
}

inline complex& operator*=(const gko::complex<value_type>& val);

inline complex& operator/=(const gko::complex<value_type>& val);

// It's for MacOS.
// TODO: check whether mac compiler always use complex version even when real
Expand All @@ -147,12 +172,15 @@ class complex<sycl::half> {

#undef COMPLEX_HALF_OPERATOR


complex operator-() const { return complex(-real_, -imag_); }

private:
value_type real_;
value_type imag_;
};

} // namespace std
} // namespace gko


// after providing std::complex<sycl::half>, we can load their <complex> to
Expand Down Expand Up @@ -181,10 +209,10 @@ class complex<sycl::half> {

// we know the complex<float> now, so we implement those functions requiring
// complex<float>
namespace std {
namespace gko {


inline complex<sycl::half>::operator complex<float>() const noexcept
inline complex<sycl::half>::operator std::complex<float>() const noexcept
{
return std::complex<float>(static_cast<float>(real_),
static_cast<float>(imag_));
Expand All @@ -193,7 +221,7 @@ inline complex<sycl::half>::operator complex<float>() const noexcept

template <typename T>
inline complex<sycl::half>& complex<sycl::half>::operator*=(
const complex<T>& val)
const std::complex<T>& val)
{
auto val_f = static_cast<std::complex<float>>(val);
auto result_f = static_cast<std::complex<float>>(*this);
Expand All @@ -206,7 +234,31 @@ inline complex<sycl::half>& complex<sycl::half>::operator*=(

template <typename T>
inline complex<sycl::half>& complex<sycl::half>::operator/=(
const complex<T>& val)
const std::complex<T>& val)
{
auto val_f = static_cast<std::complex<float>>(val);
auto result_f = static_cast<std::complex<float>>(*this);
result_f /= val_f;
real_ = result_f.real();
imag_ = result_f.imag();
return *this;
}


inline complex<sycl::half>& complex<sycl::half>::operator*=(
const gko::complex<value_type>& val)
{
auto val_f = static_cast<std::complex<float>>(val);
auto result_f = static_cast<std::complex<float>>(*this);
result_f *= val_f;
real_ = result_f.real();
imag_ = result_f.imag();
return *this;
}


inline complex<sycl::half>& complex<sycl::half>::operator/=(
const gko::complex<value_type>& val)
{
auto val_f = static_cast<std::complex<float>>(val);
auto result_f = static_cast<std::complex<float>>(*this);
Expand All @@ -217,7 +269,7 @@ inline complex<sycl::half>& complex<sycl::half>::operator/=(
}


} // namespace std
} // namespace gko


#endif // GKO_DPCPP_BASE_COMPLEX_HPP_
45 changes: 40 additions & 5 deletions dpcpp/base/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,41 @@ struct basic_float_traits<sycl::half> {
template <>
struct is_complex_or_scalar_impl<sycl::half> : public std::true_type {};

template <typename ValueType>
struct complex_helper {
using type = std::complex<ValueType>;
};

template <>
struct complex_helper<sycl::half> {
using type = gko::complex<sycl::half>;
};


template <typename T>
struct type_size_impl<gko::complex<T>> {
static constexpr auto value = sizeof(T) * byte_size;
};


template <typename T>
struct remove_complex_impl<gko::complex<T>> {
using type = T;
};


template <typename T>
struct truncate_type_impl<gko::complex<T>> {
using type =
typename complex_helper<typename truncate_type_impl<T>::type>::type;
};

template <typename T>
struct is_complex_impl<gko::complex<T>> : public std::true_type {};

template <typename T>
struct is_complex_or_scalar_impl<gko::complex<T>>
: public is_complex_or_scalar_impl<T> {};

} // namespace detail

Expand All @@ -41,7 +76,7 @@ bool __dpct_inline__ is_nan(const sycl::half& val)
return std::isnan(static_cast<float>(val));
}

bool __dpct_inline__ is_nan(const std::complex<sycl::half>& val)
bool __dpct_inline__ is_nan(const gko::complex<sycl::half>& val)
{
return is_nan(val.real()) || is_nan(val.imag());
}
Expand All @@ -52,7 +87,7 @@ sycl::half __dpct_inline__ abs(const sycl::half& val)
return abs(static_cast<float>(val));
}

sycl::half __dpct_inline__ abs(const std::complex<sycl::half>& val)
sycl::half __dpct_inline__ abs(const gko::complex<sycl::half>& val)
{
return abs(static_cast<std::complex<float>>(val));
}
Expand All @@ -62,8 +97,8 @@ sycl::half __dpct_inline__ sqrt(const sycl::half& val)
return sqrt(static_cast<float>(val));
}

std::complex<sycl::half> __dpct_inline__
sqrt(const std::complex<sycl::half>& val)
gko::complex<sycl::half> __dpct_inline__
sqrt(const gko::complex<sycl::half>& val)
{
return sqrt(static_cast<std::complex<float>>(val));
}
Expand All @@ -74,7 +109,7 @@ bool __dpct_inline__ is_finite(const sycl::half& value)
return abs(value) < std::numeric_limits<sycl::half>::infinity();
}

bool __dpct_inline__ is_finite(const std::complex<sycl::half>& value)
bool __dpct_inline__ is_finite(const gko::complex<sycl::half>& value)
{
return is_finite(value.real()) && is_finite(value.imag());
}
Expand Down
5 changes: 5 additions & 0 deletions dpcpp/base/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ struct sycl_type_impl<std::complex<T>> {
using type = std::complex<typename sycl_type_impl<T>::type>;
};

template <>
struct sycl_type_impl<std::complex<gko::half>> {
using type = gko::complex<typename sycl_type_impl<gko::half>::type>;
};

template <typename ValueType, typename IndexType>
struct sycl_type_impl<matrix_data_entry<ValueType, IndexType>> {
using type =
Expand Down
2 changes: 1 addition & 1 deletion dpcpp/preconditioner/batch_block_jacobi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class BlockJacobi final {

// reduction (it does not support complex<half>)
if constexpr (std::is_same_v<value_type,
std::complex<sycl::half>>) {
gko::complex<sycl::half>>) {
for (int i = sg_size / 2; i > 0; i /= 2) {
sum += sycl::shift_group_left(sg, sum, i);
}
Expand Down

0 comments on commit 05471a3

Please sign in to comment.