Smoothquant LLaMA builds not working on 0.8.0 release #1267

Closed
2 of 4 tasks
ghost opened this issue Mar 11, 2024 · 2 comments
Labels: bug, triaged

ghost commented Mar 11, 2024

System Info

GPU : NVIDIA A100 80GB

package version
tensorrt-9.2.0.post12.dev5-cp310-none-linux_x86_64.whl
[TensorRT-LLM] TensorRT-LLM version: 0.8.0

Who can help?

@Tracin @byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Installation
    python -m pip install tensorrt_llm==0.8.0 --extra-index-url https://pypi.nvidia.com

  2. Create smoothquant checkpoint for LLaMA
    python ./examples/llama/convert_checkpoint.py --model_dir ~/Llama-2-13b-chat-hf --output_dir ~/fp16-tp4-sq5 --dtype float16 --tp_size 4 --smoothquant 0.5 --per_token --per_channel --workers 4

Expected behavior

Checkpoint should be created.

actual behavior

Error at line - https://github.com/NVIDIA/TensorRT-LLM/blob/v0.8.0/examples/llama/convert_checkpoint.py#L1502

ValueError: You are trying to save a non contiguous tensor: transformer.layers.0.attention.qkv.weight which is not allowed. It either means you are trying to save tensors which are reference of each other in which case it's recommended to save only the full tensors, and reslice at load time, or simply call .contiguous() on your tensor to pack it before saving.

additional notes

No such error seen on release 0.7.1

My guess is that the function get_tllm_linear_sq_weight returns some non-contiguous tensors.
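This guess is consistent with how PyTorch lays out transposed tensors: `.t()` only swaps strides, and `.clone()` by default preserves the strides of a dense input, so the copy is still non-contiguous. A minimal sketch of that behavior, using a made-up array as a stand-in for the real weights (the shapes and values are assumptions, not taken from the script):

```python
import numpy as np
import torch

# Hypothetical stand-in for the weights handled in convert_checkpoint.py.
cur_weights = np.arange(12, dtype=np.float32).reshape(3, 4)

transposed = torch.from_numpy(cur_weights).t()  # a view with swapped strides
cloned = transposed.clone()                     # default memory_format keeps those strides
packed = transposed.contiguous()                # copies into row-major layout

print(transposed.is_contiguous())  # False
print(cloned.is_contiguous())      # False
print(packed.is_contiguous())      # True
```

The values are identical in all three tensors; only the memory layout differs, which is what the save step rejects.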

ghost added the bug label Mar 11, 2024
byshiue (Collaborator) commented Mar 14, 2024

Thank you for the report. We will fix it in the next update. As you mentioned, you can add .contiguous() to

results[prefix + 'weight'] = torch.from_numpy(cur_weights).t().clone()

in get_tllm_linear_sq_weight.
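As a rough sketch of that workaround, the line could be patched as below. The `prefix` value is copied from the error message, `cur_weights` is a placeholder array, and appending `.contiguous()` at the end of the chain is an assumption about placement that the comment above does not spell out:

```python
import numpy as np
import torch

results = {}
prefix = "transformer.layers.0.attention.qkv."   # taken from the error message
cur_weights = np.zeros((8, 4), dtype=np.int8)    # placeholder, not real weights

# Original line in get_tllm_linear_sq_weight:
#   results[prefix + 'weight'] = torch.from_numpy(cur_weights).t().clone()
# Patched line (assumed placement of .contiguous()):
results[prefix + 'weight'] = torch.from_numpy(cur_weights).t().clone().contiguous()

print(results[prefix + 'weight'].is_contiguous())  # True
```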

byshiue self-assigned this Mar 14, 2024
byshiue added the triaged label Mar 14, 2024
ghost closed this as completed Mar 14, 2024
Hukongtao commented

Qwen has the same error. When I run

python3 examples/qwen/convert_checkpoint.py \
    --model_dir ./Qwen-72B-Chat-Int4/ \
    --output_dir ./Qwen-72B-Chat-Int4-TRT/tllm_checkpoint_1gpu_gptq/ \
    --dtype float16 \
    --use_weight_only \
    --weight_only_precision int4_gptq \
    --tp_size 4 \
    --pp_size 1 \
    --per_group

I got:
ValueError: You are trying to save a non contiguous tensor: transformer.layers.0.mlp.gate.weights_scaling_factor which is not allowed. It either means you are trying to save tensors which are reference of each other in which case it's recommended to save only the full tensors, and reslice at load time, or simply call .contiguous() on your tensor to pack it before saving.
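Because the tensor named here is a scaling factor rather than the SmoothQuant weight, patching a single line may not cover this case. One possible stopgap, sketched under the assumption that the converter collects its tensors in a name-to-tensor dict before saving, is to pack every tensor just before the save call. `make_contiguous` is a hypothetical helper, not part of TensorRT-LLM:

```python
from typing import Dict

import torch


def make_contiguous(weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: return a dict whose tensors are all packed row-major."""
    # .contiguous() returns the tensor itself when it is already contiguous,
    # so only the tensors that need repacking are copied.
    return {name: tensor.contiguous() for name, tensor in weights.items()}


# Made-up example; the key mirrors the tensor named in the error above.
weights = {
    "transformer.layers.0.mlp.gate.weights_scaling_factor": torch.ones(4, 8).t(),
}
weights = make_contiguous(weights)
print(all(t.is_contiguous() for t in weights.values()))  # True
```

This trades some extra memory during conversion for not having to track down each producer of a non-contiguous tensor.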
