Enhancement/train batch function (#107)
* Updated RL docs with latest models

* Added POC for train_batch interface when populating RL datasets

What Changed:
- Custom train_batch method in the VPG model
- This generates a batch of data at each time step (a minimal sketch of the pattern follows this list)
- The experience source is no longer initialized with a device; instead the correct device is passed to the step() method in the train_batch function
- Moved the experience methods from rl.common to datamodules
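A minimal sketch of that pattern (the module name, constructor arguments and the exact tuple yielded below are illustrative assumptions, not the actual bolts VPG implementation; only ExperienceSource and ExperienceSourceDataset come from this change):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

from pl_bolts.datamodules import ExperienceSource, ExperienceSourceDataset


class SketchVPG(pl.LightningModule):
    """Hypothetical model showing how a train_batch generator feeds an ExperienceSourceDataset."""

    def __init__(self, env, agent, batch_size: int = 32):
        super().__init__()
        self.source = ExperienceSource(env, agent)
        self.batch_size = batch_size

    def train_batch(self):
        # Generator called by ExperienceSourceDataset.__iter__; the device is
        # resolved per step from the LightningModule rather than stored up front.
        while True:
            exp, reward, done = self.source.step(self.device)
            yield exp.state, exp.action, exp.reward, exp.done, exp.new_state

    def train_dataloader(self) -> DataLoader:
        return DataLoader(ExperienceSourceDataset(self.train_batch), batch_size=self.batch_size)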

* Updated other models to use train_batch interface

* Update tests/datamodules/test_experience_sources.py

Co-authored-by: Jirka Borovec <[email protected]>

* Fixing lint errors

* Fixed linting errors

* Update pl_bolts/datamodules/experience_source.py

Co-authored-by: Justus Schock <[email protected]>

* Resolved comments

* req

* Removed cyclic import of Agents from experience source

* Updated reference of Experience to datamodules instead of rl.common

* timeout

* Commented out test_dev_dataset to test run times

* undo commenting out of test_dev_datasets

Co-authored-by: Donal <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Jirka <[email protected]>
5 people authored Jul 11, 2020
1 parent 024b574 commit ca38ad1
Showing 17 changed files with 496 additions and 169 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
@@ -29,7 +29,7 @@ jobs:
# requires: 'minimal'

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 20
timeout-minutes: 35

steps:
- uses: actions/checkout@v2
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Device is no longer set in the DQN model init
- Moved RL loss function to the losses module
- Moved rl.common.experience to datamodules
- Added a train_batch function to the VPG model to generate a batch of data at each step (POC)
- Experience source no longer gets initialized with a device; instead the device is passed at each step()

### Fixed

2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -6,7 +6,7 @@ pandoc
docutils
sphinxcontrib-fulltoc
sphinxcontrib-mockautodoc
gym

git+https://github.com/PytorchLightning/lightning_sphinx_theme.git
# pip_shims
sphinx-autodoc-typehints
2 changes: 2 additions & 0 deletions pl_bolts/datamodules/__init__.py
@@ -6,3 +6,5 @@
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.experience_source import (ExperienceSourceDataset, ExperienceSource,
NStepExperienceSource, EpisodicExperienceStream)
202 changes: 202 additions & 0 deletions pl_bolts/datamodules/experience_source.py
@@ -0,0 +1,202 @@
"""
Datamodules for RL models that rely on experiences generated during training
Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
"""
from collections import deque, namedtuple
from typing import Iterable, Callable, Tuple, List
import numpy as np
import torch
from torch.utils.data import IterableDataset

# Datasets

Experience = namedtuple(
"Experience", field_names=["state", "action", "reward", "done", "new_state"]
)


class ExperienceSourceDataset(IterableDataset):
"""
Basic experience source dataset. Takes a generate_batch function that returns an iterator.
    The logic for the experience source and how the batch is generated is defined in the Lightning model itself
"""

def __init__(self, generate_batch: Callable):
self.generate_batch = generate_batch

def __iter__(self) -> Iterable:
iterator = self.generate_batch()
return iterator

# Experience Sources


class ExperienceSource(object):
"""
Basic single step experience source
Args:
env: Environment that is being used
agent: Agent being used to make decisions
"""

def __init__(self, env, agent):
self.env = env
self.agent = agent
self.state = self.env.reset()

def _reset(self) -> None:
"""resets the env and state"""
self.state = self.env.reset()

def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""Takes a single step through the environment"""
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,
action=action,
reward=reward,
new_state=new_state,
done=done,
)
self.state = new_state

if done:
self.state = self.env.reset()

return experience, reward, done

def run_episode(self, device: torch.device) -> float:
"""Carries out a single episode and returns the total reward. This is used for testing"""
done = False
total_reward = 0

while not done:
_, reward, done = self.step(device)
total_reward += reward

return total_reward


class NStepExperienceSource(ExperienceSource):
"""Expands upon the basic ExperienceSource by collecting experience across N steps"""

def __init__(self, env, agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent)
self.gamma = gamma
self.n_steps = n_steps
self.n_step_buffer = deque(maxlen=n_steps)

def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""
        Takes an n-step transition in the environment
Returns:
Experience
"""
exp = self.n_step(device)

while len(self.n_step_buffer) < self.n_steps:
self.n_step(device)

reward, next_state, done = self.get_transition_info()
first_experience = self.n_step_buffer[0]
multi_step_experience = Experience(
first_experience.state, first_experience.action, reward, done, next_state
)

return multi_step_experience, exp.reward, exp.done

def n_step(self, device: torch.device) -> Experience:
"""
Takes a single step in the environment and appends it to the n-step buffer
Returns:
Experience
"""
exp, _, _ = super().step(device)
self.n_step_buffer.append(exp)
return exp

def get_transition_info(self) -> Tuple[np.float, np.array, np.int]:
"""
        get the accumulated transition info for the n_step_buffer, discounted by self.gamma
        Returns:
            multi step reward, final observation and done
"""
last_experience = self.n_step_buffer[-1]
final_state = last_experience.new_state
done = last_experience.done
reward = last_experience.reward

# calculate reward
# in reverse order, go through all the experiences up till the first experience
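        # Worked example (illustrative figures, not taken from the tests): with
        # gamma = 0.99 and a three-step buffer holding rewards [1, 0, 2] and no
        # terminal steps, the fold runs right to left:
        #   reward = 0 + 0.99 * 2    = 1.98
        #   reward = 1 + 0.99 * 1.98 = 2.9602
        # so the multi-step reward is 2.9602, final_state is the last new_state
        # and done stays False.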
for experience in reversed(list(self.n_step_buffer)[:-1]):
reward_t = experience.reward
new_state_t = experience.new_state
done_t = experience.done

reward = reward_t + self.gamma * reward * (1 - done_t)
final_state, done = (new_state_t, done_t) if done_t else (final_state, done)

return reward, final_state, done


class EpisodicExperienceStream(ExperienceSource, IterableDataset):
"""
    Basic experience stream that iteratively yields the current experience of the agent in the env
    Args:
        env: Environment that is being used
agent: Agent being used to make decisions
"""

def __init__(self, env, agent, device: torch.device, episodes: int = 1):
super().__init__(env, agent)
self.episodes = episodes
self.device = device

def __getitem__(self, item):
return item

def __iter__(self) -> List[Experience]:
"""
Plays a step through the environment until the episode is complete
Returns:
Batch of all transitions for the entire episode
"""
episode_steps, batch = [], []

while len(batch) < self.episodes:
exp = self.step(self.device)
episode_steps.append(exp)

if exp.done:
batch.append(episode_steps)
episode_steps = []

yield batch

def step(self, device: torch.device) -> Experience:
"""Carries out a single step in the environment"""
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,
action=action,
reward=reward,
new_state=new_state,
done=done,
)
self.state = new_state

if done:
self.state = self.env.reset()

return experience
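A short usage sketch of the new device handling (the CartPole env and the random_agent callable below are illustrative stand-ins, not part of this commit; only the ExperienceSource API above is real): the source is created without a device and the caller passes one to each step() or run_episode() call.

import gym
import torch

from pl_bolts.datamodules import ExperienceSource

env = gym.make("CartPole-v0")

def random_agent(state, device):
    # Any callable taking (state, device) and returning an action works as the agent here
    return env.action_space.sample()

source = ExperienceSource(env, random_agent)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

exp, reward, done = source.step(device)       # device chosen per call, not stored at init
episode_reward = source.run_episode(device)   # e.g. a quick evaluation rollout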
50 changes: 30 additions & 20 deletions pl_bolts/models/rl/common/experience.py
@@ -1,11 +1,16 @@
"""Experience sources to be used as datasets for Ligthning DataLoaders
Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
..note:: Deprecated, these functions have been moved to pl_bolts.datamodules.experience_source.py
"""
import warnings
from collections import deque
from typing import List, Tuple

import numpy as np
import torch
from gym import Env
from torch.utils.data import IterableDataset

@@ -24,6 +29,8 @@ class RLDataset(IterableDataset):
"""

def __init__(self, buffer: Buffer, sample_size: int = 1) -> None:
warnings.warn("Deprecated, these functions have been moved to pl_bolts.datamodules.experience_source.py",
DeprecationWarning)
self.buffer = buffer
self.sample_size = sample_size

@@ -74,19 +81,20 @@ class ExperienceSource:
agent: Agent being used to make decisions
"""

def __init__(self, env: Env, agent: Agent, device):
def __init__(self, env: Env, agent: Agent):
warnings.warn("Deprecated, these functions have been moved to pl_bolts.datamodules.experience_source.py",
DeprecationWarning)
self.env = env
self.agent = agent
self.state = self.env.reset()
self.device = device

def _reset(self) -> None:
"""resets the env and state"""
self.state = self.env.reset()

def step(self) -> Tuple[Experience, float, bool]:
def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""Takes a single step through the environment"""
action = self.agent(self.state, self.device)
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,
@@ -102,13 +110,13 @@ def step(self) -> Tuple[Experience, float, bool]:

return experience, reward, done

def run_episode(self) -> float:
def run_episode(self, device: torch.device) -> float:
"""Carries out a single episode and returns the total reward. This is used for testing"""
done = False
total_reward = 0

while not done:
_, reward, done = self.step()
_, reward, done = self.step(device)
total_reward += reward

return total_reward
@@ -117,22 +125,23 @@ def run_episode(self) -> float:
class NStepExperienceSource(ExperienceSource):
"""Expands upon the basic ExperienceSource by collecting experience across N steps"""

def __init__(self, env: Env, agent: Agent, device, n_steps: int = 1):
super().__init__(env, agent, device)
def __init__(self, env: Env, agent: Agent, n_steps: int = 1, gamma: float = 0.99):
super().__init__(env, agent)
self.gamma = gamma
self.n_steps = n_steps
self.n_step_buffer = deque(maxlen=n_steps)

def step(self) -> Tuple[Experience, float, bool]:
def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
"""
Takes an n-step in the environment
Returns:
Experience
"""
exp = self.single_step()
exp = self.single_step(device)

while len(self.n_step_buffer) < self.n_steps:
self.single_step()
self.single_step(device)

reward, next_state, done = self.get_transition_info()
first_experience = self.n_step_buffer[0]
@@ -142,18 +151,18 @@ def step(self) -> Tuple[Experience, float, bool]:

return multi_step_experience, exp.reward, exp.done

def single_step(self) -> Experience:
def single_step(self, device: torch.device) -> Experience:
"""
Takes a single step in the environment and appends it to the n-step buffer
Returns:
Experience
"""
exp, _, _ = super().step()
exp, _, _ = super().step(device)
self.n_step_buffer.append(exp)
return exp

def get_transition_info(self, gamma=0.9) -> Tuple[np.float, np.array, np.int]:
def get_transition_info(self) -> Tuple[np.float, np.array, np.int]:
"""
get the accumulated transition info for the n_step_buffer
Args:
@@ -174,7 +183,7 @@ def get_transition_info(self, gamma=0.9) -> Tuple[np.float, np.array, np.int]:
new_state_t = experience.new_state
done_t = experience.done

reward = reward_t + gamma * reward * (1 - done_t)
reward = reward_t + self.gamma * reward * (1 - done_t)
final_state, done = (new_state_t, done_t) if done_t else (final_state, done)

return reward, final_state, done
@@ -189,9 +198,10 @@ class EpisodicExperienceStream(ExperienceSource, IterableDataset):
agent: Agent being used to make decisions
"""

def __init__(self, env: Env, agent: Agent, device, episodes: int = 1):
super().__init__(env, agent, device)
def __init__(self, env: Env, agent: Agent, device: torch.device, episodes: int = 1):
super().__init__(env, agent)
self.episodes = episodes
self.device = device

def __getitem__(self, item):
return item
@@ -206,7 +216,7 @@ def __iter__(self) -> List[Experience]:
episode_steps, batch = [], []

while len(batch) < self.episodes:
exp = self.step()
exp = self.step(self.device)
episode_steps.append(exp)

if exp.done:
@@ -215,9 +225,9 @@

yield batch

def step(self) -> Experience:
def step(self, device: torch.device) -> Experience:
"""Carries out a single step in the environment"""
action = self.agent(self.state, self.device)
action = self.agent(self.state, device)
new_state, reward, done, _ = self.env.step(action)
experience = Experience(
state=self.state,