Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(build): separate MODEL_REVISION, MODEL_PRECISION, HF_MODEL_ID
The above already worked for runtime downloads, but can now also be passed as build-args to download the model at build time and include it in your image.

SOME VERY IMPORTANT NOTES:

1) MODEL_REVISION no longer defaults to MODEL_PRECISION; you need to specify it separately (it still defaults to "fp16" in the Dockerfile). You'll get a warning if you specify MODEL_PRECISION without MODEL_REVISION, to help in the most common case of the old behaviour.
2) The build-arg PRECISION still works but has been deprecated in favour of MODEL_PRECISION (which was already the call-arg name).
3) normalized_model_id uses MODEL_REVISION, so for the early birds already using the "cloud cache" on S3, your filenames might no longer match in some cases.
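As a rough illustration of note 1, the sketch below builds a `docker build` command line that passes all three build-args explicitly, so the revision is never inferred from the precision. The helper name `docker_build_command`, the image tag, and the model ID `example-org/example-model` are placeholders invented for this example; only the build-arg names come from the commit message.

```python
def docker_build_command(hf_model_id, model_precision, model_revision, tag="my-image"):
    """Return an argv list for `docker build` with the three build-args set.

    Hypothetical helper: the function name and default tag are illustrative only.
    """
    build_args = {
        "HF_MODEL_ID": hf_model_id,
        "MODEL_PRECISION": model_precision,
        # Passed separately: it is no longer derived from MODEL_PRECISION.
        "MODEL_REVISION": model_revision,
    }
    argv = ["docker", "build"]
    for key, value in build_args.items():
        argv += ["--build-arg", f"{key}={value}"]
    argv += ["-t", tag, "."]
    return argv


if __name__ == "__main__":
    # Placeholder model ID, not taken from the repository.
    print(" ".join(docker_build_command("example-org/example-model", "fp16", "fp16")))
```

Returning an argv list rather than a shell string keeps the values free of quoting problems if the list is later handed to `subprocess.run`.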
Showing 6 changed files with 79 additions and 40 deletions.
| Change | Diff line change |
|---|---|
| | `@@ -1,17 +1,32 @@` |
| | `import os` |
| | `import torch` |
| - | `PRECISION = os.getenv("PRECISION")` |
| + | `DEPRECATED_PRECISION = os.getenv("PRECISION")` |
| + | `MODEL_PRECISION = os.getenv("MODEL_PRECISION") or DEPRECATED_PRECISION` |
| + | `MODEL_REVISION = os.getenv("MODEL_REVISION")` |
| - | `revision = None if PRECISION == "" else PRECISION` |
| - | `torch_dtype = None if PRECISION == "" else torch.float16` |
| + | `if DEPRECATED_PRECISION:` |
| + | `    print("Warning: PRECISION variable been deprecated and renamed MODEL_PRECISION")` |
| + | `    print("Your setup still works but in a future release, this will throw an error")` |
| + | `if MODEL_PRECISION and not MODEL_REVISION:` |
| + | `    print("Warning: we no longer default to MODEL_REVISION=MODEL_PRECISION, please")` |
| + | `    print(f'explicitly set MODEL_REVISION="{MODEL_PRECISION}" if that\'s what you')` |
| + | `    print("want.")` |
| - | `def revision_from_precision(precision=PRECISION):` |
| - | `    return precision if precision else None` |
| + | `def revision_from_precision(precision=MODEL_PRECISION):` |
| + | `    # return precision if precision else None` |
| + | `    raise Exception("revision_from_precision no longer supported")` |
| - | `def torch_dtype_from_precision(precision=PRECISION):` |
| + | `def torch_dtype_from_precision(precision=MODEL_PRECISION):` |
| | `    if precision == "fp16":` |
| | `        return torch.float16` |
| | `    return None` |
| + | `def torch_dtype_from_precision(precision=MODEL_PRECISION):` |
| + | `    if precision == "fp16":` |
| + | `        return torch.float16` |
| + | `    return None` |
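The resolution rules in the diff above can be sketched as a pure function that takes the environment as a mapping and returns warnings instead of printing them, which makes the behaviour easy to check in isolation. This is an illustration only: `resolve_model_settings` is a hypothetical name that does not appear in the commit, the warning texts are paraphrased, and the sketch deliberately avoids importing `torch`.

```python
def resolve_model_settings(env):
    """Return (precision, revision, warnings) from an environment mapping.

    Hypothetical helper mirroring the module-level logic in the diff.
    """
    warnings = []
    deprecated_precision = env.get("PRECISION")
    # MODEL_PRECISION wins; the deprecated PRECISION is only a fallback.
    precision = env.get("MODEL_PRECISION") or deprecated_precision
    revision = env.get("MODEL_REVISION")

    if deprecated_precision:
        warnings.append("PRECISION is deprecated; use MODEL_PRECISION instead")

    # The revision is no longer inferred from the precision, so only warn.
    if precision and not revision:
        warnings.append(
            f'MODEL_REVISION no longer defaults to MODEL_PRECISION; '
            f'set MODEL_REVISION="{precision}" explicitly if that is intended'
        )
    return precision, revision, warnings


if __name__ == "__main__":
    import os

    print(resolve_model_settings(os.environ))
```

Passing the mapping in, rather than reading `os.environ` at import time as the module in the diff does, is a design choice made here purely so the three cases (old variable only, both new variables, nothing set) can be exercised without touching the real environment.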