bugfix: saving DeepSAD model with pretrain=False #4

Open · wants to merge 5 commits into base: master
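The core fix: when `DeepSAD.train()` is called without a prior `pretrain()`, `self.ae_net` stays `None`, so the old `save_model()` crashed on `None.state_dict()` because `save_ae` defaults to `True`. A minimal repro sketch, assuming the upstream Deep-SAD-PyTorch API (`DeepSAD`, `set_network`, and an already-loaded `dataset` are assumptions outside this diff):

```python
from DeepSAD import DeepSAD

deepSAD = DeepSAD(eta=1.0)
deepSAD.set_network('mnist_LeNet')
deepSAD.train(dataset)           # no pretraining step, so deepSAD.ae_net stays None
deepSAD.save_model('model.tar')  # before this PR: AttributeError
                                 # ('NoneType' object has no attribute 'state_dict')
```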
requirements.txt (2 additions, 2 deletions)

```diff
@@ -14,5 +14,5 @@ scikit-learn==0.21.2
 scipy==1.3.0
 seaborn==0.9.0
 six==1.12.0
-torch==1.1.0
-torchvision==0.3.0
+torch>=1.1.0
+torchvision>=0.3.0
```
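A side note on relaxing the torch pin (an observation from the PyTorch 1.1.0 release notes, not part of this PR's diff): from 1.1.0 onward, `scheduler.step()` is meant to be called after `optimizer.step()`, so the epoch-start call kept in the trainer below fires milestone decays one epoch early and emits a UserWarning on newer versions. A sketch of the recommended ordering, with hypothetical `criterion` and loader names:

```python
for epoch in range(n_epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)
        loss.backward()
        optimizer.step()
    scheduler.step()  # torch >= 1.1.0: step the scheduler after the optimizer
```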
src/DeepSAD.py (4 additions, 3 deletions)

```diff
@@ -60,16 +60,17 @@ def set_network(self, net_name):
 
     def train(self, dataset: BaseADDataset, optimizer_name: str = 'adam', lr: float = 0.001, n_epochs: int = 50,
               lr_milestones: tuple = (), batch_size: int = 128, weight_decay: float = 1e-6, device: str = 'cuda',
-              n_jobs_dataloader: int = 0):
+              n_jobs_dataloader: int = 0, validate: bool = False):
         """Trains the Deep SAD model on the training data."""
 
         self.optimizer_name = optimizer_name
         self.trainer = DeepSADTrainer(self.c, self.eta, optimizer_name=optimizer_name, lr=lr, n_epochs=n_epochs,
                                       lr_milestones=lr_milestones, batch_size=batch_size, weight_decay=weight_decay,
                                       device=device, n_jobs_dataloader=n_jobs_dataloader)
         # Get the model
-        self.net = self.trainer.train(dataset, self.net)
+        self.net = self.trainer.train(dataset, self.net, validate=validate)
         self.results['train_time'] = self.trainer.train_time
+        self.train_loss = self.trainer.train_loss
         self.c = self.trainer.c.cpu().data.numpy().tolist()  # get as list
 
     def test(self, dataset: BaseADDataset, device: str = 'cuda', n_jobs_dataloader: int = 0):
@@ -130,7 +131,7 @@ def save_model(self, export_model, save_ae=True):
         """Save Deep SAD model to export_model."""
 
         net_dict = self.net.state_dict()
-        ae_net_dict = self.ae_net.state_dict() if save_ae else None
+        ae_net_dict = self.ae_net.state_dict() if (save_ae and self.ae_net is not None) else None
 
         torch.save({'c': self.c,
                     'net_dict': net_dict,
```
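Taken together with the trainer changes below, callers can now opt into per-epoch validation and read back a loss history. A hedged usage sketch (the trained `deepSAD` object and `dataset` carry over from the repro above; `validate=True` requires the dataset to expose the `validation_loader` the trainer calls):

```python
deepSAD.train(dataset, n_epochs=50, validate=True)

# train_loss holds one tuple per epoch:
# (epoch, train_loss) normally, (epoch, train_loss, valid_loss) with validate=True
for epoch, train_loss, valid_loss in deepSAD.train_loss:
    print(f'epoch {epoch:03d}: train={train_loss:.6f}  valid={valid_loss:.6f}')
```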
src/optim/DeepSAD_trainer.py (38 additions, 9 deletions)

```diff
@@ -28,11 +28,12 @@ def __init__(self, c, eta: float, optimizer_name: str = 'adam', lr: float = 0.00
 
         # Results
         self.train_time = None
+        self.train_loss = None
         self.test_auc = None
         self.test_time = None
         self.test_scores = None
 
-    def train(self, dataset: BaseADDataset, net: BaseNet):
+    def train(self, dataset: BaseADDataset, net: BaseNet, validate: bool = False):
         logger = logging.getLogger()
 
         # Get train data loader
@@ -45,25 +46,26 @@ def train(self, dataset: BaseADDataset, net: BaseNet):
         optimizer = optim.Adam(net.parameters(), lr=self.lr, weight_decay=self.weight_decay)
 
         # Set learning rate scheduler
-        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)
+        self.scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=self.lr_milestones, gamma=0.1)
 
         # Initialize hypersphere center c (if c not loaded)
         if self.c is None:
             logger.info('Initializing center c...')
             self.c = self.init_center_c(train_loader, net)
-            logger.info('Center c initialized.')
+            logger.info('Center c initialized to {}.'.format(self.c))
 
         # Training
         logger.info('Starting training...')
         start_time = time.time()
         net.train()
+        self.train_loss = []
         for epoch in range(self.n_epochs):
 
-            scheduler.step()
+            self.scheduler.step()
             if epoch in self.lr_milestones:
-                logger.info('  LR scheduler: new learning rate is %g' % float(scheduler.get_lr()[0]))
+                logger.info('  LR scheduler: new learning rate is %g' % float(self.scheduler.get_lr()[0]))
 
-            epoch_loss = 0.0
+            train_epoch_loss = 0.0
             n_batches = 0
             epoch_start_time = time.time()
             for data in train_loader:
@@ -81,13 +83,40 @@ def train(self, dataset: BaseADDataset, net: BaseNet):
                 loss.backward()
                 optimizer.step()
 
-                epoch_loss += loss.item()
+                train_epoch_loss += loss.item()
                 n_batches += 1
 
+            train_loss = train_epoch_loss / n_batches
+            epoch_loss_history = (epoch + 1, train_loss)
+
+            if validate:
+                n_batches = 0
+                valid_epoch_loss = 0.0
+                valid_loader = dataset.validation_loader(batch_size=self.batch_size, num_workers=self.n_jobs_dataloader)
+                with torch.set_grad_enabled(False):
+                    for data in valid_loader:
+                        inputs, _, semi_targets, _ = data
+                        inputs, semi_targets = inputs.to(self.device), semi_targets.to(self.device)
+
+                        outputs = net(inputs)
+                        dist = torch.sum((outputs - self.c) ** 2, dim=1)
+                        losses = torch.where(semi_targets == 0, dist, self.eta * ((dist + self.eps) ** semi_targets.float()))
+                        loss = torch.mean(losses)
+
+                        valid_epoch_loss += loss.item()
+                        n_batches += 1
+                valid_loss = valid_epoch_loss / n_batches
+                epoch_loss_history = (epoch + 1, train_loss, valid_loss)
+
+            self.train_loss.append(epoch_loss_history)
             # log epoch statistics
             epoch_train_time = time.time() - epoch_start_time
-            logger.info(f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s '
-                        f'| Train Loss: {epoch_loss / n_batches:.6f} |')
+            stats = f'| Epoch: {epoch + 1:03}/{self.n_epochs:03} | Train Time: {epoch_train_time:.3f}s ' \
+                    f'| Train Loss: {train_loss:.6f}'
+            if validate:
+                stats = stats + f' | Valid Loss: {valid_loss:.6f}'
+            logger.info(stats)
 
         self.train_time = time.time() - start_time
         logger.info('Training Time: {:.3f}s'.format(self.train_time))
```
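One dependency worth flagging: the validation loop calls `dataset.validation_loader(...)`, which the stock `BaseADDataset` interface does not define, so any dataset used with `validate=True` must provide it. A minimal sketch of such a method on a concrete dataset class (an illustrative assumption, not code from this PR):

```python
from torch.utils.data import DataLoader

class MyADDataset(BaseADDataset):
    # ... loaders(), train_set, test_set as in the base class ...

    def validation_loader(self, batch_size: int, shuffle: bool = False,
                          num_workers: int = 0) -> DataLoader:
        # validation_set would be split off the training data up front,
        # e.g. with torch.utils.data.random_split in __init__.
        return DataLoader(dataset=self.validation_set, batch_size=batch_size,
                          shuffle=shuffle, num_workers=num_workers)
```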