Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
atqy committed Sep 8, 2022
1 parent 0a899ea commit 8a548b9
Showing 1 changed file with 15 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,7 @@
" ],\n",
" outputs=[\n",
" ProcessingOutput(output_name=\"train\", source=\"/opt/ml/processing/train\"),\n",
" ProcessingOutput(\n",
" output_name=\"validation\", source=\"/opt/ml/processing/validation\"\n",
" ),\n",
" ProcessingOutput(output_name=\"validation\", source=\"/opt/ml/processing/validation\"),\n",
" ProcessingOutput(output_name=\"test\", source=\"/opt/ml/processing/test\"),\n",
" ],\n",
" code=\"code/preprocessing.py\",\n",
Expand Down Expand Up @@ -429,9 +427,7 @@
" header=None,\n",
" )\n",
" df_train = df_train.iloc[np.random.permutation(len(df_train))]\n",
" df_train.columns = [\"target\"] + [\n",
" f\"feature_{x}\" for x in range(df_train.shape[1] - 1)\n",
" ]\n",
" df_train.columns = [\"target\"] + [f\"feature_{x}\" for x in range(df_train.shape[1] - 1)]\n",
"\n",
" try:\n",
" df_validation = pd.read_csv(\n",
Expand Down Expand Up @@ -479,18 +475,12 @@
" parser.add_argument(\"--tree_method\", type=str, default=\"auto\")\n",
" parser.add_argument(\"--predictor\", type=str, default=\"auto\")\n",
" parser.add_argument(\"--learning_rate\", type=str, default=\"auto\")\n",
" parser.add_argument(\n",
" \"--output_data_dir\", type=str, default=os.environ.get(\"SM_OUTPUT_DATA_DIR\")\n",
" )\n",
" parser.add_argument(\"--output_data_dir\", type=str, default=os.environ.get(\"SM_OUTPUT_DATA_DIR\"))\n",
" parser.add_argument(\"--model_dir\", type=str, default=os.environ.get(\"SM_MODEL_DIR\"))\n",
" parser.add_argument(\"--train\", type=str, default=os.environ.get(\"SM_CHANNEL_TRAIN\"))\n",
" parser.add_argument(\n",
" \"--validation\", type=str, default=os.environ.get(\"SM_CHANNEL_VALIDATION\")\n",
" )\n",
" parser.add_argument(\"--validation\", type=str, default=os.environ.get(\"SM_CHANNEL_VALIDATION\"))\n",
" parser.add_argument(\"--sm_hosts\", type=str, default=os.environ.get(\"SM_HOSTS\"))\n",
" parser.add_argument(\n",
" \"--sm_current_host\", type=str, default=os.environ.get(\"SM_CURRENT_HOST\")\n",
" )\n",
" parser.add_argument(\"--sm_current_host\", type=str, default=os.environ.get(\"SM_CURRENT_HOST\"))\n",
"\n",
" args, _ = parser.parse_known_args()\n",
"\n",
Expand Down Expand Up @@ -575,9 +565,7 @@
"train_args = xgb_train.fit(\n",
" inputs={\n",
" \"train\": TrainingInput(\n",
" s3_data=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"train\"\n",
" ].S3Output.S3Uri,\n",
" s3_data=step_process.properties.ProcessingOutputConfig.Outputs[\"train\"].S3Output.S3Uri,\n",
" content_type=\"text/csv\",\n",
" ),\n",
" \"validation\": TrainingInput(\n",
Expand Down Expand Up @@ -719,16 +707,12 @@
" destination=\"/opt/ml/processing/model\",\n",
" ),\n",
" ProcessingInput(\n",
" source=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"test\"\n",
" ].S3Output.S3Uri,\n",
" source=step_process.properties.ProcessingOutputConfig.Outputs[\"test\"].S3Output.S3Uri,\n",
" destination=\"/opt/ml/processing/test\",\n",
" ),\n",
" ],\n",
" outputs=[\n",
" ProcessingOutput(\n",
" output_name=\"evaluation\", source=\"/opt/ml/processing/evaluation\"\n",
" ),\n",
" ProcessingOutput(output_name=\"evaluation\", source=\"/opt/ml/processing/evaluation\"),\n",
" ],\n",
" code=\"code/evaluation.py\",\n",
")"
Expand Down Expand Up @@ -815,14 +799,10 @@
" input_data = xgb.DMatrix(data=df)\n",
"\n",
" else:\n",
" raise ValueError(\n",
" \"Content type {} is not supported.\".format(request_content_type)\n",
" )\n",
" raise ValueError(\"Content type {} is not supported.\".format(request_content_type))\n",
"\n",
" prediction = model.predict(input_data)\n",
" feature_contribs = model.predict(\n",
" input_data, pred_contribs=True, validate_features=False\n",
" )\n",
" feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False)\n",
" output = np.hstack((prediction[:, np.newaxis], feature_contribs))\n",
"\n",
" logging.info(\"Successfully completed transform job!\")\n",
Expand Down Expand Up @@ -1093,9 +1073,7 @@
"source": [
"# Get output files from processing job\n",
"\n",
"processing_job_name = steps[\"PipelineExecutionSteps\"][0][\"Metadata\"][\"ProcessingJob\"][\n",
" \"Arn\"\n",
"]\n",
"processing_job_name = steps[\"PipelineExecutionSteps\"][0][\"Metadata\"][\"ProcessingJob\"][\"Arn\"]\n",
"outputs = local_pipeline_session.sagemaker_client.describe_processing_job(\n",
" ProcessingJobName=processing_job_name\n",
")[\"ProcessingOutputConfig\"][\"Outputs\"]\n",
Expand Down Expand Up @@ -1208,9 +1186,7 @@
" ],\n",
" outputs=[\n",
" ProcessingOutput(output_name=\"train\", source=\"/opt/ml/processing/train\"),\n",
" ProcessingOutput(\n",
" output_name=\"validation\", source=\"/opt/ml/processing/validation\"\n",
" ),\n",
" ProcessingOutput(output_name=\"validation\", source=\"/opt/ml/processing/validation\"),\n",
" ProcessingOutput(output_name=\"test\", source=\"/opt/ml/processing/test\"),\n",
" ],\n",
" code=\"code/preprocessing.py\",\n",
Expand Down Expand Up @@ -1261,9 +1237,7 @@
"train_args = xgb_train.fit(\n",
" inputs={\n",
" \"train\": TrainingInput(\n",
" s3_data=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"train\"\n",
" ].S3Output.S3Uri,\n",
" s3_data=step_process.properties.ProcessingOutputConfig.Outputs[\"train\"].S3Output.S3Uri,\n",
" content_type=\"text/csv\",\n",
" ),\n",
" \"validation\": TrainingInput(\n",
Expand Down Expand Up @@ -1306,16 +1280,12 @@
" destination=\"/opt/ml/processing/model\",\n",
" ),\n",
" ProcessingInput(\n",
" source=step_process.properties.ProcessingOutputConfig.Outputs[\n",
" \"test\"\n",
" ].S3Output.S3Uri,\n",
" source=step_process.properties.ProcessingOutputConfig.Outputs[\"test\"].S3Output.S3Uri,\n",
" destination=\"/opt/ml/processing/test\",\n",
" ),\n",
" ],\n",
" outputs=[\n",
" ProcessingOutput(\n",
" output_name=\"evaluation\", source=\"/opt/ml/processing/evaluation\"\n",
" ),\n",
" ProcessingOutput(output_name=\"evaluation\", source=\"/opt/ml/processing/evaluation\"),\n",
" ],\n",
" code=\"code/evaluation.py\",\n",
")\n",
Expand Down

0 comments on commit 8a548b9

Please sign in to comment.