Conversation
@@ -194,6 +202,9 @@ def text_to_instance(

        fields_dict["metadata"] = MetadataField(meta_fields)

        if weight is not None:
            fields_dict["weight"] = ArrayField(np.array([float(weight)], dtype=np.single))
Just make it a TensorField, and use Torch methods to create the tensors. ArrayField is the old way.
# shape: (batch_size,)
if len(weight.shape) > 1:
    weight = weight.squeeze()
loss = -(weight * log_likelihood).sum() / batch_size
Shouldn't this be / weight.sum()?
I've gone back and forth on this. If you use weight.sum() to normalize, then the weighting is only relative to each batch, which is probably not what you want.

For example, let's say we normalize by weight.sum() and your weights range from 0.5 to 1.0. If you have a batch that contains only instances with weights of 0.5, then this will give you the same result as if they all had weights of 1.0.
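The point about the two normalizations can be checked with a toy calculation in pure Python (the log-likelihood values are made up for illustration):

```python
def batch_loss(weights, log_likelihoods, normalize_by_weight_sum):
    """Toy version of the weighted batch loss under either normalization."""
    total = sum(-w * ll for w, ll in zip(weights, log_likelihoods))
    denom = sum(weights) if normalize_by_weight_sum else len(weights)
    return total / denom

lls = [-2.0, -3.0]  # hypothetical per-instance log-likelihoods

# Normalizing by weight.sum(): a batch of all-0.5 weights is
# indistinguishable from a batch of all-1.0 weights.
a = batch_loss([0.5, 0.5], lls, normalize_by_weight_sum=True)   # 2.5
b = batch_loss([1.0, 1.0], lls, normalize_by_weight_sum=True)   # 2.5

# Normalizing by batch_size: the absolute scale of the weights survives.
c = batch_loss([0.5, 0.5], lls, normalize_by_weight_sum=False)  # 1.25
d = batch_loss([1.0, 1.0], lls, normalize_by_weight_sum=False)  # 2.5
```

This is exactly the trade-off described above: batch_size normalization keeps weights comparable across batches, at the cost of making a uniform rescaling of all weights change the loss magnitude.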
Hmm, fair enough. I would expect that setting all the weights to 1000 should be the same as setting them all to 0.001. But to get that behavior, and also the behavior you want, we would have to sum up all the weights in the dataset before processing a single batch, and then scale each batch accordingly. That's not practical. So let's leave it like this then.
This is just a minor, general improvement to CopyNet. It adds the option to weight the loss contributions of individual instances when calculating the batch loss.