
fixing flops profiler formatting, units and precision #3927

Merged · 7 commits · Jul 19, 2023

Conversation

@clumsy (Contributor) commented on Jul 11, 2023

Here's an overview of the changes:

  • repetitive unit conversion is consolidated into a single method, number_to_string (see the sketch after this list)
  • the list of units is extended to include B for parameter counts
  • column alignment is extended to 70 characters to accommodate lines like fwd+bwd FLOPS per GPU = 3 * fwd flops per GPU / (fwd+bwd latency): (66 characters long)
  • minor artifacts are fixed, e.g. empty trailing commas left by missing optional sections, trailing zero decimals, etc.
  • the Flops profiler now logs when it starts and finishes, to help troubleshoot issues like: skipping activation function assertion while FLOPS profiler measures throughput Megatron-DeepSpeed#156
  • typos like gpu -> GPU are fixed
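
For context, below is a minimal sketch of what such a consolidated unit-conversion helper could look like. The actual number_to_string added in this PR may differ in signature and in the unit families it supports (e.g. B for parameter counts vs G for FLOPs), so the code is illustrative only.

# Illustrative sketch, not the exact implementation in this PR.
def number_to_string(num, precision=2):
    """Format a raw count with a scale suffix, e.g. 429510000 -> '429.51 M'."""
    for suffix, scale in (("T", 1e12), ("G", 1e9), ("M", 1e6), ("K", 1e3)):
        if abs(num) >= scale:
            # Truncate trailing zero decimals so "2.00 G" prints as "2 G".
            formatted = f"{num / scale:.{precision}f}".rstrip("0").rstrip(".")
            return f"{formatted} {suffix}"
    return f"{num:.{precision}f}".rstrip("0").rstrip(".")

A helper along these lines would render number_to_string(429_510_000) as 429.51 M, matching the params per GPU line in the report below.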

Example output:

...

world size:                                                             1
data parallel size:                                                     1
model parallel size:                                                    1
batch size per GPU:                                                     4
params per GPU:                                                         429.51 M
params of model = params per GPU * mp_size:                             429.51 M
fwd MACs per GPU:                                                       208.65 GMACs
fwd flops per GPU:                                                      417.36 G
fwd flops of model = fwd flops per GPU * mp_size:                       417.36 G
fwd latency:                                                            375.11 ms
fwd FLOPS per GPU = fwd flops per GPU / fwd latency:                    1.11 TFLOPS
bwd latency:                                                            148.41 ms
bwd FLOPS per GPU = 2 * fwd flops per GPU / bwd latency:                5.62 TFLOPS
fwd+bwd FLOPS per GPU = 3 * fwd flops per GPU / (fwd+bwd latency):      2.39 TFLOPS
step latency:                                                           56.51 ms
iter latency:                                                           580.03 ms
FLOPS per GPU = 3 * fwd flops per GPU / iter latency:                   2.16 TFLOPS
samples/second:                                                         6.9

...

MegatronBertModel(
  429.51 M = 100% Params, 208.65 GMACs = 100% MACs, 374.95 ms = 100% latency, 1.11 TFLOPS
  (model): BertModel(
    424.67 M = 98.87% Params, 84.72 GMACs = 40.6% MACs, 29.57 ms = 7.89% latency, 5.73 TFLOPS
    (language_model): TransformerLanguageModel(
      424.67 M = 98.87% Params, 84.72 GMACs = 40.6% MACs, 29.09 ms = 7.76% latency, 5.83 TFLOPS
      (embedding): Embedding(
        387.91 M = 90.31% Params, 0 MACs = 0% MACs, 971.32 us = 0.26% latency, 0 FLOPS
        (word_embeddings): VocabParallelEmbedding(387.13 M = 90.13% Params, 0 MACs = 0% MACs, 376.46 us = 0.1% latency, 0 FLOPS)
        (position_embeddings): Embedding(774.14 K = 0.18% Params, 0 MACs = 0% MACs, 184.3 us = 0.05% latency, 0 FLOPS, 512, 1512)
        (embedding_dropout): Dropout(0 = 0% Params, 0 MACs = 0% MACs, 127.08 us = 0.03% latency, 0 FLOPS, p=0.1, inplace=False)
      )
      (encoder): ParallelTransformer(
        36.77 M = 8.56% Params, 84.72 GMACs = 40.6% MACs, 28.02 ms = 7.47% latency, 6.05 TFLOPS
        (layers): ModuleList(
          36.77 M = 8.56% Params, 84.72 GMACs = 40.6% MACs, 27.54 ms = 7.34% latency, 6.15 TFLOPS
          (0): ParallelTransformerLayer(
            12.26 M = 2.85% Params, 28.24 GMACs = 13.53% MACs, 9.69 ms = 2.58% latency, 5.83 TFLOPS
            (input_layernorm): MixedFusedLayerNorm(3.02 K = 0% Params, 0 MACs = 0% MACs, 119.92 us = 0.03% latency, 0 FLOPS)
            (self_attention): ParallelAttention(
              9.15 M = 2.13% Params, 21.9 GMACs = 10.5% MACs, 7.01 ms = 1.87% latency, 6.25 TFLOPS
              (query_key_value): ColumnParallelLinear(6.86 M = 1.6% Params, 14.05 GMACs = 6.73% MACs, 3.45 ms = 0.92% latency, 8.14 TFLOPS)
              (core_attention): CoreAttention(
                0 = 0% Params, 3.17 GMACs = 1.52% MACs, 2 ms = 0.53% latency, 3.18 TFLOPS
                (scale_mask_softmax): FusedScaleMaskSoftmax(0 = 0% Params, 0 MACs = 0% MACs, 218.87 us = 0.06% latency, 0 FLOPS)
                (attention_dropout): Dropout(0 = 0% Params, 0 MACs = 0% MACs, 179.77 us = 0.05% latency, 0 FLOPS, p=0.1, inplace=False)
              )
              (dense): RowParallelLinear(2.29 M = 0.53% Params, 4.68 GMACs = 2.24% MACs, 1.28 ms = 0.34% latency, 7.32 TFLOPS)
            )
            (post_attention_layernorm): MixedFusedLayerNorm(3.02 K = 0% Params, 0 MACs = 0% MACs, 105.62 us = 0.03% latency, 0 FLOPS)
            (mlp): ParallelMLP(
              3.1 M = 0.72% Params, 6.34 GMACs = 3.04% MACs, 1.98 ms = 0.53% latency, 6.41 TFLOPS
              (dense_h_to_4h): ColumnParallelLinear(1.55 M = 0.36% Params, 3.17 GMACs = 1.52% MACs, 939.13 us = 0.25% latency, 6.75 TFLOPS)
              (dense_4h_to_h): RowParallelLinear(1.55 M = 0.36% Params, 3.17 GMACs = 1.52% MACs, 860.21 us = 0.23% latency, 7.37 TFLOPS)
            )
          )
          (1): ParallelTransformerLayer(
            12.26 M = 2.85% Params, 28.24 GMACs = 13.53% MACs, 8.94 ms = 2.38% latency, 6.32 TFLOPS
            (input_layernorm): MixedFusedLayerNorm(3.02 K = 0% Params, 0 MACs = 0% MACs, 99.9 us = 0.03% latency, 0 FLOPS)
            (self_attention): ParallelAttention(
              9.15 M = 2.13% Params, 21.9 GMACs = 10.5% MACs, 6.33 ms = 1.69% latency, 6.93 TFLOPS
              (query_key_value): ColumnParallelLinear(6.86 M = 1.6% Params, 14.05 GMACs = 6.73% MACs, 3.39 ms = 0.9% latency, 8.29 TFLOPS)
              (core_attention): CoreAttention(
                0 = 0% Params, 3.17 GMACs = 1.52% MACs, 1.5 ms = 0.4% latency, 4.25 TFLOPS
                (scale_mask_softmax): FusedScaleMaskSoftmax(0 = 0% Params, 0 MACs = 0% MACs, 164.51 us = 0.04% latency, 0 FLOPS)
                (attention_dropout): Dropout(0 = 0% Params, 0 MACs = 0% MACs, 157.12 us = 0.04% latency, 0 FLOPS, p=0.1, inplace=False)

...
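
For reference, the derived throughput lines in the summary above follow directly from the raw figures: fwd FLOPS per GPU = 417.36 G / 375.11 ms ≈ 1.11 TFLOPS, bwd FLOPS per GPU = 2 * 417.36 G / 148.41 ms ≈ 5.62 TFLOPS, fwd+bwd FLOPS per GPU = 3 * 417.36 G / (375.11 ms + 148.41 ms) ≈ 2.39 TFLOPS, and FLOPS per GPU = 3 * 417.36 G / 580.03 ms ≈ 2.16 TFLOPS.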

@clumsy requested a review from cli99 as a code owner on July 11, 2023 19:53
@clumsy (Contributor, Author) commented on Jul 11, 2023

As discussed, @cli99.

@clumsy (Contributor, Author) commented on Jul 18, 2023

@tjruwase, @cli99, this PR seems to pass all the checks now. Do you think it's ready to be merged? Thanks!

@mrwyattii added this pull request to the merge queue on Jul 19, 2023
Merged via the queue into deepspeedai:master with commit 488a1b9 Jul 19, 2023