Add Value Function and corresponding example script to Diffuser implementation (#884)

* valuefunction code

* start example scripts

* missing imports

* bug fixes and placeholder example script

* add value function scheduler

* load value function from hub and get best actions in example

* very close to working example

* larger batch size for planning

* more tests

* merge unet1d changes

* wandb for debugging, use newer models

* success!

* turns out we just need more diffusion steps

* run on modal

* merge and code cleanup

* use same api for rl model

* fix variance type

* wrong normalization function

* add tests

* style

* style and quality

* edits based on comments

* style and quality

* remove unused var

* hack unet1d into a value function

* add pipeline

* fix arg order

* add pipeline to core library

* community pipeline

* fix couple shape bugs

* style

* Apply suggestions from code review

Co-authored-by: Nathan Lambert <[email protected]>
bglick13 and Nathan Lambert authored Oct 21, 2022
1 parent a6314f6 commit 48a7414
Showing 14 changed files with 1,143 additions and 28 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -163,4 +163,6 @@ tags
*.lock

# DS_Store (MacOS)
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4
99 changes: 99 additions & 0 deletions examples/community/pipeline.py
@@ -0,0 +1,99 @@
import torch

import tqdm
from diffusers import DDPMScheduler, DiffusionPipeline
from diffusers.models.unet_1d import UNet1DModel


class ValueGuidedDiffuserPipeline(DiffusionPipeline):
def __init__(
self,
value_function: UNet1DModel,
unet: UNet1DModel,
scheduler: DDPMScheduler,
env,
):
super().__init__()
self.value_function = value_function
self.unet = unet
self.scheduler = scheduler
self.env = env
self.data = env.get_dataset()
        # per-key normalization statistics; entries without numeric stats are skipped
        self.means = dict()
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                pass
        self.stds = dict()
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass
self.state_dim = env.observation_space.shape[0]
self.action_dim = env.action_space.shape[0]

def normalize(self, x_in, key):
return (x_in - self.means[key]) / self.stds[key]

def de_normalize(self, x_in, key):
return x_in * self.stds[key] + self.means[key]

def to_torch(self, x_in):
        if isinstance(x_in, dict):
return {k: self.to_torch(v) for k, v in x_in.items()}
elif torch.is_tensor(x_in):
return x_in.to(self.unet.device)
return torch.tensor(x_in, device=self.unet.device)

def reset_x0(self, x_in, cond, act_dim):
for key, val in cond.items():
x_in[:, key, act_dim:] = val.clone()
return x_in

def run_diffusion(self, x, conditions, n_guide_steps, scale):
batch_size = x.shape[0]
y = None
for i in tqdm.tqdm(self.scheduler.timesteps):
# create batch of timesteps to pass into model
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            # 1. nudge the sample toward higher predicted value (guidance steps)
for _ in range(n_guide_steps):
with torch.enable_grad():
x.requires_grad_()
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    # the standard deviation is the square root of the posterior variance
                    model_std = torch.sqrt(posterior_variance)
                    grad = model_std * grad
                # skip guidance on the final, nearly noise-free timesteps
                grad[timesteps < 2] = 0
x = x.detach()
x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim)
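            # 2. denoise one step with the diffusion model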
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # 3. apply conditions to the trajectory
x = self.reset_x0(x, conditions, self.action_dim)
x = self.to_torch(x)
return x, y

def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
obs = self.normalize(obs, "observations")
obs = obs[None].repeat(batch_size, axis=0)
conditions = {0: self.to_torch(obs)}
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
x1 = torch.randn(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x)
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
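        # rank trajectories by predicted value, best first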
sorted_idx = y.argsort(0, descending=True).squeeze()
sorted_values = x[sorted_idx]
actions = sorted_values[:, :, : self.action_dim]
actions = actions.detach().cpu().numpy()
denorm_actions = self.de_normalize(actions, key="actions")
denorm_actions = denorm_actions[0, 0]
return denorm_actions
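
For orientation, a minimal sketch of how this pipeline might be driven from an environment loop. This is not part of the commit: the hub checkpoint ids below are hypothetical placeholders, and the scheduler settings are assumptions borrowed from the run_diffuser.py example further down.

import d4rl  # noqa
import gym

from diffusers import DDPMScheduler
from diffusers.models.unet_1d import UNet1DModel
from pipeline import ValueGuidedDiffuserPipeline  # the community file shown above

env = gym.make("hopper-medium-v2")

# hypothetical checkpoint ids; substitute your own trained diffusion and value models
unet = UNet1DModel.from_pretrained("your-org/diffuser-hopper-unet")
value_function = UNet1DModel.from_pretrained("your-org/diffuser-hopper-value")
scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")

pipeline = ValueGuidedDiffuserPipeline(value_function, unet, scheduler, env)

obs = env.reset()
total_reward = 0.0
for _ in range(100):
    # plan a batch of trajectories, then execute the first action of the best plan
    action = pipeline(obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1)
    obs, reward, done, _ = env.step(action)
    total_reward += reward
    if done:
        break
print(f"total reward: {total_reward}")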
99 changes: 99 additions & 0 deletions examples/community/value_guided_diffuser.py
@@ -0,0 +1,99 @@
import torch

import tqdm
from diffusers import DDPMScheduler, DiffusionPipeline
from diffusers.models.unet_1d import UNet1DModel


class ValueGuidedDiffuserPipeline(DiffusionPipeline):
def __init__(
self,
value_function: UNet1DModel,
unet: UNet1DModel,
scheduler: DDPMScheduler,
env,
):
super().__init__()
self.value_function = value_function
self.unet = unet
self.scheduler = scheduler
self.env = env
self.data = env.get_dataset()
        # per-key normalization statistics; entries without numeric stats are skipped
        self.means = dict()
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                pass
        self.stds = dict()
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass
self.state_dim = env.observation_space.shape[0]
self.action_dim = env.action_space.shape[0]

def normalize(self, x_in, key):
return (x_in - self.means[key]) / self.stds[key]

def de_normalize(self, x_in, key):
return x_in * self.stds[key] + self.means[key]

def to_torch(self, x_in):
        if isinstance(x_in, dict):
return {k: self.to_torch(v) for k, v in x_in.items()}
elif torch.is_tensor(x_in):
return x_in.to(self.unet.device)
return torch.tensor(x_in, device=self.unet.device)

def reset_x0(self, x_in, cond, act_dim):
for key, val in cond.items():
x_in[:, key, act_dim:] = val.clone()
return x_in

def run_diffusion(self, x, conditions, n_guide_steps, scale):
batch_size = x.shape[0]
y = None
for i in tqdm.tqdm(self.scheduler.timesteps):
# create batch of timesteps to pass into model
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            # 1. nudge the sample toward higher predicted value (guidance steps)
for _ in range(n_guide_steps):
with torch.enable_grad():
x.requires_grad_()
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    # the standard deviation is the square root of the posterior variance
                    model_std = torch.sqrt(posterior_variance)
                    grad = model_std * grad
                # skip guidance on the final, nearly noise-free timesteps
                grad[timesteps < 2] = 0
x = x.detach()
x = x + scale * grad
x = self.reset_x0(x, conditions, self.action_dim)
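            # 2. denoise one step with the diffusion model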
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # 3. apply conditions to the trajectory
x = self.reset_x0(x, conditions, self.action_dim)
x = self.to_torch(x)
return x, y

def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
obs = self.normalize(obs, "observations")
obs = obs[None].repeat(batch_size, axis=0)
conditions = {0: self.to_torch(obs)}
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
x1 = torch.randn(shape, device=self.unet.device)
x = self.reset_x0(x1, conditions, self.action_dim)
x = self.to_torch(x)
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
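        # rank trajectories by predicted value, best first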
sorted_idx = y.argsort(0, descending=True).squeeze()
sorted_values = x[sorted_idx]
actions = sorted_values[:, :, : self.action_dim]
actions = actions.detach().cpu().numpy()
denorm_actions = self.de_normalize(actions, key="actions")
denorm_actions = denorm_actions[0, 0]
return denorm_actions
122 changes: 122 additions & 0 deletions examples/diffuser/run_diffuser.py
@@ -0,0 +1,122 @@
import numpy as np
import torch

import d4rl # noqa
import gym
import tqdm
import train_diffuser
from diffusers import DDPMScheduler, UNet1DModel


env_name = "hopper-medium-expert-v2"
env = gym.make(env_name)
data = env.get_dataset()  # the dataset is only used for normalization in this example

DEVICE = "cpu"
DTYPE = torch.float

# diffusion model settings
n_samples = 4 # number of trajectories planned via diffusion
horizon = 128 # length of sampled trajectories
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
num_inference_steps = 100  # number of diffusion steps


# Generator for sampling the initial noise on CPU, so the example also runs in Colab
generator_cpu = torch.Generator(device="cpu")

scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")

# 3 different pretrained models are available for this task.
# The horizon represents the length of trajectories used in training.
network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE)


# network specific constants for inference
clip_denoised = network.clip_denoised
predict_epsilon = network.predict_epsilon

# reset the environment to get the initial observation
obs = env.reset()
total_reward = 0
done = False
T = 300
rollout = [obs.copy()]

try:
for t in tqdm.tqdm(range(T)):
obs_raw = obs

# normalize observations for forward passes
        obs = train_diffuser.normalize(obs, data, "observations")
        # [ observation_dim ] --> [ n_samples x observation_dim ]
        obs = obs[None].repeat(n_samples, axis=0)
conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)}

# constants for inference
batch_size = len(conditions[0])
shape = (batch_size, horizon, state_dim + action_dim)

# sample random initial noise vector
x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu)

# this model is conditioned from an initial state, so you will see this function
# multiple times to change the initial state of generated data to the state
# generated via env.reset() above or env.step() below
x = train_diffuser.reset_x0(x1, conditions, action_dim)

# convert a np observation to torch for model forward pass
x = train_diffuser.to_torch(x)

eta = 1.0 # noise factor for sampling reconstructed state

# run the diffusion process
# for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
for i in tqdm.tqdm(scheduler.timesteps):
# create batch of timesteps to pass into model
timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long)

# 1. generate prediction from model
with torch.no_grad():
residual = network(x, timesteps).sample

# 2. use the model prediction to reconstruct an observation (de-noise)
obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"]

# 3. [optional] add posterior noise to the sample
if eta > 0:
noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device)
                posterior_variance = scheduler._get_variance(i)
# no noise when t == 0
# NOTE: original implementation missing sqrt on posterior_variance
obs_reconstruct = (
obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise
) # MJ had as log var, exponentiated

# 4. apply conditions to the trajectory
obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim)
x = train_diffuser.to_torch(obs_reconstruct_postcond)
plans = train_diffuser.helpers.to_np(x[:, :, :action_dim])
# select random plan
idx = np.random.randint(plans.shape[0])
# select action at correct time
action = plans[idx, 0, :]
        # de-normalize the action before executing it in the environment
        denorm_action = train_diffuser.de_normalize(action, data, "actions")
        # execute action in environment
        next_observation, reward, terminal, _ = env.step(denorm_action)

# update return
total_reward += reward
print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}")

# save observations for rendering
rollout.append(next_observation.copy())
obs = next_observation
except KeyboardInterrupt:
pass

print(f"Total reward: {total_reward}")
render = train_diffuser.MuJoCoRenderer(env)
train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0))
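
run_diffuser.py imports a local train_diffuser module that is not among the files shown in this view. As a rough sketch, assuming d4rl-style datasets, the normalization and conditioning helpers it relies on could look like the following (the rendering utilities MuJoCoRenderer and show_sample are omitted):

import torch


def normalize(x_in, data, key):
    # z-score using the dataset statistics for the given key
    return (x_in - data[key].mean()) / data[key].std()


def de_normalize(x_in, data, key):
    # invert the z-score normalization above
    return x_in * data[key].std() + data[key].mean()


def to_torch(x_in, dtype=torch.float, device="cpu"):
    # recursively convert dicts and arrays to tensors on the target device
    if isinstance(x_in, dict):
        return {k: to_torch(v, dtype, device) for k, v in x_in.items()}
    if torch.is_tensor(x_in):
        return x_in.to(device).type(dtype)
    return torch.tensor(x_in, dtype=dtype, device=device)


def reset_x0(x_in, cond, act_dim):
    # pin the observation slice of conditioned timesteps to known states
    for key, val in cond.items():
        x_in[:, key, act_dim:] = val.clone()
    return x_in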