Add benchmark EmbeddingSpMDMNBitBenchmarkOutTypeFloat16 (#2901)
Summary:
Pull Request resolved: #2901

This diff adds the EmbeddingSpMDMNBitBenchmarkOutTypeFloat16 benchmark, which tests TBE (table-batched embedding) with output type float16. It does not change the existing EmbeddingSpMDMNBitBenchmark.
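
For context, the output type is selected at build time: the #ifndef OUT_TYPE_FLOAT16 guards in the diff compile either the fp32 path or the 16-bit path, and under the 16-bit path every configuration runs twice, once with fp16 output and once with bf16 output (distinguished by the new is_bf16_out flag). A minimal standalone sketch of that dispatch, assuming OUT_TYPE_FLOAT16 is a plain preprocessor symbol (e.g. -DOUT_TYPE_FLOAT16) and that float16 is FBGEMM's raw 16-bit storage typedef:

#include <cstdint>

using float16 = std::uint16_t; // assumption: matches fbgemm/Types.h

template <typename OutType>
int run_benchmark_sketch(bool is_bf16_out = false) {
  (void)is_bf16_out; // silence unused-parameter warning in this sketch
  // ... allocate vector<OutType> outputs, run kernels, compare results ...
  return 0;
}

int main() {
#ifndef OUT_TYPE_FLOAT16
  run_benchmark_sketch<float>();                        // fp32 outputs
#else
  run_benchmark_sketch<float16>(/*is_bf16_out=*/false); // fp16 outputs
  run_benchmark_sketch<float16>(/*is_bf16_out=*/true);  // bf16 outputs
#endif
  return 0;
}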

Differential Revision: D60254038
helloguo authored and facebook-github-bot committed Jul 25, 2024
1 parent 10d2f58 commit 8050e77
Showing 1 changed file with 209 additions and 23 deletions.
232 changes: 209 additions & 23 deletions bench/EmbeddingSpMDMNBitBenchmark.cc
@@ -61,6 +61,7 @@ static vector<vector<int>> GetInputs_() {
return input_dims;
}

template <typename OutType>
int run_benchmark(
int bit_rate,
int batch_size,
@@ -69,7 +70,8 @@ int run_benchmark(
int average_len,
bool normalize_by_lengths,
bool use_32_bit_indices = false,
bool prefetch = false) {
bool prefetch = false,
bool is_bf16_out = false) {
// Create embedding table
int num_elem_per_byte = 8 / bit_rate;
int fused_embedding_dim =
@@ -133,8 +135,8 @@ int run_benchmark(
weights[i] = embedding_distribution(generator);
}

vector<float> output_sls_ref(batch_size * embedding_dim);
vector<float> output_slws_ref(output_sls_ref.size()),
vector<OutType> output_sls_ref(batch_size * embedding_dim);
vector<OutType> output_slws_ref(output_sls_ref.size()),
output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());

constexpr int NUM_WARMUP = 10;
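
Note on the 16-bit formats the templated buffers can now hold: FBGEMM stores fp16 and bf16 elements as raw uint16_t words, so they must be widened to float before any arithmetic or comparison. As background, a minimal standalone sketch (not FBGEMM code) of the bf16 case, which is simply the top half of an IEEE-754 float32; fp16 has a different exponent/mantissa split and needs a real conversion routine such as FBGEMM's cpu_half2float:

#include <cstdint>
#include <cstring>

// bf16 keeps the sign bit, the full 8-bit exponent, and the top 7
// mantissa bits of a float32, so widening is a 16-bit shift.
static float bf16_to_float(std::uint16_t x) {
  std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Truncating narrow; production code usually rounds to nearest even.
static std::uint16_t float_to_bf16(float f) {
  std::uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<std::uint16_t>(bits >> 16);
}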
@@ -148,11 +150,12 @@ int run_benchmark(
CACHE_LINE_LEN);

for (bool has_weight : {false, true}) {
vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
vector<float> output_autovec(output_sls_ref.size());
vector<OutType>& output_ref = has_weight ? output_slws_ref : output_sls_ref;
vector<OutType> output_autovec(output_sls_ref.size());

bool success = false, success_ref = false, success_autovec = false;

#ifndef OUT_TYPE_FLOAT16
auto kernel_32 = GenerateEmbeddingSpMDMNBit<int32_t>(
bit_rate,
embedding_dim,
@@ -165,8 +168,9 @@ int run_benchmark(
has_weight,
normalize_by_lengths,
prefetch ? 16 : 0);
#endif // OUT_TYPE_FLOAT16

vector<float>& output = has_weight ? output_slws : output_sls;
vector<OutType>& output = has_weight ? output_slws : output_sls;
for (bool flush_cache : {false, true}) {
// Reference implementation
double t_ref = measureWithWarmup(
@@ -183,7 +187,13 @@ int run_benchmark(
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
output_ref.data(),
false, // is_weight_positional
true, // use_offsets
-1, // output_stride
-1, // input_stride
true, // scale_bias_last
is_bf16_out);
} else {
success_ref = EmbeddingSpMDMNBit_ref(
bit_rate,
@@ -196,7 +206,13 @@ int run_benchmark(
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
output_ref.data(),
false, // is_weight_positional
true, // use_offsets
-1, // output_stride
-1, // input_stride
true, // scale_bias_last
is_bf16_out);
}
},
NUM_WARMUP,
@@ -227,7 +243,13 @@ int run_benchmark(
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_autovec.data());
output_autovec.data(),
false, // is_weight_positional
true, // use_offsets
-1, // output_stride
-1, // input_stride
true, // scale_bias_last
is_bf16_out);
} else {
success_autovec = EmbeddingSpMDMNBit_autovec(
bit_rate,
@@ -240,7 +262,13 @@ int run_benchmark(
offsets.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_autovec.data());
output_autovec.data(),
false, // is_weight_positional
true, // use_offsets
-1, // output_stride
-1, // input_stride
true, // scale_bias_last
is_bf16_out);
}
},
NUM_WARMUP,
@@ -256,6 +284,7 @@ int run_benchmark(
}
});

#ifndef OUT_TYPE_FLOAT16
// Hand-written AVX2/AVX512 implementation
double t = measureWithWarmup(
[&]() {
@@ -293,6 +322,7 @@ int run_benchmark(
cache_evict(output);
}
});
#endif // OUT_TYPE_FLOAT16

// printMatrix(
// matrix_op_t::NoTranspose,
@@ -312,6 +342,7 @@ int run_benchmark(
if (!flush_cache) {
// vector<float>& output_ref =
// has_weight ? output_slws_ref : output_sls_ref;
#ifndef OUT_TYPE_FLOAT16
if (success != success_ref) {
assert(
false &&
@@ -320,13 +351,32 @@ int run_benchmark(
<< endl;
} else {
for (size_t i = 0; i < output.size(); ++i) {
assert(fabs(output[i] - output_ref[i]) < 1e-3);
if (fabs(output[i] - output_ref[i]) >= 1e-3) {
cout << "asmjit vs ref : " << i << " " << output[i] << " "
<< output_ref[i] << endl;
float tmp1 = 0;
float tmp2 = 0;
if (std::is_same<OutType, float>::value) {
tmp1 = output[i];
tmp2 = output_ref[i];
} else if (std::is_same<OutType, uint16_t>::value) {
if (is_bf16_out) {
tmp1 = cpu_bf162float(output[i]);
tmp2 = cpu_bf162float(output_ref[i]);
} else {
tmp1 = cpu_half2float(output[i]);
tmp2 = cpu_half2float(output_ref[i]);
}
} else {
assert(false && "ERROR: unsupported output type");
cout << "ERROR: unsupported output type" << endl;
}

assert(fabs(tmp1 - tmp2) < 1e-3);
if (fabs(tmp1 - tmp2) >= 1e-3) {
cout << "asmjit vs ref : " << i << " " << tmp1 << " " << tmp2
<< endl;
}
}
}
#endif // OUT_TYPE_FLOAT16

if (success_autovec != success_ref) {
assert(
@@ -335,16 +385,47 @@ int run_benchmark(
cout << "autovec return " << success_autovec << " ref return "
<< success_ref << endl;
} else {
for (size_t i = 0; i < output.size(); ++i) {
assert(fabs(output_autovec[i] - output_ref[i]) < 1e-3);
if (fabs(output_autovec[i] - output_ref[i]) >= 1e-3) {
cout << "autovec vs ref: " << i << " " << output_autovec[i] << " "
<< output_ref[i] << endl;
for (size_t i = 0; i < output_autovec.size(); ++i) {
float tmp1 = 0;
float tmp2 = 0;
if (std::is_same<OutType, float>::value) {
tmp1 = output_autovec[i];
tmp2 = output_ref[i];
} else if (std::is_same<OutType, uint16_t>::value) {
if (is_bf16_out) {
tmp1 = cpu_bf162float(output_autovec[i]);
tmp2 = cpu_bf162float(output_ref[i]);
} else {
tmp1 = cpu_half2float(output_autovec[i]);
tmp2 = cpu_half2float(output_ref[i]);
}
} else {
assert(false && "ERROR: unsupported output type");
cout << "ERROR: unsupported output type" << endl;
}

assert(fabs(tmp1 - tmp2) < 1e-3);
if (fabs(tmp1 - tmp2) >= 1e-3) {
cout << "autovec vs ref: " << i << " " << tmp1 << " " << tmp2
<< endl;
}
}
}
}

if (std::is_same<OutType, float>::value) {
cout << "out type fp32, ";
} else if (std::is_same<OutType, uint16_t>::value) {
if (is_bf16_out) {
cout << "out type bf16, ";
} else {
cout << "out type fp16, ";
}
} else {
assert(false && "ERROR: unsupported output type");
cout << "ERROR: unsupported output type" << endl;
}

if (has_weight) {
cout << "SLW(WEIGHTED), ";
} else {
@@ -361,6 +442,7 @@ int run_benchmark(
cout << "prefetch off, ";
}

#ifndef OUT_TYPE_FLOAT16
cout << "b/w, " << bytes / 1e9 / t << ", GB/s, " << "effective b/w, "
<< bytes_padded / 1e9 / t << ", GB/s, " << "time, " << t
<< ", autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, "
@@ -370,6 +452,14 @@ int run_benchmark(
<< bytes_padded / 1e9 / t_ref << ", GB/s, " << "ref time, " << t_ref
<< ", autovec speedup, " << t_ref / t_autovec << ", asmjit speedup, "
<< t_ref / t << endl;
#else
cout << "autovec b/w, " << bytes / 1e9 / t_autovec << ", GB/s, "
<< "autovec eff. b/w, " << bytes_padded / 1e9 / t_autovec
<< ", GB/s, " << "autovec time, " << t_autovec << ", ref b/w, "
<< bytes / 1e9 / t_ref << ", GB/s, " << "ref eff. b/w, "
<< bytes_padded / 1e9 / t_ref << ", GB/s, " << "ref time, " << t_ref
<< ", autovec speedup, " << t_ref / t_autovec << endl;
#endif // OUT_TYPE_FLOAT16
} // flush_cache
} // has_weight
return 0;
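
The two tolerance-check blocks above repeat the same OutType-to-float widening before comparing against the 1e-3 threshold. The pattern in isolation (a hypothetical helper, not part of this diff; assumes C++17 and FBGEMM's cpu_half2float / cpu_bf162float conversion helpers):

#include <type_traits>
#include "fbgemm/Types.h" // assumption: declares cpu_half2float / cpu_bf162float

// Hypothetical: widen one benchmark output element to float so the
// fp32, fp16, and bf16 paths can share a single tolerance check.
template <typename OutType>
float widen_to_float(OutType v, bool is_bf16_out) {
  if constexpr (std::is_same_v<OutType, float>) {
    return v;
  } else {
    // 16-bit storage holds fp16 or bf16; the runtime flag disambiguates.
    return is_bf16_out ? cpu_bf162float(v) : cpu_half2float(v);
  }
}

With this, both loops would reduce to comparing widen_to_float(output[i], is_bf16_out) against widen_to_float(output_ref[i], is_bf16_out).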
@@ -397,16 +487,41 @@ int main() {
// args: batch sz, num rows, emb dim, avg len, normalize, use 32b,
// prefetch
cout << "64 bit indices, ";
run_benchmark(
#ifndef OUT_TYPE_FLOAT16
run_benchmark<float>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false); // normalize_by_lengths
#else
run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
false, // use_32_bit_indices
false, // prefetch
false); // is_bf16_out

run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
false, // use_32_bit_indices
false, // prefetch
true); // is_bf16_out
#endif // OUT_TYPE_FLOAT16

cout << "64 bit indices with prefetching, ";
run_benchmark(
#ifndef OUT_TYPE_FLOAT16
run_benchmark<float>(
bit_rate,
batch_size,
num_rows,
@@ -415,19 +530,67 @@ int main() {
false, // normalize_by_lengths
false, // use_32_bit_indices
true); // prefetch
#else
run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
false, // use_32_bit_indices
true, // prefetch
false); // is_bf16_out

run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
false, // use_32_bit_indices
true, // prefetch
true); // is_bf16_out
#endif // OUT_TYPE_FLOAT16

cout << "32 bit indices, ";
run_benchmark(
#ifndef OUT_TYPE_FLOAT16
run_benchmark<float>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
true); // use_32_bit_indices
#else
run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
true, // use_32_bit_indices
false, // prefetch
false); // is_bf16_out

run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
true, // use_32_bit_indices
false, // prefetch
true); // is_bf16_out
#endif // OUT_TYPE_FLOAT16

cout << "32 bit indices with prefetching, ";
run_benchmark(
#ifndef OUT_TYPE_FLOAT16
run_benchmark<float>(
bit_rate,
batch_size,
num_rows,
@@ -436,6 +599,29 @@ int main() {
false, // normalize_by_lengths
true, // use_32_bit_indices
true); // prefetch
#else
run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
true, // use_32_bit_indices
true, // prefetch
false); // is_bf16_out

run_benchmark<float16>(
bit_rate,
batch_size,
num_rows,
embedding_dim,
average_len,
false, // normalize_by_lengths
true, // use_32_bit_indices
true, // prefetch
true); // is_bf16_out
#endif // OUT_TYPE_FLOAT16

// running with normalize by lengths
// run_benchmark(batch_size, num_rows, embedding_dim, average_len,
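
Under OUT_TYPE_FLOAT16, each of the four index-width/prefetch configurations above is spelled out twice, differing only in the trailing is_bf16_out argument. A hypothetical compaction of that repetition (not what this diff does, and it omits the per-configuration cout labels):

#ifdef OUT_TYPE_FLOAT16
      for (bool use_32_bit : {false, true}) {
        for (bool do_prefetch : {false, true}) {
          for (bool is_bf16_out : {false, true}) {
            run_benchmark<float16>(
                bit_rate,
                batch_size,
                num_rows,
                embedding_dim,
                average_len,
                false, // normalize_by_lengths
                use_32_bit,
                do_prefetch,
                is_bf16_out);
          }
        }
      }
#endif // OUT_TYPE_FLOAT16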