
RWKV - Inference NF4 quantization broken, also Int8 quantization weirdness. #23848

Closed
2 of 4 tasks
iantbutler01 opened this issue May 29, 2023 · 14 comments · Fixed by #23910 or #26134
Comments

@iantbutler01

System Info

  • transformers version: 4.30.0.dev0
  • Platform: Linux-5.15.0-70-generic-x86_64-with-glibc2.35
  • Python version: 3.10.6
  • Huggingface_hub version: 0.14.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: RTX 6000 Ada
  • Using distributed or parallel set-up in script?: Not for inference.
  • bitsandbytes version: 0.39

I'm using the RWKV/rwkv-raven-14b model.

Rescaling is broken for NF4 quantization with RWKV:

RuntimeError: result type Float can't be cast to the desired output type Byte

It looks like torch cannot perform the conversion in the in-place div_.

If I turn rescaling off, there appears to be a projection issue somewhere:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (43x5120 and 1x13107200)

Additionally, with Int8 quantization enabled, RWKV just outputs the endoftext token. I added a logits processor to print out the scores and they're all NaN:

tensor([[nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       dtype=torch.float16)

Who can help?

@sgugger

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I have a repo with everything set up in generate.py so you can quickly reproduce this here:
https://github.com/iantbutler01/rwkv-raven-qlora-4bit-instruct/blob/main/generate.py

pip install -U git+https://github.com/huggingface/transformers.git
pip install -U git+https://github.com/huggingface/peft.git
pip install -U git+https://github.com/huggingface/accelerate.git
pip install --upgrade bitsandbytes

Then run python generate.py in a Python 3.10+ environment. Uncomment the 8bit or 4bit bnb config as needed (see the sketch below).
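
For reference, a sketch of the two BitsAndBytesConfig variants the script toggles between (the values are taken from the snippets later in this thread, not from the script itself):

import torch
from transformers import BitsAndBytesConfig

# 8bit config
bnb_config_int8 = BitsAndBytesConfig(load_in_8bit=True)

# 4bit NF4 config (as used later in this thread)
bnb_config_nf4 = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)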

Expected behavior

I would expect NF4-based quantization to work at all, and Int8 quantization not to produce NaN logits.

@iantbutler01 iantbutler01 changed the title RWKV - NF4 quantization broken, also Int8 quantization weirdness. RWKV - Inference NF4 quantization broken, also Int8 quantization weirdness. May 29, 2023
@sgugger
Collaborator

sgugger commented May 30, 2023

Not sure quantization actually works for RWKV, which has quite a few custom layers. cc @younesbelkada

@iantbutler01
Author

iantbutler01 commented May 30, 2023

Hmm, I was able to do a 4bit finetuning with QLoRA last week, at the very least targeting key, value, and receptance in the attention and feed-forward blocks; it just seems like inference is broken.

I confirmed my tuned checkpoints worked fine for inference at full precision, and now that I think of it, the forward call also worked fine in 8bit in Eleuther's lm-evaluation-harness (not sure about 4bit). It just seems to break when calling generate.
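
A minimal sketch of what the adapter targeting described above might look like with peft (the hyperparameter values are illustrative assumptions, not the exact configuration used):

from peft import LoraConfig, get_peft_model

# Illustrative only: target the key, value and receptance projections of the
# RWKV attention and feed-forward blocks. r/alpha/dropout are placeholders.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["key", "value", "receptance"],
)
model = get_peft_model(model, lora_config)  # `model` loaded in 4bit beforehand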

@younesbelkada
Contributor

younesbelkada commented May 31, 2023

Hi @iantbutler01
Thanks for the issue!
The 8bit support should be added in #23468.
From my understanding, it seems you have managed to finetune RWKV in 4bit?

You wrote: "I was able to do a 4bit finetuning with QLoRA last week, at the very least targeting key, value, and receptance in the attention and feed-forward blocks"

Could you elaborate more on the error?

@iantbutler01
Author

@younesbelkada

Regarding int8: I've been testing on the development branch, which includes the code you merged, and it still just produces tensor([[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0', dtype=torch.float16) for the logits during a generate call, even with the base RWKV 14b model, so I think something is still broken. You can reproduce this easily with the steps linked in the issue.

For example, with

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, LogitsProcessor, LogitsProcessorList
# InstructionTextGenerationPipeline and PROMPT_FOR_GENERATION_FORMAT are assumed to come
# from the repo's instruct_pipeline.py (see the traceback further down)
from instruct_pipeline import InstructionTextGenerationPipeline, PROMPT_FOR_GENERATION_FORMAT

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True
)

model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-raven-14b",
    return_dict=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    context_length=1024,
    # rescale_every=0,
).cuda()

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-raven-14b")

pipeline = InstructionTextGenerationPipeline(
    model=model,
    tokenizer=tokenizer,
    top_p=0.92,
    top_k=50,
    temperature=1.0,
)
instruction = "Write me the steps to make a peanut butter and jelly sandwich"
prompt = PROMPT_FOR_GENERATION_FORMAT.format(
    instruction=instruction,
)

class IsBork(LogitsProcessor):
    def __call__(self, input_ids, scores):
        print(scores)
        return scores
    
prompt = str(prompt)
inputs = tokenizer(prompt, return_tensors="pt")

input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
input_ids, attention_mask = input_ids.to("cuda"), attention_mask.to("cuda")

generated_sequence = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    logits_processor=LogitsProcessorList([IsBork()]),
    pad_token_id=tokenizer.pad_token_id,
    top_p=0.92,
    top_k=50,
    temperature=1.0,
    max_new_tokens=512
)

print(generated_sequence)

The call to generate raises an error:

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 171, in <module>
    gen = pipeline(prompt, max_new_tokens=512)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/pipelines/base.py", line 1118, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/pipelines/base.py", line 1125, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/pipelines/base.py", line 1024, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/instruct_pipeline.py", line 112, in _forward
    generated_sequence = self.model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1568, in generate
    return self.sample(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2651, in sample
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0      

Adding a logits processor that just prints the scores shows, for the first generated token:

tensor([[nan, nan, nan, ..., nan, nan, nan]], device='cuda:0', dtype=torch.float16)

If I then set do_sample=False, the output is:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Write me the steps to make a peanut butter and jelly sandwich

### Response:
<|endoftext|>

It only generates the end-of-text token, whereas the full precision model generates correctly.

@iantbutler01
Author

iantbutler01 commented May 31, 2023

Regarding 4bit: rescaling during inference is broken for NF4 quantization with RWKV. If you run inference via a generate call with NF4 quantization:

RuntimeError: result type Float can't be cast to the desired output type Byte

This is failing in the else branch of the block your int8 PR touches (see the traceback and the minimal sketch below).

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 181, in <module>
    generated_sequence = model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1518, in generate
    return self.greedy_search(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 781, in forward
    rwkv_outputs = self.rwkv(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 642, in forward
    self._rescale_layers()
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 713, in _rescale_layers
    block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
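
A minimal sketch of why that in-place division fails (this is not the actual modeling_rwkv.py code; the assumption, consistent with the Byte dtype in the error above, is that with nested quantization quant_state[0] holds the absmax as a uint8 tensor):

import torch

# True division always produces a float result, and an in-place op cannot cast
# that result back into a uint8 (Byte) tensor, hence the error.
absmax = torch.randint(0, 255, (16,), dtype=torch.uint8)  # stand-in for quant_state[0]
absmax.div_(2)  # RuntimeError: result type Float can't be cast to the desired output type Byte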

If I then turn rescaling off by setting rescale_every=0, it looks like there's a projection issue somewhere:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (43x5120 and 1x13107200)

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 181, in <module>
    generated_sequence = model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1518, in generate
    return self.greedy_search(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 781, in forward
    rwkv_outputs = self.rwkv(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 667, in forward
    hidden_states, state, attentions = block(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 384, in forward
    attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 308, in forward
    receptance, key, value, state = self.extract_key_value(hidden, state=state)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 300, in extract_key_value
    key = self.key(key)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 219, in forward
    out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 564, in matmul_4bit
    return MatMul4Bit.apply(A, B, out, bias, quant_state)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 512, in forward
    output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (42x5120 and 1x13107200)

But yeah, this is all reproducible in the script I linked in the issue.

@younesbelkada
Contributor

I see, thanks for sharing more details with me.
So there are two issues here:

1- int8 RWKV does not seem to work for you. From the snippet, you are calling .cuda() on the 8bit model. This can lead to unexpected behavior because any .to(xxx) call on the 8bit model will re-compute the quantization statistics.
I have managed to reproduce your issue with the snippet below:

from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

model_id = "RWKV/rwkv-4-1b5-pile"

model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0}).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id)

generation_config = GenerationConfig(max_new_tokens=20, pad_token_id=tokenizer.eos_token_id)
question = "Hello my name is"
inputs = tokenizer(question, return_tensors="pt").to(0)
output_int8 = model.generate((inputs["input_ids"]), generation_config=generation_config)
print(tokenizer.decode(output_int8[0], skip_special_tokens=True))

and the model directly predicts the EOS token. The fix is to replace model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0}).cuda() with model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"":0}), i.e. drop the trailing .cuda(). Could you confirm this fixes your issue?
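
For clarity, the corrected loading line from the snippet above, with everything else unchanged:

# With load_in_8bit=True and device_map={"": 0} the model is already placed on
# GPU 0; a trailing .cuda()/.to() would re-compute the int8 quantization statistics.
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map={"": 0})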

2- RWKV + 4bit does not seem to be supported for now. I will dig into that and let you know as soon as I have a fix.

@younesbelkada
Contributor

I just added 4bit inference support for RWKV in #23910 - please try out the fixes stated above together with #23910 and let us know how it goes.

@iantbutler01
Author

@younesbelkada

Okay so 8bit is working fine now, thank you very much for the workaround!

4bit, loaded with this configuration:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)


model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-raven-14b",
    return_dict=True,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    context_length=1024,
    # rescale_every=0,
    device_map={"":0}
)

is still failing, unfortunately :(

Traceback (most recent call last):
  File "/home/crow/SoftwareProjects/rwkv-raven-lora-instruct/generate.py", line 182, in <module>
    generated_sequence = model.generate(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 1518, in generate
    return self.greedy_search(
  File "/home/crow/SoftwareProjects/transformers/src/transformers/generation/utils.py", line 2335, in greedy_search
    outputs = self(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 789, in forward
    rwkv_outputs = self.rwkv(
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/crow/venvs/experimental/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 642, in forward
    self._rescale_layers()
  File "/home/crow/SoftwareProjects/transformers/src/transformers/models/rwkv/modeling_rwkv.py", line 714, in _rescale_layers
    block.attention.output.weight.quant_state[0].div_(
RuntimeError: result type Float can't be cast to the desired output type Byte

@younesbelkada
Contributor

I see, this is because you are using nested quantization (bnb_4bit_use_double_quant=True). Can you try without that while I find a fix for this specific use case? 🙏

@iantbutler01
Author

Yes, sorry about that. I had always intended this to be with double quant (it was in my original repro code), but I should have been more explicit when communicating it to you 👍

I tried it without double quantization and it does work.
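
For reference, the 4bit configuration that works for me (same as above, minus the nested-quantization flag):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4bit NF4 without nested (double) quantization; per this thread, this loads
# and runs inference, while bnb_4bit_use_double_quant=True still fails.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "RWKV/rwkv-raven-14b",
    quantization_config=bnb_config,
    device_map={"": 0},
)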

@younesbelkada
Contributor

No problem, and thanks for double-checking. Will get back once I fix the issue with nested quantization!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Aug 7, 2023
@jonataslaw

jonataslaw commented Aug 9, 2023

I think it should not be closed, @younesbelkada

@younesbelkada younesbelkada reopened this Aug 17, 2023
@younesbelkada
Contributor

younesbelkada commented Aug 17, 2023

Correct, it is known that RWKV double-quant 4bit inference does not work yet. I am not sure I can propose a fix anytime soon because of the rescale-layers operation here: https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py#L722
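
One possible direction, purely as a sketch (this is an assumption about how the rescaling could be made to work with nested quantization, not the fix that was eventually merged): dequantize the 4bit weight, apply the division in floating point, then requantize with the same settings.

import bitsandbytes as bnb

def rescale_4bit_weight(param, divisor):
    # Sketch only: roundtrip through float so the division never touches the
    # uint8 nested statistics directly. Assumes `param` is a bnb Params4bit
    # parameter with a populated `quant_state`.
    dequant = bnb.functional.dequantize_4bit(param.data, param.quant_state)
    requant, new_state = bnb.functional.quantize_4bit(
        dequant / divisor,
        quant_type="nf4",
        compress_statistics=True,  # keep nested (double) quantization
    )
    param.data = requant
    param.quant_state = new_state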
