Add Value Function and corresponding example script to Diffuser implementation (#884)

* valuefunction code
* start example scripts
* missing imports
* bug fixes and placeholder example script
* add value function scheduler
* load value function from hub and get best actions in example
* very close to working example
* larger batch size for planning
* more tests
* merge unet1d changes
* wandb for debugging, use newer models
* success! turns out we just need more diffusion steps
* run on modal
* merge and code cleanup
* use same api for rl model
* fix variance type
* wrong normalization function
* add tests
* style and quality
* edits based on comments
* remove unused var
* hack unet1d into a value function
* add pipeline
* fix arg order
* add pipeline to core library
* community pipeline
* fix couple shape bugs
* Apply suggestions from code review

Co-authored-by: Nathan Lambert <[email protected]>
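In `run_diffusion` below, each guidance step nudges the sampled trajectory along the gradient of the learned value function before the regular denoising update. Schematically, with the `scale` argument written as $s$ and the scheduler's posterior standard deviation as $\sigma_t$:

$x \leftarrow x + s \, \sigma_t \, \nabla_x V(x, t)$

Guidance is disabled for the final timesteps (`timesteps < 2`), and the conditioned initial state is re-pinned after every update.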
Showing 14 changed files with 1,143 additions and 28 deletions.
.gitignore
@@ -163,4 +163,6 @@ tags
 *.lock

 # DS_Store (MacOS)
 .DS_Store
+# RL pipelines may produce mp4 outputs
+*.mp4
@@ -0,0 +1,99 @@
import torch

import tqdm
from diffusers import DDPMScheduler, DiffusionPipeline
from diffusers.models.unet_1d import UNet1DModel


class ValueGuidedDiffuserPipeline(DiffusionPipeline):
    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()
        # per-key dataset statistics, used to (de)normalize observations and actions
        self.means = dict()
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                # skip entries without numeric statistics (e.g., dataset metadata)
                pass
        self.stds = dict()
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        if isinstance(x_in, dict):
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        # overwrite the observation part of the trajectory at the conditioned timesteps
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create a batch of timesteps to pass into the model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            # 1. nudge the sample along the value-function gradient
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    # NOTE: mirrors the original Diffuser code, which treats the
                    # posterior variance as a log-variance and exponentiates it
                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad
                # do not guide the final, nearly noise-free steps
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)
            # 2. run the regular denoising step
            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # 3. apply conditions to the trajectory
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)
        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
        # rank trajectories by predicted value, best first
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")
        # return the first action of the highest-value plan
        denorm_actions = denorm_actions[0, 0]
        return denorm_actions
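A minimal usage sketch of the new pipeline. It assumes `value_function` and `unet` are `UNet1DModel` instances loaded beforehand (e.g., via `UNet1DModel.from_pretrained`); the env name and scheduler config mirror the example script later in this commit, and the call arguments are the pipeline's own defaults:

import d4rl  # noqa
import gym
from diffusers import DDPMScheduler

env = gym.make("hopper-medium-expert-v2")
scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")

pipeline = ValueGuidedDiffuserPipeline(
    value_function=value_function,  # UNet1DModel trained to predict returns
    unet=unet,                      # UNet1DModel trained to denoise trajectories
    scheduler=scheduler,
    env=env,
)

obs = env.reset()
# plan 64 trajectories over a 32-step horizon, execute the best first action
action = pipeline(obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1)
obs, reward, done, info = env.step(action)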
@@ -0,0 +1,122 @@
import numpy as np
import torch

import d4rl  # noqa
import gym
import tqdm
import train_diffuser
from diffusers import DDPMScheduler, UNet1DModel


env_name = "hopper-medium-expert-v2"
env = gym.make(env_name)
data = env.get_dataset()  # dataset is only used for normalization in this colab

DEVICE = "cpu"
DTYPE = torch.float

# diffusion model settings
n_samples = 4  # number of trajectories planned via diffusion
horizon = 128  # length of sampled trajectories
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
num_inference_steps = 100  # number of diffusion steps


# Two generators for different parts of the diffusion loop to work in colab
generator_cpu = torch.Generator(device="cpu")

scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")

# 3 different pretrained models are available for this task.
# The horizon represents the length of trajectories used in training.
network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE)


# network specific constants for inference
clip_denoised = network.clip_denoised
predict_epsilon = network.predict_epsilon

# [ observation_dim ] --> [ n_samples x observation_dim ]
obs = env.reset()
total_reward = 0
done = False
T = 300
rollout = [obs.copy()]

try:
    for t in tqdm.tqdm(range(T)):
        obs_raw = obs

        # normalize observations for forward passes
        obs = train_diffuser.normalize(obs, data, "observations")
        obs = obs[None].repeat(n_samples, axis=0)
        conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)}

        # constants for inference
        batch_size = len(conditions[0])
        shape = (batch_size, horizon, state_dim + action_dim)

        # sample random initial noise vector
        x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu)

        # this model is conditioned from an initial state, so you will see this function
        # multiple times to change the initial state of generated data to the state
        # generated via env.reset() above or env.step() below
        x = train_diffuser.reset_x0(x1, conditions, action_dim)

        # convert a np observation to torch for model forward pass
        x = train_diffuser.to_torch(x)

        eta = 1.0  # noise factor for sampling reconstructed state

        # run the diffusion process
        # for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
        for i in tqdm.tqdm(scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long)

            # 1. generate prediction from model
            with torch.no_grad():
                residual = network(x, timesteps).sample

            # 2. use the model prediction to reconstruct an observation (de-noise)
            obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"]

            # 3. [optional] add posterior noise to the sample
            if eta > 0:
                noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device)
                posterior_variance = scheduler._get_variance(i)  # * noise
                # no noise when t == 0
                # NOTE: original implementation missing sqrt on posterior_variance
                obs_reconstruct = (
                    obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise
                )  # MJ had as log var, exponentiated

            # 4. apply conditions to the trajectory
            obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim)
            x = train_diffuser.to_torch(obs_reconstruct_postcond)

        plans = train_diffuser.helpers.to_np(x[:, :, :action_dim])
        # select random plan
        idx = np.random.randint(plans.shape[0])
        # select action at correct time
        action = plans[idx, 0, :]
        denorm_action = train_diffuser.de_normalize(action, data, "actions")
        # execute action in environment
        next_observation, reward, terminal, _ = env.step(denorm_action)

        # update return
        total_reward += reward
        print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}")

        # save observations for rendering
        rollout.append(next_observation.copy())
        obs = next_observation
except KeyboardInterrupt:
    pass

print(f"Total reward: {total_reward}")
render = train_diffuser.MuJoCoRenderer(env)
train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0))
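The script imports a local train_diffuser module that is not shown in this diff. A minimal sketch of the helpers it relies on, reconstructed from the call sites above (an assumption, not the committed implementation; helpers.to_np, MuJoCoRenderer, and show_sample are omitted):

import torch


def normalize(x_in, data, key):
    # shift/scale by the dataset statistics for the given key,
    # matching the pipeline's normalize() above
    return (x_in - data[key].mean()) / data[key].std()


def de_normalize(x_in, data, key):
    return x_in * data[key].std() + data[key].mean()


def to_torch(x_in, device="cpu"):
    if isinstance(x_in, dict):
        return {k: to_torch(v, device) for k, v in x_in.items()}
    if torch.is_tensor(x_in):
        return x_in.to(device)
    return torch.tensor(x_in, device=device)


def reset_x0(x_in, cond, act_dim):
    # pin the observation part of the trajectory at the conditioned timesteps
    for key, val in cond.items():
        x_in[:, key, act_dim:] = val.clone()
    return x_in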