H100 Transformer Engine implementation #249
Linked issue in the lightning repo: Lightning-AI/pytorch-lightning#17172 |
@SinanAkkoyun Do you have access to H100s? If so, would you like to try out the PR Lightning-AI/pytorch-lightning#17597? You can install it by running pip install -U https://github.com/Lightning-AI/lightning/archive/refs/heads/carmocca/transformer-engine.zip We would love to get your feedback. |
Wouldn't you need to replace all your nn.Modules with te.Module equivalents if you switch to TransformerEngine? In particular, the te modules also have fused components (e.g. LayerNormLinear, LayerNormMLP).
You would also need to update your training code: […]

The speed-up is definitely there, we see 2-3x using FP8...
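For illustration, a minimal sketch of what the comment above describes, written against the public Transformer Engine API (this is my own example, not the PR's code): swap in TE's fused modules and wrap the forward pass in `fp8_autocast`.

```python
# Minimal sketch: TE's fused modules plus the fp8_autocast context.
# Assumes an FP8-capable GPU (H100) and transformer_engine installed.
import torch
import transformer_engine.pytorch as te

# Hypothetical toy block built from the fused TE modules mentioned above.
block = torch.nn.Sequential(
    te.LayerNormLinear(4096, 4096),  # fused LayerNorm + Linear
    te.LayerNormMLP(4096, 16384),    # fused LayerNorm + MLP
).cuda()

x = torch.randn(8, 4096, device="cuda", requires_grad=True)

# FP8 is only used for the compute executed inside this context.
with te.fp8_autocast(enabled=True):
    out = block(x)
out.sum().backward()
```

The linked PR aims to handle this layer swapping and the autocast context behind Fabric's precision flag instead of requiring manual changes.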
Yes. Automatic replacement of the layers is missing but it's something that we want to do too. The parallelism customization would be left to the user to do though. |
@carmocca Hello! Thank you very much for the info! I currently have access to an H100 cloud GPU, although if I shut it down I might not get my hands on it again, so any quick help would be very appreciated :)
@SinanAkkoyun Did you run […]? Can you […]? (The embedded reference pointed at line 116 in 6e60162.)
@carmocca Sure! Thank you so much for the reply! I just ran generate.py and only get 30 tokens/second with the H100.
@carmocca Does this mean it's not automatically set? These are some stats: […]
It is not automatically set, try this:

```python
fabric = L.Fabric(devices=1, precision="8-mixed")
dtype = None
```
@carmocca
I received the error above... What do I need to import? I already installed your zip; does it have anything to do with that? I am sorry for all the minor questions, but I want to be safe with the implementation.
Thank you very much, I am in the process of installing it |
Actually, based on NVIDIA/TransformerEngine#242 (comment) it seems like we can keep the weights in fp16 or bf16 during inference. Meaning not doing […]
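As a rough sketch of that idea (a minimal example against the public TE API, not the PR's code): the module parameters can stay in bf16, and TE only casts to FP8 for the GEMM inside `fp8_autocast`.

```python
# Sketch: keep weights in bf16; FP8 is used only for the matmul inside fp8_autocast.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).cuda()
x = torch.randn(1, 8, 4096, device="cuda", dtype=torch.bfloat16)

with torch.no_grad(), te.fp8_autocast(enabled=True):
    y = layer(x)

print(y.dtype)  # torch.bfloat16 -- activations and weights stay bf16
```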
Thank you so much! I had/have trouble reinstalling the right CUDA and cuDNN versions; it should be done in a couple of minutes. Thanks, I will leave the original dtype code.
It is still installing TransformerEngine (dependency hell, wrong CUDA/PyTorch versions, etc.; now compiling PyTorch with CUDA 12.1 from scratch).
Normal generate.py:

root@e7bbd97bc0d2:/app/lit-llama# python3 generate.py --prompt "Hello, my name is"
Loading model ...
Time to load model: 19.90 seconds.
Global seed set to 1234
Hello, my name is TJ.
I am a stay at home dad with 3 kids. I work part time for my church as a Custodian and I also tutor online. I have a Masters degree in Human Resource Management and I’
Time for inference 1: 1.25 sec total, 40.02 tokens/sec
Memory used: 13.57 GB

So, after lots of trial and error, I finally set up an NVIDIA dev Docker with CUDA 12.1, installed PyTorch (CUDA 12.1), and ran the modified gen.py, receiving this output:

root@e7bbd97bc0d2:/app/lit-llama# python3 gen.py --prompt "Hello, my name is"
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7efbb27bfe80>
cuda
Loading model ...
Time to load model: 24.39 seconds.
Traceback (most recent call last):
File "gen.py", line 179, in <module>
CLI(main)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
return _run_component(component, cfg_init)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
return component(**cfg)
File "gen.py", line 139, in main
model = fabric.setup_module(model)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/fabric.py", line 254, in setup_module
module = self._precision.convert_module(module)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 82, in convert_module
_convert_layers(module)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 116, in _convert_layers
_convert_layers(module)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 116, in _convert_layers
_convert_layers(module)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 116, in _convert_layers
_convert_layers(module)
[Previous line repeated 989 more times]
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/plugins/precision/fp8_transformer_engine.py", line 90, in _convert_layers
for name, child in module.named_children():
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 2206, in named_children
memo = set()
RecursionError: maximum recursion depth exceeded while calling a Python object

I am trying to fix this, but I am not achieving any results at the moment...
I am still running the server; how could I debug this?
This is what GPT-4 told me, I don't know if it makes sense:

```python
def _convert_layers(module: torch.nn.Module) -> None:
    import transformer_engine.pytorch as te

    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            if child.in_features % 16 != 0 or child.out_features % 16 != 0:
                # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
                rank_zero_warn(
                    "Support for FP8 in the linear layers with `precision='8-mixed'` is currently limited to tensors"
                    f" with shapes where both dimensions are divisible by 16. The layer {name!r} does not fit this"
                    " criteria. You might want to add padding to your inputs."
                )
                continue
            has_bias = child.bias is not None
            replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
            replacement.weight.data = child.weight.data.clone()
            if has_bias:
                replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        elif isinstance(child, torch.nn.LayerNorm):
            replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
            replacement.weight.data = child.weight.data.clone()
            replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        else:
            _convert_layers(child)  # Recurse on the child, not the parent
```

After doing this mod, the following error occurs:

root@d395b61d47d7:/app/lit-llama# python3 gen.py --prompt "Hello, my name is"
Loading model ...
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7f2df85656a0>
cuda
Time to load model: 22.48 seconds.
Global seed set to 1234
Traceback (most recent call last):
File "gen.py", line 179, in <module>
CLI(main)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
return _run_component(component, cfg_init)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
return component(**cfg)
File "gen.py", line 154, in main
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "gen.py", line 67, in generate
logits = model(x, max_seq_length, input_pos)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
output = self._forward_module(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 114, in forward
x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 159, in forward
h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 191, in forward
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 2267, in forward
with self.prepare_forward(inp, is_first_microbatch) as inp:
File "/usr/lib/python3.8/contextlib.py", line 113, in __enter__
return next(self.gen)
File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 593, in prepare_forward
self.set_activation_dtype(inp)
File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 484, in set_activation_dtype
assert all(
AssertionError: Data type for activations and buffers must match when outside of autocasted region |
I tried to resolve this, this is my attempt and findings: […]

When autocasting to f32, I get this: […]
Your fp16 vs fp32 issues might be caused by this: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/pytorch/module.py#L3464-L3467 (the embedded snippet referenced lines 128 to 133 at a24fc5e).

So maybe the easiest thing to start with is to remove the `EmptyInitOnDevice` context manager.

BTW thank you for your efforts, this is really useful.
Thank you very much for helping me out!!! I tried to remove the EmptyInitOnDevice context manager, but I think I did not succeed; here is the code:

```python
    """
    ``"gptq.int4"``: GPTQ 4-bit mode.
    """
    assert checkpoint_path.is_file(), checkpoint_path
    assert tokenizer_path.is_file(), tokenizer_path

    fabric = L.Fabric(devices=1, precision="8-mixed")
    dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

    print("Loading model ...", file=sys.stderr)
    print(fabric.strategy.precision)
    print(fabric.device.type)
    t0 = time.time()
    with lazy_load(checkpoint_path) as checkpoint:
        name = llama_model_lookup(checkpoint)
        # Removed the context manager here
        model = LLaMA.from_name(name)
        model.load_state_dict(checkpoint)
    print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)

    model.eval()
    model = fabric.setup_module(model)

    tokenizer = Tokenizer(tokenizer_path)
    encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
    prompt_length = encoded.size(0)

    L.seed_everything(1234)
    for i in range(num_samples):
        t0 = time.perf_counter()
        y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
        t = time.perf_counter() - t0

        model.reset_cache()
        print(tokenizer.decode(y))
        tokens_generated = y.size(0) - prompt_length
        print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
    if fabric.device.type == "cuda":
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", file=sys.stderr)


if __name__ == "__main__":
```

I still get the error:

root@d395b61d47d7:/app/lit-llama# python3 gen.py --prompt "Hello, my name is"
Loading model ...
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7f075c468fa0>
cuda
Time to load model: 18.31 seconds.
Global seed set to 1234
DEBUG: IDX TYPE:
torch.int32
DEBUG: x dtype:
torch.float32
Traceback (most recent call last):
File "gen.py", line 179, in <module>
CLI(main)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
return _run_component(component, cfg_init)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
return component(**cfg)
File "gen.py", line 154, in main
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "gen.py", line 67, in generate
logits = model(x, max_seq_length, input_pos)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
output = self._forward_module(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 119, in forward
x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 164, in forward
h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 200, in forward
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 2306, in forward
out = linear_fn(*args)
File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 1671, in forward
assert (
AssertionError: Input and weight dimensions are not compatible for FP8 execution. |
The assertion you are hitting gets raised under two conditions: https://github.com/NVIDIA/TransformerEngine/blob/stable/transformer_engine/pytorch/module.py#L1675. The 7B config will have weight dimensions that are already divisible by 16, so it must be the input side, i.e. the number of tokens in the current forward pass not being a multiple of 8.
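To make the constraint concrete, here is my own approximation of the check behind that assertion, following the FP8 primer (the product of all input dimensions except the last must be divisible by 8, and the feature dimensions by 16); the helper name is hypothetical:

```python
import math
import torch

def fp8_shapes_ok(x: torch.Tensor, weight: torch.Tensor) -> bool:
    # Approximate divisibility rule from the TE FP8 primer.
    tokens = math.prod(x.shape[:-1])  # batch * sequence
    return (
        tokens % 8 == 0
        and x.shape[-1] % 16 == 0
        and weight.shape[0] % 16 == 0
        and weight.shape[1] % 16 == 0
    )

# LLaMA-7B's c_attn weight is (3 * 4096, 4096): both divisible by 16.
# So a failure points at the token count, e.g. a 6-token window:
x = torch.randn(1, 6, 4096)
w = torch.randn(3 * 4096, 4096)
print(fp8_shapes_ok(x, w))  # False -> 6 tokens is not a multiple of 8
```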
Thank you so so much!!! I first tried to pad it like this:

```python
    # generate max_new_tokens tokens
    for _ in range(max_new_tokens):
        x = idx.index_select(0, input_pos).view(1, -1)

        # new: padding code
        original_length = x.size(1)
        new_length = ((original_length - 1) // 8 + 1) * 8  # Round up to nearest multiple of 8
        if original_length != new_length:
            padding = torch.zeros(x.size(0), new_length - original_length, dtype=x.dtype, device=x.device)
            x = torch.cat([x, padding], dim=1)

        # forward
        logits = model(x, max_seq_length, input_pos)
        logits = logits[0, -1] / temperature
```

Which resulted in:

Traceback (most recent call last):
File "gen.py", line 185, in <module>
CLI(main)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
return _run_component(component, cfg_init)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
return component(**cfg)
File "gen.py", line 160, in main
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "gen.py", line 75, in generate
logits = model(x, max_seq_length, input_pos)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
output = self._forward_module(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 120, in forward
x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 165, in forward
h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 208, in forward
q = apply_rope(q, rope)
File "/app/lit-llama/lit_llama/model.py", line 317, in apply_rope
rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
RuntimeError: shape '[1, 8, 1, 64, 2]' is invalid for input of size 768

Then, I tried supplying only 8 tokens to see if it matches (because I failed at the padding). It now worked for the first CausalSelfAttention forward passes but crashed after the second LLaMA forward pass:

root@d395b61d47d7:/app/lit-llama# python3 gen.py --prompt "Hello Hello Hello Hello Hello Hello Hello"
Loading model ...
<lightning.fabric.plugins.precision.fp8_transformer_engine.Fp8TransformerEnginePrecision object at 0x7f21b0f43a90>
cuda
Time to load model: 62.42 seconds.
Length of encoded prompt: 8
Size of encoded prompt: 8
Global seed set to 1234
DEBUG: IDX TYPE: torch.int32
DEBUG: x dtype: torch.float32
DEBUG: x shape before split: torch.Size([1, 8, 4096])
DEBUG: q, k, v shape: torch.Size([1, 8, 4096]), torch.Size([1, 8, 4096]), torch.Size([1, 8, 4096])
[... the same three DEBUG lines repeat for each of the 32 transformer blocks ...]
DEBUG: IDX TYPE: torch.int32
DEBUG: x dtype: torch.float32
DEBUG: x shape before split: torch.Size([1, 1, 4096])
Traceback (most recent call last):
File "gen.py", line 189, in <module>
CLI(main)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 85, in CLI
return _run_component(component, cfg_init)
File "/usr/local/lib/python3.8/dist-packages/jsonargparse/cli.py", line 147, in _run_component
return component(**cfg)
File "gen.py", line 164, in main
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "gen.py", line 67, in generate
logits = model(x, max_seq_length, input_pos)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/lightning/fabric/wrappers.py", line 109, in forward
output = self._forward_module(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 118, in forward
x, self.kv_caches[i] = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i])
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 163, in forward
h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/app/lit-llama/lit_llama/model.py", line 199, in forward
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1502, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1511, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 2306, in forward
out = linear_fn(*args)
File "/usr/local/lib/python3.8/dist-packages/transformer_engine/pytorch/module.py", line 1671, in forward
assert (
AssertionError: Input and weight dimensions are not compatible for FP8 execution.

I really do not know what to do now; I tried many things over the last hours. I would greatly appreciate it if you could implement the padding just as you imagined it, and I would be happy to build on that and test it out. (I will keep the cloud GPU running, so if you find the time, I would be very glad to test it soon.)
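One hedged reading of the logs above, plus an untested direction (the variable names follow lit-llama's `generate` loop, as in the earlier snippet): the last crash happens on the incremental decode step, where `x` has shape `[1, 1, 4096]`, so the single new token also violates the multiple-of-8 rule; and the earlier rope error came from padding `x` without extending `input_pos`. A sketch that keeps the two consistent on every step might look like this:

```python
# Hypothetical, untested sketch: pad the token window AND input_pos to a
# multiple of 8 on every step (prompt pass and single-token decode alike),
# then read the logits at the last real position only.
pad_to = 8
x = idx.index_select(0, input_pos).view(1, -1)
real_len = x.size(1)
padded_len = ((real_len - 1) // pad_to + 1) * pad_to
if padded_len != real_len:
    pad = torch.zeros(1, padded_len - real_len, dtype=x.dtype, device=x.device)
    x = torch.cat([x, pad], dim=1)
    # Extend the positions so the rope/kv-cache indexing matches the padded length.
    last = int(input_pos[-1])
    extra = torch.arange(
        last + 1,
        last + 1 + (padded_len - real_len),
        device=input_pos.device,
        dtype=input_pos.dtype,
    )
    padded_pos = torch.cat([input_pos, extra])
else:
    padded_pos = input_pos

logits = model(x, max_seq_length, padded_pos)
logits = logits[0, real_len - 1] / temperature  # ignore logits at pad positions
```

Whether the kv-cache entries written at the pad positions interfere with later steps would still need checking; this is only meant as a starting point.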
@28Smiles That seems like a completely separate issue from H100 support. Can you open a different issue?
Hello!
As I asked on Discord, here is the issue about implementing NVIDIA's Transformer Engine with compute capability 9 (H100 GPU).
I would really love to see and help with implementing that!
Thank you very much 😊