Skip to content

Commit

Permalink
chore: LLM - Improved the typing of the from_pretrained method
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 540099054
  • Loading branch information
Ark-kun authored and copybara-github committed Jun 13, 2023
1 parent 1fda417 commit 0ab62a0
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions vertexai/_model_garden/_model_garden_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""Base class for working with Model Garden models."""

import dataclasses
from typing import Dict, Optional, Type
from typing import Dict, Optional, Type, TypeVar

from google.cloud import aiplatform
from google.cloud.aiplatform import base
Expand Down Expand Up @@ -44,6 +44,8 @@

_LOGGER = base.Logger(__name__)

T = TypeVar("T", bound="_ModelGardenModel")


@dataclasses.dataclass
class _ModelInfo:
Expand Down Expand Up @@ -180,7 +182,7 @@ def __init__(self, model_id: str, endpoint_name: Optional[str] = None):
)

@classmethod
def from_pretrained(cls, model_name: str) -> "_ModelGardenModel":
def from_pretrained(cls: Type[T], model_name: str) -> T:
"""Loads a _ModelGardenModel.
Args:
Expand Down

0 comments on commit 0ab62a0

Please sign in to comment.