Skip to content

Commit

Permalink
replace view with reshape in aishell/asr1 (#3887)
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatV authored Nov 15, 2024
1 parent 6f44ac9 commit 62c21e9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
33 changes: 18 additions & 15 deletions paddlespeech/s2t/models/u2/u2.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _calc_att_loss(self,
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
loss_att = loss_att * (1 - reverse_weight) + r_loss_att * reverse_weight
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
decoder_out.reshape([-1, self.vocab_size]),
ys_out_pad,
ignore_label=self.ignore_id, )
return loss_att, acc_att
Expand Down Expand Up @@ -271,11 +271,13 @@ def recognize(
maxlen = encoder_out.shape[1]
encoder_dim = encoder_out.shape[2]
running_size = batch_size * beam_size
encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
running_size, maxlen, encoder_dim) # (B*N, maxlen, encoder_dim)
encoder_out = encoder_out.unsqueeze(1).repeat(
1, beam_size, 1, 1).reshape(
[running_size, maxlen,
encoder_dim]) # (B*N, maxlen, encoder_dim)
encoder_mask = encoder_mask.unsqueeze(1).repeat(
1, beam_size, 1, 1).view(running_size, 1,
maxlen) # (B*N, 1, max_len)
1, beam_size, 1, 1).reshape([running_size, 1,
maxlen]) # (B*N, 1, max_len)

hyps = paddle.ones(
[running_size, 1], dtype=paddle.long).fill_(self.sos) # (B*N, 1)
Expand Down Expand Up @@ -305,34 +307,35 @@ def recognize(

            # 2.3 Second beam prune: select topk score with history
scores = scores + top_k_logp # (B*N, N), broadcast add
scores = scores.view(batch_size, beam_size * beam_size) # (B, N*N)
scores = scores.reshape(
[batch_size, beam_size * beam_size]) # (B, N*N)
scores, offset_k_index = scores.topk(k=beam_size) # (B, N)
scores = scores.view(-1, 1) # (B*N, 1)
scores = scores.reshape([-1, 1]) # (B*N, 1)

# 2.4. Compute base index in top_k_index,
            # regard top_k_index as (B*N*N), regard offset_k_index as (B*N),
# then find offset_k_index in top_k_index
base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
base_k_index = paddle.arange(batch_size).reshape([-1, 1]).repeat(
1, beam_size) # (B, N)
base_k_index = base_k_index * beam_size * beam_size
best_k_index = base_k_index.view(-1) + offset_k_index.view(
-1) # (B*N)
best_k_index = base_k_index.reshape([-1]) + offset_k_index.reshape(
[-1]) # (B*N)

# 2.5 Update best hyps
best_k_pred = paddle.index_select(
top_k_index.view(-1), index=best_k_index, axis=0) # (B*N)
top_k_index.reshape([-1]), index=best_k_index, axis=0) # (B*N)
best_hyps_index = best_k_index // beam_size
last_best_k_hyps = paddle.index_select(
hyps, index=best_hyps_index, axis=0) # (B*N, i)
hyps = paddle.cat(
(last_best_k_hyps, best_k_pred.view(-1, 1)),
(last_best_k_hyps, best_k_pred.reshape([-1, 1])),
dim=1) # (B*N, i+1)

# 2.6 Update end flag
end_flag = paddle.equal(hyps[:, -1], self.eos).view(-1, 1)
end_flag = paddle.equal(hyps[:, -1], self.eos).reshape([-1, 1])

# 3. Select best of best
scores = scores.view(batch_size, beam_size)
scores = scores.reshape([batch_size, beam_size])
# TODO: length normalization
best_index = paddle.argmax(scores, axis=-1).long() # (B)
best_hyps_index = best_index + paddle.arange(
Expand Down Expand Up @@ -379,7 +382,7 @@ def ctc_greedy_search(
ctc_probs = self.ctc.log_softmax(encoder_out) # (B, maxlen, vocab_size)

topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
topk_index = topk_index.reshape([batch_size, maxlen]) # (B, maxlen)
pad_mask = make_pad_mask(encoder_out_lens) # (B, maxlen)
topk_index = topk_index.masked_fill_(pad_mask, self.eos) # (B, maxlen)

Expand Down
11 changes: 6 additions & 5 deletions paddlespeech/s2t/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def forward_attention(

p_attn = self.dropout(attn)
x = paddle.matmul(p_attn, value) # (batch, head, time1, d_k)
x = x.transpose([0, 2, 1, 3]).reshape([n_batch, -1, self.h *
self.d_k]) # (batch, time1, d_model)
x = x.transpose([0, 2, 1, 3]).reshape(
[n_batch, -1, self.h * self.d_k]) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)

Expand Down Expand Up @@ -280,8 +280,8 @@ def rel_shift(self, x, zero_triu: bool=False):
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)

x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x_padded = x_padded.reshape(
[x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]

if zero_triu:
Expand Down Expand Up @@ -349,7 +349,8 @@ def forward(self,
new_cache = paddle.concat((k, v), axis=-1)

n_batch_pos = pos_emb.shape[0]
p = self.linear_pos(pos_emb).reshape([n_batch_pos, -1, self.h, self.d_k])
p = self.linear_pos(pos_emb).reshape(
[n_batch_pos, -1, self.h, self.d_k])
p = p.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)

# (batch, head, time1, d_k)
Expand Down

0 comments on commit 62c21e9

Please sign in to comment.