Fix for ffn_hidden_size of 128, and better error message for incompatible ffn sizes. #108

Merged: 4 commits merged into databricks:main on May 15, 2024

Conversation

@snarayan21 (Collaborator) commented on May 15, 2024

Previously, if ffn_hidden_size was 128 and top_k was equal to the number of experts, the output of nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) would be something like torch.tensor(1) instead of torch.tensor([1]) -- a zero-dimensional tensor instead of a one-dimensional tensor. This caused an error during concatenation on the next line:

  File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 148, in sparse_forward_once
    topo = self.topology(x, padded_bins)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 98, in topology
    column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
                                                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3/dist-packages/megablocks/layers/dmoe.py", line 56, in sparse_transpose
    offsets_t = torch.cat([zero, nnz_per_column])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: zero-dimensional tensor (at position 1) cannot be concatenated
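
For context, here is a minimal PyTorch snippet, independent of megablocks, that shows why the concatenation fails (the variable names are illustrative, not taken from the library):

import torch

zero = torch.zeros((1,), dtype=torch.int64)

# 1-D cumsum result, as produced when there is more than one block column.
nnz_per_column_1d = torch.tensor([3])
print(torch.cat([zero, nnz_per_column_1d]))  # tensor([0, 3]) -- works

# 0-D cumsum result, as produced in the ffn_hidden_size=128 case.
nnz_per_column_0d = torch.tensor(3)
torch.cat([zero, nnz_per_column_0d])
# RuntimeError: zero-dimensional tensor (at position 1) cannot be concatenated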

To address the bug, we simply make nnz_per_column a 1-D tensor if it is 0-D. I added a new set of parameters to the dmoe tests that fails without this change and succeeds with it. I also ran the llm-foundry torch_dmoe vs. mb_dmoe tests to verify the correctness of this change.
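
A minimal sketch of the reshaping the fix applies inside sparse_transpose (the exact code in the PR may differ slightly):

# Sketch of the fix: inclusive_cumsum can return a 0-D tensor when there is a
# single block column, so promote it to 1-D before concatenating.
nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
if nnz_per_column.dim() == 0:
    # Promote the scalar result to a one-element 1-D tensor so torch.cat works.
    nnz_per_column = nnz_per_column.unsqueeze(0)
offsets_t = torch.cat([zero, nnz_per_column])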

The second change adds clearer error messages for invalid ffn_hidden_size values, to help external users.
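
A hypothetical illustration of the kind of validation and message intended here; the specific constraint checked (divisibility by the block size is assumed below) and the wording in the PR may differ:

# Hypothetical sketch: validate ffn_hidden_size up front and fail with a clear
# message instead of an opaque kernel error. The divisibility-by-128 constraint
# is an assumption for illustration, not necessarily the exact check in the PR.
blocking = 128
if args.ffn_hidden_size % blocking != 0:
    raise ValueError(
        f'ffn_hidden_size ({args.ffn_hidden_size}) must be divisible by '
        f'the block size ({blocking}).')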

You can reproduce this error with the small script below as well:

import torch

from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import ParallelDroplessMLP

# ffn_hidden_size of 128 triggers the zero-dimensional nnz_per_column described above.
args = Arguments(hidden_size=256, ffn_hidden_size=128)
pdmlp = ParallelDroplessMLP(args)

x = torch.randn((128, 128)).cuda().to(torch.bfloat16)
expert_weights = torch.randn((128, 1)).cuda().to(torch.bfloat16)
top_experts = torch.zeros((128, 1)).cuda().to(torch.int32)

# Before the fix, this fails in sparse_transpose with the concatenation error above.
topo = pdmlp.sparse_forward_once(x, expert_weights, top_experts)

@mihir-db merged commit 0411977 into databricks:main on May 15, 2024