Skip to content

Commit

Permalink
Add int4b_t/uint4b_t support for mixed dtypes GEMM
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsamardzic committed Oct 24, 2023
1 parent 5f13dca commit f0ef31d
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 16 deletions.
85 changes: 84 additions & 1 deletion include/cutlass/numeric_conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -2454,7 +2454,7 @@ struct FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> {
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;

#if 0 // Scalar conversion (Please keep this code for reference for vectorized version below)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
Expand Down Expand Up @@ -2541,6 +2541,54 @@ struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> {
}
};

/// Converts four of the eight signed 4-bit values packed in `source` to half_t.
///
/// `convert_lsb_quad` picks which half of `source` is used: bytes 0-1 when true
/// (prmt selector 0x1100), bytes 2-3 when false (prmt selector 0x3322). Each nibble is
/// first widened into its own byte of `tmp1`; `tmp1` is then handed to the
/// int8_t -> half_t converter.
template<FloatRoundStyle Round, bool convert_lsb_quad>
CUTLASS_DEVICE
Array<cutlass::half_t, 4>
convert_int4b_to_half_quad(uint32_t const& source) {
uint64_t result;

uint32_t tmp1, tmp2;

// Replicate each selected source byte into both bytes of a 16-bit lane.
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp1) : "r"(source), "n"(convert_lsb_quad ? 0x1100 : 0x3322));
// High nibble: arithmetic shift right by 4, then keep only the upper byte of each lane.
// NOTE(review): the sequence assumes the .s16 shifts act on both 16-bit lanes of the
// 32-bit register -- confirm against the PTX ISA.
asm volatile("shr.s16 %0,%1,%2;\n" : "=r"(tmp2) : "r"(tmp1), "n"(4));
asm volatile("and.b32 %0,%1,%2;\n" : "=r"(tmp2) : "r"(tmp2), "n"(0xFF00FF00));
// Low nibble: move it to the top of the lane, arithmetic shift right by 12 to
// sign-extend it, then keep only the lower byte of each lane.
asm volatile("shl.s16 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "n"(4));
asm volatile("shr.s16 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "n"(12));
asm volatile("and.b32 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "n"(0x00FF00FF));
// Merge the two partial results into four bytes.
asm volatile("or.b32 %0,%1,%2;\n" : "=r"(tmp1) : "r"(tmp1), "r"(tmp2));

This comment has been minimized.

Copy link
@manishucsd

manishucsd Oct 25, 2023

Collaborator

I see that you are taking s4 to s8 and then s8 to f16. Did you try to see if there is a way to take s4 to f16 directly? I will try to think it through; I think it should be possible.

This comment has been minimized.

Copy link
@alexsamardzic

alexsamardzic Oct 25, 2023

Author Owner

It's doable directly too: if the given 4-bit number is put after 10 in the mantissa of the f16 number and the remaining 4 bits of the mantissa are set to 0, then it's the same code as in the s8 to f16 conversion, with 0x6600 replaced by 0x5600, and 1536.0_hf by 96.0_hf. The 4-bit numbers still have to be extracted, as they come packed, two of them within a byte; but I realized later that this could probably be done by distinguishing between odd/even threads. So the code above is going to be improved, but it works at the moment, and as mentioned in my latest comment on the issue, I'm first trying to get a unit test working (please let me know if you have anything to suggest in that regard).


// Widen the four bytes assembled in tmp1 to half_t with the existing
// int8_t -> half_t fast converter.
// NOTE(review): tmp1 is a uint32_t; this relies on it being accepted as that
// converter's source_type -- confirm.
FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> converter;
return converter.convert(tmp1);
}

/// Partial specialization for Array<cutlass::half_t, 8> <= Array<int4b_t, 8>
///
/// The source array is read as a single 32-bit word. Its low two bytes produce the
/// first four results and its high two bytes produce the last four.
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, cutlass::int4b_t, 8, Round> {
  using result_type = Array<cutlass::half_t, 8>;
  using source_type = Array<cutlass::int4b_t, 8>;
  static FloatRoundStyle const round_style = Round;

  /// Converts all eight elements of `source`, four at a time.
  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    using quad_type = Array<cutlass::half_t, 4>;

    // View the packed input as one 32-bit word.
    uint32_t const *packed = reinterpret_cast<uint32_t const *>(&source);

    result_type converted;
    quad_type *quads = reinterpret_cast<quad_type *>(&converted);

    // Both calls read the same word; the template flag picks the half to convert.
    quads[0] = convert_int4b_to_half_quad<Round, true>(packed[0]);
    quads[1] = convert_int4b_to_half_quad<Round, false>(packed[0]);

    return converted;
  }

  /// Functor form; forwards to convert().
  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};

/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, uint8_t, 4, Round> {
Expand Down Expand Up @@ -2661,6 +2709,41 @@ struct FastNumericArrayConverter<T, S, N, Round,

};

/// Partial specialization for Array<half_t, N> <= Array<int4b_t, N>, N a multiple of 8.
///
/// The input is processed in groups of eight elements, each group being converted
/// by the 8-element FastNumericArrayConverter specialization.
template <
  int N,
  FloatRoundStyle Round
>
struct FastNumericArrayConverter<half_t, int4b_t, N, Round> {
  static_assert(!(N % 8), "N must be multiple of 8.");

  using result_type = Array<half_t, N>;
  using source_type = Array<int4b_t, N>;
  static FloatRoundStyle const round_style = Round;

  /// Converts all N elements of `source`, eight at a time.
  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const & source) {
    // Use the fast 8-wide converter so that wider arrays take the same path.
    // NOTE(review): the 8-wide specialization's convert() is CUTLASS_DEVICE while
    // this function is CUTLASS_HOST_DEVICE -- confirm the intended annotation.
    FastNumericArrayConverter<half_t, int4b_t, 8, Round> convert_vector_;

    result_type result;

    // Reinterpret both arrays as sequences of 8-element chunks.
    Array<half_t, 8> *result_ptr = reinterpret_cast<Array<half_t, 8> *>(&result);
    Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 8; ++i) {
      result_ptr[i] = convert_vector_(source_ptr[i]);
    }

    return result;
  }

  /// Functor form; forwards to convert().
  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines preferred rounding mode for a pair of types
Expand Down
54 changes: 39 additions & 15 deletions test/unit/core/fast_numeric_conversion.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ void run_test_integer_range_limited() {
}
}


template <typename Destination, typename Source, int Count>
void run_test_integer_range_all() {
const int kN = Count;
Expand All @@ -97,13 +96,31 @@ void run_test_integer_range_all() {
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});

int const kIntSourceMin = std::numeric_limits<Source>::min();
int const kIntSourceMax = std::numeric_limits<Source>::max();
int const kIntRange = kIntSourceMax - kIntSourceMin + 1;
constexpr bool int4b_source = std::is_same<Source, cutlass::int4b_t>::value;

for (int i = 0; i < kN; ++i) {
source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange));
int kIntSourceMin;
int kIntSourceMax;
if constexpr (int4b_source) {

This comment has been minimized.

Copy link
@manishucsd

manishucsd Oct 25, 2023

Collaborator

Let us make std::numeric_limits work on s4 and u4, if it doesn't already.

This comment has been minimized.

Copy link
@alexsamardzic

alexsamardzic Oct 26, 2023

Author Owner

Fixed.

kIntSourceMin = -(1 << (Source::kBits - 1));
kIntSourceMax = -kIntSourceMin - 1;
} else {
kIntSourceMin = std::numeric_limits<Source>::min();
kIntSourceMax = std::numeric_limits<Source>::max();
}
int kIntRange = kIntSourceMax - kIntSourceMin + 1;

using SourceDataPtr = std::conditional_t<int4b_source, cutlass::Array<Source, kN>*, Source**>;
Source* source_data = nullptr;
SourceDataPtr source_data_ptr = nullptr;
if constexpr (int4b_source) {
source_data_ptr = reinterpret_cast<SourceDataPtr>(source.host_data());
} else {
source_data = source.host_data();
source_data_ptr = &source_data;
}

for (int i = 0; i < kN; ++i) {
(*source_data_ptr)[i] = Source(kIntSourceMin + (i % kIntRange));
}

source.sync_device();
Expand All @@ -114,25 +131,25 @@ void run_test_integer_range_all() {
);

destination.sync_host();

// Verify conversion
bool passed = true;
for (int i = 0; i < kN; ++i) {
if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) {
if(!(float(destination.host_data()[i]) == float((*source_data_ptr)[i]))) {
passed = false;
break;
}
}
EXPECT_TRUE(passed) << " FastNumericArrayConverter failed";
// Print out results for the failed conversion.
if (!passed) {

// Print out results for the failed conversion.
if (!passed) {
for (int i = 0; i < kN; ++i) {
std::cout << "source(" << float(source.host_data()[i]) << ") -> "
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl;
std::cout << "source(" << float((*source_data_ptr)[i]) << ") -> "
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl;
}
}
std::flush(std::cout);
}
std::flush(std::cout);
}

} // namespace kernel
Expand Down Expand Up @@ -174,3 +191,10 @@ TEST(FastNumericConversion, s8_to_bf16_array) {
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}

// Checks the int4b_t -> half_t fast conversion over an array of kN elements.
TEST(FastNumericConversion, s4_to_f16_array) {
// Element count passed to the test harness.
// NOTE(review): a 4-bit source has fewer than 256 distinct values, so the
// generated inputs repeat -- consider reducing.
int const kN = 256;

This comment has been minimized.

Copy link
@manishucsd

manishucsd Oct 25, 2023

Collaborator

You don't have 256 bit patterns for a 4b number. You can reduce this, or you will be repeating the bit patterns that you are testing.

This comment has been minimized.

Copy link
@alexsamardzic

alexsamardzic Oct 26, 2023

Author Owner

Fixed.

// Source and destination element types for this conversion test.
using Source = cutlass::int4b_t;
using Destination = cutlass::half_t;
// Run the shared integer-range harness for this type pair.
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}

0 comments on commit f0ef31d

Please sign in to comment.