Pack/Unpack to Different Dtypes for FSDP #14

Thanks for this great package!
I've noticed that the existing packing/unpacking only works with certain dtypes. FSDP requires all the params to be a float dtype for sharding, so are there any plans to extend them to different dtypes?
Hi @KeremTurgutlu, thank you for your question!
You can find the related PR for bnb here: bitsandbytes-foundation/bitsandbytes#970
Also, BnB packing/unpacking is dtype-agnostic; I think it is due to the way they implemented the packing/unpacking logic in their kernels. For example:

import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit
import torch
W = torch.randn(128,128)
W.dtype  # torch.float32
param = Params4bit(W, quant_storage=torch.uint8)
w = param.data.contiguous().cuda(0)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=param.blocksize, compress_statistics=param.compress_statistics,
quant_type=param.quant_type, quant_storage=param.quant_storage)
w_dq_uint8 = bnb.functional.dequantize_4bit(w_4bit, quant_state)  # dequantize the uint8-packed weights
param = Params4bit(W, quant_storage=torch.bfloat16)
w = param.data.contiguous().cuda(0)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=param.blocksize, compress_statistics=param.compress_statistics,
quant_type=param.quant_type, quant_storage=param.quant_storage)
w_dq_bf16 = bnb.functional.dequantize_4bit(w_4bit, quant_state)  # dequantize the bfloat16-packed weights
assert torch.equal(w_dq_uint8, w_dq_bf16)

I am not an expert in quantization, but this might be the difference: https://github.com/TimDettmers/bitsandbytes/blob/e820409c095ea7cbb5ce156992307b84352cbf90/csrc/kernels.cu#L827
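As a side check, under the assumption that `quant_storage` only changes how the packed buffer is typed rather than the bytes the kernel writes (an assumption on my part, not something stated in the thread), one could compare the raw storage directly:

```python
# Reuses `w` and `bnb` from the snippet above. If quant_storage is only a
# reinterpretation of the same packed bytes, both buffers should be identical
# when viewed as uint8.
w4_u8, _ = bnb.functional.quantize_4bit(w, quant_storage=torch.uint8)
w4_bf16, _ = bnb.functional.quantize_4bit(w, quant_storage=torch.bfloat16)
assert torch.equal(w4_u8.flatten(), w4_bf16.view(torch.uint8).flatten())
```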
Understood! I think we can use int32 instead of uint8 for packing, then cast to float32. I did a few checks regarding the range of the values and it looks like it could work. Since 3-bit already uses int32, we can quickly test it: https://github.com/mobiusml/hqq/blob/master/hqq/core/bitpack.py#L68. Replace these with:

# Drop-in replacements for the 3-bit methods in hqq/core/bitpack.py
# (numpy is imported there as np).
@staticmethod
def pack_3bit_32(W_q_in):
    W_q = torch.zeros([int(10*np.ceil(W_q_in.shape[0]/10.)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32)
    W_q[:len(W_q_in)] = W_q_in
    _step = int(len(W_q)/10)
    W_q = (W_q[:_step] << 27) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 21) | (W_q[_step*3:_step*4] << 18) | (W_q[_step*4:_step*5] << 15) | (W_q[_step*5:_step*6] << 12) | (W_q[_step*6:_step*7] << 9) | (W_q[7*_step:_step*8] << 6) | (W_q[_step*8:_step*9] << 3) | (W_q[_step*9:])
    return W_q.to(torch.float32)  # Now the stored quantized weights are float32

@staticmethod
def unpack_3bit_32(W_q):
    W_q = W_q.to(torch.int32)
    _step = W_q.shape[0]
    tmp = torch.empty([10*_step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
    tmp[:_step]          = ((W_q & 0b00111000000000000000000000000000) >> 27)
    tmp[1*_step:2*_step] = ((W_q & 0b00000111000000000000000000000000) >> 24)
    tmp[2*_step:3*_step] = ((W_q & 0b00000000111000000000000000000000) >> 21)
    tmp[3*_step:4*_step] = ((W_q & 0b00000000000111000000000000000000) >> 18)
    tmp[4*_step:5*_step] = ((W_q & 0b00000000000000111000000000000000) >> 15)
    tmp[5*_step:6*_step] = ((W_q & 0b00000000000000000111000000000000) >> 12)
    tmp[6*_step:7*_step] = ((W_q & 0b00000000000000000000111000000000) >> 9)
    tmp[7*_step:8*_step] = ((W_q & 0b00000000000000000000000111000000) >> 6)
    tmp[8*_step:9*_step] = ((W_q & 0b00000000000000000000000000111000) >> 3)
    tmp[9*_step:]        = ((W_q & 0b00000000000000000000000000000111))
    return tmp

Example:
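A rough sketch of how the methods above could be called, assuming they are applied to the `BitPack` class in `hqq/core/bitpack.py` (where the originals live); note that, as found further down in the thread, the float32 cast turns out to be lossy for large packed values:

```python
import torch
from hqq.core.bitpack import BitPack  # assuming the replacements above are applied there

# Illustration only: pack/unpack random 3-bit codes with the cast-based helpers.
W_q = torch.randint(low=0, high=2**3, size=(64, 64))
W_q_packed = BitPack.pack_3bit_32(W_q)            # stored as float32 after the cast
W_q_unpacked = BitPack.unpack_3bit_32(W_q_packed)

# Caveat: float32 cannot represent every packed 30-bit integer exactly, so this
# comparison is not guaranteed to hold (which is what the discussion finds below).
print(torch.equal(W_q_unpacked[:len(W_q)].to(W_q.dtype), W_q))
```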
So then we would need to implement bitpacking with int32 instead of uint8 for 8/4/2/1 bits, plus add the corresponding CUDA kernels.
Great, thanks! I will give it a try and attempt to update the CUDA kernels as well.
Yes, once the quantized weights are packed to float32 they should not be touched. Let me know if the trick for the 3-bit case works with FSDP. Meanwhile, I will run more tests to make sure the casting doesn't create any issues. If it works, I can add new bitpacking with the same logic plus their CUDA kernels. The bitpacking logic with int32 will be quite different, but not too difficult to add.
import torch
import numpy as np

def pack_4bit_32(W_q):
    _step = int(len(W_q)/8)
    W_q = (W_q[:_step] << 28) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 20) | (W_q[_step*3:_step*4] << 16) | (W_q[_step*4:_step*5] << 12) | (W_q[_step*5:_step*6] << 8) | (W_q[_step*6:_step*7] << 4) | (W_q[_step*7:])
    return W_q

def unpack_4bit_32_cat(W_q):
    return torch.cat([((W_q & 0b11110000000000000000000000000000) >> 28),
                      ((W_q & 0b00001111000000000000000000000000) >> 24),
                      ((W_q & 0b00000000111100000000000000000000) >> 20),
                      ((W_q & 0b00000000000011110000000000000000) >> 16),
                      ((W_q & 0b00000000000000001111000000000000) >> 12),
                      ((W_q & 0b00000000000000000000111100000000) >> 8),
                      ((W_q & 0b00000000000000000000000011110000) >> 4),
                      ((W_q & 0b00000000000000000000000000001111))], axis=0)

# A bit faster than the _cat version
def unpack_4bit_32(W_q):
    _step = W_q.shape[0]
    tmp = torch.empty([8*_step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
    tmp[:_step]          = ((W_q & 0b11110000000000000000000000000000) >> 28)
    tmp[1*_step:2*_step] = ((W_q & 0b00001111000000000000000000000000) >> 24)
    tmp[2*_step:3*_step] = ((W_q & 0b00000000111100000000000000000000) >> 20)
    tmp[3*_step:4*_step] = ((W_q & 0b00000000000011110000000000000000) >> 16)
    tmp[4*_step:5*_step] = ((W_q & 0b00000000000000001111000000000000) >> 12)
    tmp[5*_step:6*_step] = ((W_q & 0b00000000000000000000111100000000) >> 8)
    tmp[6*_step:7*_step] = ((W_q & 0b00000000000000000000000011110000) >> 4)
    tmp[7*_step:8*_step] = (W_q & 0b00000000000000000000000000001111)
    return tmp

for i in range(100):
    x = torch.randint(0, 2**4, (32,))  # random 4-bit quantized 1D weights
    assert torch.equal(unpack_4bit_32_cat(pack_4bit_32(x)), x)

for i in range(100):
    x = torch.randint(0, 2**4, (32, 32))  # random 4-bit quantized 2D weights
    assert torch.equal(unpack_4bit_32(pack_4bit_32(x)), x)

Using int32 does seem to work! I will test FSDP training with 3-bit HQQ LoRA now.
Edit: Actually, casting back and forth breaks it.
The problem is that the first bit is used for the sign. Maybe instead of packing 8 groups of 4-bit values we can do 7 groups; it would cause some memory overhead, but it would keep the sign bit out of the packed values.
So actually it only works for the rows up to 8; the 2 rows after that are not properly decoded. I will play with some toy examples and see if there's another way.
The problem seems to be the int32 <-> float32 casting; the round trip is lossy:

In [105]: torch.tensor([365215118], dtype=torch.int32).to(torch.float32).to(torch.int32)
Out[105]: tensor([365215104], dtype=torch.int32)
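For context, float32 has a 24-bit significand, so only integers up to 2**24 are guaranteed to survive an int32 to float32 and back round trip; the packed 3-bit values above use up to 30 bits, which is why they get rounded. A quick check:

```python
import torch

# 2**24 = 16777216 round-trips exactly; 2**24 + 1 rounds back down to 2**24.
ok  = torch.tensor([2**24],     dtype=torch.int32).to(torch.float32).to(torch.int32)
bad = torch.tensor([2**24 + 1], dtype=torch.int32).to(torch.float32).to(torch.int32)
print(ok.item(), bad.item())  # 16777216 16777216
```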
This might be helpful? https://discuss.pytorch.org/t/bitwise-operation-on-float-tensor/170863

torch.tensor([365215118], dtype=torch.int32).view(torch.float32).view(torch.int32)
We can probably use `view()` without needing to change the packing itself: `pack_4bit_u8(x).view(torch.bfloat16)`, `pack_4bit_u8(x).view(torch.float16)`.
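A minimal sketch of why `view()` is safe here (the uint8 tensor below is just a stand-in for any packed buffer, e.g. the output of `pack_4bit_u8`): reinterpreting the bytes as a float dtype and back is bit-exact, unlike the numeric cast above:

```python
import torch

# Stand-in for a uint8-packed weight buffer.
packed = torch.randint(0, 256, (4096, 2048), dtype=torch.uint8)

as_bf16 = packed.view(torch.bfloat16)   # same bytes, shape becomes (4096, 1024)
restored = as_bf16.view(torch.uint8)    # back to (4096, 2048), bit-exact

assert torch.equal(restored, packed)
```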
Actually that was very helpful, now it works:

import torch
import numpy as np
# Float32 cast via view(): the bit pattern is preserved, so nothing is lost.
def pack_3bit_32(W_q_in):
    W_q = torch.zeros([int(10*np.ceil(W_q_in.shape[0]/10.)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.float32).view(torch.int32)
    W_q[:len(W_q_in)] = W_q_in
    _step = int(len(W_q)/10)
    W_q = (W_q[:_step] << 27) | (W_q[_step:_step*2] << 24) | (W_q[_step*2:_step*3] << 21) | (W_q[_step*3:_step*4] << 18) | (W_q[_step*4:_step*5] << 15) | (W_q[_step*5:_step*6] << 12) | (W_q[_step*6:_step*7] << 9) | (W_q[7*_step:_step*8] << 6) | (W_q[_step*8:_step*9] << 3) | (W_q[_step*9:])
    return W_q.view(torch.float32)

def unpack_3bit_32(W_q):
    W_q = W_q.view(torch.int32)
    _step = W_q.shape[0]
    tmp = torch.empty([10*_step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
    tmp[:_step]          = ((W_q & 0b00111000000000000000000000000000) >> 27)
    tmp[1*_step:2*_step] = ((W_q & 0b00000111000000000000000000000000) >> 24)
    tmp[2*_step:3*_step] = ((W_q & 0b00000000111000000000000000000000) >> 21)
    tmp[3*_step:4*_step] = ((W_q & 0b00000000000111000000000000000000) >> 18)
    tmp[4*_step:5*_step] = ((W_q & 0b00000000000000111000000000000000) >> 15)
    tmp[5*_step:6*_step] = ((W_q & 0b00000000000000000111000000000000) >> 12)
    tmp[6*_step:7*_step] = ((W_q & 0b00000000000000000000111000000000) >> 9)
    tmp[7*_step:8*_step] = ((W_q & 0b00000000000000000000000111000000) >> 6)
    tmp[8*_step:9*_step] = ((W_q & 0b00000000000000000000000000111000) >> 3)
    tmp[9*_step:]        = ((W_q & 0b00000000000000000000000000000111))
    return tmp

#######################################################################################################
W_q = torch.randint(low=0, high=(2**3), size=(4096, 4096))
W_q_packed = pack_3bit_32(W_q)        # float32 storage, int32 bit pattern
W_q_unpacked = unpack_3bit_32(W_q_packed)
assert torch.mean(1.*(W_q_unpacked[:len(W_q)] == W_q)) == 1.  # Works!

I am not sure how that would work on the CUDA side.
Probably we can keep the existing packing logic and CUDA kernels in place.
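If we only need FSDP to see a float dtype, a rough sketch of wrapping the current uint8 packing with a bit-exact `view()` could look like the following (`pack_for_fsdp` / `unpack_from_fsdp` are hypothetical helper names, not existing HQQ API):

```python
import torch

def pack_for_fsdp(W_q_packed: torch.Tensor, storage_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Reinterpret the existing uint8-packed buffer as a float dtype so FSDP can
    # shard it; same bytes, no numeric cast. The last dim must be divisible by
    # the element-size ratio (4 for float32).
    return W_q_packed.view(storage_dtype)

def unpack_from_fsdp(W_q_storage: torch.Tensor) -> torch.Tensor:
    # Recover the original uint8 bit pattern before the existing unpack logic
    # or CUDA kernels touch it.
    return W_q_storage.view(torch.uint8)

# The round trip is bit-exact:
packed = torch.randint(0, 256, (4096, 2048), dtype=torch.uint8)
assert torch.equal(unpack_from_fsdp(pack_for_fsdp(packed)), packed)
```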
Yeah, it only works with float32; it didn't work with float16/int16.
By the way, for your test with 3-bit, you can still use the existing CUDA kernels; you just need to replace the cast-based packing above with the `view()`-based version. Hope it works!