-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2454,7 +2454,7 @@ struct FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> { | |
CUTLASS_DEVICE | ||
static result_type convert(source_type const &source) { | ||
result_type result; | ||
|
||
#if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) | ||
CUTLASS_PRAGMA_UNROLL | ||
for (int i = 0; i < 4; ++i) { | ||
|
@@ -2541,6 +2541,54 @@ struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> { | |
} | ||
}; | ||
|
||
template<FloatRoundStyle Round, bool convert_lsb_quad> | ||
CUTLASS_DEVICE | ||
Array<cutlass::half_t, 4> | ||
convert_int4b_to_half_quad(uint32_t const& source) { | ||
uint64_t result; | ||
|
||
uint32_t tmp1, tmp2; | ||
|
||
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp1) : "r"(source), "n"(convert_lsb_quad ? 0x1100 : 0x3322)); | ||
asm volatile("shr.s16 %0,%1,%2;\n" : "=r"(tmp2) : "r"(tmp1), "n"(4)); | ||
asm volatile("and.b32 %0,%1,%2;\n" : "=r"(tmp2) : "r"(tmp2), "n"(0xFF00FF00)); | ||
asm volatile("shl.s16 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "n"(4)); | ||
asm volatile("sh4.s16 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "n"(12)); | ||
asm volatile("and.b32 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "n"(0x00FF00FF)); | ||
asm volatile("or.b32 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "r"(tmp2)); | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
alexsamardzic
Author
Owner
|
||
|
||
FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> converter; | ||
return converter.convert(tmp1); | ||
|
||
return result; | ||
} | ||
|
||
/// Partial specialization for Array<cutlass::half_t, 8> <= Array<int4b_t, 8> | ||
template <FloatRoundStyle Round> | ||
struct FastNumericArrayConverter<cutlass::half_t, cutlass::int4b_t, 8, Round> { | ||
using result_type = Array<cutlass::half_t, 8>; | ||
using source_type = Array<cutlass::int4b_t, 8>; | ||
static FloatRoundStyle const round_style = Round; | ||
|
||
CUTLASS_DEVICE | ||
static result_type convert(source_type const &source) { | ||
result_type result; | ||
|
||
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source); | ||
Array<cutlass::half_t, 4>* result_ptr = reinterpret_cast<Array<cutlass::half_t, 4>*>(&result); | ||
|
||
result_ptr[0] = convert_int4b_to_half_quad<Round, true>(source_ptr[0]); | ||
result_ptr[1] = convert_int4b_to_half_quad<Round, false>(source_ptr[0]); | ||
|
||
return result; | ||
} | ||
|
||
CUTLASS_DEVICE | ||
result_type operator()(source_type const &s) const { | ||
return convert(s); | ||
} | ||
}; | ||
|
||
/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<uint8_t, 4> | ||
template <FloatRoundStyle Round> | ||
struct FastNumericArrayConverter<cutlass::bfloat16_t, uint8_t, 4, Round> { | ||
|
@@ -2661,6 +2709,41 @@ struct FastNumericArrayConverter<T, S, N, Round, | |
|
||
}; | ||
|
||
/// Partial specialization for Array<half, 8> <= Array<int4b, 8> | ||
template < | ||
int N, | ||
FloatRoundStyle Round | ||
> | ||
struct FastNumericArrayConverter<half_t, int4b_t, N, Round> { | ||
static_assert(!(N % 8), "N must be multiple of 8."); | ||
|
||
using result_type = Array<half_t, N>; | ||
using source_type = Array<int4b_t, N>; | ||
static FloatRoundStyle const round_style = Round; | ||
|
||
CUTLASS_HOST_DEVICE | ||
static result_type convert(source_type const & source) { | ||
NumericArrayConverter<half_t, int4b_t, 8, Round> convert_vector_; | ||
|
||
result_type result; | ||
|
||
Array<half_t, 8> *result_ptr = reinterpret_cast<Array<half_t, 8> *>(&result); | ||
Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source); | ||
|
||
CUTLASS_PRAGMA_UNROLL | ||
for (int i = 0; i < N / 8; ++i) { | ||
result_ptr[i] = convert_vector_(source_ptr[i]); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
CUTLASS_HOST_DEVICE | ||
result_type operator()(source_type const &s) const { | ||
return convert(s); | ||
} | ||
}; | ||
|
||
///////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
/// Defines preferred rounding mode for a pair of types | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -86,7 +86,6 @@ void run_test_integer_range_limited() { | |
} | ||
} | ||
|
||
|
||
template <typename Destination, typename Source, int Count> | ||
void run_test_integer_range_all() { | ||
const int kN = Count; | ||
|
@@ -97,13 +96,31 @@ void run_test_integer_range_all() { | |
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN}); | ||
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN}); | ||
|
||
int const kIntSourceMin = std::numeric_limits<Source>::min(); | ||
int const kIntSourceMax = std::numeric_limits<Source>::max(); | ||
int const kIntRange = kIntSourceMax - kIntSourceMin + 1; | ||
constexpr bool int4b_source = std::is_same<Source, cutlass::int4b_t>::value; | ||
|
||
for (int i = 0; i < kN; ++i) { | ||
source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange)); | ||
int kIntSourceMin; | ||
int kIntSourceMax; | ||
if constexpr (int4b_source) { | ||
This comment has been minimized.
Sorry, something went wrong.
manishucsd
Collaborator
|
||
kIntSourceMin = -(1 << (Source::kBits - 1)); | ||
kIntSourceMax = -kIntSourceMin - 1; | ||
} else { | ||
kIntSourceMin = std::numeric_limits<Source>::min(); | ||
kIntSourceMax = std::numeric_limits<Source>::max(); | ||
} | ||
int kIntRange = kIntSourceMax - kIntSourceMin + 1; | ||
|
||
using SourceDataPtr = std::conditional_t<int4b_source, cutlass::Array<Source, kN>*, Source**>; | ||
Source* source_data = nullptr; | ||
SourceDataPtr source_data_ptr = nullptr; | ||
if constexpr (int4b_source) { | ||
source_data_ptr = reinterpret_cast<SourceDataPtr>(source.host_data()); | ||
} else { | ||
source_data = source.host_data(); | ||
source_data_ptr = &source_data; | ||
} | ||
|
||
for (int i = 0; i < kN; ++i) { | ||
(*source_data_ptr)[i] = Source(kIntSourceMin + (i % kIntRange)); | ||
} | ||
|
||
source.sync_device(); | ||
|
@@ -114,25 +131,25 @@ void run_test_integer_range_all() { | |
); | ||
|
||
destination.sync_host(); | ||
|
||
// Verify conversion | ||
bool passed = true; | ||
for (int i = 0; i < kN; ++i) { | ||
if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) { | ||
if(!(float(destination.host_data()[i]) == float((*source_data_ptr)[i]))) { | ||
passed = false; | ||
break; | ||
} | ||
} | ||
EXPECT_TRUE(passed) << " FastNumericArrayConverter failed"; | ||
// Print out results for the failed conversion. | ||
if (!passed) { | ||
|
||
// Print out results for the failed conversion. | ||
if (!passed) { | ||
for (int i = 0; i < kN; ++i) { | ||
std::cout << "source(" << float(source.host_data()[i]) << ") -> " | ||
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl; | ||
std::cout << "source(" << float((*source_data_ptr)[i]) << ") -> " | ||
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl; | ||
} | ||
} | ||
std::flush(std::cout); | ||
} | ||
std::flush(std::cout); | ||
} | ||
|
||
} // namespace kernel | ||
|
@@ -174,3 +191,10 @@ TEST(FastNumericConversion, s8_to_bf16_array) { | |
using Destination = cutlass::bfloat16_t; | ||
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>(); | ||
} | ||
|
||
/// Runs the integer-range conversion test for int4b_t -> half_t with Count = 256.
TEST(FastNumericConversion, s4_to_f16_array) {
  using Source = cutlass::int4b_t;
  using Destination = cutlass::half_t;
  constexpr int kN = 256;
  test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
I see that you are taking `s4` to `s8` and then `s8` to `f16`. Did you try to see if there is a way to take `s4` to `f16` directly? I will try to think it through; I think it should be possible.