This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[BLOOM] Configurable tokenizer name (#735)
Co-authored-by: Zhanyuan Zhang <[email protected]>
Co-authored-by: Lianmin Zheng <[email protected]>
3 people authored Oct 8, 2022
1 parent cfa8b86 commit 15c0a52
Showing 7 changed files with 55 additions and 15 deletions.
README.md: 2 changes (1 addition & 1 deletion)

@@ -33,7 +33,7 @@ from transformers import AutoTokenizer
 from llm_serving.model.wrapper import get_model
 
 # Load the tokenizer
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
 tokenizer.add_bos_token = False
 
 # Load the model. Alpa automatically downloads the weights to the specified path
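The add_bos_token = False line above matters because OPT tokenizers prepend the BOS token (</s>, id 2) to every encoding by default. A minimal sketch of what the flag toggles, assuming transformers is installed and the tokenizer can be downloaded:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
    print(tokenizer("Paris is").input_ids)   # starts with the BOS id 2 by default

    tokenizer.add_bos_token = False
    print(tokenizer("Paris is").input_ids)   # the leading 2 should now be gone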
examples/llm_serving/README.rst: 24 changes (20 additions & 4 deletions)

@@ -43,9 +43,8 @@ The code below shows how to use the huggingface/transformers interface and Alpa dist
 from transformers import AutoTokenizer
 from llm_serving.model.wrapper import get_model
-# Load the tokenizer. We have to use the 30B version because
-# other versions have some issues. The 30B version works for all OPT models.
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
+# Load the tokenizer. All OPT models with different sizes share the same tokenizer
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
 tokenizer.add_bos_token = False
 # Load the model. Alpa automatically downloads the weights to the specified path
@@ -60,7 +59,6 @@ The code below shows how to use the huggingface/transformers interface and Alpa dist
 print(generated_string)
 
 Requirements
 ============
 1. Install Alpa following the `installation guide <https://alpa-projects.github.io/install.html>`_. You can either install via the Python wheel or build from source.
@@ -227,6 +225,24 @@ Here are some tips for improving the generation speed.
 
 If you find the generation speed too slow and want to accelerate it, please join the `Alpa slack <https://forms.gle/YEZTCrtZD6EAVNBQ7>`_ and tell us your use cases. We are actively working on improving the performance.


+Other Models
+============
+
+Alpa also supports `BLOOM <https://huggingface.co/bigscience/bloom>`_.
+You can use commands similar to OPT but with a different model name.
+
+.. code:: shell
+
+    # Huggingface/pytorch backend
+    python3 textgen.py --model bigscience/bloom-560m
+
+    # Jax backend
+    python3 textgen.py --model jax/bloom-560m
+
+    # Alpa backend
+    python3 textgen.py --model alpa/bloom-560m
+
 License
 =======
 The use of the OPT pretrained weights is subject to the `Model License <https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/MODEL_LICENSE.md>`_ by Metaseq.
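For reference, the alpa/bloom-560m command above corresponds roughly to the following use of the wrapper API from the README snippet. This is a sketch: the weight path is chosen arbitrarily, and the get_model argument names and generate call are assumed to match the OPT example:

    from transformers import AutoTokenizer
    from llm_serving.model.wrapper import get_model

    # BLOOM reuses the bigscience tokenizer regardless of backend
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

    # Alpa downloads the weights to the given path on first use
    model = get_model(model_name="alpa/bloom-560m", path="~/bloom_weights")

    input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
    outputs = model.generate(input_ids=input_ids, max_length=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))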
examples/llm_serving/generator.py: 16 changes (13 additions & 3 deletions)

@@ -36,7 +36,7 @@ def __init__(self,
         self.torch_device = torch_device
 
         # Tokenizer arguments
-        self.tokenizer_name = "facebook/opt-30b" if not tokenizer_name else tokenizer_name
+        self.tokenizer_name = tokenizer_name
         self.tokenizer = None
         self.add_bos_token = add_bos_token
@@ -70,8 +70,18 @@ def load_model(self):
         load_time = time.time() - tic
 
         # Init tokenizer
-        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
-        self.tokenizer.add_bos_token = False
+        if self.tokenizer_name:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
+        else:
+            if "opt" in self.model_name:
+                self.tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b")
+                self.tokenizer.add_bos_token = False
+            elif "bloom" in self.model_name:
+                tokenizer_name = self.model_name.replace("alpa", "bigscience")\
+                                                .replace("jax", "bigscience")
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+            else:
+                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
         if "alpa" in self.model_name:
             import alpa
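The selection logic added above reduces to a small pure function. A standalone sketch (resolve_tokenizer_name is a hypothetical helper written for illustration, not a function in this repository):

    def resolve_tokenizer_name(model_name, tokenizer_name=None):
        if tokenizer_name:                  # an explicit name always wins
            return tokenizer_name
        if "opt" in model_name:             # all OPT sizes share one tokenizer
            return "facebook/opt-30b"
        if "bloom" in model_name:           # jax/alpa checkpoints use the HF tokenizer
            return model_name.replace("alpa", "bigscience").replace("jax", "bigscience")
        return model_name                   # fall back to the model's own tokenizer

    assert resolve_tokenizer_name("jax/bloom-560m") == "bigscience/bloom-560m"
    assert resolve_tokenizer_name("alpa/opt-1.3b") == "facebook/opt-30b"
    assert resolve_tokenizer_name("alpa/opt-1.3b", "facebook/opt-2.7b") == "facebook/opt-2.7b"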
examples/llm_serving/launch_model_worker.py: 5 changes (4 additions & 1 deletion)

@@ -34,6 +34,7 @@ def __init__(self,
                  model_name: str,
                  path: str,
                  torch_device: str,
+                 tokenizer_name: str,
                  num_beams: int,
                  num_return_sequences: int,
                  use_recaptcha: bool,
@@ -64,6 +65,7 @@ def __init__(self,
         self.generator = Generator(model_name,
                                    path,
                                    torch_device=torch_device,
+                                   tokenizer_name=tokenizer_name,
                                    num_beams=num_beams,
                                    num_return_sequences=num_return_sequences,
                                    max_seq_len=self.max_seq_len,
@@ -342,6 +344,7 @@ def get_remote_ip(self, request):
     parser.add_argument("--path", type=str, default="~/opt_weights/")
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--torch-device", type=str, default="cpu")
+    parser.add_argument("--tokenizer", type=str)
     parser.add_argument("--no-recaptcha", action="store_true")
     parser.add_argument("--register-name", type=str, default="default")
     parser.add_argument("--ssl-keyfile", type=str)
@@ -360,7 +363,7 @@ def get_remote_ip(self, request):
         controller.launch_mesh_group_manager.remote(group_id)
     t = controller.register_model.remote(
         args.register_name, LangaugeModelWorker,
-        (args.model, args.path, args.torch_device, NUM_BEAMS, NUM_RETURN_SEQ,
+        (args.model, args.path, args.torch_device, args.tokenizer, NUM_BEAMS, NUM_RETURN_SEQ,
          False if args.no_recaptcha else USE_RECAPTCHA),
         override=True)
     ray.get(t)
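Note that --tokenizer is declared without a default, so args.tokenizer is None unless the user supplies one, which is exactly what lets Generator fall back to the model-based resolution above. A self-contained illustration of that plumbing (argument values are examples only):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="alpa/bloom-560m")
    parser.add_argument("--tokenizer", type=str)   # defaults to None, i.e. "auto"

    print(parser.parse_args([]).tokenizer)         # None -> Generator decides
    print(parser.parse_args(["--tokenizer", "bigscience/bloom-560m"]).tokenizer)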
examples/llm_serving/model/wrapper.py: 9 changes (5 additions & 4 deletions)

@@ -211,11 +211,12 @@ def inference_func(input_ids,
 
     inference_func_config = InferenceFuncConfig()
     for key in inference_func_config.__dataclass_fields__.keys():
-        setattr(inference_func_config, key, getattr(model.config, key))
-    if hasattr(model.config, "seq_length"):
-        seq_len = model.config.seq_length
-    else:
+        if hasattr(model.config, key):
+            setattr(inference_func_config, key, getattr(model.config, key))
+    if hasattr(model.config, "max_position_embeddings"):
+        seq_len = model.config.max_position_embeddings
+    else:
         seq_len = 2048
 
     transformer_config = TransformerModelConfig(
         H=model.config.hidden_size,
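Both changes make the wrapper tolerate configs that lack a given field: dataclass fields are copied only when present, and the sequence length falls back to 2048 when the config has no max_position_embeddings. A standalone sketch of the probing pattern, using stand-in config classes with arbitrary values:

    class OptLikeConfig:                    # exposes max_position_embeddings
        max_position_embeddings = 4096

    class BloomLikeConfig:                  # no max_position_embeddings attribute
        hidden_size = 1024

    def get_seq_len(config, default=2048):
        # Probe the config instead of assuming the field exists
        if hasattr(config, "max_position_embeddings"):
            return config.max_position_embeddings
        return default

    print(get_seq_len(OptLikeConfig()))     # 4096
    print(get_seq_len(BloomLikeConfig()))   # 2048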
examples/llm_serving/test_textgen.sh: 10 changes (10 additions & 0 deletions, new file)

@@ -0,0 +1,10 @@
+# Test the correctness of textgen.py
+set -x
+
+python3 textgen.py --model bigscience/bloom-560m
+python3 textgen.py --model jax/bloom-560m
+python3 textgen.py --model alpa/bloom-560m
+
+python3 textgen.py --model facebook/opt-1.3b
+python3 textgen.py --model jax/opt-1.3b
+python3 textgen.py --model alpa/opt-1.3b
examples/llm_serving/textgen.py: 4 changes (2 additions & 2 deletions)

@@ -11,12 +11,12 @@ def main(args):
     if "opt" in args.model:
         # We have to use the 30B version because other versions have some issues.
         # The 30B version works for all OPT models.
-        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-30b")
         tokenizer.add_bos_token = False
     elif "bloom" in args.model:
         name = args.model.replace("alpa", "bigscience")\
                          .replace("jax", "bigscience")
-        tokenizer = AutoTokenizer.from_pretrained(name, use_fast=False)
+        tokenizer = AutoTokenizer.from_pretrained(name)
 
     generate_params = {
         "do_sample": args.do_sample,
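Dropping use_fast=False means AutoTokenizer can now return the Rust-backed fast tokenizer whenever one is available, which speeds up encoding on long prompts. A quick way to confirm which implementation was loaded (requires downloading the tokenizer):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    print(type(tokenizer).__name__)   # a ...Fast class when the fast path is used
    print(tokenizer.is_fast)          # True for the Rust-backed implementation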
