From 875c110b2ab4953d1273848398dd4b1227a80f60 Mon Sep 17 00:00:00 2001 From: atqy Date: Thu, 8 Sep 2022 16:49:11 -0700 Subject: [PATCH] reformat --- ...PT-J-6B-model-parallel-inference-DJL.ipynb | 53 ++++++++++--------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/advanced_functionality/pytorch_deploy_large_GPT_model/GPT-J-6B-model-parallel-inference-DJL.ipynb b/advanced_functionality/pytorch_deploy_large_GPT_model/GPT-J-6B-model-parallel-inference-DJL.ipynb index d40325d877..c36008b831 100644 --- a/advanced_functionality/pytorch_deploy_large_GPT_model/GPT-J-6B-model-parallel-inference-DJL.ipynb +++ b/advanced_functionality/pytorch_deploy_large_GPT_model/GPT-J-6B-model-parallel-inference-DJL.ipynb @@ -135,19 +135,26 @@ "\n", "predictor = None\n", "\n", + "\n", "def get_model():\n", - " model_name = 'EleutherAI/gpt-j-6B'\n", - " tensor_parallel = int(os.getenv('TENSOR_PARALLEL_DEGREE', '2'))\n", - " local_rank = int(os.getenv('LOCAL_RANK', '0'))\n", - " model = AutoModelForCausalLM.from_pretrained(model_name, revision=\"float32\", torch_dtype=torch.float32)\n", + " model_name = \"EleutherAI/gpt-j-6B\"\n", + " tensor_parallel = int(os.getenv(\"TENSOR_PARALLEL_DEGREE\", \"2\"))\n", + " local_rank = int(os.getenv(\"LOCAL_RANK\", \"0\"))\n", + " model = AutoModelForCausalLM.from_pretrained(\n", + " model_name, revision=\"float32\", torch_dtype=torch.float32\n", + " )\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", - " \n", - " model = deepspeed.init_inference(model,\n", - " mp_size=tensor_parallel,\n", - " dtype=model.dtype,\n", - " replace_method='auto',\n", - " replace_with_kernel_inject=True)\n", - " generator = pipeline(task='text-generation', model=model, tokenizer=tokenizer, device=local_rank)\n", + "\n", + " model = deepspeed.init_inference(\n", + " model,\n", + " mp_size=tensor_parallel,\n", + " dtype=model.dtype,\n", + " replace_method=\"auto\",\n", + " replace_with_kernel_inject=True,\n", + " )\n", + " generator = pipeline(\n", + " task=\"text-generation\", model=model, tokenizer=tokenizer, device=local_rank\n", + " )\n", " return generator\n", "\n", "\n", @@ -190,7 +197,7 @@ "source": [ "%%writefile serving.properties\n", "\n", - "engine=Rubikon" + "engine = Rubikon" ] }, { @@ -213,11 +220,11 @@ "session = sagemaker.Session()\n", "account = session.account_id()\n", "region = session.boto_region_name\n", - "img = 'djl_deepspeed'\n", - "fullname = account+'.dkr.ecr.'+region+'amazonaws.com/'+img+':latest'\n", + "img = \"djl_deepspeed\"\n", + "fullname = account + \".dkr.ecr.\" + region + \"amazonaws.com/\" + img + \":latest\"\n", "\n", "bucket = session.default_bucket()\n", - "path = 's3://' + bucket + '/DEMO-djl-big-model/'" + "path = \"s3://\" + bucket + \"/DEMO-djl-big-model/\"" ] }, { @@ -362,18 +369,16 @@ "source": [ "import boto3, json\n", "\n", - "client = boto3.client('sagemaker-runtime')\n", + "client = boto3.client(\"sagemaker-runtime\")\n", "\n", - "endpoint_name = \"gpt-j\" # Your endpoint name.\n", - "content_type = \"text/plain\" # The MIME type of the input data in the request body.\n", + "endpoint_name = \"gpt-j\" # Your endpoint name.\n", + "content_type = \"text/plain\" # The MIME type of the input data in the request body.\n", "# accept = \"...\" # The desired MIME type of the inference in the response.\n", - "payload = \"Amazon.com is the best\" # Payload for inference.\n", + "payload = \"Amazon.com is the best\" # Payload for inference.\n", "response = client.invoke_endpoint(\n", - " EndpointName=endpoint_name, \n", - " ContentType=content_type,\n", - " Body=payload\n", - " )\n", - "print(response['Body'].read())" + " EndpointName=endpoint_name, ContentType=content_type, Body=payload\n", + ")\n", + "print(response[\"Body\"].read())" ] }, {