Skip to content

Commit

Permalink
mxp: Cleanup truncated test
Browse files Browse the repository at this point in the history
  • Loading branch information
cmpfeil committed Jan 30, 2025
1 parent edbab32 commit a9a6370
Showing 1 changed file with 23 additions and 43 deletions.
66 changes: 23 additions & 43 deletions tests/test_mxp_truncated_mantissa_t.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@

#include "test_mxp_helper.h"

template <std::uint8_t bits, typename S, typename T>
void generic_truncated_add(const gt::gtensor<T, 1, S>& x,
gt::gtensor<T, 1, S>& y)
{
using mxp_type = mxp::truncated_mantissa_t<T, bits>;
const auto mxp_x = mxp::adapt<1, S, mxp_type>(x.data(), x.size());
auto mxp_y = mxp::adapt<1, S, mxp_type>(y.data(), y.size());

mxp_y = mxp_y + mxp_x;
}

template <std::uint8_t bits>
float ref_truncated_add_float()
{
Expand Down Expand Up @@ -81,7 +70,12 @@ struct run_test_add_host
auto gt_y = gt::adapt<1, S>(y.data(), y.size());
y.view() = y_init;

generic_truncated_add<bits, S>(x, y);
using mxp_type = mxp::truncated_mantissa_t<T, bits>;
const auto mxp_x = mxp::adapt<1, S, mxp_type>(x.data(), x.size());
auto mxp_y = mxp::adapt<1, S, mxp_type>(y.data(), y.size());

mxp_y = mxp_y + mxp_x;

EXPECT_EQ(
y, (gt::gtensor<T, 1, S>(y.size(), ref_truncated_add_gen<bits, T>())));
}
Expand Down Expand Up @@ -151,20 +145,6 @@ TEST(mxp_truncated_mantissa, add_complex_double)
Loop<0, 52, run_test_add_host>::Run(x, y, y_init);
}

template <std::uint8_t bits, typename S, typename T>
void generic_view_truncated_add(const gt::gtensor<T, 1, S>& x,
gt::gtensor<T, 1, S>& y)
{
using mxp_type = mxp::truncated_mantissa_t<T, bits>;
const auto mxp_x = mxp::adapt<1, S, mxp_type>(x.data(), x.size());
auto mxp_y = mxp::adapt<1, S, mxp_type>(y.data(), y.size());

using gt::placeholders::_all;
using gt::placeholders::_s;

mxp_y.view(_s(1, -1)) = mxp_y.view(_s(1, -1)) + mxp_x.view(_all);
}

struct run_test_view_add_host
{
using S = gt::space::host;
Expand All @@ -176,7 +156,14 @@ struct run_test_view_add_host
auto gt_y = gt::adapt<1, S>(y.data(), y.size());
y.view() = y_init;

generic_view_truncated_add<bits, S>(x, y);
using mxp_type = mxp::truncated_mantissa_t<T, bits>;
const auto mxp_x = mxp::adapt<1, S, mxp_type>(x.data(), x.size());
auto mxp_y = mxp::adapt<1, S, mxp_type>(y.data(), y.size());

using gt::placeholders::_all;
using gt::placeholders::_s;

mxp_y.view(_s(1, -1)) = mxp_y.view(_s(1, -1)) + mxp_x.view(_all);
EXPECT_EQ(y,
(gt::gtensor<T, 1, S>{y_init, ref_truncated_add_gen<bits, T>(),
ref_truncated_add_gen<bits, T>(),
Expand Down Expand Up @@ -255,20 +242,6 @@ TEST(mxp_truncated_mantissa, view_add_complex_double)
Loop<0, 52, run_test_view_add_host>::Run(x, y, y_init);
}

template <std::uint8_t bits, typename S, typename T>
void generic_view_2D_truncated_add(const gt::gtensor<T, 1, S>& x,
gt::gtensor<T, 2, S>& y)
{
using mxp_type = mxp::truncated_mantissa_t<T, bits>;
const auto mxp_x = mxp::adapt<1, S, mxp_type>(x.data(), x.size());
auto mxp_y = mxp::adapt<2, S, mxp_type>(y.data(), y.shape());

using gt::placeholders::_all;
using gt::placeholders::_s;

mxp_y.view(_s(1, -1), 1) = mxp_y.view(_s(1, -1), 1) + mxp_x.view(_all);
}

struct run_test_view_2D_add_host
{
using S = gt::space::host;
Expand All @@ -280,7 +253,14 @@ struct run_test_view_2D_add_host
auto gt_y = gt::adapt<2, S>(y.data(), y.shape());
y.view() = y_init;

generic_view_2D_truncated_add<bits, S>(x, y);
using mxp_type = mxp::truncated_mantissa_t<T, bits>;
const auto mxp_x = mxp::adapt<1, S, mxp_type>(x.data(), x.size());
auto mxp_y = mxp::adapt<2, S, mxp_type>(y.data(), y.shape());

using gt::placeholders::_all;
using gt::placeholders::_s;

mxp_y.view(_s(1, -1), 1) = mxp_y.view(_s(1, -1), 1) + mxp_x.view(_all);
EXPECT_EQ(
y, (gt::gtensor<T, 2, S>{{y_init, y_init, y_init, y_init, y_init},
{y_init, ref_truncated_add_gen<bits, T>(),
Expand Down Expand Up @@ -370,7 +350,7 @@ struct run_test_error_bounds_host
using mxp_type = mxp::truncated_mantissa_t<T, bits>;
const auto mxp_x = mxp::adapt<1, S, mxp_type>(x.data(), x.size());
const auto gt_x = gt::adapt<1, S>(x.data(), x.size());
/* */ auto gt_y = gt::adapt<1, S>(y.data(), y.size());
auto gt_y = gt::adapt<1, S>(y.data(), y.size());

gt_y = gt_x - mxp_x;

Expand Down

0 comments on commit a9a6370

Please sign in to comment.