Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingestion/SageMaker): Remove deprecated apis and add stateful ingestion capability #10573

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable
from typing import DefaultDict, Dict, Iterable, List, Optional

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
Expand All @@ -10,7 +10,7 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import Source
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.sagemaker_processors.common import (
SagemakerSourceConfig,
Expand All @@ -26,13 +26,19 @@
)
from datahub.ingestion.source.aws.sagemaker_processors.lineage import LineageProcessor
from datahub.ingestion.source.aws.sagemaker_processors.models import ModelProcessor
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)


@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
@support_status(SupportStatus.CERTIFIED)
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class SagemakerSource(Source):
class SagemakerSource(StatefulIngestionSourceBase):
"""
This plugin extracts the following:

Expand All @@ -45,7 +51,7 @@ class SagemakerSource(Source):
report = SagemakerSourceReport()

def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext):
super().__init__(ctx)
super().__init__(config, ctx)
self.source_config = config
self.report = SagemakerSourceReport()
self.sagemaker_client = config.sagemaker_client
Expand All @@ -56,6 +62,14 @@ def create(cls, config_dict, ctx):
config = SagemakerSourceConfig.parse_obj(config_dict)
return cls(config, ctx)

def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
    """Return the base work-unit processors plus a stale-entity removal processor.

    The stale-entity removal handler enables stateful ingestion: entities seen
    in a previous run but missing from the current one can be soft-deleted.
    """
    processors: List[Optional[MetadataWorkUnitProcessor]] = list(
        super().get_workunit_processors()
    )
    stale_entity_handler = StaleEntityRemovalHandler.create(
        self, self.source_config, self.ctx
    )
    processors.append(stale_entity_handler.workunit_processor)
    return processors

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# get common lineage graph
lineage_processor = LineageProcessor(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from pydantic.fields import Field

from datahub.ingestion.api.source import SourceReport
from datahub.configuration.source_common import PlatformInstanceConfigMixin
from datahub.ingestion.source.aws.aws_common import AwsSourceConfig
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalSourceReport,
StatefulIngestionConfigBase,
StatefulStaleMetadataRemovalConfig,
)


class SagemakerSourceConfig(AwsSourceConfig):
class SagemakerSourceConfig(
AwsSourceConfig,
PlatformInstanceConfigMixin,
StatefulIngestionConfigBase,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't look like this source supports platform instances, so having it in the config might be misleading

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! It's interesting that if I give platform_instance in the ingestion recipe, then I can see the platform instance available when browsing ML models, it's in the response of graphQL query getBrowseResultsV2 even though each ML models entity does not have platform_instance.

Do you think I can make some more improvements to support platform instances? Most of the make_ml_model..*_urn code is in mce_builder.py.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah adding support for it isn't too bad. it involves

  1. adding platform instance to the config (you already did this)
  2. adding the platform instance as a prefix to all urns. modifying the helpers in mce_builder.py to take an optional platform_instance is the right approach there
  3. emit dataPlatformInstance aspects for every entity

the browse path v2 aspects will get auto-generated based on everything else

):
extract_feature_groups: Optional[bool] = Field(
default=True, description="Whether to extract feature groups."
)
Expand All @@ -17,21 +26,24 @@ class SagemakerSourceConfig(AwsSourceConfig):
extract_jobs: Optional[Union[Dict[str, str], bool]] = Field(
default=True, description="Whether to extract AutoML jobs."
)
# Custom Stateful Ingestion settings
stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None

@property
def sagemaker_client(self):
return self.get_sagemaker_client()


@dataclass
class SagemakerSourceReport(SourceReport):
class SagemakerSourceReport(StaleEntityRemovalSourceReport):
feature_groups_scanned = 0
features_scanned = 0
endpoints_scanned = 0
groups_scanned = 0
models_scanned = 0
jobs_scanned = 0
datasets_scanned = 0
filtered: List[str] = field(default_factory=list)

def report_feature_group_scanned(self) -> None:
self.feature_groups_scanned += 1
Expand All @@ -53,3 +65,6 @@ def report_job_scanned(self) -> None:

def report_dataset_scanned(self) -> None:
self.datasets_scanned += 1

def report_dropped(self, name: str) -> None:
    """Record an entity name that was filtered out of ingestion.

    The accumulated names are surfaced in the source report so users can see
    which entities were skipped.
    """
    dropped_names = self.filtered
    dropped_names.append(name)
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Dict, Final
from typing import Dict

from typing_extensions import Final

from datahub.metadata.schema_classes import JobStatusClass

Expand Down Expand Up @@ -63,8 +65,9 @@ class AutoMlJobInfo(SageMakerJobInfo):
list_key: Final = "AutoMLJobSummaries"
list_name_key: Final = "AutoMLJobName"
list_arn_key: Final = "AutoMLJobArn"
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
describe_command: Final = "describe_auto_ml_job"
# DescribeAutoMLJobV2 is the newer version of DescribeAutoMLJob (paired with CreateAutoMLJobV2) and offers backward compatibility.
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeAutoMLJobV2.html
describe_command: Final = "describe_auto_ml_job_v2"
describe_name_key: Final = "AutoMLJobName"
describe_arn_key: Final = "AutoMLJobArn"
describe_status_key: Final = "AutoMLJobStatus"
Expand Down Expand Up @@ -101,28 +104,6 @@ class CompilationJobInfo(SageMakerJobInfo):
processor = "process_compilation_job"


class EdgePackagingJobInfo(SageMakerJobInfo):
    """boto3 command names and response keys for SageMaker edge packaging jobs.

    Declares, as class-level constants, everything needed to list and describe
    edge packaging jobs generically, plus the mapping from the job's status
    string to DataHub's JobStatusClass and the name of the processor method
    that handles each described job.
    """

    # Listing: command to call and keys to read from its paginated response.
    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_edge_packaging_jobs
    list_command: Final = "list_edge_packaging_jobs"
    list_key: Final = "EdgePackagingJobSummaries"
    list_name_key: Final = "EdgePackagingJobName"
    list_arn_key: Final = "EdgePackagingJobArn"
    # Describing a single job: command and the keys for name/ARN/status.
    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_edge_packaging_job
    describe_command: Final = "describe_edge_packaging_job"
    describe_name_key: Final = "EdgePackagingJobName"
    describe_arn_key: Final = "EdgePackagingJobArn"
    describe_status_key: Final = "EdgePackagingJobStatus"
    # SageMaker edge-packaging status string -> DataHub JobStatusClass.
    status_map = {
        "INPROGRESS": JobStatusClass.IN_PROGRESS,
        "COMPLETED": JobStatusClass.COMPLETED,
        "FAILED": JobStatusClass.FAILED,
        "STARTING": JobStatusClass.STARTING,
        "STOPPING": JobStatusClass.STOPPING,
        "STOPPED": JobStatusClass.STOPPED,
    }
    # Name of the JobProcessor method that converts a describe() result.
    processor = "process_edge_packaging_job"


class HyperParameterTuningJobInfo(SageMakerJobInfo):
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_hyper_parameter_tuning_jobs
list_command: Final = "list_hyper_parameter_tuning_jobs"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from datahub.ingestion.source.aws.sagemaker_processors.job_classes import (
AutoMlJobInfo,
CompilationJobInfo,
EdgePackagingJobInfo,
HyperParameterTuningJobInfo,
LabelingJobInfo,
ProcessingJobInfo,
Expand All @@ -53,7 +52,6 @@
"JobInfo",
AutoMlJobInfo,
CompilationJobInfo,
EdgePackagingJobInfo,
HyperParameterTuningJobInfo,
LabelingJobInfo,
ProcessingJobInfo,
Expand All @@ -65,7 +63,6 @@
class JobType(Enum):
AUTO_ML = "auto_ml"
COMPILATION = "compilation"
EDGE_PACKAGING = "edge_packaging"
HYPER_PARAMETER_TUNING = "hyper_parameter_tuning"
LABELING = "labeling"
PROCESSING = "processing"
Expand All @@ -78,7 +75,6 @@ class JobType(Enum):
job_type_to_info: Mapping[JobType, Any] = {
JobType.AUTO_ML: AutoMlJobInfo(),
JobType.COMPILATION: CompilationJobInfo(),
JobType.EDGE_PACKAGING: EdgePackagingJobInfo(),
JobType.HYPER_PARAMETER_TUNING: HyperParameterTuningJobInfo(),
JobType.LABELING: LabelingJobInfo(),
JobType.PROCESSING: ProcessingJobInfo(),
Expand Down Expand Up @@ -416,23 +412,20 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
"""
Process outputs from Boto3 describe_auto_ml_job()

See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/describe_auto_ml_job_v2.html
"""

JOB_TYPE = JobType.AUTO_ML

input_datasets = {}

for input_config in job.get("InputDataConfig", []):
for input_config in job.get("AutoMLJobInputDataConfig", []):
input_data = input_config.get("DataSource", {}).get("S3DataSource")

if input_data is not None and "S3Uri" in input_data:
input_datasets[make_s3_urn(input_data["S3Uri"], self.env)] = {
"dataset_type": "s3",
"uri": input_data["S3Uri"],
"datatype": input_data.get("S3DataType"),
}

output_datasets = {}

output_s3_path = job.get("OutputDataConfig", {}).get("S3OutputPath")
Expand All @@ -448,6 +441,18 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
JOB_TYPE,
)

metrics: Dict[str, Any] = {}
# Get job metrics from CandidateMetrics
candidate_metrics = (
job.get("BestCandidate", {})
.get("CandidateProperties", {})
.get("CandidateMetrics", [])
)
if candidate_metrics:
metrics = {
metric["MetricName"]: metric["Value"] for metric in candidate_metrics
}

model_containers = job.get("BestCandidate", {}).get("InferenceContainers", [])

for model_container in model_containers:
Expand All @@ -456,7 +461,7 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
if model_data_url is not None:
job_key = JobKey(job_snapshot.urn, JobDirection.TRAINING)

self.update_model_image_jobs(model_data_url, job_key)
self.update_model_image_jobs(model_data_url, job_key, metrics=metrics)

return SageMakerJob(
job_name=job_name,
Expand Down Expand Up @@ -515,83 +520,6 @@ def process_compilation_job(self, job: Dict[str, Any]) -> SageMakerJob:
output_datasets=output_datasets,
)

def process_edge_packaging_job(
    self,
    job: Dict[str, Any],
) -> SageMakerJob:
    """
    Process outputs from Boto3 describe_edge_packaging_job()

    See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_edge_packaging_job
    """

    job_type = JobType.EDGE_PACKAGING

    name: str = job["EdgePackagingJobName"]
    arn: str = job["EdgePackagingJobArn"]

    # Both the packaged model artifact and the configured S3 output location
    # (when present) count as output datasets of this job.
    output_datasets = {
        make_s3_urn(s3_uri, self.env): {
            "dataset_type": "s3",
            "uri": s3_uri,
        }
        for s3_uri in (
            job.get("ModelArtifact"),
            job.get("OutputConfig", {}).get("S3OutputLocation"),
        )
        if s3_uri is not None
    }

    # from docs: "The name of the SageMaker Neo compilation job that is used to locate model artifacts that are being packaged."
    compilation_job_name: Optional[str] = job.get("CompilationJobName")

    output_jobs = set()
    if compilation_job_name is not None:
        # globally unique job name
        lookup_key = ("compilation", compilation_job_name)

        if lookup_key in self.name_to_arn:
            output_jobs.add(
                make_sagemaker_job_urn(
                    "compilation",
                    compilation_job_name,
                    self.name_to_arn[lookup_key],
                    self.env,
                )
            )
        else:
            self.report.report_warning(
                name,
                f"Unable to find ARN for compilation job {compilation_job_name} produced by edge packaging job {arn}",
            )

    console_url = (
        f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home"
        f"?region={self.aws_region}#/edge-packaging-jobs/{job['EdgePackagingJobName']}"
    )
    job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
        job,
        job_type,
        console_url,
    )

    # Link the packaged model (by name) downstream of this job.
    if job.get("ModelName") is not None:
        downstream_key = JobKey(job_snapshot.urn, JobDirection.DOWNSTREAM)
        self.update_model_name_jobs(job["ModelName"], downstream_key)

    return SageMakerJob(
        job_name=job_name,
        job_arn=job_arn,
        job_type=job_type,
        job_snapshot=job_snapshot,
        output_datasets=output_datasets,
        output_jobs=output_jobs,
    )

def process_hyper_parameter_tuning_job(
self,
job: Dict[str, Any],
Expand Down
Loading
Loading