Skip to content

Commit

Permalink
fix: update passing additional model data sources to API (aws#1472)
Browse files Browse the repository at this point in the history
* feat: Added utils for extracting JumpStart (JS) data sources (aws#1471)

* added utils for accessing hosting data sources

* added utils for accessing hosting data sources

* removed other changes

* fixed formatting issues

* remove .keys()

* updated JumpStartModelDataSource

* fix slots

* remove print

* fix tests

* update tests

* fix: update passing additional model data sources to API

* format

* format

* format

* format and address comments

* format

* format

* format

---------

Co-authored-by: Adam Kozdrowicz <[email protected]>
  • Loading branch information
Captainia and akozd authored Jun 12, 2024
1 parent 9a410e5 commit 2331dec
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 21 deletions.
40 changes: 39 additions & 1 deletion src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.session import Session
from sagemaker.utils import name_from_base, format_tags, Tags
from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker import resource_requirements
Expand Down Expand Up @@ -615,6 +615,40 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
return kwargs


def _add_additional_model_data_sources_to_kwargs(
    kwargs: JumpStartModelInitKwargs,
) -> JumpStartModelInitKwargs:
    """Sets default additional model data sources to init kwargs.

    Retrieves the model specs for the inference scope and, when the caller did
    not already provide a truthy ``additional_model_data_sources``, fills it in
    from the specs' additional S3 data sources. Each data source is serialized
    with ``to_json()`` and its keys are converted with
    ``camel_case_to_pascal_case``.

    Args:
        kwargs (JumpStartModelInitKwargs): Init kwargs; updated in place.

    Returns:
        JumpStartModelInitKwargs: The same ``kwargs`` object. Its
            ``additional_model_data_sources`` is the caller-supplied value when
            truthy, otherwise the list built from the specs, otherwise ``None``.
    """

    specs = verify_model_region_and_return_specs(
        model_id=kwargs.model_id,
        version=kwargs.model_version,
        scope=JumpStartScriptScope.INFERENCE,
        region=kwargs.region,
        tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
        tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
        sagemaker_session=kwargs.sagemaker_session,
        model_type=kwargs.model_type,
        config_name=kwargs.config_name,
    )

    # Query the specs once and reuse the result for both the emptiness check
    # and the conversion, instead of calling the getter a second time.
    additional_data_sources = specs.get_additional_s3_data_sources()
    api_shape_additional_model_data_sources = (
        [
            camel_case_to_pascal_case(data_source.to_json())
            for data_source in additional_data_sources
        ]
        if additional_data_sources
        else None
    )

    # A user-provided value always wins over the default taken from the specs.
    kwargs.additional_model_data_sources = (
        kwargs.additional_model_data_sources or api_shape_additional_model_data_sources
    )

    return kwargs


def _add_config_name_to_deploy_kwargs(
kwargs: JumpStartModelDeployKwargs, training_config_name: Optional[str] = None
) -> JumpStartModelInitKwargs:
Expand Down Expand Up @@ -861,6 +895,7 @@ def get_init_kwargs(
disable_instance_type_logging: bool = False,
resources: Optional[ResourceRequirements] = None,
config_name: Optional[str] = None,
additional_model_data_sources: Optional[Dict[str, Any]] = None,
) -> JumpStartModelInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -893,6 +928,7 @@ def get_init_kwargs(
training_instance_type=training_instance_type,
resources=resources,
config_name=config_name,
additional_model_data_sources=additional_model_data_sources,
)

model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
Expand Down Expand Up @@ -925,4 +961,6 @@ def get_init_kwargs(

model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)

model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)

return model_init_kwargs
17 changes: 7 additions & 10 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
model_package_arn: Optional[str] = None,
resources: Optional[ResourceRequirements] = None,
config_name: Optional[str] = None,
additional_model_data_sources: Optional[Dict[str, Any]] = None,
):
"""Initializes a ``JumpStartModel``.
Expand Down Expand Up @@ -287,8 +288,10 @@ def __init__(
for a model to be deployed to an endpoint.
Only EndpointType.INFERENCE_COMPONENT_BASED supports this feature.
(Default: None).
config_name (Optional[str]): The name of the JumpStartConfig that can be
optionally applied to the model and override corresponding fields.
config_name (Optional[str]): The name of the JumpStart config that can be
optionally applied to the model.
additional_model_data_sources (Optional[Dict[str, Any]]): Additional location
of SageMaker model data (default: None).
Raises:
ValueError: If the model ID is not recognized by JumpStart.
"""
Expand Down Expand Up @@ -339,6 +342,7 @@ def _validate_model_id_and_type():
model_package_arn=model_package_arn,
resources=resources,
config_name=config_name,
additional_model_data_sources=additional_model_data_sources,
)

self.orig_predictor_cls = predictor_cls
Expand All @@ -352,6 +356,7 @@ def _validate_model_id_and_type():
self.region = model_init_kwargs.region
self.sagemaker_session = model_init_kwargs.sagemaker_session
self.config_name = model_init_kwargs.config_name
self.additional_model_data_sources = model_init_kwargs.additional_model_data_sources

if self.model_type == JumpStartModelType.PROPRIETARY:
self.log_subscription_warning()
Expand All @@ -369,14 +374,6 @@ def _validate_model_id_and_type():
model_type=self.model_type,
)

self.additional_model_data_sources = (
self._metadata_configs.get(self.config_name).resolved_config.get(
"hosting_additional_data_sources"
)
if self._metadata_configs.get(self.config_name)
else None
)

def log_subscription_warning(self) -> None:
"""Log message prompting the customer to subscribe to the proprietary model."""
subscription_link = verify_model_region_and_return_specs(
Expand Down
26 changes: 17 additions & 9 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,13 +844,15 @@ def to_json(self) -> Dict[str, Any]:
cur_val = getattr(self, att)
if issubclass(type(cur_val), JumpStartDataHolderType):
json_obj[att] = cur_val.to_json()
else:
elif cur_val:
json_obj[att] = cur_val
return json_obj


class AdditionalModelDataSource(JumpStartDataHolderType):
"""Data class of additional model data source mirrors Hosting API."""
"""Data class of additional model data source mirrors CreateModel API."""

SERIALIZATION_EXCLUSION_SET: Set[str] = set()

__slots__ = ["channel_name", "s3_data_source"]

Expand All @@ -871,23 +873,26 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
self.channel_name: str = json_obj["channel_name"]
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])

def to_json(self) -> Dict[str, Any]:
def to_json(self, exclude_keys=True) -> Dict[str, Any]:
"""Returns json representation of AdditionalModelDataSource object."""
json_obj = {}
for att in self.__slots__:
if hasattr(self, att):
cur_val = getattr(self, att)
if issubclass(type(cur_val), JumpStartDataHolderType):
json_obj[att] = cur_val.to_json()
else:
json_obj[att] = cur_val
if exclude_keys and att not in self.SERIALIZATION_EXCLUSION_SET or not exclude_keys:
cur_val = getattr(self, att)
if issubclass(type(cur_val), JumpStartDataHolderType):
json_obj[att] = cur_val.to_json()
else:
json_obj[att] = cur_val
return json_obj


class JumpStartModelDataSource(AdditionalModelDataSource):
"""Data class JumpStart additional model data source."""

__slots__ = ["artifact_version"] + AdditionalModelDataSource.__slots__
SERIALIZATION_EXCLUSION_SET = {"artifact_version"}

__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__

def from_json(self, json_obj: Dict[str, Any]) -> None:
"""Sets fields in object based on json.
Expand Down Expand Up @@ -1761,6 +1766,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
"training_instance_type",
"resources",
"config_name",
"additional_model_data_sources",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -1806,6 +1812,7 @@ def __init__(
training_instance_type: Optional[str] = None,
resources: Optional[ResourceRequirements] = None,
config_name: Optional[str] = None,
additional_model_data_sources: Optional[Dict[str, Any]] = None,
) -> None:
"""Instantiates JumpStartModelInitKwargs object."""

Expand Down Expand Up @@ -1837,6 +1844,7 @@ def __init__(
self.training_instance_type = training_instance_type
self.resources = resources
self.config_name = config_name
self.additional_model_data_sources = additional_model_data_sources


class JumpStartModelDeployKwargs(JumpStartKwargs):
Expand Down
30 changes: 30 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,3 +1798,33 @@ def extract_instance_rate_per_hour(price_data: Dict[str, Any]) -> Optional[Dict[
"name": "Instance Rate",
}
return None


def camel_case_to_pascal_case(data: Dict[str, Any]) -> Dict[str, Any]:
    """Builds a copy of a dictionary with every key rewritten from snake_case to PascalCase.

    Keys are split on ``"_"`` and each piece is passed through ``str.capitalize``
    (first character upper-cased, the rest lower-cased) before being joined back
    together. Dictionaries found as values, including those inside lists, are
    converted recursively; every other value is kept as is. The input dictionary
    is not modified.

    Args:
        data (dict): The dictionary whose keys should be converted.

    Returns:
        dict: A new dictionary with keys in PascalCase.
    """

    def _pascalize(name):
        """Turns a single snake_case key into PascalCase."""
        return "".join(map(str.capitalize, name.split("_")))

    def _walk(node):
        """Converts nested dictionaries and lists; returns anything else untouched."""
        if isinstance(node, list):
            return [_walk(element) for element in node]
        if isinstance(node, dict):
            return camel_case_to_pascal_case(node)
        return node

    return {_pascalize(name): _walk(node) for name, node in data.items()}
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/jumpstart/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
Please add the new argument to the skip set below,
and reach out to JumpStart team."""

init_args_to_skip: Set[str] = set(["additional_model_data_sources"])
init_args_to_skip: Set[str] = set([])
deploy_args_to_skip: Set[str] = set(["kwargs"])

parent_class_init = Model.__init__
Expand Down
40 changes: 40 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sagemaker.experiments._run_context import _RunContext
from sagemaker.session_settings import SessionSettings
from sagemaker.utils import (
camel_case_to_pascal_case,
deep_override_dict,
flatten_dict,
get_instance_type_family,
Expand Down Expand Up @@ -2055,3 +2056,42 @@ def test_resolve_routing_config(routing_config, expected):

def test_resolve_routing_config_ex():
pytest.raises(ValueError, lambda: _resolve_routing_config({"RoutingStrategy": "Invalid"}))


class TestConvertToPascalCase(TestCase):
    """Unit tests for ``camel_case_to_pascal_case``."""

    def test_simple_dict(self):
        input_dict = {"first_name": "John", "last_name": "Doe"}
        expected_output = {"FirstName": "John", "LastName": "Doe"}
        self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output)

    # The method name must start with ``test`` for unittest/pytest to collect it;
    # under its previous name this nested case was never executed.
    def test_nested_dict(self):
        input_dict = {
            "model_name": "my-model",
            "primary_container": {
                "image": "my-docker-image:latest",
                "model_data_url": "s3://my-bucket/model.tar.gz",
                "environment": {"env_var_1": "value1", "env_var_2": "value2"},
            },
            "execution_role_arn": "arn:aws:iam::123456789012:role/my-sagemaker-role",
            "tags": [
                {"key": "project", "value": "my-project"},
                {"key": "environment", "value": "development"},
            ],
        }
        expected_output = {
            "ModelName": "my-model",
            "PrimaryContainer": {
                "Image": "my-docker-image:latest",
                "ModelDataUrl": "s3://my-bucket/model.tar.gz",
                "Environment": {"EnvVar1": "value1", "EnvVar2": "value2"},
            },
            "ExecutionRoleArn": "arn:aws:iam::123456789012:role/my-sagemaker-role",
            "Tags": [
                {"Key": "project", "Value": "my-project"},
                {"Key": "environment", "Value": "development"},
            ],
        }
        self.assertEqual(camel_case_to_pascal_case(input_dict), expected_output)

    def test_empty_input(self):
        self.assertEqual(camel_case_to_pascal_case({}), {})

0 comments on commit 2331dec

Please sign in to comment.