Make RL training compatible with PyTorch (#1520)
* Make RLEstimator() PyTorch compatible & modify cartpole notebook

* set use_pytorch to False by default

* minor refactor; check in first unit test

* indent correction
annaluo676 authored Oct 9, 2020
1 parent fbdca81 commit 09ad9a7
Showing 8 changed files with 60 additions and 22 deletions.
2 changes: 2 additions & 0 deletions reinforcement_learning/README.md
@@ -6,6 +6,8 @@ These examples demonstrate how to train reinforcement learning models on SageMak

**IMPORTANT for rllib users:** Some examples may break with latest [rllib](https://docs.ray.io/en/latest/rllib.html) due to breaking API changes. Please refer to [Amazon SageMaker RL Container](https://github.com/aws/sagemaker-rl-container) for the latest public images and modify the configs in entrypoint scripts according to [rllib algorithm config](https://docs.ray.io/en/latest/rllib-algorithms.html).

If you are using PyTorch rather than TensorFlow, please set `debugger_hook_config=False` when calling `RLEstimator()` to avoid TensorBoard conflicts.
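
For illustration, a minimal sketch of such an estimator call, assuming the cartpole notebook's image-URI pattern with the `torch` framework tag and placeholder region and instance type (these values are illustrative, not part of the original README):

```python
from sagemaker import get_execution_role
from sagemaker.rl import RLEstimator

role = get_execution_role()
# Assumed image URI following the notebook's pattern, with the 'torch' framework tag.
custom_image_name = (
    "462105765813.dkr.ecr.us-west-2.amazonaws.com/"
    "sagemaker-rl-ray-container:ray-0.8.5-torch-cpu-py36"
)

estimator = RLEstimator(
    entry_point="train-rl-cartpole-ray.py",
    source_dir="src",
    dependencies=["common/sagemaker_rl"],
    image_name=custom_image_name,
    role=role,
    debugger_hook_config=False,  # required for PyTorch: skip the TensorBoard/Debugger hooks
    train_instance_type="ml.m5.xlarge",
    train_instance_count=1,
)
estimator.fit()
```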

- [Contextual Bandit with Live Environment](bandits_statlog_vw_customEnv) illustrates how you can manage your own contextual multi-armed bandit workflow on SageMaker using the built-in [Vowpal Wabbit](https://github.com/VowpalWabbit/vowpal_wabbit) (VW) container to train and deploy contextual bandit models.
- [Cartpole](rl_cartpole_coach) uses SageMaker RL base [docker image](https://github.com/aws/sagemaker-rl-container) to balance a broom upright.
- [Cartpole Batch](rl_cartpole_batch_coach) uses batch RL techniques to train Cartpole with offline data.
11 changes: 8 additions & 3 deletions reinforcement_learning/common/sagemaker_rl/ray_launcher.py
@@ -243,10 +243,13 @@ def create_tf_serving_model(self, algorithm=None, env_string=None):
agent.restore(checkpoint)
export_tf_serving(agent, MODEL_OUTPUT_DIR)

def save_checkpoint_and_serving_model(self, algorithm=None, env_string=None):
def save_checkpoint_and_serving_model(self, algorithm=None, env_string=None, use_pytorch=False):
self.save_experiment_config()
self.copy_checkpoints_to_model_output()
self.create_tf_serving_model(algorithm, env_string)
if use_pytorch:
print("Skipped PyTorch serving.")
else:
self.create_tf_serving_model(algorithm, env_string)

# To ensure SageMaker local mode works fine
change_permissions_recursive(INTERMEDIATE_DIR, 0o777)
@@ -335,8 +338,10 @@ def launch(self):

algo = experiment_config["training"]["run"]
env_string = experiment_config["training"]["config"]["env"]
use_pytorch = experiment_config["training"]["config"].get("use_pytorch", False)
self.save_checkpoint_and_serving_model(algorithm=algo,
env_string=env_string)
env_string=env_string,
use_pytorch=use_pytorch)

@classmethod
def train_main(cls):
@@ -17,8 +17,8 @@ def change_permissions_recursive(path, mode):
for root, dirs, files in os.walk(path, topdown=False):
for dir in [os.path.join(root, d) for d in dirs]:
os.chmod(dir, mode)
for file in [os.path.join(root, f) for f in files]:
os.chmod(file, mode)
for file in [os.path.join(root, f) for f in files]:
os.chmod(file, mode)


def export_tf_serving(agent, output_dir):
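The rendering above drops leading whitespace, so the removed and added `for file` lines look identical; the change is indentation only (the "indent correction" from the commit message). A sketch of one plausible corrected form, assuming the fix moves the file loop out of the inner directory loop so it runs once per directory visited by `os.walk`:

```python
import os


def change_permissions_recursive(path, mode):
    # Walk bottom-up and chmod every directory and file under `path`.
    for root, dirs, files in os.walk(path, topdown=False):
        for dir in [os.path.join(root, d) for d in dirs]:
            os.chmod(dir, mode)
        # After the correction, this loop sits at the same depth as the directory
        # loop, so the files of each visited `root` are chmod-ed exactly once.
        for file in [os.path.join(root, f) for f in files]:
            os.chmod(file, mode)
```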
Empty file.
Empty file.
31 changes: 31 additions & 0 deletions reinforcement_learning/common/tests/unit/test_ray_launcher.py
@@ -0,0 +1,31 @@
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import pytest
from mock import Mock, MagicMock, patch

from sagemaker_rl.ray_launcher import SageMakerRayLauncher

@patch("sagemaker_rl.ray_launcher.SageMakerRayLauncher.__init__", return_value=None)
@patch("sagemaker_rl.ray_launcher.change_permissions_recursive")
def test_pytorch_save_checkpoint_and_serving_model(change_permission, launcher_init):
launcher = SageMakerRayLauncher()
launcher.copy_checkpoints_to_model_output = Mock()
launcher.create_tf_serving_model = Mock()
launcher.save_experiment_config = Mock()

launcher.save_checkpoint_and_serving_model(use_pytorch=True)
launcher.create_tf_serving_model.assert_not_called()
launcher.save_checkpoint_and_serving_model(use_pytorch=False)
launcher.create_tf_serving_model.assert_called_once()
assert 4 == change_permission.call_count
33 changes: 16 additions & 17 deletions reinforcement_learning/rl_cartpole_ray/rl_cartpole_ray_gymEnv.ipynb
@@ -9,7 +9,7 @@
"---\n",
"## Introduction\n",
"\n",
"In this notebook we'll start from the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart, moving along a frictionless track. Instead of applying control theory to solve the problem, this example shows how to solve the problem with reinforcement learning on Amazon SageMaker and Ray RLlib \n",
"In this notebook we'll start from the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart, moving along a frictionless track. Instead of applying control theory to solve the problem, this example shows how to solve the problem with reinforcement learning on Amazon SageMaker and Ray RLlib. You can choose either TensorFlow or PyTorch as your underlying DL framework.\n",
"\n",
"(For a similar example using Coach library, see this [link](../rl_cartpole_coach/rl_cartpole_coach_gymEnv.ipynb). Another Cart-pole example using Coach library and offline data can be found [here](../rl_cartpole_batch_coach/rl_cartpole_batch_coach.ipynb).)\n",
"\n",
@@ -196,7 +196,8 @@
"\n",
"cpu_or_gpu = 'gpu' if instance_type.startswith('ml.p') else 'cpu'\n",
"aws_region = boto3.Session().region_name\n",
"custom_image_name = \"462105765813.dkr.ecr.%s.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.5-tf-%s-py36\" % (aws_region, cpu_or_gpu)\n",
"framework = 'tf' # change to 'torch' for PyTorch training\n",
"custom_image_name = \"462105765813.dkr.ecr.%s.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.5-%s-%s-py36\" % (aws_region, framework, cpu_or_gpu)\n",
"custom_image_name"
]
},
@@ -206,8 +207,10 @@
"source": [
"## Write the Training Code\n",
"\n",
"The training code is written in the file “train-coach.py” which is uploaded in the /src directory. \n",
"First import the environment files and the preset files, and then define the main() function. "
"The training code is written in the file “train-rl-cartpole-ray.py” which is uploaded in the /src directory. \n",
"First import the environment files and the preset files, and then define the main() function. \n",
"\n",
"**Note**: If PyTorch is used, plese update the above training code and set `use_pytorch` to `True` in the config."
]
},
{
@@ -218,7 +221,7 @@
},
"outputs": [],
"source": [
"!pygmentize src/train-{job_name_prefix}.py"
"!pygmentize src/train-rl-cartpole-ray.py"
]
},
{
@@ -249,11 +252,12 @@
"\n",
"metric_definitions = RLEstimator.default_metric_definitions(RLToolkit.RAY)\n",
" \n",
"estimator = RLEstimator(entry_point=\"train-%s.py\" % job_name_prefix,\n",
"estimator = RLEstimator(entry_point=\"train-rl-cartpole-ray.py\",\n",
" source_dir='src',\n",
" dependencies=[\"common/sagemaker_rl\"],\n",
" image_name=custom_image_name,\n",
" role=role,\n",
" debugger_hook_config=False,\n",
" train_instance_type=instance_type,\n",
" train_instance_count=1,\n",
" output_path=s3_output_path,\n",
@@ -456,22 +460,17 @@
"print(\"Evaluation job: %s\" % job_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Visualize the output \n",
"\n",
"Optionally, you can run the steps defined earlier to visualize the output."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model deployment\n",
"\n",
"Now let us deploy the RL policy so that we can get the optimal action, given an environment observation."
"Now let us deploy the RL policy so that we can get the optimal action, given an environment observation.\n",
"\n",
"**Note**: Model deployment is supported for TensorFLow only at current stage. \n",
"\n",
"STOP HERE IF PYTORCH IS USED."
]
},
{
@@ -563,4 +562,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
@@ -26,6 +26,7 @@ def get_experiment_config(self):
"training_iteration": 40
},
"config": {
"use_pytorch": False,
"gamma": 0.99,
"kl_coeff": 1.0,
"num_sgd_iter": 20,
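The hunk above adds `use_pytorch: False` to the cartpole training script's experiment config. A sketch of how that script might look with PyTorch enabled; the keys surrounding `use_pytorch` are copied from the hunk, while the class name, algorithm, and environment id are illustrative assumptions rather than a verbatim copy of the file:

```python
from sagemaker_rl.ray_launcher import SageMakerRayLauncher


class MyLauncher(SageMakerRayLauncher):
    # Illustrative launcher subclass; any other overrides the real script defines are omitted.
    def get_experiment_config(self):
        return {
            "training": {
                "run": "PPO",  # assumed algorithm for the cartpole example
                "stop": {"training_iteration": 40},
                "config": {
                    "use_pytorch": True,  # flip the flag added above to train with PyTorch
                    "gamma": 0.99,
                    "kl_coeff": 1.0,
                    "num_sgd_iter": 20,
                    "env": "CartPole-v0",  # assumed Gym environment id
                },
            }
        }


if __name__ == "__main__":
    MyLauncher.train_main()  # train_main() is a classmethod on the base launcher, per the hunk above
```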
