diff --git a/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb b/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb index 98d64661b9..d0dac37bc9 100644 --- a/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb +++ b/training/distributed_training/pytorch/model_parallel/gpt2/smp-train-gpt-simple.ipynb @@ -545,7 +545,7 @@ " kwargs[\"security_group_ids\"] = [fsx_security_group_id]\n", " kwargs[\"subnets\"] = [fsx_subnet]\n", "\n", - "smp_estimator = HuggingFace(\n", + "smp_estimator = PyTorch(\n", " entry_point=\"train_gpt_simple.py\",\n", " source_dir=os.getcwd(),\n", " role=role,\n", @@ -579,8 +579,8 @@ " }\n", " },\n", " },\n", - " pytorch_version=\"1.11\",\n", - " transformers_version=\"4.17\",\n", + " pytorch_version=\"1.12\",\n", + " transformers_version=\"4.21\",\n", " py_version=\"py38\",\n", " output_path=s3_output_location,\n", " checkpoint_s3_uri=checkpoint_s3_uri if not use_fsx else None,\n",