Enhancement/train batch function #107
Merged: williamFalcon merged 21 commits into Lightning-Universe:master from djbyrne:enhancement/train_batch_function on Jul 11, 2020.
Changes from all commits (21):
- 9c06583 Updated RL docs with latest models
- 33be076 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
- fdc92f9 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
- 682bbe6 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
- 17073bc Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
- 47e5fa0 Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
- 86b0dee Added POC for train_batch interface when populating RL datasets
- 885f198 Updated other models to use train_batch interface
- 896b032 Update tests/datamodules/test_experience_sources.py (djbyrne)
- 2baa02c Fixing lint errors
- 08e71f7 Merge branch 'enhancement/train_batch_function' of https://github.com…
- db18cd8 Fixed linting errors
- 0f5ca79 Update pl_bolts/datamodules/experience_source.py (djbyrne)
- 5d9dfa6 Resolved comments
- c3f62ac req (Borda)
- 577569c Removed cyclic import of Agents from experience source
- fa658a4 Merge branch 'enhancement/train_batch_function' of https://github.com…
- 2292528 Updated reference of Experience to datamodules instead of the rl.common
- 13cc727 timeout (Borda)
- d4c1cc7 Commented out test_dev_dataset to test run times (djbyrne)
- 04f02cd undo commenting out of test_dev_datasets (djbyrne)
pl_bolts/datamodules/experience_source.py (new file, 202 lines):

```python
"""
Datamodules for RL models that rely on experiences generated during training.

Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py
"""
from collections import deque, namedtuple
from typing import Iterable, Callable, Tuple, List

import numpy as np
import torch
from torch.utils.data import IterableDataset

# Datasets

Experience = namedtuple(
    "Experience", field_names=["state", "action", "reward", "done", "new_state"]
)


class ExperienceSourceDataset(IterableDataset):
    """
    Basic experience source dataset. Takes a generate_batch function that returns an iterator.
    The logic for the experience source and how the batch is generated is defined in the
    Lightning model itself.
    """

    def __init__(self, generate_batch: Callable):
        self.generate_batch = generate_batch

    def __iter__(self) -> Iterable:
        # Defer entirely to the model-supplied generator
        iterator = self.generate_batch()
        return iterator
```
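A minimal usage sketch (not part of the diff): in the updated models, a LightningModule method plays out steps in the environment and yields transitions, and this dataset simply wraps that generator so a standard DataLoader can consume it. The `generate_batch` body below is a placeholder assumption, not the real model code.

```python
import torch
from torch.utils.data import DataLoader

def generate_batch():
    # Placeholder: the real generator would step an environment with the
    # current policy and yield (state, action, reward, done, new_state).
    for _ in range(16):
        state = torch.zeros(4)
        yield state, 0, 1.0, False, state

dataset = ExperienceSourceDataset(generate_batch)
loader = DataLoader(dataset, batch_size=8)

for states, actions, rewards, dones, new_states in loader:
    print(states.shape)  # torch.Size([8, 4])
```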
```python
# Experience Sources


class ExperienceSource(object):
    """
    Basic single step experience source

    Args:
        env: Environment that is being used
        agent: Agent being used to make decisions
    """

    def __init__(self, env, agent):
        self.env = env
        self.agent = agent
        self.state = self.env.reset()

    def _reset(self) -> None:
        """Resets the env and state"""
        self.state = self.env.reset()

    def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
        """Takes a single step through the environment"""
        action = self.agent(self.state, device)
        new_state, reward, done, _ = self.env.step(action)
        experience = Experience(
            state=self.state,
            action=action,
            reward=reward,
            new_state=new_state,
            done=done,
        )
        self.state = new_state

        # start a fresh episode once the current one terminates
        if done:
            self.state = self.env.reset()

        return experience, reward, done

    def run_episode(self, device: torch.device) -> float:
        """Carries out a single episode and returns the total reward. This is used for testing"""
        done = False
        total_reward = 0

        while not done:
            _, reward, done = self.step(device)
            total_reward += reward

        return total_reward
```
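For context, a quick sketch of the source on its own. `gym.make("CartPole-v0")` and the random lambda agent are stand-ins for the environments and Agent classes the bolts models actually pass in; the four-value `env.step` unpacking above matches the classic gym API.

```python
import gym
import torch

env = gym.make("CartPole-v0")
# Stand-in agent: the real agents map (state, device) -> action.
agent = lambda state, device: env.action_space.sample()

source = ExperienceSource(env, agent)
exp, reward, done = source.step(torch.device("cpu"))     # one transition
total_reward = source.run_episode(torch.device("cpu"))   # a full episode
print(exp.action, reward, done, total_reward)
```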
```python
class NStepExperienceSource(ExperienceSource):
    """Expands upon the basic ExperienceSource by collecting experience across N steps"""

    def __init__(self, env, agent, n_steps: int = 1, gamma: float = 0.99):
        super().__init__(env, agent)
        self.gamma = gamma
        self.n_steps = n_steps
        self.n_step_buffer = deque(maxlen=n_steps)

    def step(self, device: torch.device) -> Tuple[Experience, float, bool]:
        """
        Takes an n-step transition in the environment

        Returns:
            tuple of (n-step Experience, most recent single-step reward, most recent done flag)
        """
        exp = self.n_step(device)

        # fill the buffer before aggregating the transition
        while len(self.n_step_buffer) < self.n_steps:
            self.n_step(device)

        reward, next_state, done = self.get_transition_info()
        first_experience = self.n_step_buffer[0]
        multi_step_experience = Experience(
            first_experience.state, first_experience.action, reward, done, next_state
        )

        return multi_step_experience, exp.reward, exp.done

    def n_step(self, device: torch.device) -> Experience:
        """
        Takes a single step in the environment and appends it to the n-step buffer

        Returns:
            Experience
        """
        exp, _, _ = super().step(device)
        self.n_step_buffer.append(exp)
        return exp

    def get_transition_info(self) -> Tuple[float, np.ndarray, bool]:
        """
        Gets the accumulated transition info for the n_step_buffer, discounted by self.gamma

        Returns:
            multi step reward, final observation and done
        """
        last_experience = self.n_step_buffer[-1]
        final_state = last_experience.new_state
        done = last_experience.done
        reward = last_experience.reward

        # calculate the discounted reward: in reverse order, go through all the
        # experiences up till the first experience
        for experience in reversed(list(self.n_step_buffer)[:-1]):
            reward_t = experience.reward
            new_state_t = experience.new_state
            done_t = experience.done

            reward = reward_t + self.gamma * reward * (1 - done_t)
            final_state, done = (new_state_t, done_t) if done_t else (final_state, done)

        return reward, final_state, done
```
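To make the backwards accumulation in `get_transition_info` concrete, here is a small worked example (not from the PR) for a 3-step buffer with no terminal state; the reward values are arbitrary.

```python
gamma = 0.99
rewards = [1.0, 2.0, 3.0]  # r1, r2, r3 from three buffered experiences

# Fold from the last reward backwards, exactly as the loop does when
# every done_t is False:
#   reward = r3                  -> 3.0
#   reward = r2 + gamma * reward -> 2.0 + 0.99 * 3.0  = 4.97
#   reward = r1 + gamma * reward -> 1.0 + 0.99 * 4.97 = 5.9203
n_step_return = 0.0
for r in reversed(rewards):
    n_step_return = r + gamma * n_step_return

# Equivalent to the standard n-step return r1 + gamma*r2 + gamma^2*r3
assert abs(n_step_return - 5.9203) < 1e-9
```

A `done_t` flag part-way through the buffer zeroes the discounted tail via the `(1 - done_t)` factor and moves `final_state` back to that step's `new_state`.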
```python
class EpisodicExperienceStream(ExperienceSource, IterableDataset):
    """
    Basic experience stream that iteratively yields the current experience of the agent in the env

    Args:
        env: Environment that is being used
        agent: Agent being used to make decisions
    """

    def __init__(self, env, agent, device: torch.device, episodes: int = 1):
        super().__init__(env, agent)
        self.episodes = episodes
        self.device = device

    def __getitem__(self, item):
        return item

    def __iter__(self) -> Iterable[List[List[Experience]]]:
        """
        Plays steps through the environment until the requested number of episodes is complete

        Returns:
            Batch of all transitions for the entire episode
        """
        episode_steps, batch = [], []

        while len(batch) < self.episodes:
            exp = self.step(self.device)
            episode_steps.append(exp)

            # an episode ends at the termination state; start collecting the next one
            if exp.done:
                batch.append(episode_steps)
                episode_steps = []

        yield batch

    def step(self, device: torch.device) -> Experience:
        """Carries out a single step in the environment"""
        action = self.agent(self.state, device)
        new_state, reward, done, _ = self.env.step(action)
        experience = Experience(
            state=self.state,
            action=action,
            reward=reward,
            new_state=new_state,
            done=done,
        )
        self.state = new_state

        if done:
            self.state = self.env.reset()

        return experience
```
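And a consumption sketch for the episodic stream, with the same illustrative env and agent assumptions as before:

```python
import gym
import torch

env = gym.make("CartPole-v0")
agent = lambda state, device: env.action_space.sample()

stream = EpisodicExperienceStream(env, agent, torch.device("cpu"), episodes=2)

# One iteration yields a batch: a list containing one list of Experience
# tuples per completed episode.
for batch in stream:
    for episode in batch:
        print(len(episode), episode[-1].done)  # episode length, True
```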
Conversations:

Reviewer: Is "episode" a common RL term for this? Intuitively I would have called this "sequence"... (and, on EpisodicExperienceStream: same question about wording with "episodic".)

djbyrne: It depends on the task. Most tasks are episodic in some form and will have a termination state denoting the end of the episode. This function was originally used for carrying out a validation episode and is useful