Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

feat(model): add support for half-precision #45

Merged
merged 1 commit into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ENV MODEL_CACHE_DIR="/models"
ENV MODEL_LOAD_IN_8BIT="false"
ENV MODEL_LOCAL_FILES_ONLY="false"
ENV MODEL_TRUST_REMOTE_CODE="false"
ENV MODEL_HALF_PRECISION="false"
ENV SERVER_THREADS="8"
ENV SERVER_IDENTITY="basaran"
ENV SERVER_CONNECTION_LIMIT="1024"
Expand Down
1 change: 1 addition & 0 deletions basaran/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def is_true(value):
MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
MODEL_LOCAL_FILES_ONLY = is_true(os.getenv("MODEL_LOCAL_FILES_ONLY", ""))
MODEL_TRUST_REMOTE_CODE = is_true(os.getenv("MODEL_TRUST_REMOTE_CODE", ""))
MODEL_HALF_PRECISION = is_true(os.getenv("MODEL_HALF_PRECISION", ""))

# Server-related arguments:
# https://docs.pylonsproject.org/projects/waitress/en/stable/arguments.html
Expand Down
2 changes: 2 additions & 0 deletions basaran/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from . import MODEL_LOAD_IN_8BIT
from . import MODEL_LOCAL_FILES_ONLY
from . import MODEL_TRUST_REMOTE_CODE
from . import MODEL_HALF_PRECISION
from . import SERVER_THREADS
from . import SERVER_IDENTITY
from . import SERVER_CONNECTION_LIMIT
Expand All @@ -39,6 +40,7 @@
load_in_8bit=MODEL_LOAD_IN_8BIT,
local_files_only=MODEL_LOCAL_FILES_ONLY,
trust_remote_code=MODEL_TRUST_REMOTE_CODE,
half_precision=MODEL_HALF_PRECISION,
)

# Create and configure application.
Expand Down
5 changes: 5 additions & 0 deletions basaran/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def load_model(
load_in_8bit=False,
local_files_only=False,
trust_remote_code=False,
half_precision=False,
):
"""Load a text generation model and make it stream-able."""
kwargs = {
Expand All @@ -333,6 +334,10 @@ def load_model(
except ValueError:
model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path, **kwargs)

# Cast all parameters to half-precision if required.
if half_precision:
model = model.half()

# Check if the model has text generation capabilities.
if not model.can_generate():
raise TypeError(f"{name_or_path} is not a text generation model")
Expand Down