Do not install unnecessary dependencies (#410)
carmocca authored Aug 15, 2023
1 parent 24d608c commit fdc142b
Showing 13 changed files with 217 additions and 160 deletions.
2 changes: 1 addition & 1 deletion .github/azure-gpu-tests.yml
@@ -50,7 +50,7 @@ jobs:
python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'"
displayName: 'Image info & NVIDIA'
- script: pip install pytest pytest-rerunfailures -r requirements.txt transformers einops
- script: pip install pytest pytest-rerunfailures -r requirements.txt transformers einops bitsandbytes scipy tokenizers zstandard
displayName: 'Install dependencies'

- bash: pytest -v --durations=10 --disable-pytest-warnings --strict-markers --color=yes
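The GPU job now installs the test-only extras (bitsandbytes, scipy, tokenizers, zstandard) at test time rather than pulling them in through `requirements.txt`. For tests that exercise an optional package, one common way to keep the suite usable without it is pytest's `importorskip`; the sketch below is illustrative only and not taken from the repository's test suite (the test name and assertion are hypothetical):

```python
# Hypothetical test sketch: skip (rather than error) when an optional
# dependency such as bitsandbytes is not installed in the environment.
import pytest


def test_optional_bitsandbytes_import():
    bnb = pytest.importorskip("bitsandbytes")  # skips the test if missing
    assert hasattr(bnb, "nn")  # placeholder check for illustration
```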
18 changes: 16 additions & 2 deletions .github/workflows/cpu-tests.yml
@@ -38,10 +38,24 @@ jobs:
requirements.txt
setup.py
- name: Run tests without the package installed
- name: Install minimal dependencies
run: |
pip install --index-url https://download.pytorch.org/whl/nightly/cpu --pre torch>=2.1.0dev
pip install pytest pytest-rerunfailures -r requirements.txt transformers einops
pip install -r requirements.txt
pip list
# make sure all modules are importable
modules=$(
find * -type f -name "*.py" | \
grep -v tests | grep "/" | grep -v lm_eval | \
sed 's/\.py$//' | sed 's/\//./g' | \
sed 's/.__init__//g' | xargs -I {} echo "import {};"
)
echo "$modules"
python -c "$modules"
- name: Run tests without the package installed
run: |
pip install pytest pytest-rerunfailures transformers einops bitsandbytes scipy tokenizers zstandard
pip list
pytest --disable-pytest-warnings --strict-markers --color=yes
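The `# make sure all modules are importable` step above catches packages that were dropped from `requirements.txt` but are still imported at module scope somewhere in the codebase. Below is a rough Python equivalent of that shell pipeline, offered only as a sketch (the filtering mirrors the `grep -v tests | grep "/" | grep -v lm_eval` filters; the snippet is not part of the workflow):

```python
# Sketch: import every first-party module so that a missing hard dependency
# surfaces as an ImportError, mirroring the shell pipeline in the workflow.
import importlib
from pathlib import Path

module_names = []
for path in Path(".").rglob("*.py"):
    parts = path.with_suffix("").parts
    # keep only files inside packages, and skip tests and the lm_eval code
    if len(parts) < 2 or "tests" in parts or "lm_eval" in parts:
        continue
    name = ".".join(parts)
    if name.endswith(".__init__"):
        name = name[: -len(".__init__")]
    module_names.append(name)

for name in module_names:
    importlib.import_module(name)  # raises if a required dependency is absent
print(f"imported {len(module_names)} modules")
```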
24 changes: 12 additions & 12 deletions README.md
@@ -25,16 +25,16 @@ Hackable [implementation](lit_gpt/model.py) of state-of-the-art open-source larg

Supports the following popular model checkpoints:

| Model and usage | Reference |
|--------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
| Meta AI [Llama 2](tutorials/download_llama_2.md) | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Stability AI [FreeWilly2](tutorials/download_freewilly_2.md) | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| TII UAE [Falcon](tutorials/download_falcon.md) | [TII 2023](https://falconllm.tii.ae) |
| OpenLM Research [OpenLLaMA](tutorials/download_openllama.md) | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| LMSYS [Vicuna](tutorials/download_vicuna.md) | [Li et al. 2023](https://lmsys.org/blog/2023-06-29-longchat) |
| Together [RedPajama-INCITE](tutorials/download_redpajama_incite.md) | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| EleutherAI [Pythia](tutorials/download_pythia.md) | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| StabilityAI [StableLM](tutorials/download_stablelm.md) | [Stability AI 2023](https://github.com/Stability-AI/StableLM)
| Model and usage | Reference |
|---------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
| Meta AI [Llama 2](tutorials/download_llama_2.md) | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Stability AI [FreeWilly2](tutorials/download_freewilly_2.md) | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| TII UAE [Falcon](tutorials/download_falcon.md) | [TII 2023](https://falconllm.tii.ae) |
| OpenLM Research [OpenLLaMA](tutorials/download_openllama.md) | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| LMSYS [Vicuna](tutorials/download_vicuna.md) | [Li et al. 2023](https://lmsys.org/blog/2023-06-29-longchat) |
| Together [RedPajama-INCITE](tutorials/download_redpajama_incite.md) | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| EleutherAI [Pythia](tutorials/download_pythia.md) | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| StabilityAI [StableLM](tutorials/download_stablelm.md) | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |

This implementation extends on [Lit-LLaMA](https://github.com/lightning-AI/lit-llama) and [nanoGPT](https://github.com/karpathy/nanoGPT), and it's **powered by [Lightning Fabric](https://lightning.ai/docs/fabric/stable/)**.

@@ -109,10 +109,10 @@ pip install --index-url https://download.pytorch.org/whl/nightly/cpu --pre 'torc
MAX_JOBS=4 pip install 'flash-attn>=2.0.0.post1' --no-build-isolation
```

All good, now install the dependencies:
All good, now install the dependencies plus some optional ones:

```bash
pip install -r requirements.txt
pip install -r requirements.txt tokenizers sentencepiece
```

You are all set! 🎉
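If you want to double-check the optional installs, here is a small snippet (not part of the README) that simply verifies the two optional packages import under their usual module names:

```python
# Optional sanity check: confirm the optional tokenizer backends are importable.
import importlib

for name in ("tokenizers", "sentencepiece"):
    try:
        importlib.import_module(name)
        print(f"{name}: OK")
    except ImportError as exc:
        print(f"{name}: not available ({exc})")
```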
286 changes: 155 additions & 131 deletions quantize/gptq.py
@@ -10,147 +10,169 @@
from typing import Optional

import torch
from datasets import load_dataset
from lightning import Fabric

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

import triton
import triton.language as tl
from lightning_utilities.core.imports import RequirementCache

from lit_gpt import GPT, Config, Tokenizer
from lit_gpt.utils import check_valid_checkpoint_dir, lazy_load


# This is adapted from the OpenAI Triton matmul example.
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=8
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=5, num_warps=2
),
],
key=["M", "N", "K"],
)
@triton.jit
def linear_kernel_4bit_weight(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
bscales_ptr,
bzeros_ptr,
# bdequant,
# Matrix dimensions
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.T.
A has shape (M, K), B has shape (N, K) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
# See above `L2 Cache Optimizations` section for details
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
# see above `Pointer Arithmetics` section for details
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
a_mask = offs_am[:, None] < M
b_mask = offs_bn[None, :] < N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + ((offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn)

bscales_ptrs = bscales_ptr + offs_bn[None, :]
bzeros_ptrs = bzeros_ptr + offs_bn[None, :]

scale = tl.load(bscales_ptrs)
zero = tl.load(bzeros_ptrs)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# wasteful as it is to load everything twice, my attempts at avoiding it lead to slower code
b12 = tl.load(b_ptrs, mask=b_mask)
# Note that for simplicity, we don't apply a mask in K here.
a = tl.load(a_ptrs, mask=a_mask).to(tl.float32)
b = (((b12.to(tl.uint8) >> ((offs_k[:, None] % 2) * 4)) & 0xF).to(tl.float32) - zero) * scale
accumulator += tl.dot(a, b)

# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
c = accumulator

# -----------------------------------------------------------
# Write back the block of the output matrix C
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
_TRITON_AVAILABLE = RequirementCache("triton")
if _TRITON_AVAILABLE:
import triton
import triton.language as tl

# This is adapted from the OpenAI Triton matmul example.
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=5,
num_warps=2,
),
triton.Config(
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
num_stages=5,
num_warps=2,
),
],
key=["M", "N", "K"],
)
@triton.jit
def linear_kernel_4bit_weight(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
bscales_ptr,
bzeros_ptr,
# bdequant,
# Matrix dimensions
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.T.
A has shape (M, K), B has shape (N, K) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse
# See above `L2 Cache Optimizations` section for details
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# a_ptrs is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# b_ptrs is a block of [BLOCK_SIZE_K, BLOCK_SIZE_n] pointers
# see above `Pointer Arithmetics` section for details
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
a_mask = offs_am[:, None] < M
b_mask = offs_bn[None, :] < N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + ((offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn)

bscales_ptrs = bscales_ptr + offs_bn[None, :]
bzeros_ptrs = bzeros_ptr + offs_bn[None, :]

scale = tl.load(bscales_ptrs)
zero = tl.load(bzeros_ptrs)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
# wasteful as it is to load everything twice, my attempts at avoiding it lead to slower code
b12 = tl.load(b_ptrs, mask=b_mask)
# Note that for simplicity, we don't apply a mask in K here.
a = tl.load(a_ptrs, mask=a_mask).to(tl.float32)
b = (((b12.to(tl.uint8) >> ((offs_k[:, None] % 2) * 4)) & 0xF).to(tl.float32) - zero) * scale
accumulator += tl.dot(a, b)

# Advance the ptrs to the next K block
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
c = accumulator

# -----------------------------------------------------------
# Write back the block of the output matrix C
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


def qlinear_4bit_weight(inp, weight, scales, zeros):
@@ -446,6 +468,8 @@ def quantize(self):


def get_sample_data():
from datasets import load_dataset

traindata = load_dataset(
"allenai/c4", "allenai--c4", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, split="train"
)
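Both changes to `quantize/gptq.py` follow the same idea: optional dependencies are imported only when the code path that needs them actually runs. A minimal sketch of that pattern follows, under the assumptions that `RequirementCache` is truthy when the requirement is installed and that `str()` on it yields a human-readable explanation otherwise; the function names below are hypothetical and not from the diff:

```python
# Sketch of the optional-import pattern used in quantize/gptq.py.
from lightning_utilities.core.imports import RequirementCache

_TRITON_AVAILABLE = RequirementCache("triton")


def run_triton_kernel(*args, **kwargs):
    # Hypothetical entry point: fail with an actionable message only when the
    # triton-backed path is used, not at import time for every user.
    if not _TRITON_AVAILABLE:
        raise ModuleNotFoundError(str(_TRITON_AVAILABLE))
    import triton  # deferred import; only needed on this code path

    ...


def get_calibration_data():
    # Heavy optional libraries can also be imported inside the one function
    # that uses them, as `get_sample_data` now does with `datasets`.
    from datasets import load_dataset

    return load_dataset(
        "allenai/c4",
        "allenai--c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
```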
(Diffs for the remaining 9 changed files are not shown.)
