Feat/cnn vae #52

Open · wants to merge 3 commits into main
6 changes: 3 additions & 3 deletions .gitignore
@@ -11,6 +11,6 @@ __pycache__
*.egg-info
.idea
build/
output/
/wandb/
/outputs/
*output/
*wandb/
*outputs/
2 changes: 2 additions & 0 deletions priorCVAE/datasets/__init__.py
@@ -1 +1,3 @@
from .gp_dataset import GPDataset
from .offline_dataset import OfflineDataset

52 changes: 52 additions & 0 deletions priorCVAE/datasets/offline_dataset.py
@@ -0,0 +1,52 @@
"""
Offline dataset.
"""
import random

import jax
import jax.numpy as jnp


class OfflineDataset:
"""Use offline dataset, and randomly generate a batch from it."""

def __init__(self, dataset: jnp.ndarray, x: jnp.ndarray = None, c: jnp.ndarray = None):
"""
Initialize the OfflineDataset class.

:param dataset: jax ndarray of the dataset.
:param x: jax ndarray of the x values, defaults to None.
:param c: jax ndarray of the conditional value, defaults to None.
"""
self.dataset = dataset
self.x = x
self.c = c

if self.x is not None:
assert self.dataset.shape[0] == self.x.shape[0]
if self.c is not None:
assert self.dataset.shape[0] == self.c.shape[0]

def simulatedata(self, n_samples: int = 10000, batch_idx: jnp.ndarray = None) -> [jnp.ndarray, jnp.ndarray,
jnp.ndarray]:
"""
Make a batch of data from the dataset array.

:param n_samples: number of samples.
:param batch_idx: an array of indices of the elements to return in a batch; if None, indices are drawn uniformly at random.

:returns: A tuple
- x
- samples
- conditional value
"""

if batch_idx is None:
rng_key, _ = jax.random.split(jax.random.PRNGKey(random.randint(0, 9999)))
batch_idx = jax.random.randint(rng_key, [n_samples], 0, self.dataset.shape[0])

batch_data = self.dataset[batch_idx]
batch_x = self.x[batch_idx] if self.x is not None else None
batch_c = self.c[batch_idx] if self.c is not None else None

return batch_x, batch_data, batch_c
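A minimal usage sketch of the new OfflineDataset (the shapes and values below are illustrative, not taken from the PR): it wraps pre-computed samples and draws a random batch via simulatedata.

import jax.numpy as jnp
from priorCVAE.datasets import OfflineDataset

samples = jnp.ones((2500, 100))  # pre-computed prior draws, e.g. GP samples
grid = jnp.tile(jnp.linspace(0.0, 1.0, 100), (2500, 1))  # x locations per sample
lengthscales = jnp.ones((2500, 1))  # conditional values

dataset = OfflineDataset(samples, x=grid, c=lengthscales)
batch_x, batch_y, batch_c = dataset.simulatedata(n_samples=128)
# batch_y.shape == (128, 100); indices are drawn uniformly with replacement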
4 changes: 2 additions & 2 deletions priorCVAE/losses/__init__.py
@@ -1,2 +1,2 @@
from .loss_units import scaled_sum_squared_loss, mean_squared_loss, kl_divergence, square_maximum_mean_discrepancy
from .loss_classes import SquaredSumAndKL, Loss, MMDAndKL
from .loss_units import scaled_sum_squared_loss, mean_squared_loss, kl_divergence, square_maximum_mean_discrepancy, pixel_sum_loss
from .loss_classes import SquaredSumAndKL, Loss, MMDAndKL, SumPixelAndKL
47 changes: 46 additions & 1 deletion priorCVAE/losses/loss_classes.py
@@ -12,7 +12,7 @@
from flax.core import FrozenDict
import flax.linen as nn

from priorCVAE.losses import kl_divergence, scaled_sum_squared_loss, square_maximum_mean_discrepancy
from priorCVAE.losses import kl_divergence, scaled_sum_squared_loss, square_maximum_mean_discrepancy, pixel_sum_loss
from priorCVAE.priors import Kernel


@@ -101,3 +101,48 @@ def __call__(self, state_params: FrozenDict, state: TrainState, batch: [jnp.ndar
kld_loss = kl_divergence(z_mu, z_logvar)
loss = jnp.sqrt(relu_sq_mmd_loss) + self.kl_scaling * kld_loss
return loss


class SumPixelAndKL(Loss):
"""
Loss function with Sum pixel loss and KL.
"""

def __init__(self, conditional: bool = False):
"""
Initialize the SumPixelAndKL loss.

:param conditional: if True, the conditional value from the batch is passed to the model.
"""
super().__init__(conditional)
self.kl_scale = 0.1
self.itr = 0

def step_increase_parameter(self):
"""
Increase the KL scaling using predefined steps:
after every 500 iterations, add 0.1 to kl_scale.
"""
self.itr = self.itr + 1
if self.itr % 500 == 0:
self.kl_scale = self.kl_scale + 0.1

@partial(jax.jit, static_argnames=['self'])
def __call__(self, state_params: FrozenDict, state: TrainState, batch: [jnp.ndarray, jnp.ndarray, jnp.ndarray],
z_rng: KeyArray) -> jnp.ndarray:
"""
Calculates the loss value.

:param state_params: Current state parameters of the model.
:param state: Current state of the model.
:param batch: Current batch of the data. It is a list of [x, y, c] values.
:param z_rng: a PRNG key used as the random key.
"""
_, y, ls = batch
c = ls if self.conditional else None
y_hat, z_mu, z_logvar = state.apply_fn({'params': state_params}, y, z_rng, c=c)
pixel_loss = pixel_sum_loss(y, y_hat)
kld_loss = self.kl_scale * kl_divergence(z_mu, z_logvar)
loss = pixel_loss + kld_loss
self.step_increase_parameter()
return loss
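Side note: the warm-up implemented above amounts to kl_scale = 0.1 + 0.1 * (itr // 500). A small standalone sketch of that schedule (a hypothetical helper, not part of the PR) is below. Also worth checking: since __call__ is jitted with self as a static argument, the Python-side counter update in step_increase_parameter only runs at trace time, so the scale may not actually advance during jitted training.

def kl_warmup(itr: int, start: float = 0.1, step: float = 0.1, every: int = 500) -> float:
    # Hypothetical pure-function form of the schedule: +0.1 every `every` iterations.
    return start + step * (itr // every)

# kl_warmup(0) == 0.1, kl_warmup(499) == 0.1, kl_warmup(500) == 0.2, kl_warmup(1500) == 0.4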
50 changes: 42 additions & 8 deletions priorCVAE/losses/loss_units.py
@@ -19,12 +19,17 @@ def kl_divergence(mean: jnp.ndarray, logvar: jnp.ndarray) -> jnp.ndarray:

Detailed derivation can be found here: https://learnopencv.com/variational-autoencoder-in-tensorflow/

:param mean: the mean of the Gaussian distribution with shape (N,).
:param logvar: the log-variance of the Gaussian distribution with shape (N,) i.e. only diagonal values considered.
:param mean: the mean of the Gaussian distribution with shape (B, D).
:param logvar: the log-variance of the Gaussian distribution with shape (B, D) i.e. only diagonal values considered.

:return: the KL divergence value.

Note: the result is averaged over the batch.
"""
return -0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar))
assert len(mean.shape) == len(logvar.shape) == 2
assert mean.shape == logvar.shape

return jnp.mean(-0.5 * jnp.sum(1 + logvar - jnp.square(mean) - jnp.exp(logvar), axis=-1), axis=0)
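A quick sanity check of the new batched reduction (toy values): for zero mean and zero log-variance, i.e. a standard normal posterior, every row's KL term is zero, so the batch mean is zero as well.

import jax.numpy as jnp
from priorCVAE.losses import kl_divergence

mean = jnp.zeros((4, 3))
logvar = jnp.zeros((4, 3))
print(kl_divergence(mean, logvar))  # 0.0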


@jax.jit
@@ -38,14 +43,17 @@ def scaled_sum_squared_loss(y: jnp.ndarray, reconstructed_y: jnp.ndarray, vae_va

-1 * log N(y | y', sigma^2) \approx 0.5 * ((y - y')^2 / sigma^2)

:param y: the ground-truth value of y with shape (N, D).
:param reconstructed_y: the reconstructed value of y with shape (N, D).
:param y: the ground-truth value of y with shape (B, D).
:param reconstructed_y: the reconstructed value of y with shape (B, D).
:param vae_var: a float value representing the variance of the VAE.

:returns: the loss value

Note: the result is averaged over the batch.
"""
assert len(y.shape) == len(reconstructed_y.shape) == 2
assert y.shape == reconstructed_y.shape
return 0.5 * jnp.sum((reconstructed_y - y) ** 2 / vae_var)
return jnp.mean(0.5 * jnp.sum((reconstructed_y - y) ** 2 / vae_var, axis=-1), 0)
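Similarly, a toy check of the updated reduction: with vae_var = 1 and a reconstruction off by 1 in each of D = 3 dimensions, the per-sample loss is 0.5 * 3 = 1.5, and so is the batch mean.

import jax.numpy as jnp
from priorCVAE.losses import scaled_sum_squared_loss

y = jnp.zeros((2, 3))
y_hat = jnp.ones((2, 3))
print(scaled_sum_squared_loss(y, y_hat, 1.0))  # ~1.5 (vae_var = 1.0)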


@jax.jit
@@ -55,11 +63,12 @@ def mean_squared_loss(y: jnp.ndarray, reconstructed_y: jnp.ndarr

L(y, y') = mean(((y - y')^2))

:param y: the ground-truth value of y with shape (N, D).
:param reconstructed_y: the reconstructed value of y with shape (N, D).
:param y: the ground-truth value of y with shape (B, D).
:param reconstructed_y: the reconstructed value of y with shape (B, D).

:returns: the loss value
"""
assert len(y.shape) == len(reconstructed_y.shape) == 2
assert y.shape == reconstructed_y.shape
return jnp.mean((reconstructed_y - y) ** 2)

@@ -108,3 +117,28 @@ def square_maximum_mean_discrepancy(kernel: Kernel, target_samples: jnp.ndarray,

mmd_val_square = term_xx + term_yy - term_xy
return mmd_val_square


@jax.jit
def pixel_sum_loss(y: jnp.ndarray, reconstructed_y: jnp.ndarray) -> jnp.ndarray:
"""
Sum of absolute errors over the pixels of each image, averaged over the batch.

L(y, y') = mean_batch( sum_pixels( |y - y'| ) )

:param y: the ground-truth value of y with shape (N, D, D, C).
:param reconstructed_y: the reconstructed value of y with shape (N, D, D, C).

:returns: the loss value
"""
assert len(y.shape) == 4
assert y.shape == reconstructed_y.shape

N, D, D, C = y.shape

pixel_diff = jnp.abs(y - reconstructed_y) # (N, D, D, C)
sum_pixel_diff = jnp.sum(pixel_diff.reshape((N, -1)), axis=-1) # (N, )
assert sum_pixel_diff.shape == (N, )
mean_loss_val = jnp.mean(sum_pixel_diff, axis=0) # Over batch

return mean_loss_val
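A toy check of pixel_sum_loss: two 2x2 single-channel images whose reconstructions are off by 0.5 at every pixel give a per-image sum of 4 * 0.5 = 2.0, so the batch mean is 2.0.

import jax.numpy as jnp
from priorCVAE.losses import pixel_sum_loss

y = jnp.zeros((2, 2, 2, 1))
y_hat = y + 0.5
print(pixel_sum_loss(y, y_hat))  # 2.0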
4 changes: 2 additions & 2 deletions priorCVAE/models/__init__.py
@@ -1,3 +1,3 @@
from .encoder import MLPEncoder, Encoder
from .decoder import MLPDecoder, Decoder
from .encoder import MLPEncoder, Encoder, CNNEncoder
from .decoder import MLPDecoder, Decoder, CNNDecoder
from .vae import VAE
63 changes: 63 additions & 0 deletions priorCVAE/models/decoder.py
@@ -3,6 +3,7 @@
"""
from abc import ABC
from typing import Tuple, Union
from math import prod

from flax import linen as nn
import jax.numpy as jnp
@@ -11,6 +12,7 @@

class Decoder(ABC, nn.Module):
"""Parent class for decoder model."""

def __init__(self):
super().__init__()

@@ -42,3 +44,64 @@ def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
z = activation_fn(z)
z = nn.Dense(self.out_dim, name="dec_out")(z)
return z


class CNNDecoder(Decoder):
"""
CNN based decoder with the following structure:

for _ in hidden_dims:
y = Activation(Dense(y))
y = reshape_into_grid(y)
for _ in conv_features[:-1]:
y = Activation(TransposeConvolution(y))
y = TransposeConvolution(y)

"""
conv_features: Tuple[int]
hidden_dim: Union[Tuple[int], int]
out_channel: int
decoder_reshape: Tuple
conv_activation: Union[Tuple, PjitFunction] = nn.sigmoid
conv_stride: Union[int, Tuple[int]] = 2
conv_kernel_size: Union[Tuple[Tuple[int]], Tuple[int]] = (3, 3)
activations: Union[Tuple, PjitFunction] = nn.sigmoid

@nn.compact
def __call__(self, y: jnp.ndarray) -> jnp.ndarray:

assert self.conv_features[-1] == self.out_channel

# If a single activation function or single hidden dimension is passed.
hidden_dims = [self.hidden_dim] if isinstance(self.hidden_dim, int) else self.hidden_dim
activations = [self.activations] * len(hidden_dims) if not isinstance(self.activations,
Tuple) else self.activations

conv_activation = [self.conv_activation] * len(self.conv_features) if not isinstance(self.conv_activation,
Tuple) else self.conv_activation
conv_stride = [self.conv_stride] * len(self.conv_features) if not isinstance(self.conv_stride,
Tuple) else self.conv_stride
conv_kernel_size = [self.conv_kernel_size] * len(self.conv_features) if not isinstance(
self.conv_kernel_size[0], Tuple) else self.conv_kernel_size

# MLP layers
for i, (hidden_dim, activation_fn) in enumerate(zip(hidden_dims, activations)):
y = nn.Dense(hidden_dim, name=f"dec_hidden_{i}")(y)
y = activation_fn(y)

# Apply Dense and reshape into grid
y = nn.Dense(prod(self.decoder_reshape), name="dec_hidden_reshape")(y)
y = activations[-1](y) # FIXME: should be -1 or new variable?
y = y.reshape((-1,) + self.decoder_reshape)

# Conv layers
for i, (feat, k_s, stride, activation_fn) in enumerate(
zip(self.conv_features, conv_kernel_size, conv_stride, conv_activation)):
if i == (len(self.conv_features) - 1): # no activation for last layer
y = nn.ConvTranspose(features=feat, kernel_size=k_s, strides=(stride, stride),
padding="VALID")(y)
else:
y = nn.ConvTranspose(features=feat, kernel_size=k_s, strides=(stride, stride), padding="VALID")(y)
y = activation_fn(y)

return y
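A minimal shape-check sketch for CNNDecoder (the hyper-parameters are illustrative, not taken from the PR's experiments): with two stride-2, kernel (3, 3), VALID-padded transpose convolutions, a (7, 7, 16) grid grows to 15 x 15 and then 31 x 31.

import jax
import jax.numpy as jnp
from priorCVAE.models import CNNDecoder

decoder = CNNDecoder(conv_features=(32, 1), hidden_dim=(60,), out_channel=1,
                     decoder_reshape=(7, 7, 16))
z = jnp.ones((4, 30))  # (batch, latent_dim)
variables = decoder.init(jax.random.PRNGKey(0), z)
y = decoder.apply(variables, z)
print(y.shape)  # (4, 31, 31, 1)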
57 changes: 57 additions & 0 deletions priorCVAE/models/encoder.py
@@ -11,6 +11,7 @@

class Encoder(ABC, nn.Module):
"""Parent class for encoder model."""

def __init__(self):
super().__init__()

@@ -45,3 +46,59 @@ def __call__(self, y: jnp.ndarray) -> (jnp.ndarray, jnp.ndarray):
z_mu = nn.Dense(self.latent_dim, name="z_mu")(y)
z_logvar = nn.Dense(self.latent_dim, name="z_logvar")(y)
return z_mu, z_logvar


class CNNEncoder(Encoder):
"""
CNN based encoder with the following structure:

for _ in conv_features:
y = Activation(Convolution(y))

y = flatten(y)

for _ in hidden_dims:
y = Activation(Dense(y))

z_m = Dense(y)
z_logvar = Dense(y)

"""
conv_features: Tuple[int]
hidden_dim: Union[Tuple[int], int]
latent_dim: int
conv_activation: Union[Tuple, PjitFunction] = nn.sigmoid
conv_stride: Union[int, Tuple[int]] = 2
conv_kernel_size: Union[Tuple[Tuple[int]], Tuple[int]] = (3, 3)
activations: Union[Tuple, PjitFunction] = nn.sigmoid

@nn.compact
def __call__(self, y: jnp.ndarray) -> (jnp.ndarray, jnp.ndarray):
# If a single activation function or single hidden dimension is passed.
hidden_dims = [self.hidden_dim] if isinstance(self.hidden_dim, int) else self.hidden_dim
activations = [self.activations] * len(hidden_dims) if not isinstance(self.activations,
Tuple) else self.activations

conv_activation = [self.conv_activation] * len(self.conv_features) if not isinstance(self.conv_activation,
Tuple) else self.conv_activation
conv_stride = [self.conv_stride] * len(self.conv_features) if not isinstance(self.conv_stride,
Tuple) else self.conv_stride
conv_kernel_size = [self.conv_kernel_size] * len(self.conv_features) if not isinstance(
self.conv_kernel_size[0], Tuple) else self.conv_kernel_size

# Conv layers
for i, (feat, k_s, stride, activation_fn) in enumerate(
zip(self.conv_features, conv_kernel_size, conv_stride, conv_activation)):
y = nn.Conv(features=feat, kernel_size=k_s, strides=stride, padding="VALID")(y)
y = activation_fn(y)

# Flatten
y = y.reshape((y.shape[0], -1))

# MLP layers
for i, (hidden_dim, activation_fn) in enumerate(zip(hidden_dims, activations)):
y = nn.Dense(hidden_dim, name=f"enc_hidden_{i}")(y)
y = activation_fn(y)
z_mu = nn.Dense(self.latent_dim, name="z_mu")(y)
z_logvar = nn.Dense(self.latent_dim, name="z_logvar")(y)
return z_mu, z_logvar
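And a matching sketch for CNNEncoder (again with illustrative hyper-parameters): two stride-2 VALID convolutions take a (28, 28, 1) image to 13 x 13 and then 6 x 6 feature maps before flattening, and the two heads return (batch, latent_dim) mean and log-variance.

import jax
import jax.numpy as jnp
from priorCVAE.models import CNNEncoder

encoder = CNNEncoder(conv_features=(8, 16), hidden_dim=(60,), latent_dim=30)
y = jnp.ones((4, 28, 28, 1))  # (batch, H, W, C)
variables = encoder.init(jax.random.PRNGKey(0), y)
z_mu, z_logvar = encoder.apply(variables, y)
print(z_mu.shape, z_logvar.shape)  # (4, 30) (4, 30)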
2 changes: 1 addition & 1 deletion priorCVAE/utility.py
@@ -133,4 +133,4 @@ def generate_decoder_samples(key: KeyArray, decoder_params: Dict, decoder: Decod

def decode(decoder_params: Dict, decoder: Decoder, z: jnp.ndarray):
"""Decode a latent vector z."""
return decoder.apply({'params': decoder_params}, z)
return decoder.apply({'params': decoder_params}, z)