Skip to content

Commit

Permalink
feat: contrastive classifier noise
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Mar 1, 2021
1 parent ab00cd9 commit 7193e0e
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions models/cut_semantic_mask_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import util.util as util
from .modules import loss
import torch.nn.functional as F
from util.util import gaussian

class CUTSemanticMaskModel(BaseModel):
""" This class implements CUT and FastCUT model, described in the paper
Expand Down Expand Up @@ -45,6 +46,8 @@ def modify_commandline_options(parser, is_train=True):
parser.add_argument('--lambda_out_mask', type=float, default=10.0, help='weight for loss out mask')
parser.add_argument('--loss_out_mask', type=str, default='L1', help='loss mask')

parser.add_argument('--contrastive_noise', type=float, default=0.0, help='noise on constrastive classifier')

parser.set_defaults(pool_size=0) # no image pooling

opt, _ = parser.parse_known_args()
Expand Down Expand Up @@ -276,6 +279,9 @@ def calculate_NCE_loss(self, src, tgt):

total_nce_loss = 0.0
for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
if self.opt.contrastive_noise>0.0:
f_q=gaussian(f_q,self.opt.contrastive_noise)
f_k=gaussian(f_k,self.opt.contrastive_noise)
loss = crit(f_q, f_k) * self.opt.lambda_NCE
total_nce_loss += loss.mean()

Expand Down

0 comments on commit 7193e0e

Please sign in to comment.