Skip to content

Commit

Permalink
Initial FSDP Support for QLoRA Finetuning (#970)
Browse files Browse the repository at this point in the history
This PR adds initial FSDP support for training QLoRA models. It enables basic FSDP and CPU Offload support, with low memory training via FSDP.sync_module_states option unsupported.

This PR builds off of #840 commit 8278fca and BNB FSDP by @TimDettmers and @Titus-von-Koeller.

An example of using this PR to finetune QLoRA models with FSDP can be found in the demo repo: AnswerDotAi/fsdp_qlora.

* Minimal changes for fp32 4bit storage from BNB commit 8278fca

* Params4bit with selectable storage dtype

* possible fix for double quantizing linear weight & quant storage dtype

* minor fixes in Params4bit for peft tests

* remove redundant

* add float16

* update test

* Remove float16 quant cast as there are fp32, bf16, & fp16 quant kernels

---------

Co-authored-by: Kerem Turgutlu <[email protected]>
  • Loading branch information
warner-benjamin and KeremTurgutlu authored Jan 17, 2024
1 parent 64a28d0 commit dcfb6f8
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 33 deletions.
25 changes: 13 additions & 12 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState
qs_dict: based on state_dict, with only relevant keys, striped of prefixes.
item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items.
"""

# unpacking tensor with non-tensor components
Expand Down Expand Up @@ -802,7 +802,7 @@ def dequantize_blockwise(

if quant_state is None:
quant_state = QuantState(absmax=absmax, code=code, blocksize=blocksize, dtype=torch.float32)

absmax = quant_state.absmax
if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
Expand Down Expand Up @@ -884,13 +884,13 @@ def get_4bit_type(typename, device=None, blocksize=64):
return data.to(device)


def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4', quant_storage)

def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4')
def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8):
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage)

def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> Tensor:
"""
Quantize tensor A in blocks of 4-bit values.
Expand All @@ -903,7 +903,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor (8-bit).
The output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
Expand All @@ -912,7 +912,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
Returns
-------
torch.Tensor:
The 8-bit tensor with packed 4-bit values.
Tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
Expand All @@ -931,7 +931,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz


if out is None:
out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
mod = dtype2bytes[quant_storage] * 2
out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)

assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]

Expand Down Expand Up @@ -985,7 +986,7 @@ def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor =
Parameters
----------
A : torch.Tensor
The input 8-bit tensor (packed 4-bit values).
The input tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
Expand Down Expand Up @@ -1626,7 +1627,7 @@ def gemv_4bit(
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)

if B.dtype == torch.uint8:
if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state.code), get_ptr(out), lda, ldb, ldc, ct.c_int32(state.blocksize))
elif A.dtype == torch.bfloat16:
Expand Down
60 changes: 45 additions & 15 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,18 @@ def forward(self, input: Tensor) -> Tensor:


class Params4bit(torch.nn.Parameter):

def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_state: QuantState = None, blocksize: int = 64, compress_statistics: bool = True, quant_type: str = 'fp4') -> "Params4bit":
def __new__(
cls,
data: Optional[torch.Tensor] = None,
requires_grad=True,
quant_state: QuantState = None,
blocksize: int = 64,
compress_statistics: bool = True,
quant_type: str = 'fp4',
quant_storage: torch.dtype = torch.uint8,
module: Optional["Linear4bit"] = None,
bnb_quantized: bool = False
) -> "Params4bit":
if data is None:
data = torch.empty(0)

Expand All @@ -151,7 +161,10 @@ def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad=True, quant_
self.compress_statistics = compress_statistics
self.quant_type = quant_type
self.quant_state = quant_state
self.quant_storage = quant_storage
self.bnb_quantized = bnb_quantized
self.data = data
self.module = module
return self

@classmethod
Expand All @@ -162,16 +175,23 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any],
self.blocksize = self.quant_state.blocksize
self.compress_statistics = self.quant_state.nested
self.quant_type = self.quant_state.quant_type
self.bnb_quantized = True
return self

def cuda(self, device):
w = self.data.contiguous().half().cuda(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
def _quantize(self, device):
w = self.data.contiguous().cuda(device)
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics,
quant_type=self.quant_type, quant_storage=self.quant_storage)
self.data = w_4bit
self.quant_state = quant_state

if self.module is not None:
self.module.quant_state = quant_state
self.bnb_quantized = True
return self

def cuda(self, device: Optional[Union[int, device, str]] = None, non_blocking: bool = False):
return self.to(device='cuda' if device is None else device, non_blocking=non_blocking)

@overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
...
Expand All @@ -187,8 +207,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
return self.cuda(device)
if (device is not None and device.type == "cuda" and not self.bnb_quantized):
return self._quantize(device)
else:
if self.quant_state is not None:
self.quant_state.to(device)
Expand All @@ -203,12 +223,14 @@ def to(self, *args, **kwargs):

class Linear4bit(nn.Linear):

def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', device=None):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4', quant_storage=torch.uint8, device=None):
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage, module=self)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
self.compute_type_is_set = False
self.quant_state = None
self.quant_storage = quant_storage

def set_compute_type(self, x):
if x.dtype in [torch.float32, torch.bfloat16]:
Expand Down Expand Up @@ -243,7 +265,15 @@ def forward(self, x: torch.Tensor):
self.bias.data = self.bias.data.to(x.dtype)

if getattr(self.weight, 'quant_state', None) is None:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
if getattr(self, 'quant_state', None) is not None:
# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert self.weight.shape[1] == 1
if not isinstance(self.weight, Params4bit):
self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
self.weight.quant_state = self.quant_state
else:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
if not self.compute_type_is_set:
self.set_compute_type(x)
self.compute_type_is_set = True
Expand All @@ -261,8 +291,8 @@ def forward(self, x: torch.Tensor):


class LinearFP4(Linear4bit):
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', device)
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4', quant_storage, device)


class LinearNF4(Linear4bit):
Expand All @@ -276,8 +306,8 @@ class LinearNF4(Linear4bit):
Implementation of the NF4 data type in bitsandbytes can be found in the `create_normal_map` function in
the `functional.py` file: https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L236.
'''
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', device)
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_storage=torch.uint8, device=None):
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4', quant_storage, device)


class Int8Params(torch.nn.Parameter):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2370,7 +2370,8 @@ def test_normal_map_tree():
@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
@pytest.mark.parametrize("kind", ['fc1', 'fc2', 'attn', 'attn_packed'], ids=['fc1', 'fc2', 'attn', 'attn_packed'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, double_quant, kind):
@pytest.mark.parametrize("quant_storage", [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=['uint8', 'fp16', 'bf16', 'fp32'])
def test_gemv_4bit(dtype, storage_type, quant_storage, double_quant, kind):
for dim in [128, 256, 512, 1024]:
#for dim in [4*1024]:
#for dim in [1*16]:
Expand Down Expand Up @@ -2399,7 +2400,7 @@ def test_gemv_4bit(dtype, storage_type, double_quant, kind):
A = torch.randn(1, dim, dtype=dtype, device='cuda')
B = torch.randn(dim*3, dim, dtype=dtype, device='cuda')/math.sqrt(dim)

qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant, quant_storage=quant_storage)
C3 = torch.matmul(A, B.t())
C2 = F.gemv_4bit(A, qB.t(), state=state)
A.requires_grad = True
Expand Down
42 changes: 38 additions & 4 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@

import bitsandbytes as bnb

storage = {
'uint8': torch.uint8,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
'float32': torch.float32
}

@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize(
"quant_type, compress_statistics, bias",
list(product(["nf4", "fp4"], [False, True], [False, True])),
"quant_type, compress_statistics, bias, quant_storage",
list(product(["nf4", "fp4"], [False, True], [False, True], ['uint8', 'float16', 'bfloat16', 'float32'])),
)
def test_linear_serialization(quant_type, compress_statistics, bias):
def test_linear_serialization(quant_type, compress_statistics, bias, quant_storage):
original_dtype = torch.float16
compute_dtype = None
device = "cuda"
Expand All @@ -32,7 +38,7 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
quant_type=quant_type,
device="meta",
)
new_weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False)
new_weight = bnb.nn.Params4bit(data=linear.weight, quant_type=quant_type, requires_grad=False)
linear_q.weight = new_weight
if bias:
linear_q.bias = torch.nn.Parameter(linear.bias)
Expand Down Expand Up @@ -65,6 +71,22 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
# MATCHING
a, b = linear_q.weight, linear_q2.weight

# Quantizing original layer with specified quant_storage type
linear_qs = bnb.nn.Linear4bit(
linear.in_features,
linear.out_features,
bias=bias,
compute_dtype=compute_dtype,
compress_statistics=compress_statistics,
quant_type=quant_type,
quant_storage=storage[quant_storage],
device="meta",
)
linear_qs.weight = bnb.nn.Params4bit(data=linear.weight, requires_grad=False, quant_type=quant_type, quant_storage=storage[quant_storage])
if bias:
linear_qs.bias = torch.nn.Parameter(linear.bias)
linear_qs = linear_qs.to(device)

assert a.device == b.device
assert a.dtype == b.dtype
assert torch.equal(a, b)
Expand Down Expand Up @@ -96,9 +118,21 @@ def test_linear_serialization(quant_type, compress_statistics, bias):
x = torch.rand(42, layer_shape[0], device=device)
a = linear_q(x)
b = linear_q2(x)
c = linear_qs(x)
assert a.device == b.device
assert a.dtype == b.dtype
assert a.device == c.device
assert a.dtype == c.dtype
assert torch.equal(a, b)
assert torch.equal(a, c)

# Test moving to CPU and back to GPU
linear_q2.to('cpu')
linear_q2.to(device)
d = linear_qs(x)
assert c.dtype == d.dtype
assert c.device == d.device
assert torch.equal(c, d)

# Saved size ratio test. Target set for layer_shape == (300, 400) w/ bias
with TemporaryDirectory() as tmpdir:
Expand Down

0 comments on commit dcfb6f8

Please sign in to comment.