-
Notifications
You must be signed in to change notification settings - Fork 42
[WIP] Implement 2nd pass training using 1-best decoding results from the 1st pass network #198
base: master
Are you sure you want to change the base?
Conversation
snowfall/models/second_pass_model.py
Outdated
|
||
# now x2 is (B, T, F) | ||
|
||
x_concat = torch.cat((padded_acoustics, x2), dim=-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO(fangjun): Use cross attention here
- query: x2
- key and value: padded_acoustics
and masked self-attention
- key, query, and value: x2
@@ -0,0 +1,484 @@ | |||
#!/usr/bin/env python3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
common2.py is the same as common.py, except that it has some code supporting
the second pass model. To avoid conflicts with the master, a new file is used.
The same goes for the following xxx2.py files, e.g., lm_rescore2.py, mmi2.py.
import k2 | ||
|
||
|
||
class Nbest(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file implements the Nbest
class proposed in
#232 (comment)
Please have a review if it matches the proposal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's great! Yes it looks like what I had in mind.
I assume you would separate it from this PR though? Or maybe even submit it to k2? Since there's a lot going on here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will move it to k2.
It implements #106 (comment)
The training objf is decreasing and seems to be converging. Will post the decoding results later.