Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

fixed Criteria and added weights #204

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions inferno/extensions/criteria/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

class Criteria(nn.Module):
"""Aggregate multiple criteria to one."""
def __init__(self, *criteria):
def __init__(self, *criteria, weights=None):
super(Criteria, self).__init__()
if len(criteria) == 1 and isinstance(criteria[0], (list, tuple)):
criteria = list(criteria[0])
Expand All @@ -19,6 +19,12 @@ def __init__(self, *criteria):
"Criterion must be a torch module."
self.criteria = criteria

if not weights:
weights = (1,) * len(criteria)
assert len(weights) == len(criteria), \
"weight must be given for every criterion"
self.weights = weights

def forward(self, prediction, target):
assert isinstance(prediction, (list, tuple)), \
"`prediction` must be a list or a tuple, got {} instead."\
Expand All @@ -30,8 +36,9 @@ def forward(self, prediction, target):
"Number of predictions must equal the number of targets. " \
"Got {} predictions but {} targets.".format(len(prediction), len(target))
# Compute losses
losses = [criterion(prediction, target)
for _prediction, _target, criterion in zip(prediction, target, self.criteria)]
losses = [weight * criterion(_prediction, _target)
for weight, _prediction, _target, criterion
in zip(self.weights, prediction, target, self.criteria)]
# Aggegate losses
loss = reduce(lambda x, y: x + y, losses)
# Done
Expand Down