Use uint dtypes consistently for loads and stores #1334

Closed · wants to merge 1 commit
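For context, the stores touched in this diff all follow the same tiered dispatch: a full 128-bit store when the destination is 16-byte aligned and every lane is valid, 64-bit and 32-bit stores otherwise, and scalar stores as the fallback. Below is a minimal sketch of that pattern with the uint-typed casts this PR adopts, paraphrased from the `VecNT<4>` store path; the helper name and the explicit alignment flags are illustrative, not from the PR.

```cuda
#include <cuda_runtime.h>

// Minimal sketch (illustrative only): write up to four float lanes through
// uint-typed vector casts, picking the widest store that the destination
// alignment and the number of valid lanes allow.
__device__ __forceinline__ void store_vec4_sketch(
    float4 acc,
    float* output_ptr,
    bool aligned_16b,
    bool aligned_8b,
    int num_valid_outputs) {
  if (aligned_16b && num_valid_outputs == 4) {
    // Single 128-bit store of the raw bits.
    *reinterpret_cast<uint4*>(output_ptr) =
        *reinterpret_cast<const uint4*>(&acc);
  } else if (aligned_8b && num_valid_outputs >= 2) {
    // 64-bit store for the first two lanes, then handle the tail.
    *reinterpret_cast<uint2*>(output_ptr) =
        *reinterpret_cast<const uint2*>(&acc.x);
    if (num_valid_outputs == 4) {
      *reinterpret_cast<uint2*>(output_ptr + 2) =
          *reinterpret_cast<const uint2*>(&acc.x + 2);
    } else if (num_valid_outputs == 3) {
      output_ptr[2] = acc.z;
    }
  } else {
    // Unaligned or short tail: plain scalar stores.
    for (int i = 0; i < num_valid_outputs; ++i) {
      output_ptr[i] = reinterpret_cast<const float*>(&acc)[i];
    }
  }
}
```

Because uint4/uint2/uint are opaque 128-, 64-, and 32-bit containers, the memory transaction type no longer depends on whether the payload holds float, half, or int8 data, which appears to be the consistency the PR title refers to; the old float4/float2 and half8/half4/half2 casts tied the transaction type to the payload type.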
@@ -710,6 +710,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
}));
#undef X


// launch 4-bit kernel
#define X(DeviceOnly, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \
nbit::INT4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L<index_t, output_t, OutputRowsPerThread, kWarpsPerBlock, InputRowsInFlight, MinNum128BRows, MaxNum128BRows, DeviceOnly><<< \
112 changes: 56 additions & 56 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -1970,14 +1970,14 @@ struct VecNT<4, PrimitiveType::FP> {
// Since byte granule is guaranteed, num_valid_outputs can be any integer
// for int8.
if (aligned_16b && num_valid_outputs == 4) {
-*reinterpret_cast<float4*>(output_ptr) =
-*reinterpret_cast<const float4*>(&acc);
+*reinterpret_cast<uint4*>(output_ptr) =
+*reinterpret_cast<const uint4*>(&acc);
} else if (aligned_8b && num_valid_outputs >= 2) {
-*reinterpret_cast<float2*>(output_ptr) =
-*reinterpret_cast<const float2*>(&(acc.x));
+*reinterpret_cast<uint2*>(output_ptr) =
+*reinterpret_cast<const uint2*>(&(acc.x));
if (num_valid_outputs == 4) {
-*reinterpret_cast<float2*>(output_ptr + 2) =
-*reinterpret_cast<const float2*>(&(acc.x) + 2);
+*reinterpret_cast<uint2*>(output_ptr + 2) =
+*reinterpret_cast<const uint2*>(&(acc.x) + 2);
} else if (num_valid_outputs == 3) {
*(output_ptr + 2) = *(&(acc.x) + 2);
}
@@ -1998,14 +1998,14 @@ struct VecNT<4, PrimitiveType::FP> {
// Since byte granule is guaranteed, num_valid_outputs can be any integer
// for int8.
if (aligned_8b && num_valid_outputs == 4) {
-*reinterpret_cast<float2*>(output_ptr) =
-*reinterpret_cast<const float2*>(&val);
+*reinterpret_cast<uint2*>(output_ptr) =
+*reinterpret_cast<const uint2*>(&val);
} else if (aligned_4b && num_valid_outputs >= 2) {
-*reinterpret_cast<float*>(output_ptr) =
-*reinterpret_cast<const float*>(&(val.vals[0].x));
+*reinterpret_cast<uint*>(output_ptr) =
+*reinterpret_cast<const uint*>(&(val.vals[0].x));
if (num_valid_outputs == 4) {
-*reinterpret_cast<float*>(output_ptr + 2) =
-*reinterpret_cast<const float*>(&(val.vals[0].x) + 2);
+*reinterpret_cast<uint*>(output_ptr + 2) =
+*reinterpret_cast<const uint*>(&(val.vals[0].x) + 2);
} else if (num_valid_outputs == 3) {
*(output_ptr + 2) =
*reinterpret_cast<const at::Half*>(&(val.vals[0].x) + 2);
@@ -2084,14 +2084,14 @@ struct VecNT<4, PrimitiveType::INT> {
// Since byte granule is guaranteed, num_valid_outputs can be any integer
// for int8.
if (aligned_16b && num_valid_outputs == 4) {
-*reinterpret_cast<float4*>(output_ptr) =
-*reinterpret_cast<const float4*>(&acc);
+*reinterpret_cast<uint4*>(output_ptr) =
+*reinterpret_cast<const uint4*>(&acc);
} else if (aligned_8b && num_valid_outputs >= 2) {
-*reinterpret_cast<float2*>(output_ptr) =
-*reinterpret_cast<const float2*>(&(acc.x));
+*reinterpret_cast<uint2*>(output_ptr) =
+*reinterpret_cast<const uint2*>(&(acc.x));
if (num_valid_outputs == 4) {
-*reinterpret_cast<float2*>(output_ptr + 2) =
-*reinterpret_cast<const float2*>(&(acc.x) + 2);
+*reinterpret_cast<uint2*>(output_ptr + 2) =
+*reinterpret_cast<const uint2*>(&(acc.x) + 2);
} else if (num_valid_outputs == 3) {
*(output_ptr + 2) = *(&(acc.x) + 2);
}
@@ -2112,14 +2112,14 @@ struct VecNT<4, PrimitiveType::INT> {
// Since byte granule is guaranteed, num_valid_outputs can be any integer
// for int8.
if (aligned_8b && num_valid_outputs == 4) {
-*reinterpret_cast<float2*>(output_ptr) =
-*reinterpret_cast<const float2*>(&val);
+*reinterpret_cast<uint2*>(output_ptr) =
+*reinterpret_cast<const uint2*>(&val);
} else if (aligned_4b && num_valid_outputs >= 2) {
-*reinterpret_cast<float*>(output_ptr) =
-*reinterpret_cast<const float*>(&(val.vals[0].x));
+*reinterpret_cast<uint*>(output_ptr) =
+*reinterpret_cast<const uint*>(&(val.vals[0].x));
if (num_valid_outputs == 4) {
-*reinterpret_cast<float*>(output_ptr + 2) =
-*reinterpret_cast<const float*>(&(val.vals[0].x) + 2);
+*reinterpret_cast<uint*>(output_ptr + 2) =
+*reinterpret_cast<const uint*>(&(val.vals[0].x) + 2);
} else if (num_valid_outputs == 3) {
*(output_ptr + 2) =
*reinterpret_cast<const at::Half*>(&(val.vals[0].x) + 2);
@@ -2198,21 +2198,21 @@ struct VecNT<8, PrimitiveType::INT> {
// Since byte granule is guaranteed, num_valid_outputs is multiple of 2 for
// int4.
if (aligned_16b && num_valid_outputs >= 4) { // 128 bit cache line
-*reinterpret_cast<float4*>(output_ptr) =
-*reinterpret_cast<const float4*>(&(acc.vals[0]));
+*reinterpret_cast<uint4*>(output_ptr) =
+*reinterpret_cast<const uint4*>(&(acc.vals[0]));
if (num_valid_outputs == 8) {
-*reinterpret_cast<float4*>(output_ptr + 4) =
-*reinterpret_cast<const float4*>(&(acc.vals[1]));
+*reinterpret_cast<uint4*>(output_ptr + 4) =
+*reinterpret_cast<const uint4*>(&(acc.vals[1]));
} else if (num_valid_outputs == 6) {
-*reinterpret_cast<float2*>(output_ptr + 4) =
-*reinterpret_cast<const float2*>(&(acc.vals[1]));
+*reinterpret_cast<uint2*>(output_ptr + 4) =
+*reinterpret_cast<const uint2*>(&(acc.vals[1]));
}
} else if (aligned_8b) {
#pragma unroll
for (int i = 0; i < 8; i += 2) {
if (i < num_valid_outputs) {
-*reinterpret_cast<float2*>(output_ptr + i) =
-*reinterpret_cast<const float2*>(&(acc.vals[0].x) + i);
+*reinterpret_cast<uint2*>(output_ptr + i) =
+*reinterpret_cast<const uint2*>(&(acc.vals[0].x) + i);
}
}
} else {
@@ -2233,24 +2233,24 @@ struct VecNT<8, PrimitiveType::INT> {
// Since byte granule is guaranteed, num_valid_outputs is multiple of 2 for
// int4.
if (aligned_16b && num_valid_outputs == 8) {
-*reinterpret_cast<half8*>(output_ptr) =
-*reinterpret_cast<const half8*>(&val);
+*reinterpret_cast<uint4*>(output_ptr) =
+*reinterpret_cast<const uint4*>(&val);
} else if (aligned_8b && num_valid_outputs >= 4) {
-*reinterpret_cast<half4*>(output_ptr) =
-*reinterpret_cast<const half4*>(&(val.vals[0].x));
+*reinterpret_cast<uint2*>(output_ptr) =
+*reinterpret_cast<const uint2*>(&(val.vals[0].x));
if (num_valid_outputs == 8) {
-*reinterpret_cast<half4*>(output_ptr + 4) =
-*reinterpret_cast<const half4*>(&(val.vals[0].x) + 4);
+*reinterpret_cast<uint2*>(output_ptr + 4) =
+*reinterpret_cast<const uint2*>(&(val.vals[0].x) + 4);
} else if (num_valid_outputs == 6) {
-*reinterpret_cast<half2*>(output_ptr + 4) =
-*reinterpret_cast<const half2*>(&(val.vals[0].x) + 4);
+*reinterpret_cast<uint*>(output_ptr + 4) =
+*reinterpret_cast<const uint*>(&(val.vals[0].x) + 4);
}
} else if (aligned_4b) {
#pragma unroll
for (int i = 0; i < 8; i += 2) {
if (i < num_valid_outputs) {
-*reinterpret_cast<half2*>(output_ptr + i) =
-*reinterpret_cast<const half2*>(&(val.vals[0].x) + i);
+*reinterpret_cast<uint*>(output_ptr + i) =
+*reinterpret_cast<const uint*>(&(val.vals[0].x) + i);
}
}
} else {
@@ -2334,16 +2334,16 @@ struct VecNT<16, PrimitiveType::INT> {
#pragma unroll
for (int i = 0; i < 16; i += 4) {
if (i < num_valid_outputs) {
-*reinterpret_cast<float4*>(output_ptr + i) =
-*reinterpret_cast<const float4*>(&(acc.vals[0].vals[0]) + i);
+*reinterpret_cast<uint4*>(output_ptr + i) =
+*reinterpret_cast<const uint4*>(&(acc.vals[0].vals[0]) + i);
}
}
} else if (aligned_8b) {
#pragma unroll
for (int i = 0; i < 16; i += 2) {
if (i < num_valid_outputs) {
-*reinterpret_cast<float2*>(output_ptr + i) =
-*reinterpret_cast<const float2*>(&(acc.vals[0].vals[0]) + i);
+*reinterpret_cast<uint2*>(output_ptr + i) =
+*reinterpret_cast<const uint2*>(&(acc.vals[0].vals[0]) + i);
}
}
} else {
@@ -2365,29 +2365,29 @@ struct VecNT<16, PrimitiveType::INT> {
// Since byte granule is guaranteed, num_valid_outputs is multiple of 4 for
// int2.
if (aligned_16b && num_valid_outputs >= 8) {
-*reinterpret_cast<half8*>(output_ptr) =
-*reinterpret_cast<const half8*>(&(val.vals[0].x));
+*reinterpret_cast<uint4*>(output_ptr) =
+*reinterpret_cast<const uint4*>(&(val.vals[0].x));
if (num_valid_outputs == 16) {
-*reinterpret_cast<half8*>(output_ptr + 8) =
-*reinterpret_cast<const half8*>(&(val.vals[0].x) + 8);
+*reinterpret_cast<uint4*>(output_ptr + 8) =
+*reinterpret_cast<const uint4*>(&(val.vals[0].x) + 8);
} else if (num_valid_outputs == 12) {
-*reinterpret_cast<half4*>(output_ptr + 8) =
-*reinterpret_cast<const half4*>(&(val.vals[0].x) + 8);
+*reinterpret_cast<uint2*>(output_ptr + 8) =
+*reinterpret_cast<const uint2*>(&(val.vals[0].x) + 8);
}
} else if (aligned_8b) {
#pragma unroll
for (int i = 0; i < 16; i += 4) {
if (i < num_valid_outputs) {
-*reinterpret_cast<half4*>(output_ptr + i) =
-*reinterpret_cast<const half4*>(&(val.vals[0].x) + i);
+*reinterpret_cast<uint2*>(output_ptr + i) =
+*reinterpret_cast<const uint2*>(&(val.vals[0].x) + i);
}
}
} else if (aligned_4b) {
#pragma unroll
for (int i = 0; i < 16; i += 2) {
if (i < num_valid_outputs) {
-*reinterpret_cast<half2*>(output_ptr + i) =
-*reinterpret_cast<const half2*>(&(val.vals[0].x) + i);
+*reinterpret_cast<uint*>(output_ptr + i) =
+*reinterpret_cast<const uint*>(&(val.vals[0].x) + i);
}
}
} else {
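The half-payload hunks follow the same idea: pairs of 16-bit values are moved through 32-, 64-, or 128-bit uint stores instead of through the custom half2/half4/half8 vector structs, so the store path no longer depends on those helper types. A small illustration of the equivalence for a single pair of half values (the function name is hypothetical, not from the PR):

```cuda
#include <cuda_fp16.h>

// Hypothetical helper, for illustration only: two __half values are 32 bits
// of payload, so one unsigned-int store moves the pair exactly as a
// half2-typed store would, provided output_ptr is 4-byte aligned.
__device__ __forceinline__ void store_half_pair(__half2 v, __half* output_ptr) {
  *reinterpret_cast<unsigned int*>(output_ptr) =
      *reinterpret_cast<const unsigned int*>(&v);
}
```

This also lines up with the byte-granularity comments in the diff: with int4 inputs two elements share a byte, so num_valid_outputs is always even, and with int2 inputs four elements share a byte, so it is always a multiple of 4; that is why the tail branches only handle counts like 6 or 12 and never odd remainders.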