Skip to content

Commit

Permalink
black-nb
Browse files Browse the repository at this point in the history
  • Loading branch information
Qingwei Li committed Sep 9, 2022
1 parent 841d4d2 commit bbea127
Showing 1 changed file with 19 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@
"account = session.account_id()\n",
"region = session.boto_region_name\n",
"img = \"djl_deepspeed\"\n",
"fullname = account+'.dkr.ecr.'+region+'.amazonaws.com/'+img+':latest'\n",
"fullname = account + \".dkr.ecr.\" + region + \".amazonaws.com/\" + img + \":latest\"\n",
"bucket = session.default_bucket()\n",
"path = \"s3://\" + bucket + \"/DEMO-djl-big-model\""
]
Expand Down Expand Up @@ -261,7 +261,9 @@
"metadata": {},
"outputs": [],
"source": [
"model_s3_url = sagemaker.s3.S3Uploader.upload('gpt-j.tar.gz', path, kms_key=None, sagemaker_session=session)"
"model_s3_url = sagemaker.s3.S3Uploader.upload(\n",
" \"gpt-j.tar.gz\", path, kms_key=None, sagemaker_session=session\n",
")"
]
},
{
Expand All @@ -288,21 +290,21 @@
"outputs": [],
"source": [
"from datetime import datetime\n",
"sm_client = boto3.client('sagemaker')\n",
"\n",
"sm_client = boto3.client(\"sagemaker\")\n",
"\n",
"time_stamp = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
"model_name = 'gpt-j-' + time_stamp\n",
"model_name = \"gpt-j-\" + time_stamp\n",
"\n",
"create_model_response = sm_client.create_model(\n",
" ModelName = model_name,\n",
" ExecutionRoleArn = session.get_caller_identity_arn(),\n",
" PrimaryContainer = {\n",
" 'Image': fullname,\n",
" 'ModelDataUrl': model_s3_url,\n",
" 'Environment': {\n",
" 'TENSOR_PARALLEL_DEGREE': '2'\n",
" }\n",
" })"
" ModelName=model_name,\n",
" ExecutionRoleArn=session.get_caller_identity_arn(),\n",
" PrimaryContainer={\n",
" \"Image\": fullname,\n",
" \"ModelDataUrl\": model_s3_url,\n",
" \"Environment\": {\"TENSOR_PARALLEL_DEGREE\": \"2\"},\n",
" },\n",
")"
]
},
{
Expand All @@ -322,17 +324,17 @@
"source": [
"initial_instance_count = 1\n",
"instance_type = \"ml.g5.48xlarge\"\n",
"variant_name = \"AllTraffic\" \n",
"endpoint_config_name = \"t-j-config-\"+time_stamp\n",
"variant_name = \"AllTraffic\"\n",
"endpoint_config_name = \"t-j-config-\" + time_stamp\n",
"\n",
"production_variants = [\n",
" {\n",
" \"VariantName\": variant_name,\n",
" \"ModelName\": model_name,\n",
" \"InitialInstanceCount\": initial_instance_count,\n",
" \"InstanceType\": instance_type,\n",
" 'ModelDataDownloadTimeoutInSeconds':1800,\n",
" 'ContainerStartupHealthCheckTimeoutInSeconds':3600\n",
" \"ModelDataDownloadTimeoutInSeconds\": 1800,\n",
" \"ContainerStartupHealthCheckTimeoutInSeconds\": 3600,\n",
" }\n",
"]\n",
"\n",
Expand Down

0 comments on commit bbea127

Please sign in to comment.