Skip to content

Commit

Permalink
Update name
Browse files Browse the repository at this point in the history
  • Loading branch information
LirongWu committed Aug 20, 2020
1 parent f3b71a3 commit 674dff4
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 20 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ The code includes the following modules:
* eval.py -- Calculate performance metrics from results, each being the average of 10 seeds
* utils.py
* GIFPloter() -- Auxiliary tool for online plot
* GetIndicator() -- Auxiliary tool for evaluating metric
* CompPerformMetrics() -- Auxiliary tool for evaluating metric
* Sampling() -- Sampling in the latent space for generating new data on the learned manifold

## Running the code
Expand Down Expand Up @@ -103,7 +103,7 @@ python main.py -MultiRun
python eval.py -M ML-Enc
python eval.py -M ML-AE
```
The evaluation metrics are available in `./pic/indicators.csv`
The evaluation metrics are available in `./pic/PerformMetrics.csv`

6. To test the generalization to unseen data
```
Expand Down
4 changes: 2 additions & 2 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def LoadData(data_name='SwissRoll', data_num=1500, seed=0, noise=0.0, device=tor
train_data = train_data / 2

# Load 7Mnist Dataset
if data_name == '7MNIST':
if data_name == 'MNIST_7':

train_data = torchvisiondatasets.MNIST(
'~/data', train=True, download=True,
Expand All @@ -138,7 +138,7 @@ def LoadData(data_name='SwissRoll', data_num=1500, seed=0, noise=0.0, device=tor
train_label = train_label[mask][data_num:data_num*2]

# Load 10Mnist Dataset
if data_name == '10MNIST':
if data_name == 'MNIST_10':

train_data = torchvisiondatasets.MNIST(
'~/data', train=True, download=True,
Expand Down
8 changes: 4 additions & 4 deletions eval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import csv
import argparse
import numpy as np
from utils import GetIndicator
from utils import CompPerformMetrics

if __name__ == "__main__":

Expand All @@ -19,10 +19,10 @@

if args.mode == 'ML-Enc':
data9 = np.loadtxt(path + '9.txt')
indicator = GetIndicator(data=data0, latent=data9, lat=[data8])
indicator = CompPerformMetrics(data=data0, latent=data9, lat=[data8])
if args.mode == 'ML-AE':
data18 = np.loadtxt(path + '18.txt')
indicator = GetIndicator(data=data0, latent=data18, lat=[data8, data11])
indicator = CompPerformMetrics(data=data0, latent=data18, lat=[data8, data11])

out_seeds.append(np.array(list(indicator.values())))
print(indicator)
Expand All @@ -31,7 +31,7 @@
out_seeds = out_seeds.mean(axis=0)

# Save metrics results to a csv file
outFile = open('./pic/indicators.csv','a+', newline='')
outFile = open('./pic/PerformMetrics.csv','a+', newline='')
writer = csv.writer(outFile, dialect='excel')

names = []
Expand Down
18 changes: 9 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def InlinePlot(model, batch_size, datas, labels, path, name, indicator=False, mo
label_point = np.concatenate((label_point, label.cpu().detach().numpy()), axis=0)

# Plotting a new fig for the current epoch
if param['DATASET'] != '10MNIST':
if param['DATASET'] != 'MNIST_10':
gif_ploter.AddNewFig(
latent_point,
label_point,
Expand All @@ -176,7 +176,7 @@ def InlinePlot(model, batch_size, datas, labels, path, name, indicator=False, mo
# Used for metrics evaluation and executed at the completion of the entire training process.
if indicator:
latent_index = 2 * len(param['NetworkStructure']) - 3
indicator = GetIndicator(
indicator = CompPerformMetrics(
datas.reshape(datas.shape[0], -1),
torch.tensor(latent_point[latent_index], device=device),
dataset = param['DATASET']
Expand All @@ -190,7 +190,7 @@ def InlinePlot(model, batch_size, datas, labels, path, name, indicator=False, mo
np.savetxt(path + '/out/label.txt', label_point)

# Save the metrics to a csv file
outFile = open(path + '/indicators.csv','a+', newline='')
outFile = open(path + '/PerformMetrics.csv','a+', newline='')
writer = csv.writer(outFile, dialect='excel')
names = []
results = []
Expand Down Expand Up @@ -226,7 +226,7 @@ def SetParam():
parser.add_argument("-N", "--name", default=None, type=str) # File names where data and figs are stored
parser.add_argument("-PP", "--ParamPath", default='None', type=str) # Path for an existing parameter
parser.add_argument("-M", "--Mode", default='ML-AE', type=str)
parser.add_argument("-D", "--DATASET", default='SwissRoll', type=str, choices=['SwissRoll', 'SCurve', '7MNIST', '10MNIST', 'Spheres5500'])
parser.add_argument("-D", "--DATASET", default='SwissRoll', type=str, choices=['SwissRoll', 'SCurve', 'MNIST_7', 'MNIST_10', 'Spheres5500'])
parser.add_argument("-LR", "--LEARNINGRATE", default=1e-3, type=float)
parser.add_argument("-B", "--BATCHSIZE", default=800, type=int)
parser.add_argument("-RB", "--RegularB", default=3, type=float) # Boundary parameters for push-away Loss
Expand All @@ -243,10 +243,10 @@ def SetParam():
parser.add_argument("-MultiRun", "--Train_MultiRun", default=False, action='store_true')
args = parser.parse_args()

if args.DATASET == '7MNIST':
args.ParamPath = './param/7mnist.json'
if args.DATASET == '10MNIST':
args.ParamPath = './param/10mnist.json'
if args.DATASET == 'MNIST_7':
args.ParamPath = './param/mnist_7.json'
if args.DATASET == 'MNIST_10':
args.ParamPath = './param/mnist_10.json'
if args.DATASET == 'Spheres5500':
args.ParamPath = './param/spheres5500.json'
if args.ParamPath is not 'None':
Expand Down Expand Up @@ -362,7 +362,7 @@ def Train_MultiRun():

# Plotting the final results and evaluating the metrics
InlinePlot(Model, param['BATCHSIZE'], train_data, train_label, path, name='Train', indicator=True, mode=param['Mode'])
if param['DATASET'] != '10MNIST':
if param['DATASET'] != 'MNIST_10':
gif_ploter.SaveGIF(path=path)

# Testing the generalizability of the model to out-of-samples
Expand Down
2 changes: 1 addition & 1 deletion param/10mnist.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": null,
"ParamPath": "None",
"Mode": "ML-AE",
"DATASET": "10MNIST",
"DATASET": "MNIST_10",
"LEARNINGRATE": 0.001,
"BATCHSIZE": 8000,
"RegularB": 3,
Expand Down
2 changes: 1 addition & 1 deletion param/7mnist.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"name": null,
"ParamPath": "None",
"Mode": "ML-AE",
"DATASET": "7MNIST",
"DATASET": "MNIST_10",
"LEARNINGRATE": 0.001,
"BATCHSIZE": 8000,
"RegularB": 2.2,
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def Plot_Generation(self, input_data, latent, rec_data, latent_gen, gen_data, la
plt.close()


def GetIndicator(data, latent, lat=None, dataset='None'):
def CompPerformMetrics(data, latent, lat=None, dataset='None'):

"""
function used to evaluate metrics
Expand Down

0 comments on commit 674dff4

Please sign in to comment.