Initialize BCast Map Failure #4013

Closed · csarofeen opened this issue Mar 5, 2025 · 0 comments · Fixed by #4019

Labels: bug (Something isn't working), Thunder

csarofeen (Collaborator) commented:

When running the hf_llama example from https://github.com/kevinstephano/thunder_model_blocks with the Thunder-nvFuser-more-ops backend, I'm getting an error.

The error is:

RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/Fuser/csrc/logical_domain_map.cpp":1048, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
Exception raised from initializeBcastMap at /opt/pytorch/Fuser/csrc/logical_domain_map.cpp:1048 (most recent call first):

The repro is:

# CUDA devices:
#  0: NVIDIA Graphics Device
#  1: NVIDIA Graphics Device
#  2: NVIDIA Graphics Device
#  3: NVIDIA Graphics Device
# torch version: 2.7.0a0+ecf3bae40a.nvInternal
# cuda version: 12.8
# nvfuser version: 0.2.26+gita38ce70
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[128256, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0])
    T3 = fd.define_tensor(shape=[32], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[2048, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T7 = fd.define_tensor(shape=[512, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    S8 = fd.define_scalar(2.00000, dtype=DataType.Double)
    S9 = fd.define_scalar(False, dtype=DataType.Bool)
    S10 = fd.define_scalar(False, dtype=DataType.Bool)
    T11 = fd.ops.embedding_fwd(T0, T1, None, None, S8, S9, S10)
    S12 = fd.define_scalar(6, dtype=DataType.Int)
    S13 = fd.define_scalar(0, dtype=DataType.Int)
    S14 = fd.define_scalar(1, dtype=DataType.Int)
    T15 = fd.ops.iota(S12, S13, S14, dtype=DataType.Int)
    T19 = fd.ops.broadcast_in_dim(T15, shape=[1, 6], broadcast_dims=[1])
    S20 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double)
    T24 = fd.ops.full(shape=[6, 6], fill_value=S20, dtype=DataType.BFloat16)
    T28 = fd.ops.broadcast_in_dim(T15, shape=[6, 1], broadcast_dims=[0])
    T32 = fd.ops.broadcast_in_dim(T19, shape=[6, 6], broadcast_dims=[0, 1])
    T36 = fd.ops.broadcast_in_dim(T28, shape=[6, 6], broadcast_dims=[0, 1])
    T37 = fd.ops.sub(T32, T36)
    S38 = fd.define_scalar(1, dtype=DataType.Int)
    T39 = fd.ops.ge(T37, S38)
    S40 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T41 = fd.ops.where(T39, T24, S40)
    T42 = fd.ops.cast(T41, dtype=DataType.BFloat16)
    T46 = fd.ops.reshape(T15, new_shape=[6, 1])
    T50 = fd.ops.broadcast_in_dim(T15, shape=[6, 6], broadcast_dims=[1])
    T54 = fd.ops.broadcast_in_dim(T46, shape=[6, 6], broadcast_dims=[0, 1])
    T55 = fd.ops.gt(T50, T54)
    T56 = fd.ops.cast(T42, dtype=DataType.Float)
    T57 = fd.ops.cast(T55, dtype=DataType.Float)
    T58 = fd.ops.mul(T56, T57)
    T59 = fd.ops.cast(T58, dtype=DataType.BFloat16)
    T65 = fd.ops.broadcast_in_dim(T59, shape=[1, 1, 6, 6], broadcast_dims=[2, 3])
    T71 = fd.ops.broadcast_in_dim(T65, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3])
    T72 = fd.ops.set(T71)
    T78 = fd.ops.broadcast_in_dim(T2, shape=[1, 1, 1, 6], broadcast_dims=[0, 3])
    T84 = fd.ops.broadcast_in_dim(T78, shape=[1, 1, 6, 6], broadcast_dims=[0, 1, 2, 3])
    T85 = fd.ops.cast(T72, dtype=DataType.Float)
    T86 = fd.ops.cast(T84, dtype=DataType.Float)
    T87 = fd.ops.add(T85, T86)
    T88 = fd.ops.cast(T87, dtype=DataType.BFloat16)
    S89 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T90 = fd.ops.eq(T88, S89)
    S91 = fd.define_scalar(-3.38953e+38, dtype=DataType.Double)
    T92 = fd.ops.where(T90, S91, T72)
    T93 = fd.ops.cast(T92, dtype=DataType.BFloat16)
    T98 = fd.ops.broadcast_in_dim(T3, shape=[1, 32, 1], broadcast_dims=[1])
    T99 = fd.ops.cast(T98, dtype=DataType.Float)
    T104 = fd.ops.broadcast_in_dim(T99, shape=[1, 32, 1], broadcast_dims=[0, 1, 2])
    T109 = fd.ops.broadcast_in_dim(T19, shape=[1, 1, 6], broadcast_dims=[0, 2])
    T110 = fd.ops.cast(T109, dtype=DataType.Float)
    T111 = fd.ops.matmul(T104, T110)
    T112 = fd.ops.permute(T111, dims=[0, 2, 1])
    T113 = fd.ops.cat([T112, T112], dim=-1, manual_padding=0)
    T114 = fd.ops.cos(T113)
    T115 = fd.ops.sin(T113)
    S116 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T117 = fd.ops.mul(T114, S116)
    S118 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T119 = fd.ops.mul(T115, S118)
    T120 = fd.ops.cast(T117, dtype=DataType.BFloat16)
    T121 = fd.ops.cast(T119, dtype=DataType.BFloat16)
    T122 = fd.ops.cast(T11, dtype=DataType.Float)
    S123 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T124 = fd.ops.pow(T122, S123)
    T125 = fd.ops.sum(T124, dims=[2], keepdim=False, dtype=DataType.Null)
    T130 = fd.ops.broadcast_in_dim(T125, shape=[1, 6, 1], broadcast_dims=[0, 1])
    S131 = fd.define_scalar(2048.00, dtype=DataType.Double)
    S132 = fd.ops.reciprocal(S131)
    T133 = fd.ops.mul(T130, S132)
    S134 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T135 = fd.ops.add(T133, S134)
    T136 = fd.ops.rsqrt(T135)
    T141 = fd.ops.broadcast_in_dim(T136, shape=[1, 6, 2048], broadcast_dims=[0, 1, 2])
    T142 = fd.ops.mul(T122, T141)
    T147 = fd.ops.broadcast_in_dim(T4, shape=[1, 6, 2048], broadcast_dims=[2])
    T148 = fd.ops.cast(T147, dtype=DataType.Float)
    T149 = fd.ops.mul(T148, T142)
    T150 = fd.ops.cast(T149, dtype=DataType.BFloat16)
    T151 = fd.ops.linear(T150, T5)
    T157 = fd.ops.reshape(T151, new_shape=[1, 6, 32, 64])
    T158 = fd.ops.permute(T157, dims=[0, 2, 1, 3])
    T159 = fd.ops.linear(T150, T6)
    T165 = fd.ops.reshape(T159, new_shape=[1, 6, 8, 64])
    T166 = fd.ops.permute(T165, dims=[0, 2, 1, 3])
    T167 = fd.ops.linear(T150, T7)
    T173 = fd.ops.reshape(T167, new_shape=[1, 6, 8, 64])
    T174 = fd.ops.permute(T173, dims=[0, 2, 1, 3])
    T180 = fd.ops.broadcast_in_dim(T120, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
    T186 = fd.ops.broadcast_in_dim(T121, shape=[1, 1, 6, 64], broadcast_dims=[0, 2, 3])
    T192 = fd.ops.broadcast_in_dim(T180, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T193 = fd.ops.cast(T158, dtype=DataType.Float)
    T194 = fd.ops.cast(T192, dtype=DataType.Float)
    T195 = fd.ops.mul(T193, T194)
    T211 = fd.ops.slice(T158, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T227 = fd.ops.slice(T158, start_indices=[0, 0, 0, 32], end_indices=[1, 32, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T228 = fd.ops.cast(T227, dtype=DataType.Float)
    T229 = fd.ops.neg(T228)
    T230 = fd.ops.cast(T229, dtype=DataType.BFloat16)
    T231 = fd.ops.cat([T230, T211], dim=-1, manual_padding=0)
    T237 = fd.ops.broadcast_in_dim(T186, shape=[1, 32, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T238 = fd.ops.cast(T231, dtype=DataType.Float)
    T239 = fd.ops.cast(T237, dtype=DataType.Float)
    T240 = fd.ops.mul(T238, T239)
    T241 = fd.ops.add(T195, T240)
    T242 = fd.ops.cast(T241, dtype=DataType.BFloat16)
    T248 = fd.ops.broadcast_in_dim(T180, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T249 = fd.ops.cast(T166, dtype=DataType.Float)
    T250 = fd.ops.cast(T248, dtype=DataType.Float)
    T251 = fd.ops.mul(T249, T250)
    T267 = fd.ops.slice(T166, start_indices=[0, 0, 0, 0], end_indices=[1, 8, 6, 32], strides=[1, 1, 1, 1], manual_normalization=0)
    T283 = fd.ops.slice(T166, start_indices=[0, 0, 0, 32], end_indices=[1, 8, 6, 64], strides=[1, 1, 1, 1], manual_normalization=0)
    T284 = fd.ops.cast(T283, dtype=DataType.Float)
    T285 = fd.ops.neg(T284)
    T286 = fd.ops.cast(T285, dtype=DataType.BFloat16)
    T287 = fd.ops.cat([T286, T267], dim=-1, manual_padding=0)
    T293 = fd.ops.broadcast_in_dim(T186, shape=[1, 8, 6, 64], broadcast_dims=[0, 1, 2, 3])
    T294 = fd.ops.cast(T287, dtype=DataType.Float)
    T295 = fd.ops.cast(T293, dtype=DataType.Float)
    T296 = fd.ops.mul(T294, T295)
    T297 = fd.ops.add(T251, T296)
    T298 = fd.ops.cast(T297, dtype=DataType.BFloat16)
    T305 = fd.ops.broadcast_in_dim(T298, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
    T312 = fd.ops.broadcast_in_dim(T305, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T318 = fd.ops.reshape(T312, new_shape=[1, 32, 6, 64])
    T325 = fd.ops.broadcast_in_dim(T174, shape=[1, 8, 1, 6, 64], broadcast_dims=[0, 1, 3, 4])
    T332 = fd.ops.broadcast_in_dim(T325, shape=[1, 8, 4, 6, 64], broadcast_dims=[0, 1, 2, 3, 4])
    T338 = fd.ops.reshape(T332, new_shape=[1, 32, 6, 64])
    T339 = fd.ops.stride_order(T242, stride_order=[3, 2, 1, 0])
    T340 = fd.ops.stride_order(T318, stride_order=[3, 2, 1, 0])
    T341 = fd.ops.stride_order(T338, stride_order=[3, 2, 1, 0])
    fd.add_output(T11)
    fd.add_output(T72)
    fd.add_output(T93)
    fd.add_output(T174)
    fd.add_output(T298)
    fd.add_output(T339)
    fd.add_output(T340)
    fd.add_output(T341)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.testing.make_tensor((1, 6), dtype=torch.int64, device='cuda:0'),
    torch.testing.make_tensor((128256, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((1, 6), dtype=torch.int64, device='cuda:0'),
    torch.testing.make_tensor((32,), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((2048, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((512, 2048), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
csarofeen added the bug (Something isn't working) and Thunder labels on Mar 5, 2025
naoyam added a commit that referenced this issue on Mar 6, 2025
From the fix (#4019):

Fixes #4013.

`ComputeAtLogicalDomainMap` needs a `handle` for `EmbeddingFwdOp` to analyze broadcast resolutions.
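
For context, here is a reduced sketch of the pattern the fix targets: an `embedding_fwd` output combined with a broadcast tensor, so that `ComputeAtLogicalDomainMap` has to resolve broadcast domains across the embedding op. It reuses only calls that appear in the repro above; the shapes are illustrative, and it is not a verified minimal reproducer of the original assert.

import torch
from nvfuser import FusionDefinition, DataType

def embedding_bcast_sketch(fd: FusionDefinition) -> None:
    # Token ids and embedding table, same shapes as in the repro above.
    T0 = fd.define_tensor(shape=[1, 6], contiguity=[None, True], dtype=DataType.Int, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[128256, 2048], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[2048], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    S3 = fd.define_scalar(2.00000, dtype=DataType.Double)
    S4 = fd.define_scalar(False, dtype=DataType.Bool)
    S5 = fd.define_scalar(False, dtype=DataType.Bool)
    # Embedding lookup: [1, 6] indices into a [128256, 2048] table -> [1, 6, 2048].
    T6 = fd.ops.embedding_fwd(T0, T1, None, None, S3, S4, S5)
    # Broadcast a [2048] vector to [1, 6, 2048]; resolving these broadcast
    # domains against the embedding output is the pattern that exercised
    # ComputeAtLogicalDomainMap::initializeBcastMap.
    T7 = fd.ops.broadcast_in_dim(T2, shape=[1, 6, 2048], broadcast_dims=[2])
    T8 = fd.ops.cast(T6, dtype=DataType.Float)
    T9 = fd.ops.cast(T7, dtype=DataType.Float)
    T10 = fd.ops.add(T8, T9)
    T11 = fd.ops.cast(T10, dtype=DataType.BFloat16)
    fd.add_output(T11)

with FusionDefinition() as fd:
    embedding_bcast_sketch(fd)

inputs = [
    # Keep indices in range for the embedding table.
    torch.testing.make_tensor((1, 6), dtype=torch.int64, device='cuda:0', low=0, high=128256),
    torch.testing.make_tensor((128256, 2048), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((2048,), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)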