[Bug Report] The output from HookedTransformer is not identical to the Huggingface model for Llama 3 #615
Thank you for being so thorough on this. We are in the middle of working out some tools that will make benchmarking models easier, and I think that is going to help us debug this a lot. If anyone sees this and is able to take this on, that would be highly appreciated in the meantime. If not, then hopefully those tools will be ready soon, so that we can more easily figure out what the deal is here.
I observed similar problems with
I observe the same issue. It might be related to #570. There seems to be quite a large difference in logits between the two models.

```python
from matplotlib import pyplot as plt
import seaborn as sns
import torch
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer, LlamaForCausalLM

MODEL_PATH = 'meta-llama/Meta-Llama-3-8B-Instruct'

# Load the Huggingface model and tokenizer using LlamaForCausalLM (single GPU here)
hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH).to('cuda:0')
hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Load the HookedTransformer model without weight processing
hooked_model = HookedTransformer.from_pretrained_no_processing(
    MODEL_PATH,
    device="cuda",
    dtype=torch.float32,
    default_padding_side='left',
    n_devices=1,
)

a_large_chunk_of_text = "Generate is not a good way to check a model is running properly. Can you run the following and share the results?"
tokens = hf_tokenizer(a_large_chunk_of_text, return_tensors="pt").input_ids.to('cuda:0')

logits_hf = hf_model(tokens).logits
logits_tl = hooked_model(tokens, return_type="logits")
logits_diff = logits_hf - logits_tl
logits_diff_last = logits_diff[:, -1, :]

print("HF Greedy:", logits_hf[:, -1, :].argmax(dim=-1), "Logit:", logits_hf[:, -1, :].max())
print("TL Greedy:", logits_tl[:, -1, :].argmax(dim=-1), "Logit:", logits_tl[:, -1, :].max())

# Histogram of the difference between the last-position logits of the two models
sns.histplot(logits_diff_last.flatten().cpu().numpy(), bins=100)
plt.title("Difference between logits (last) from the hooked model and the huggingface model")
```
Made a few more investigations. It seems like the differences stem from both the attention and the MLP (I haven't appended the plot, but the embedding matrix output is equal).

```python
from nnsight import NNsight

hf_model_nnsight = NNsight(hf_model)
with hf_model_nnsight.trace(tokens):
    hf_attn_out = hf_model_nnsight.model.layers[0].self_attn.output.save()
    hf_mlp_out = hf_model_nnsight.model.layers[0].mlp.output.save()
    hf_resid_post = hf_model_nnsight.model.layers[0].output.save()

_, cache = hooked_model.run_with_cache(tokens)

layer0_attn_out_diff = hf_attn_out[0] - cache["blocks.0.hook_attn_out"]
layer0_mlp_out_diff = hf_mlp_out[0] - cache["blocks.0.hook_mlp_out"]
layer0_resid_post_diff = hf_resid_post[0] - cache["blocks.0.hook_resid_post"]

# Histograms of the differences between the two models' layer-0 outputs
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 5))
sns.histplot(layer0_attn_out_diff.detach().flatten().cpu().numpy(), bins=100, ax=ax1)
ax1.set_title("Difference between attention output")
sns.histplot(layer0_mlp_out_diff.detach().flatten().cpu().numpy(), bins=100, ax=ax2)
ax2.set_title("Difference between mlp output")
sns.histplot(layer0_resid_post_diff.detach().flatten().cpu().numpy(), bins=100, ax=ax3)
ax3.set_title("Difference between residual post output")
plt.suptitle("Difference between outputs from the hooked model and the huggingface model in layer 0")
```
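Building on the same setup, a sketch like the one below would show how the divergence grows across layers, assuming the usual correspondence between the Huggingface `hidden_states[layer]` (the input to decoder layer `layer`) and `blocks.{layer}.hook_resid_pre` for a model loaded with `from_pretrained_no_processing`:

```python
# Per-layer max absolute difference between the Huggingface hidden states and
# the TransformerLens residual stream (sketch; assumes hf_model, hooked_model,
# tokens, and cache from the snippets above).
with torch.no_grad():
    hf_hidden = hf_model(tokens, output_hidden_states=True).hidden_states

for layer in range(hooked_model.cfg.n_layers):
    tl_resid = cache[f"blocks.{layer}.hook_resid_pre"]
    max_diff = (hf_hidden[layer] - tl_resid).abs().max().item()
    print(f"layer {layer:2d}: max |resid_pre diff| = {max_diff:.6f}")
```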
@jkminder If you are looking into this, the thing to do is check the implementation of the model in
@bryce13950 thanks, I already started digging a bit deeper here. It seems like the weights are slightly different, even though an equality test doesn't show this. Would be great to have a second pair of eyes on this! I am examining the MLPs in the first layer of Llama 3 (following my code above).

**Test: are all weights the same? -> Yes**

```python
(hooked_model.blocks[0].mlp._parameters["W_gate"] == hf_model.model.layers[0].mlp.gate_proj.weight.T).all(), \
    (hooked_model.blocks[0].mlp._parameters["W_in"] == hf_model.model.layers[0].mlp.up_proj.weight.T).all(), \
    (hooked_model.blocks[0].mlp._parameters["W_out"] == hf_model.model.layers[0].mlp.down_proj.weight.T).all()
```

Output (truncated): `(tensor(True, device='cuda:0'), ...`

**Test: Is the output equal? -> No**

```python
test_in = torch.ones(1, 1, 4096).to('cuda:0')
hf_out = hf_model.model.layers[0].mlp(test_in).detach().cpu()
hooked_out = hooked_model.blocks[0].mlp(test_in).detach().cpu()
is_close = torch.isclose(hf_out, hooked_out)
(~is_close).sum(), is_close.all()
```

Output: `(tensor(9), tensor(False))`

**Test: Zooming in on W_gate -> same issue**

```python
from fancy_einsum import einsum

hf_w_gate_out = hf_model.model.layers[0].mlp.gate_proj(test_in).detach().cpu()
# https://github.com/TransformerLensOrg/TransformerLens/blob/318236402ddcc9cabace3f2fca40c71f0c2e9e57/transformer_lens/components/gated_mlp.py#L102C13-L107C18
hooked_w_gate_out = einsum(
    "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
    test_in,
    hooked_model.blocks[0].mlp.W_gate,
).detach().cpu()
is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)
(~is_close).sum(), is_close.shape
```

Output: `(tensor(9), torch.Size([1, 1, 14336]))`

**Test: is this an issue with fancy_einsum / does opt_einsum work? -> No, same issue**

```python
from opt_einsum import contract

hf_w_gate_out = hf_model.model.layers[0].mlp.gate_proj(test_in).detach().cpu()
# Adapted from https://github.com/TransformerLensOrg/TransformerLens/blob/318236402ddcc9cabace3f2fca40c71f0c2e9e57/transformer_lens/components/gated_mlp.py#L102C13-L107C18
hooked_w_gate_out = contract(
    "bpk, kd -> bpd",
    test_in,
    hooked_model.blocks[0].mlp.W_gate,
).detach().cpu()
is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)
(~is_close).sum(), is_close.shape
```

Output: `(tensor(9), torch.Size([1, 1, 14336]))`

**Test: How about just a plain torch matrix product? -> Same issue**

```python
hooked_w_gate_out = (test_in @ hooked_model.blocks[0].mlp.W_gate).detach().cpu()
is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)
(~is_close).sum(), is_close.shape
```

Output: `(tensor(9), torch.Size([1, 1, 14336]))`

**Test: Convert to a torch.nn.Linear and then compute -> does not fix it**

```python
lin = torch.nn.Linear(4096, 14336, bias=False)
lin.weight = torch.nn.Parameter(hooked_model.blocks[0].mlp.W_gate.T)
hooked_w_gate_out = lin(test_in).detach().cpu()
is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)
(~is_close).sum(), is_close.shape
```

Output: `(tensor(9), torch.Size([1, 1, 14336]))`

**Test: Verify that the implementation of linear is not the problem -> it is not**

```python
hooked_w_gate_out = torch.nn.functional.linear(test_in, hooked_model.blocks[0].mlp.W_gate.T).detach().cpu()
hf_w_gate_out = torch.nn.functional.linear(test_in, hf_model.model.layers[0].mlp.gate_proj.weight).detach().cpu()
is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)
(~is_close).sum(), (hf_w_gate_out != hooked_w_gate_out).sum()
```

Output: `(tensor(9), tensor(4819))`

**Test: Replace the hooked weights with the HF weights -> solves it**

```python
hooked_model.blocks[0].mlp.W_gate.data = hf_model.model.layers[0].mlp.gate_proj.weight.T
hf_w_gate_out = hf_model.model.layers[0].mlp.gate_proj(test_in).detach().cpu()
# https://github.com/TransformerLensOrg/TransformerLens/blob/318236402ddcc9cabace3f2fca40c71f0c2e9e57/transformer_lens/components/gated_mlp.py#L102C13-L107C18
hooked_w_gate_out = einsum(
    "batch pos d_model, d_model d_mlp -> batch pos d_mlp",
    test_in,
    hooked_model.blocks[0].mlp.W_gate,
).detach().cpu()
is_close = torch.isclose(hf_w_gate_out, hooked_w_gate_out)
(~is_close).sum(), is_close.shape
```

Output: `(tensor(0), torch.Size([1, 1, 14336]))`

I'm really unsure what to make of this behaviour, but I guess it could explain why we are seeing differences in outputs. I assume there are slight changes in the weights, which result in numerical issues. Weirdly, when comparing them they all are equal. Any insights on what could be done to solve this? I guess one could switch to torch.functional calls for the matrix multiplications. Also, in case you don't want to copy-paste, here is my notebook: https://gist.github.com/jkminder/d05d708f3f93c66037ac7f0c352eefa4
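To make the suggestion about switching to torch.functional calls concrete, here is a minimal sketch of a gated-MLP forward that uses `torch.nn.functional.linear` instead of the einsum, assuming the TransformerLens weight layout (`W_gate`/`W_in` of shape `[d_model, d_mlp]`, `W_out` of shape `[d_mlp, d_model]`, no biases). This is an illustration of the idea, not the actual TransformerLens implementation:

```python
import torch
import torch.nn.functional as F

def gated_mlp_forward_with_linear(x, W_gate, W_in, W_out):
    """Sketch of a SwiGLU-style gated MLP forward using F.linear (no biases).

    F.linear(x, W.T) computes x @ W, so this matches the einsum
    "batch pos d_model, d_model d_mlp -> batch pos d_mlp" used above.
    """
    gate_pre = F.linear(x, W_gate.T)   # [batch, pos, d_mlp]
    in_pre = F.linear(x, W_in.T)       # [batch, pos, d_mlp]
    post = F.silu(gate_pre) * in_pre   # gated activation
    return F.linear(post, W_out.T)     # [batch, pos, d_model]
```

Note that the test above where both sides go through `torch.nn.functional.linear` still shows small differences, so the root cause may be the transposed weight layout rather than the einsum call itself, and this change alone may not remove the mismatch.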
@jkminder Sorry for the delay in getting back to you. We have some pretty major changes coming to existing components in the same vein as what you are bringing up. At the moment, I am completely focused on wrapping up a number of changes for Mixtral, which is similar to this, and after that I intend to start playing with this model to see what the status is at that point. Hopefully that will be all done, merged, and released within the next few days. Let's revisit this at that point. Some more info on the details of what is being discussed can be found in #645. Regardless, more work is likely to be needed here, but this will benefit from the work being done.
Any updates on this? I used llama3-8b-instruct with Both were run using
Can you share some of your generations as an example in this issue? I have a list of models that need to be investigated, and generally the meta-llama family hasn't had too many major issues. In my tests, there are other models that are higher priority, but if you have found a pretty large issue with TransformerLens generation for these models, that may push them up the priority queue.
Thank you for your prompt response. Here is an example.
transformerlens:
vllm:
vllm never starts a response with "I apologize", and the generations are quite different between transformerlens and vllm.
Very interesting. Finally, if you can share the code you are using with vllm to generate this text, that would be very useful to expedite recreating this issue.
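For reference, a minimal vLLM generation script of the kind being asked for might look roughly like the following; the prompt and sampling settings below are placeholders, not the reporter's actual setup:

```python
# Hypothetical vLLM generation script (prompt and sampling settings are placeholders).
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)  # greedy decoding

prompt = "Explain what a transformer language model is."
outputs = llm.generate([prompt], sampling_params)
print(outputs[0].outputs[0].text)
```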
Hi!
If you are submitting a bug report, please fill in the following details and use the tag [bug].
Describe the bug
The generations from the Huggingface model (LlamaForCausalLM) and HookedTransformer are different.
Code example
The result from the Hooked model is different from the Huggingface model's. The outputs from the Huggingface model are identical to the results from LLM Arena, so there is likely a bug in the HookedTransformer implementation.
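The original code example does not appear above; a minimal sketch of the comparison being described might look like the following, where the prompt and generation settings are placeholders rather than the reporter's actual values:

```python
# Hypothetical repro sketch: compare greedy generations from the Huggingface model
# and HookedTransformer. The prompt and generation length are placeholders.
import torch
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer, LlamaForCausalLM

MODEL_PATH = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
hf_model = LlamaForCausalLM.from_pretrained(MODEL_PATH).to("cuda")
tl_model = HookedTransformer.from_pretrained_no_processing(MODEL_PATH, device="cuda", dtype=torch.float32)

prompt = "What is the capital of France?"
tokens = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

hf_generation = hf_model.generate(tokens, max_new_tokens=50, do_sample=False)
print("HF:", tokenizer.decode(hf_generation[0], skip_special_tokens=True))

tl_generation = tl_model.generate(tokens, max_new_tokens=50, do_sample=False)
print("TL:", tl_model.to_string(tl_generation[0]))
```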
System Info
Describe the characteristic of your environment:
How transformer_lens was installed (pip, docker, source, ...): `!pip install transformers transformers_stream_generator tiktoken transformer_lens einops jaxtyping colorama`
Additional context
Add any other context about the problem here.
Checklist
This may be related to #385