Merge branch 'master' into sup_xpu
tjruwase authored Sep 22, 2023
2 parents 89c20e9 + 9c94044 commit 029bbb0
Showing 5 changed files with 23 additions and 14 deletions.
2 changes: 1 addition & 1 deletion applications/DeepSpeed-Chat/requirements.txt
@@ -4,5 +4,5 @@ protobuf==3.20.3
accelerate>=0.15.0
torch>=1.12.0
deepspeed>=0.9.0
-transformers>=4.31.0
+transformers>=4.31.0,!=4.33.2
tensorboard
@@ -541,7 +541,7 @@ def main():
print_rank_0(
f'Epoch: {epoch} | Step: {step} | PPO Epoch: {ppo_ep+1} | Actor Loss: {actor_loss_sum/inner_iter} | Critic Loss: {critic_loss_sum/inner_iter} | Unsupervised Loss: {unsup_loss_sum/inner_iter}',
args.global_rank)
-print_throughput_step3(rlhf_engine.actor.model,
+print_throughput_step3(rlhf_engine.actor.module,
                        rlhf_engine.critic, args, e2e_time,
                        trainer.generate_time, training_time,
                        args.global_rank)
@@ -76,7 +76,7 @@ def _generate_sequence(self, prompts, mask, step):
# This has been added due to a probability/nan error that happens after
# meta-llama/Llama-2-7b-hf enabled do_sample:
# https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
-if self.actor_model.model.config.model_type == "llama":
+if self.actor_model.module.config.model_type == "llama":
    kwargs = dict(do_sample=False)
else:
    kwargs = dict()
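Both hunks above make the same kind of fix: the actor object here is a DeepSpeed engine that wraps the original Hugging Face model, and the wrapped model is exposed through the engine's `.module` attribute rather than an inner `.model` attribute. The stand-in classes below (`FakeEngine`, `FakeHFModel`, `FakeHFConfig`) are invented purely to illustrate that attribute pattern; they are not DeepSpeed APIs.

```python
# Hypothetical sketch: a wrapper that exposes the wrapped model as `.module`,
# mimicking why the diff reads `rlhf_engine.actor.module` / `self.actor_model.module`.
class FakeHFConfig:
    model_type = "llama"

class FakeHFModel:
    config = FakeHFConfig()

class FakeEngine:
    def __init__(self, module):
        self.module = module  # the original (Hugging Face) model lives here

actor = FakeEngine(FakeHFModel())
kwargs = dict(do_sample=False) if actor.module.config.model_type == "llama" else dict()
print(kwargs)  # -> {'do_sample': False}
```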
11 changes: 6 additions & 5 deletions inference/huggingface/zero_inference/README.md
@@ -10,7 +10,7 @@ With these two added techniques, we show the significant throughput and batch si
We plan to release more performance improvements to ZeRO-Inference, such as partial offloading, KV cache quantization, and more, in the near future. Please check the [Working-In-Progress](#working-in-progress) list and stay tuned.

## Performance and Feature Highlights
-We use token generation workload for our benchmarking of ZeRO-Inference. We run all our experiments on a single `NVIDIA A6000 GPU` with 48GB of device HBM on a Lambda workstation with 252GB of host CPU memory and a [CS3040 NVMe 2TB SSD](https://www.pny.com/CS3040-M2-NVMe-SSD?sku=M280CS3040-2TB-RB) with throughput of 5600 MB/s sequential reads. We configure a prompt length of 512 tokens and generation length of 32 tokens.
+We use a token generation workload for our benchmarking of ZeRO-Inference. We run all our experiments on a single `NVIDIA A6000 GPU` with 48GB of device HBM on a Lambda workstation with 252GB of host CPU memory and a [CS3040 NVMe 2TB SSD](https://www.pny.com/CS3040-M2-NVMe-SSD?sku=M280CS3040-2TB-RB) with throughput of 5600 MB/s sequential reads. We configure a prompt length of 512 tokens and a generation length of 32 tokens.


### 😽 Overall Throughput Improvement of new ZeRO-Inference release 😽
@@ -38,7 +38,7 @@ Framework | Weight Quantization | KV Cache Offload | OPT-30B | OPT-66B | OPT
| ZeRO-Inference | Yes | Yes | 19.34 (bsz=128, cpu_offload) | **8.08** (bsz=64, cpu_offload) | **2.26** (bsz=24, cpu_offload) | **1.33** (bsz=24, cpu_offload) | 3.65 (bsz=200, cpu_offload)

#### Generality
-Unlike FlexGen which supports only the OPT model family, ZeRO-Inference is designed as a general technique to support different model families. With our new optimizations, we continue to make it easy for model scientists to inference their favorite models using ZeRO-Inference. Our weight quantization optimization is generally applicable to any model without requiring modifcations. For KV cache offloading which requires minor code changes for each model family, we provide the required modifications for three model families (BLOOM, LLAMA2, and OPT) as a guide.
+Unlike FlexGen which supports only the OPT model family, ZeRO-Inference is designed as a general technique to support different model families. With our new optimizations, we continue to make it easy for model scientists to inference their favorite models using ZeRO-Inference. Our weight quantization optimization is generally applicable to any model without requiring modifications. For KV cache offloading which requires minor code changes for each model family, we provide the required modifications for three model families (BLOOM, LLAMA2, and OPT) as a guide.

#### Token Generation Throughput
For fairness, we evaluate the same set of optimizations supported by both FlexGen and our ZeRO-Inference for performance comparison, specifically 4-bit weight quantization and KV cache offloading to CPU memory. We measure the impact of the optimizations individually and collectively. We consider model sizes that exceed the available 48GB HBM, thus requiring that model weights be offloaded to CPU or NVMe. Each data point is described using the format of | `throughput` (`batch size` and the memory used for weights offloading) |. Throughput is measured by `tokens/sec`. Each data point represents the best observed throughput from a batch size sweep. We observe that for the OPT family of models supported by both frameworks, ZeRO-Inference consistently achieved better generation throughput.
@@ -111,10 +111,10 @@ The following features/improvements are part of our work-in-progress. Please sta

## How to Enable INT4 Weight Quantization in ds_config

-INT4 weight quantization can be easily enabled with a few lines of configuration change in your ds_config. ZeRO-Inference engine will automatically identify all candidate layers and convert their weight tensors into INT4. Currently, we support 2 modes: quantized initialization and post initialization quantization.
+INT4 weight quantization can be easily enabled with a few lines of configuration change in your ds_config. ZeRO-Inference engine will automatically identify all candidate layers and convert their weight tensors into INT4. Currently, we support 2 modes: quantized initialization and post-initialization quantization.

### Quantized Initialization
-This is the easiest way to getting started. By providing a few lines of hints in ds_config, the model will be on-the-fly quantized during model initialization (e.g., AutoModel.from_pretrained). All candidate layers will be automatically quantized.
+This is the easiest way to get started. By providing a few lines of hints in ds_config, the model will be on-the-fly quantized during model initialization (e.g., AutoModel.from_pretrained). All candidate layers will be automatically quantized.
```python
ds_config = {
    'weight_quantization': {
```
@@ -134,7 +134,7 @@ with torch.no_grad():
Currently, ZeRO-inference can quantize the weight matrix of nn.Embedding and nn.Linear into INT4 format. In the example above, we applied group_size=64 and performed asymmetric quantization on the 1st dimension of the weight matrix. `group_size` here is configurable based on users' demand.
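The configuration block above is mostly collapsed in this diff, so the following is only a hedged sketch of what a complete quantized-initialization setup might look like. The section and option names (`quantized_initialization`, `num_bits`, `group_size`, `group_dim`, `symmetric`) and the use of `HfDeepSpeedConfig` are assumptions pieced together from the surrounding text and the Hugging Face DeepSpeed integration, not a restoration of the hidden lines; consult the DeepSpeed documentation for the exact schema.

```python
# Hedged sketch only; key names below are assumptions, not the collapsed original.
import torch
from transformers import AutoModelForCausalLM
from transformers.deepspeed import HfDeepSpeedConfig  # newer releases: transformers.integrations

ds_config = {
    "weight_quantization": {
        "quantized_initialization": {  # assumed section name for this mode
            "num_bits": 4,             # INT4 weights
            "group_size": 64,          # matches the group_size=64 described above
            "group_dim": 1,            # quantize along the 1st dimension of the weight matrix
            "symmetric": False,        # asymmetric quantization, as described above
        },
    },
}

# Keeping this object alive registers ds_config with the Hugging Face DeepSpeed
# integration; a full ZeRO-Inference run additionally carries a ZeRO stage-3 /
# offload section, omitted here to keep the sketch minimal.
dschf = HfDeepSpeedConfig(ds_config)

with torch.no_grad():
    # With the hints in place, candidate nn.Linear / nn.Embedding weights are meant to be
    # quantized on the fly while the checkpoint loads ("facebook/opt-1.3b" is just an
    # example model name).
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b",
                                                 torch_dtype=torch.float16)
```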
### Post Initialization Quantization
-In this mode, model is first loaded in FP16 format and then convert into INT4. The advantage of enabling this mode is that users will have an overview of the model architecture. Thus, they will have fine-grained control over the quantization decision. For example, which layer should be quantized with which quantization configuration can be controlled. Only a few lines of code changes are needed. Note that we plan to expand this mode to accommodate more formats in the near future.
+In this mode, the model is first loaded in FP16 format and then converted into INT4. The advantage of enabling this mode is that users will have an overview of the model architecture. Thus, they will have fine-grained control over the quantization decision. For example, which layer should be quantized with which quantization configuration can be controlled. Only a few lines of code changes are needed. Note that we plan to expand this mode to accommodate more formats in the near future.
```python
from deepspeed.compression.inference.quantization import _init_group_wise_weight_quantization
ds_config = {
```
@@ -172,4 +172,5 @@ In running example above, only two fully connected layers (fc1 and fc2) and the
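The example block above is likewise collapsed, so here is a hedged sketch of the post-initialization flow it describes: load the model in FP16 first, then convert selected layers to INT4 with `_init_group_wise_weight_quantization` (the import shown above). The `post_init_quant` section name, the per-layer entries (`fc1`, `fc2`, taken from the truncated sentence above), and the call signature are assumptions; check the DeepSpeed documentation for the authoritative schema.

```python
# Hedged sketch only; section/layer key names and the call signature are assumptions.
import torch
from transformers import AutoModelForCausalLM
from deepspeed.compression.inference.quantization import _init_group_wise_weight_quantization

ds_config = {
    "weight_quantization": {
        "post_init_quant": {  # assumed section name for post-initialization quantization
            # Fine-grained control: only layers whose names match these entries are converted.
            "fc1": {"num_bits": 4, "group_size": 64, "group_dim": 1, "symmetric": False},
            "fc2": {"num_bits": 4, "group_size": 64, "group_dim": 1, "symmetric": False},
        },
    },
}

with torch.no_grad():
    # Load in FP16 first so the full architecture is visible, then convert the selected
    # layers to INT4 ("facebook/opt-1.3b" is just an example model name).
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b",
                                                 torch_dtype=torch.float16)
    model = _init_group_wise_weight_quantization(model, ds_config)
```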
## References
- DeepSpeed [ZeRO-Inference](https://www.deepspeed.ai/2022/09/09/zero-inference.html)
- Sheng, Ying et al. [FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU](https://arxiv.org/abs/2303.06865)
+- Shen, Sheng, et al. "Q-bert: Hessian based ultra low precision quantization of bert." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 34. No. 05. 2020.
20 changes: 14 additions & 6 deletions inference/huggingface/zero_inference/run_model.py
@@ -26,7 +26,7 @@
from packaging import version


-assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "ZeRO-Inference with weight quantization and kv cache offloading is available only in DeepSpeed 0.10.3+, please upgrade DeepSpeed"
+assert version.parse(deepspeed.__version__) >= version.parse("0.10.3"), "ZeRO-Inference with weight quantization and kv cache offloading is available only in DeepSpeed 0.10.3+, please upgrade DeepSpeed"

def get_model_config(model_name):
    if "175b" in model_name:
@@ -161,11 +161,19 @@ def run_generation(
return_token_type_ids = True
padding_side = "left" if config.model_type in ["opt"] else "right"

-tokenizer = AutoTokenizer.from_pretrained(
-    model_name,
-    return_token_type_ids=return_token_type_ids,
-    padding_side=padding_side
-)
+if config.model_type == "opt":
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name.replace("175b", "66b"),
+        return_token_type_ids=return_token_type_ids,
+        padding_side=padding_side
+    )
+else:
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        return_token_type_ids=return_token_type_ids,
+        padding_side=padding_side
+    )


tokenizer.pad_token = tokenizer.eos_token

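One note on the new tokenizer branch above, as an inference rather than something stated in the commit: the `175b` OPT setting used for benchmarking has no publicly downloadable checkpoint, so the tokenizer appears to be fetched from a smaller public OPT checkpoint via a simple string substitution (OPT checkpoints share the same tokenizer, so this should not change tokenization). A minimal illustration, with a hypothetical model name:

```python
# Hedged illustration of the string substitution in the new "opt" branch above.
model_name = "facebook/opt-175b"                      # hypothetical benchmark model name
tokenizer_source = model_name.replace("175b", "66b")  # fall back to a public checkpoint
print(tokenizer_source)                               # -> facebook/opt-66b
```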
