Merge pull request pytorch#88 from chsasank/rl_cuda
Use cuda on RL tutorial
chsasank authored May 30, 2017
2 parents 78ff247 + 7ae5a2c commit 05ffab1
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions intermediate_source/reinforcement_q_learning.py
@@ -68,17 +68,28 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T


env = gym.make('CartPole-v0').unwrapped

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor

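# Added illustration (not part of the original commit): with the aliases above,
# any tensor created through ``Tensor``/``LongTensor`` lands on the GPU whenever
# one is available, so the rest of the tutorial never has to branch on
# ``use_cuda`` explicitly. The ``example_*`` names are hypothetical.
example_state = Tensor([[0.0, 1.0]])         # cuda.FloatTensor if use_cuda else FloatTensor
example_action = LongTensor([[1]])           # action indices follow the same rule
example_zeros = torch.zeros(3).type(Tensor)  # retyping moves an existing CPU tensor too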

######################################################################
# Replay Memory
# -------------
@@ -260,12 +271,12 @@ def get_screen():
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0)
    return resize(screen).unsqueeze(0).type(Tensor)

env.reset()
plt.figure()
plt.imshow(get_screen().squeeze(0).permute(
    1, 2, 0).numpy(), interpolation='none')
plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(),
           interpolation='none')
plt.title('Example extracted screen')
plt.show()
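# Added sanity check (hypothetical usage, not part of the original commit): on a
# CUDA machine ``get_screen()`` now returns a ``torch.cuda.FloatTensor`` with a
# leading batch dimension, which is why the call above moves it back with
# ``.cpu()`` before handing it to NumPy/matplotlib.
example_screen = get_screen()
print(type(example_screen), example_screen.size())  # a (1, 3, H, W) batch of one frame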

@@ -300,22 +311,14 @@ def get_screen():
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200
USE_CUDA = torch.cuda.is_available()

model = DQN()
memory = ReplayMemory(10000)
optimizer = optim.RMSprop(model.parameters())

if USE_CUDA:
if use_cuda:
    model.cuda()


class Variable(autograd.Variable):

    def __init__(self, data, *args, **kwargs):
        if USE_CUDA:
            data = data.cuda()
        super(Variable, self).__init__(data, *args, **kwargs)
optimizer = optim.RMSprop(model.parameters())
memory = ReplayMemory(10000)
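# Added note (a reading of the reordering above, not part of the original commit):
# the optimizer and replay memory are now constructed only after ``model.cuda()``,
# so the optimizer is handed the parameters in their final (GPU) location.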


steps_done = 0
@@ -328,9 +331,10 @@ def select_action(state):
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        return model(Variable(state, volatile=True)).data.max(1)[1].cpu()
        return model(
            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1]
    else:
        return torch.LongTensor([[random.randrange(2)]])
        return LongTensor([[random.randrange(2)]])

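# Added side calculation (for intuition only, not part of the original commit):
# the threshold used in ``select_action`` decays exponentially from EPS_START to
# EPS_END with time constant EPS_DECAY, so exploration falls off quickly over
# the first few hundred steps.
def eps_at(step):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * step / EPS_DECAY)

print(eps_at(0), eps_at(200), eps_at(1000))  # roughly 0.90, 0.36, 0.06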

episode_durations = []
@@ -339,7 +343,7 @@ def select_action(state):
def plot_durations():
    plt.figure(2)
    plt.clf()
    durations_t = torch.Tensor(episode_durations)
    durations_t = torch.FloatTensor(episode_durations)
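    # Added note (not part of the original commit): the plotted durations stay an
    # explicit CPU ``torch.FloatTensor`` rather than the CUDA-aware ``Tensor``
    # alias, since matplotlib only consumes host-side NumPy arrays.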
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
@@ -349,6 +353,8 @@ def plot_durations():
    means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
    means = torch.cat((torch.zeros(99), means))
    plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())
@@ -370,6 +376,7 @@ def plot_durations():

last_sync = 0


def optimize_model():
    global last_sync
    if len(memory) < BATCH_SIZE:
@@ -380,10 +387,9 @@ def optimize_model():
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    non_final_mask = torch.ByteTensor(
        tuple(map(lambda s: s is not None, batch.next_state)))
    if USE_CUDA:
        non_final_mask = non_final_mask.cuda()
    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)))

    # We don't want to backprop through the expected action values and volatile
    # will save us on temporarily changing the model parameters'
    # requires_grad to False!
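    # Added sketch (an assumption about the lines collapsed from this hunk): the
    # non-final next states are batched into a volatile Variable, along the lines of
    #   non_final_next_states = Variable(torch.cat([s for s in batch.next_state
    #                                               if s is not None]),
    #                                    volatile=True)
    # so that the forward pass used for the targets builds no autograd graph.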
@@ -399,7 +405,7 @@ def optimize_model():
    state_action_values = model(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    next_state_values = Variable(torch.zeros(BATCH_SIZE))
    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))
    next_state_values[non_final_mask] = model(non_final_next_states).max(1)[0]
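    # Added note (formula restated for clarity, not part of the original commit):
    # together with the collapsed lines below, the training target is the Bellman
    # estimate  r_t + GAMMA * max_a Q(s_{t+1}, a),  where the max term stays 0 for
    # final states thanks to ``non_final_mask``.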
    # Now, we don't want to mess up the loss with a volatile flag, so let's
    # clear it. After this, we'll just end up with a Variable that has
@@ -440,7 +446,7 @@ def optimize_model():
        # Select and perform an action
        action = select_action(state)
        _, reward, done, _ = env.step(action[0, 0])
        reward = torch.Tensor([reward])
        reward = Tensor([reward])

        # Observe new state
        last_screen = current_screen
@@ -463,6 +469,8 @@ def optimize_model():
            plot_durations()
            break

print('Complete')
env.render(close=True)
env.close()
plt.ioff()
plt.show()
