Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add environment variable backed properties to config #2051

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/cog/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .base_input import BaseInput
from .base_predictor import BasePredictor
from .code_xforms import load_module_from_string, strip_model_source_code
from .env_property import env_property
from .errors import ConfigDoesNotExist
from .mode import Mode
from .predictor import (
Expand All @@ -23,6 +24,11 @@
from .types import CogConfig

COG_YAML_FILE = "cog.yaml"
COG_PREDICT_TYPE_STUB_ENV_VAR = "COG_PREDICT_TYPE_STUB"
COG_TRAIN_TYPE_STUB_ENV_VAR = "COG_TRAIN_TYPE_STUB"
COG_PREDICT_CODE_STRIP_ENV_VAR = "COG_PREDICT_CODE_STRIP"
COG_TRAIN_CODE_STRIP_ENV_VAR = "COG_TRAIN_CODE_STRIP"
COG_GPU_ENV_VAR = "COG_GPU"
PREDICT_METHOD_NAME = "predict"
TRAIN_METHOD_NAME = "train"

Expand All @@ -37,6 +43,14 @@ def _method_name_from_mode(mode: Mode) -> str:
raise ValueError(f"Mode {mode} not recognised for method name mapping")


def _env_var_from_mode(mode: Mode) -> str:
if mode == Mode.PREDICT:
return COG_PREDICT_CODE_STRIP_ENV_VAR
elif mode == Mode.TRAIN:
return COG_TRAIN_CODE_STRIP_ENV_VAR
raise ValueError(f"Mode {mode} not recognised for env var mapping")


class Config:
"""A class for reading the cog.yaml properties."""

Expand Down Expand Up @@ -65,16 +79,19 @@ def _cog_config(self) -> CogConfig:
return config

@property
@env_property(COG_PREDICT_TYPE_STUB_ENV_VAR)
def predictor_predict_ref(self) -> Optional[str]:
"""Find the predictor ref for the predict mode."""
return self._cog_config.get(str(Mode.PREDICT))

@property
@env_property(COG_TRAIN_TYPE_STUB_ENV_VAR)
def predictor_train_ref(self) -> Optional[str]:
"""Find the predictor ref for the train mode."""
return self._cog_config.get(str(Mode.TRAIN))

@property
@env_property(COG_GPU_ENV_VAR)
def requires_gpu(self) -> bool:
"""Whether this cog requires the use of a GPU."""
return bool(self._cog_config.get("build", {}).get("gpu", False))
Expand All @@ -87,6 +104,9 @@ def _predictor_code(
mode: Mode,
module_name: str,
) -> Optional[str]:
source_code = os.environ.get(_env_var_from_mode(mode))
if source_code is not None:
return source_code
if sys.version_info >= (3, 9):
with open(module_path, encoding="utf-8") as file:
return strip_model_source_code(file.read(), [class_name], [method_name])
Expand Down
42 changes: 42 additions & 0 deletions python/cog/env_property.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, Union

R = TypeVar("R")


def _get_origin(typ: Any) -> Any:
if hasattr(typ, "__origin__"):
return typ.__origin__
return None


def _get_args(typ: Any) -> Any:
if hasattr(typ, "__args__"):
return typ.__args__
return ()


def env_property(
env_var: str,
) -> Callable[[Callable[[Any], R]], Callable[[Any], R]]:
"""Wraps a class property in an environment variable check."""

def decorator(func: Callable[[Any], R]) -> Callable[[Any], R]:
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> R:
result = os.environ.get(env_var)
if result is not None:
expected_type = func.__annotations__.get("return", str)
if (
_get_origin(expected_type) is Optional
or _get_origin(expected_type) is Union
):
expected_type = _get_args(expected_type)[0]
return expected_type(result)
result = func(*args, **kwargs)
return result

return wrapper

return decorator
125 changes: 125 additions & 0 deletions python/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,110 @@
import os
import tempfile

import pytest

from cog.config import (
COG_GPU_ENV_VAR,
COG_PREDICT_CODE_STRIP_ENV_VAR,
COG_PREDICT_TYPE_STUB_ENV_VAR,
COG_TRAIN_TYPE_STUB_ENV_VAR,
COG_YAML_FILE,
Config,
)
from cog.errors import ConfigDoesNotExist
from cog.mode import Mode


def test_predictor_predict_ref_env_var():
predict_ref = "predict.py:Predictor"
os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] = predict_ref
config = Config()
config_predict_ref = config.predictor_predict_ref
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
assert (
config_predict_ref == predict_ref
), "Predict Reference should come from the environment variable."


def test_predictor_predict_ref_no_env_var():
if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
pwd = os.getcwd()
with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)
with open(COG_YAML_FILE, "w", encoding="utf-8") as handle:
handle.write("""
build:
python_version: "3.11"
predict: "predict.py:Predictor"
""")
config = Config()
config_predict_ref = config.predictor_predict_ref
assert (
config_predict_ref == "predict.py:Predictor"
), "Predict Reference should come from the cog config file."
os.chdir(pwd)


def test_config_no_config_file():
if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
config = Config()
with pytest.raises(ConfigDoesNotExist):
_ = config.predictor_predict_ref


def test_config_initial_values():
if COG_PREDICT_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
config = Config(config={"predict": "predict.py:Predictor"})
config_predict_ref = config.predictor_predict_ref
assert (
config_predict_ref == "predict.py:Predictor"
), "Predict Reference should come from the initial config dictionary."


def test_predictor_train_ref_env_var():
train_ref = "predict.py:Predictor"
os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR] = train_ref
config = Config()
config_train_ref = config.predictor_train_ref
del os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR]
assert (
config_train_ref == train_ref
), "Train Reference should come from the environment variable."


def test_predictor_train_ref_no_env_var():
train_ref = "predict.py:Predictor"
if COG_TRAIN_TYPE_STUB_ENV_VAR in os.environ:
del os.environ[COG_TRAIN_TYPE_STUB_ENV_VAR]
config = Config(config={"train": train_ref})
config_train_ref = config.predictor_train_ref
assert (
config_train_ref == train_ref
), "Train Reference should come from the initial config dictionary."


def test_requires_gpu_env_var():
gpu = True
os.environ[COG_GPU_ENV_VAR] = str(gpu)
config = Config()
config_gpu = config.requires_gpu
del os.environ[COG_GPU_ENV_VAR]
assert config_gpu, "Requires GPU should come from the environment variable."


def test_requires_gpu_no_env_var():
if COG_GPU_ENV_VAR in os.environ:
del os.environ[COG_GPU_ENV_VAR]
config = Config(config={"build": {"gpu": False}})
config_gpu = config.requires_gpu
assert (
not config_gpu
), "Requires GPU should come from the initial config dictionary."


def test_get_predictor_ref_predict():
train_ref = "predict.py:Predictor"
config = Config(config={"train": train_ref})
Expand All @@ -25,6 +123,33 @@ def test_get_predictor_ref_train():
), "The predict ref should equal the config predict ref."


def test_get_predictor_types_with_env_var():
predict_ref = "predict.py:Predictor"
os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR] = predict_ref
os.environ[COG_PREDICT_CODE_STRIP_ENV_VAR] = """
from cog import BasePredictor, Path
from typing import Optional
from pydantic import BaseModel
class ModelOutput(BaseModel):
success: bool
error: Optional[str]
segmentedImage: Optional[Path]
class Predictor(BasePredictor):
def predict(self, msg: str) -> ModelOutput:
return None
"""
config = Config()
input_type, output_type = config.get_predictor_types(Mode.PREDICT)
del os.environ[COG_PREDICT_CODE_STRIP_ENV_VAR]
del os.environ[COG_PREDICT_TYPE_STUB_ENV_VAR]
assert (
str(input_type) == "<class 'cog.predictor.Input'>"
), "Predict input type should be the predictor Input."
assert (
str(output_type) == "<class 'cog.predictor.get_output_type.<locals>.Output'>"
), "Predict output type should be the predictor Output."


def test_get_predictor_types():
with tempfile.TemporaryDirectory() as tmpdir:
predict_python_file = os.path.join(tmpdir, "predict.py")
Expand Down