reformat
atqy committed Sep 8, 2022
1 parent dc7c060 commit 875c110
Showing 1 changed file with 29 additions and 24 deletions.
@@ -135,19 +135,26 @@
"\n",
"predictor = None\n",
"\n",
"\n",
"def get_model():\n",
" model_name = 'EleutherAI/gpt-j-6B'\n",
" tensor_parallel = int(os.getenv('TENSOR_PARALLEL_DEGREE', '2'))\n",
" local_rank = int(os.getenv('LOCAL_RANK', '0'))\n",
" model = AutoModelForCausalLM.from_pretrained(model_name, revision=\"float32\", torch_dtype=torch.float32)\n",
" model_name = \"EleutherAI/gpt-j-6B\"\n",
" tensor_parallel = int(os.getenv(\"TENSOR_PARALLEL_DEGREE\", \"2\"))\n",
" local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n",
" model = AutoModelForCausalLM.from_pretrained(\n",
" model_name, revision=\"float32\", torch_dtype=torch.float32\n",
" )\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" \n",
" model = deepspeed.init_inference(model,\n",
" mp_size=tensor_parallel,\n",
" dtype=model.dtype,\n",
" replace_method='auto',\n",
" replace_with_kernel_inject=True)\n",
" generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer, device=local_rank)\n",
"\n",
" model = deepspeed.init_inference(\n",
" model,\n",
" mp_size=tensor_parallel,\n",
" dtype=model.dtype,\n",
" replace_method=\"auto\",\n",
" replace_with_kernel_inject=True,\n",
" )\n",
" generator = pipeline(\n",
" task=\"text-generation\", model=model, tokenizer=tokenizer, device=local_rank\n",
" )\n",
" return generator\n",
"\n",
"\n",
@@ -190,7 +197,7 @@
"source": [
"%%writefile serving.properties\n",
"\n",
"engine=Rubikon"
"engine = Rubikon"
]
},
{
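serving.properties is the configuration file DJL Serving reads from the model directory; this notebook only uses it to select the engine. A hypothetical, slightly fuller version is sketched below -- the commented option.* keys are DJL Serving properties the notebook does not set and are assumptions, not part of this commit.

    %%writefile serving.properties

    engine = Rubikon
    # Assumed extras (not used in this notebook):
    # option.entryPoint=model.py          # explicit Python handler file
    # option.tensor_parallel_degree=2     # alternative to the TENSOR_PARALLEL_DEGREE env var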
@@ -213,11 +220,11 @@
"session = sagemaker.Session()\n",
"account = session.account_id()\n",
"region = session.boto_region_name\n",
"img = 'djl_deepspeed'\n",
"fullname = account+'.dkr.ecr.'+region+'amazonaws.com/'+img+':latest'\n",
"img = \"djl_deepspeed\"\n",
"fullname = account + \".dkr.ecr.\" + region + \"amazonaws.com/\" + img + \":latest\"\n",
"\n",
"bucket = session.default_bucket()\n",
"path = 's3://' + bucket + '/DEMO-djl-big-model/'"
"path = \"s3://\" + bucket + \"/DEMO-djl-big-model/\""
]
},
{
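The cells that push this image to ECR and create the SageMaker model and endpoint sit outside the hunks shown here. Purely as orientation, a hedged sketch of how an image URI and S3 prefix like these are typically turned into an endpoint with the SageMaker Python SDK; the artifact name, instance type, and role lookup below are assumptions, not taken from the notebook.

    import sagemaker
    from sagemaker.model import Model

    role = sagemaker.get_execution_role()  # assumes a SageMaker execution context
    model = Model(
        image_uri=fullname,                   # ECR image name built above
        model_data=path + "model.tar.gz",     # hypothetical artifact under the S3 prefix
        role=role,
        env={"TENSOR_PARALLEL_DEGREE": "2"},  # read by get_model() in model.py
    )
    model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.12xlarge",       # assumed multi-GPU instance type
        endpoint_name="gpt-j",                # matches the endpoint invoked below
    )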
@@ -362,18 +369,16 @@
"source": [
"import boto3, json\n",
"\n",
"client = boto3.client('sagemaker-runtime')\n",
"client = boto3.client(\"sagemaker-runtime\")\n",
"\n",
"endpoint_name = \"gpt-j\" # Your endpoint name.\n",
"content_type = \"text/plain\" # The MIME type of the input data in the request body.\n",
"endpoint_name = \"gpt-j\" # Your endpoint name.\n",
"content_type = \"text/plain\" # The MIME type of the input data in the request body.\n",
"# accept = \"...\" # The desired MIME type of the inference in the response.\n",
"payload = \"Amazon.com is the best\" # Payload for inference.\n",
"payload = \"Amazon.com is the best\" # Payload for inference.\n",
"response = client.invoke_endpoint(\n",
" EndpointName=endpoint_name, \n",
" ContentType=content_type,\n",
" Body=payload\n",
" )\n",
"print(response['Body'].read())"
" EndpointName=endpoint_name, ContentType=content_type, Body=payload\n",
")\n",
"print(response[\"Body\"].read())"
]
},
{
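invoke_endpoint returns the prediction as a streaming body of raw bytes. A small hedged follow-up, assuming the handler returns the generated text (possibly JSON-encoded):

    import json

    # Decode the raw bytes; fall back to plain text if the body is not JSON.
    body = response["Body"].read().decode("utf-8")
    try:
        print(json.loads(body))
    except json.JSONDecodeError:
        print(body)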