Add rnnt loss #891

Merged (14 commits) on Jan 17, 2022
5 changes: 5 additions & 0 deletions .flake8
@@ -2,6 +2,11 @@
show-source=true
statistics=true
max-line-length=80
per-file-ignores =
# line too long E501
# line break before operator W503
k2/python/k2/rnnt_loss.py: E501, W503
k2/python/tests/rnnt_loss_test.py: W503
exclude =
.git,
setup.py,
2 changes: 2 additions & 0 deletions k2/python/csrc/torch.cu
@@ -31,6 +31,7 @@
#include "k2/python/csrc/torch/fsa_algo.h"
#include "k2/python/csrc/torch/index_add.h"
#include "k2/python/csrc/torch/index_select.h"
#include "k2/python/csrc/torch/mutual_information.h"
#include "k2/python/csrc/torch/nbest.h"
#include "k2/python/csrc/torch/ragged.h"
#include "k2/python/csrc/torch/ragged_ops.h"
@@ -44,6 +45,7 @@ void PybindTorch(py::module &m) {
PybindFsaAlgo(m);
PybindIndexAdd(m);
PybindIndexSelect(m);
PybindMutualInformation(m);
PybindNbest(m);
PybindRagged(m);
PybindRaggedOps(m);
6 changes: 6 additions & 0 deletions k2/python/csrc/torch/CMakeLists.txt
@@ -7,6 +7,8 @@ set(torch_srcs
fsa_algo.cu
index_add.cu
index_select.cu
mutual_information.cu
mutual_information_cpu.cu
nbest.cu
ragged.cu
ragged_ops.cu
@@ -19,6 +21,10 @@ set(torch_srcs
v2/ragged_shape.cu
)

if (K2_WITH_CUDA)
list(APPEND torch_srcs mutual_information_cuda.cu)
endif()

set(torch_srcs_with_prefix)
foreach(src IN LISTS torch_srcs)
list(APPEND torch_srcs_with_prefix "torch/${src}")
68 changes: 68 additions & 0 deletions k2/python/csrc/torch/mutual_information.cu
@@ -0,0 +1,68 @@
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Wei Kang)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "k2/csrc/device_guard.h"
#include "k2/python/csrc/torch/mutual_information.h"
#include "k2/python/csrc/torch/torch_util.h"

void PybindMutualInformation(py::module &m) {
m.def(
"mutual_information_forward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p) -> torch::Tensor {
k2::DeviceGuard guard(k2::GetContext(px));
if (px.device().is_cpu()) {
return k2::MutualInformationCpu(px, py, boundary, p);
} else {
#ifdef K2_WITH_CUDA
return k2::MutualInformationCuda(px, py, boundary, p);
#else
K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
<< "that you compiled the code with K2_WITH_CUDA.";
return torch::Tensor();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"));

m.def(
"mutual_information_backward",
[](torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary, torch::Tensor p,
torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
k2::DeviceGuard guard(k2::GetContext(px));
if (px.device().is_cpu()) {
return k2::MutualInformationBackwardCpu(px, py, boundary, p,
ans_grad);
} else {
#ifdef K2_WITH_CUDA
return k2::MutualInformationBackwardCuda(px, py, boundary, p,
ans_grad, true);
#else
K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
<< "that you compiled the code with K2_WITH_CUDA.";
return std::vector<torch::Tensor>();
#endif
}
},
py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"),
py::arg("ans_grad"));
}
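For orientation, here is a minimal sketch of how these two bindings might be wrapped in a `torch.autograd.Function` on the Python side. The class name and the exact handling of `boundary` are illustrative assumptions, not necessarily what this PR's Python code does; only the `_k2.mutual_information_forward` / `_k2.mutual_information_backward` signatures are taken from the bindings above.

```python
import torch
import _k2  # the extension module that PybindMutualInformation registers into


class MutualInformationFunction(torch.autograd.Function):
    """Illustrative autograd wrapper around the bindings above."""

    @staticmethod
    def forward(ctx, px, py, boundary):
        # px: [B][S][T+1], py: [B][S+1][T]; p is an output buffer of
        # shape [B][S+1][T+1] that the C++/CUDA code fills in.
        (B, S, T1) = px.shape
        p = px.new_empty(B, S + 1, T1)
        ans = _k2.mutual_information_forward(px, py, boundary, p)
        ctx.save_for_backward(px, py, p)
        ctx.boundary = boundary  # may be None; a sketch-level shortcut
        return ans  # shape [B]

    @staticmethod
    def backward(ctx, ans_grad):
        px, py, p = ctx.saved_tensors
        px_grad, py_grad = _k2.mutual_information_backward(
            px, py, ctx.boundary, p, ans_grad)
        return px_grad, py_grad, None  # no gradient w.r.t. boundary


# Usage: ans = MutualInformationFunction.apply(px, py, boundary)
```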
107 changes: 107 additions & 0 deletions k2/python/csrc/torch/mutual_information.h
@@ -0,0 +1,107 @@
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_
#define K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_

#include <torch/extension.h>

#include <vector>

#include "k2/python/csrc/torch.h"

namespace k2 {
/*
Forward of mutual_information. See also the comment for
`mutual_information` in mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.

@param px Tensor of shape [B][S][T + 1]; contains the log-odds ratio of
generating the next x in the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s),
Collaborator review comment: in the last conditioning term, s-1 should be t-1:

    p(x_s | x_0..x_{s-1}, y_0..y_{s-1})
--> p(x_s | x_0..x_{s-1}, y_0..y_{t-1})

i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating
x_s. (See mutual_information.py for more info).
@param py The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
@param p This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th pair of sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:

p[b,0,0] = 0.0
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                   p[b,s,t-1] + py[b,s,t-1])
if s > 0 or t > 0,
treating values with any -1 index as -infinity.
If `boundary` is set, we start from p[b,s_begin,t_begin] = 0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t; for each
batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end],
which are the beginnings and ends (i.e. one-past-the-last)
of the x and y sequences that we should process.
Alternatively, may be a tensor of shape [0][0] and type
int64_t; the elements will default to (0, 0, S, T).

Collaborator review comment: I would recommend using
torch::optional<torch::Tensor>. You don't need to create a tensor of
shape (0, 0); just leave it unset. Also, you can pass None from Python.
@return A tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (boundary[b][2], boundary[b][3]) if
`boundary` was specified, and (S, T) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directy to this function).
Collaborator review comment: typo: directy -> directly


The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
torch::Tensor MutualInformationCpu(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output

torch::Tensor MutualInformationCuda(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output

/*
backward of mutual_information; returns (grad_px, grad_py)

if overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor> MutualInformationBackwardCpu(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad);

std::vector<torch::Tensor> MutualInformationBackwardCuda(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);

} // namespace k2

void PybindMutualInformation(py::module &m);

#endif // K2_PYTHON_CSRC_TORCH_MUTUAL_INFORMATION_H_
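To make the recursion described in the comment above concrete, here is a naive pure-PyTorch reference of the forward computation. This is a hypothetical test helper, not code from this PR; the C++/CUDA kernels exist precisely because these per-element Python loops are far too slow for training:

```python
import torch


def mutual_information_reference(px, py, boundary=None):
    """Naive reference of the forward recursion described above:

        p[b,s_begin,t_begin] = 0.0
        p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                           p[b,s,t-1] + py[b,s,t-1])

    treating any index outside the boundary box as -infinity.
    For correctness checks only: O(B*S*T) Python loops.
    """
    B, S, T1 = px.shape  # px: [B][S][T+1]
    T = T1 - 1
    neg_inf = torch.tensor(float("-inf"), dtype=px.dtype)
    p = torch.full((B, S + 1, T + 1), float("-inf"), dtype=px.dtype)
    ans = torch.empty(B, dtype=px.dtype)
    for b in range(B):
        # Default boundary is (0, 0, S, T), matching the header comment.
        s0, t0, s1, t1 = (boundary[b].tolist() if boundary is not None
                          else (0, 0, S, T))
        p[b, s0, t0] = 0.0
        for s in range(s0, s1 + 1):
            for t in range(t0, t1 + 1):
                if s == s0 and t == t0:
                    continue
                term_x = p[b, s - 1, t] + px[b, s - 1, t] if s > s0 else neg_inf
                term_y = p[b, s, t - 1] + py[b, s, t - 1] if t > t0 else neg_inf
                p[b, s, t] = torch.logaddexp(term_x, term_y)
        ans[b] = p[b, s1, t1]
    return ans
```

Given the same px, py and boundary, `_k2.mutual_information_forward` should agree with this (and fill `p` with the same values) up to floating-point rounding.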