Skip to content

Commit

Permalink
Fix the read_write_bytes in device bench
Browse files: browse the repository at this point in the history
  • Loading branch information
zhuzilin committed Jul 15, 2022
1 parent b3e129c commit 741d944
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -215,6 +215,17 @@ def device( # noqa C901

nparams = sum(w.numel() for w in emb.split_embedding_weights())
param_size_multiplier = weights_precision.bit_rate() / 8.0
output_size_multiplier = output_dtype.bit_rate() / 8.0
if do_pooling:
read_write_bytes = (
output_size_multiplier * B * sum(Ds) + param_size_multiplier * B * sum(Ds) * L
)
else:
read_write_bytes = (
output_size_multiplier * B * sum(Ds) * L
+ param_size_multiplier * B * sum(Ds) * L
)

logging.info(
f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, "
f"{nparams * param_size_multiplier / 1.0e9: .2f} GB"
Expand Down Expand Up @@ -251,7 +262,7 @@ def device( # noqa C901
logging.info(
f"Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)

Expand All @@ -276,7 +287,7 @@ def device( # noqa C901
)
logging.info(
f"ForwardBackward, B: {B}, E: {E}, T: {T}, D: {D}, L: {L}, "
f"BW: {3 * param_size_multiplier * B * sum(Ds) * L / time_per_iter / 1.0e9: .2f} GB/s, "
f"BW: {3 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, "
f"T: {time_per_iter * 1.0e6:.0f}us"
)

Expand Down

0 comments on commit 741d944

Please sign in to comment.