Skip to content

Commit

Permalink
model name
Browse files Browse the repository at this point in the history
  • Loading branch information
xu-hao committed Mar 18, 2020
1 parent 78aafb6 commit 53aa7a6
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from plotnine import *
import itertools

def load_data(input_file, columns):
def load_data(input_file, columns, plotdata=True):
df0 = pd.read_csv(input_file)
cols = columns if columns is not None else df0.columns
df = df0[cols]
if plotdata:
plot_sample(df, f"{model_name}_input_plot.png")
onehotencodeddf = pd.get_dummies(df, columns=cols)
return onehotencodeddf

Expand Down Expand Up @@ -83,11 +85,13 @@ def sample_decoder(decoder,
df.to_csv(filename, index=False)

if len(categorical_columns) >= 3:
df2 = df.groupby(list(df.columns)).size().reset_index(name="Frequency")
cols = list(df2.columns)
(ggplot(df2, aes(x = cols[1], y = "np.log(Frequency + 1)", color = cols[2])) + geom_point() + geom_line() + facet_grid(f"{cols[0]} ~ {cols[2]}")).save(f"{model_name}_samples_plot.png")

plot_sample(df, f"{model_name}_samples_plot.png")


def plot_sample(df, plotfile):
df2 = df.groupby(list(df.columns)).size().reset_index(name="Frequency")
cols = list(df2.columns)
(ggplot(df2, aes(x = cols[1], y = "np.log(Frequency + 1)", color = cols[2])) + geom_point() + geom_line() + facet_grid(f"{cols[0]} ~ {cols[2]}")).save(plotfile)


# MNIST dataset
Expand Down Expand Up @@ -141,7 +145,7 @@ def get_model(original_dim, scale_width, latent_dim, loss_function="xent"):

# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')
vae = Model(inputs, outputs, name=model_name)
# VAE loss = mse_loss or xent_loss + kl_loss
if loss_function == "mse":
reconstruction_loss = mse(inputs, outputs)
Expand Down

0 comments on commit 53aa7a6

Please sign in to comment.