diff --git a/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/dtree_evaluate.py b/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/dtree_evaluate.py index 93f0a0961a..d94c9c2cbc 100644 --- a/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/dtree_evaluate.py +++ b/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/dtree_evaluate.py @@ -50,10 +50,15 @@ test_path = "/opt/ml/processing/test/test.csv" - logger.info("Loading test input data") + logger.info("Loading test input data") df = pd.read_csv(test_path, header=None) + logger.debug("Reading test data.") + y_test = df.iloc[:, 0].to_numpy() + df.drop(df.columns[0], axis=1, inplace=True) + X_test = numpy.array(df.values) + logger.info("Performing predictions against test data.") predictions = model.predict(X_test) diff --git a/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/pipeline.py b/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/pipeline.py index 2614eb7b46..3cccbdaa86 100644 --- a/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/pipeline.py +++ b/sagemaker-pipeline-multi-model/sagemaker-pipeline/pipelines/restate/pipeline.py @@ -185,7 +185,6 @@ def get_pipeline( print(f"Data Wrangler flow {flow_file_name} uploaded to {flow_s3_uri}") - ## Input - Flow: restate-athena-russia.flow flow_input = ProcessingInput( source=flow_s3_uri, destination="/opt/ml/processing/flow", @@ -347,7 +346,6 @@ def get_pipeline( cache_config=cache_config, ) - # dtree_image_uri = '625467769535.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-decision-tree:latest' dtree_image_uri = sagemaker_session.sagemaker_client.describe_image_version( ImageName="restate-dtree" )["ContainerImage"]