Skip to content

Commit

Permalink
normalise_scores arg
Browse files Browse the repository at this point in the history
  • Loading branch information
davidcpage committed Apr 21, 2021
1 parent ae221ca commit 7f952f9
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions bonito/crf/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ def prepare_ctc_scores(self, scores, targets):
move_scores = scores.gather(2, move_indices.expand(T, -1, -1))
return stay_scores, move_scores

def ctc_loss(self, scores, targets, target_lengths, loss_clip=None, reduction='mean'):
scores = self.normalise(scores)
def ctc_loss(self, scores, targets, target_lengths, loss_clip=None, reduction='mean', normalise_scores=True):
if normalise_scores:
scores = self.normalise(scores)
stay_scores, move_scores = self.prepare_ctc_scores(scores, targets)
logz = logZ_cupy(stay_scores, move_scores, target_lengths + 1 - self.state_len)
loss = - (logz / target_lengths)
Expand Down Expand Up @@ -164,9 +165,9 @@ def __init__(self, config):
state_len=config['global_norm']['state_len'],
alphabet=config['labels']['labels']
)
if 'type' in config['encoder']: #new-skool
if 'type' in config['encoder']: #new-style config
encoder = from_dict(config['encoder'])
else: #old-skool
else: #old-style
encoder = rnn_encoder(seqdist.n_base, seqdist.state_len, insize=config['input']['features'], **config['encoder'])
super().__init__(encoder, seqdist)
self.config = config

0 comments on commit 7f952f9

Please sign in to comment.