diff --git a/vertexai/_model_garden/_model_garden_models.py b/vertexai/_model_garden/_model_garden_models.py index b135742b16..506bc23a39 100644 --- a/vertexai/_model_garden/_model_garden_models.py +++ b/vertexai/_model_garden/_model_garden_models.py @@ -16,7 +16,7 @@ """Base class for working with Model Garden models.""" import dataclasses -from typing import Dict, Optional, Type +from typing import Dict, Optional, Type, TypeVar from google.cloud import aiplatform from google.cloud.aiplatform import base @@ -44,6 +44,8 @@ _LOGGER = base.Logger(__name__) +T = TypeVar("T", bound="_ModelGardenModel") + @dataclasses.dataclass class _ModelInfo: @@ -180,7 +182,7 @@ def __init__(self, model_id: str, endpoint_name: Optional[str] = None): ) @classmethod - def from_pretrained(cls, model_name: str) -> "_ModelGardenModel": + def from_pretrained(cls: Type[T], model_name: str) -> T: """Loads a _ModelGardenModel. Args: