Skip to content

Commit

Permalink
Generate unique SageMaker training job name based on pipeline and ste… (
Browse files Browse the repository at this point in the history
#1505)

* Generate unique SageMaker training job name based on pipeline and step name

* Client import

* Reformatted based on format.sh

* Replace random string generation with string_utils.random_str

* Revert "Reformatted based on format.sh"

This reverts commit 98ca3f1.

* Ruff and black changes

---------

Co-authored-by: Alex Strick van Linschoten <[email protected]>
  • Loading branch information
christianversloot and strickvl authored May 2, 2023
1 parent d8e1d1c commit 0ab4c5c
Showing 1 changed file with 13 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import sagemaker

from zenml.client import Client
from zenml.config.build_configuration import BuildConfiguration
from zenml.enums import StackComponentType
from zenml.integrations.aws.flavors.sagemaker_step_operator_flavor import (
Expand All @@ -26,6 +27,7 @@
from zenml.logger import get_logger
from zenml.stack import Stack, StackValidator
from zenml.step_operators import BaseStepOperator
from zenml.utils.string_utils import random_str

if TYPE_CHECKING:
from zenml.config.base_settings import BaseSettings
Expand Down Expand Up @@ -179,9 +181,16 @@ def launch(
image_name, self.config.role, **estimator_args
)

# SageMaker allows 63 characters at maximum for job name - ZenML uses 60 for safety margin.
step_name = Client().get_run_step(info.step_run_id).name
training_job_name = f"{info.pipeline.name}-{step_name}"[:55]
suffix = random_str(4)
unique_training_job_name = f"{training_job_name}-{suffix}"

# Sagemaker doesn't allow any underscores in job/experiment/trial names
job_name = f"{info.run_name}-{info.pipeline_step_name}"
job_name = job_name.replace("_", "-")
sanitized_training_job_name = unique_training_job_name.replace(
"_", "-"
)

# Construct training input object, if necessary
inputs = None
Expand All @@ -201,12 +210,12 @@ def launch(
if settings.experiment_name:
experiment_config = {
"ExperimentName": settings.experiment_name,
"TrialName": job_name,
"TrialName": sanitized_training_job_name,
}

estimator.fit(
wait=True,
inputs=inputs,
experiment_config=experiment_config,
job_name=job_name,
job_name=sanitized_training_job_name,
)

0 comments on commit 0ab4c5c

Please sign in to comment.