-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ctc edit distance operator #5300
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
db69417
Add edit distance operator
b7a4e3d
rename some variables in ctc_edit_distance_op
f5681f1
Merge branch 'develop' of upstream into ctc_edit_distance_dev
6bc6ccd
add gpu kernel for ctc_edit_distance_op
116687a
clean up code in ctc_edit_distance_op
b82049b
revise the doc in ctc_edit_distance_op
c16d1ca
Merge branch 'develop' of upstream into ctc_edit_distance_dev
4745a0b
Merge branch 'develop' of upstream into ctc_edit_distance_dev
36ec3e9
Merge branch 'develop' of upstream into ctc_edit_distance_dev
2c1adb0
Rename ctc_edit_distance_op to edit_distance_op
2e49fac
Rename inputs & format license
0250e54
Enable batch input in edit_distance_op
f594ca4
Reuse the usable variable in edit_distance_op
a1935b2
Remove unnecessary prefix in test name of edit_distance_op
f3dcd00
Merge branch 'develop' of upstream into ctc_edit_distance_dev
fe0ef91
fix ci error in edit_distance_op
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/operators/edit_distance_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class EditDistanceOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null."); | ||
PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null."); | ||
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); | ||
auto hyp_dims = ctx->GetInputDim("Hyps"); | ||
auto ref_dims = ctx->GetInputDim("Refs"); | ||
PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1, | ||
"Input(Hyps) must be a 2-D LoDTensor with the 2nd dimension " | ||
"equal to 1."); | ||
PADDLE_ENFORCE(ref_dims.size() == 2 && ref_dims[1] == 1, | ||
"Input(Refs) must be a 2-D LoDTensor with the 2nd dimension " | ||
"equal to 1."); | ||
ctx->SetOutputDim("Out", ctx->GetInputDim("Refs")); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext &ctx) const override { | ||
return framework::OpKernelType(framework::proto::DataType::FP32, | ||
ctx.device_context()); | ||
} | ||
}; | ||
|
||
class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
EditDistanceOpMaker(OpProto *proto, OpAttrChecker *op_checker) | ||
: OpProtoAndCheckerMaker(proto, op_checker) { | ||
AddInput("Hyps", | ||
"(2-D LoDTensor<int>, 2nd dim. equal to 1) " | ||
"The indices for hypothesis strings."); | ||
AddInput("Refs", | ||
"(2-D LoDTensor<int>, 2nd dim. equal to 1) " | ||
"The indices for reference strings."); | ||
AddAttr<bool>("normalized", | ||
"(bool, default false) Indicated whether to normalize " | ||
"the edit distance by the length of reference string.") | ||
.SetDefault(false); | ||
AddOutput("Out", | ||
"(2-D Tensor with shape [`batch_size` x 1]) " | ||
"The output edit distances of EditDistance operator."); | ||
AddComment(R"DOC( | ||
|
||
EditDistance operator computes the edit distances between a batch of hypothesis | ||
strings and their references. | ||
|
||
Edit distance, also called Levenshtein distance, measures how dissimilar two strings | ||
are by counting the minimum number of operations to transform one string into anthor. | ||
Here the operations include insertion, deletion, and substitution. For example, | ||
given hypothesis string A = "kitten" and reference B = "sitting", the edit distance | ||
is 3 for A will be transformed into B at least after two substitutions and one | ||
insertion: | ||
|
||
"kitten" -> "sitten" -> "sittin" -> "sitting" | ||
|
||
Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total | ||
number denoted by `batch_size`, and the separation is specified by the LoD information. | ||
And the `batch_size` reference strings are arranged in order in the same way in the | ||
LoDTensor Input(Refs). | ||
|
||
Output(Out) contains the `batch_size` results and each stands for the edit stance | ||
for a pair of strings respectively. If Attr(normalized) is true, the edit distance | ||
will be divided by the length of reference string. | ||
)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OPERATOR(edit_distance, ops::EditDistanceOp, ops::EditDistanceOpMaker, | ||
paddle::framework::EmptyGradOpMaker); | ||
REGISTER_OP_CPU_KERNEL( | ||
edit_distance, ops::EditDistanceKernel<paddle::platform::CPUPlace, float>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include <algorithm> | ||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/platform/cuda_helper.h" | ||
#include "paddle/platform/gpu_info.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
using platform::PADDLE_CUDA_NUM_THREADS; | ||
|
||
template <typename T> | ||
__global__ void FillFirstRow(T* dist, const int N) { | ||
int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
if (idx < N + 1) { | ||
dist[idx] = idx; | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void FillFirstColumn(T* dist, const int M, const int N) { | ||
int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
if (idx < M + 1) { | ||
dist[idx * (N + 1)] = idx; | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void Levenshtein(T* dist, const int* x1, const int* x2, const int M, | ||
const int N, const int start) { | ||
int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
int offset = N; | ||
int index = start + idx * offset; | ||
int row = index / (N + 1); | ||
int col = index % (N + 1); | ||
if (row > 0 && col > 0 && row < M + 1 && col < N + 1) { | ||
int cost = x1[row - 1] == x2[col - 1] ? 0 : 1; | ||
int dels = dist[(row - 1) * (N + 1) + col] + 1; | ||
int ins = dist[row * (N + 1) + col - 1] + 1; | ||
int subs = dist[(row - 1) * (N + 1) + (col - 1)] + cost; | ||
dist[index] = min(dels, min(ins, subs)); | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void SetOutput(T* out, const T* dist, const int M, const int N, | ||
bool normalized) { | ||
int idx = blockDim.x * blockIdx.x + threadIdx.x; | ||
if (idx == 0) { | ||
out[0] = normalized ? dist[M * (N + 1) + N] / N : dist[M * (N + 1) + N]; | ||
} | ||
} | ||
|
||
template <typename Place, typename T> | ||
class EditDistanceGPUKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const { | ||
auto* out_t = ctx.Output<framework::Tensor>("Out"); | ||
|
||
auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps"); | ||
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs"); | ||
|
||
auto normalized = ctx.Attr<bool>("normalized"); | ||
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>( | ||
ctx.device_context()) | ||
.stream(); | ||
|
||
auto hyp_lod = x1_t->lod()[0]; | ||
auto ref_lod = x2_t->lod()[0]; | ||
PADDLE_ENFORCE( | ||
hyp_lod.size() == ref_lod.size(), | ||
"Input(Hyps) and Input(Refs) must have the same batch size."); | ||
for (size_t i = 1; i < ref_lod.size(); ++i) { | ||
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], | ||
"Reference string %d is empty.", i); | ||
} | ||
|
||
auto num_strs = hyp_lod.size() - 1; | ||
out_t->Resize({static_cast<int64_t>(num_strs), 1}); | ||
out_t->mutable_data<T>(ctx.GetPlace()); | ||
auto out = out_t->data<T>(); | ||
|
||
T distance = 0.0; | ||
for (size_t num = 0; num < num_strs; num++) { | ||
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]); | ||
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]); | ||
if (m == 0 || n == 0) { | ||
distance = std::max(m, n); | ||
if (normalized) { | ||
PADDLE_ENFORCE(n > 0, | ||
"The reference string (#%d) cannot be empty " | ||
"when Attr(normalized) is enabled.", | ||
n); | ||
distance = distance / n; | ||
} | ||
memory::Copy(boost::get<Place>(ctx.GetPlace()), out + num, | ||
platform::CPUPlace(), &distance, sizeof(T), stream); | ||
} else { | ||
framework::Tensor dist_t; | ||
dist_t.Resize({m + 1, n + 1}); | ||
dist_t.mutable_data<T>(ctx.GetPlace()); | ||
auto dist = dist_t.data<T>(); | ||
auto x1 = x1_t->data<int>() + hyp_lod[num]; | ||
auto x2 = x2_t->data<int>() + ref_lod[num]; | ||
|
||
FillFirstColumn<T><<<1 + m / PADDLE_CUDA_NUM_THREADS, | ||
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, m, n); | ||
|
||
FillFirstRow<T><<<1 + n / PADDLE_CUDA_NUM_THREADS, | ||
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, n); | ||
// Compute the elements of distance matrix in the anti-diagonal diretion | ||
for (int64_t slice = 2; slice < m + n + 1; ++slice) { | ||
int z_m = slice < m + 1 ? 0 : slice - m; | ||
int z_n = slice < n + 1 ? 0 : slice - n; | ||
int size = slice - (z_m + z_n) + 1; // number of elments in the same | ||
// anti-diagonal line to update | ||
// the start index at which computes from | ||
int start = slice < n + 1 ? slice : (z_n + 1) * (n + 1) - 1; | ||
Levenshtein<T><<<1 + (size - 1) / PADDLE_CUDA_NUM_THREADS, | ||
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(dist, x1, x2, | ||
m, n, start); | ||
} | ||
SetOutput<T><<<1, 1, 0, stream>>>(out + num, dist, m, n, normalized); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OP_CUDA_KERNEL( | ||
edit_distance, | ||
ops::EditDistanceGPUKernel<paddle::platform::CUDAPlace, float>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include <algorithm> | ||
#include "paddle/framework/eigen.h" | ||
#include "paddle/framework/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename Place, typename T> | ||
class EditDistanceKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const { | ||
auto* out_t = ctx.Output<framework::Tensor>("Out"); | ||
|
||
auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps"); | ||
auto* x2_t = ctx.Input<framework::LoDTensor>("Refs"); | ||
|
||
auto normalized = ctx.Attr<bool>("normalized"); | ||
|
||
auto hyp_lod = x1_t->lod()[0]; | ||
auto ref_lod = x2_t->lod()[0]; | ||
PADDLE_ENFORCE( | ||
hyp_lod.size() == ref_lod.size(), | ||
"Input(Hyps) and Input(Refs) must have the same batch size."); | ||
for (size_t i = 1; i < ref_lod.size(); ++i) { | ||
PADDLE_ENFORCE(ref_lod[i] > ref_lod[i - 1], | ||
"Reference string %d is empty.", i); | ||
} | ||
auto num_strs = hyp_lod.size() - 1; | ||
|
||
out_t->Resize({static_cast<int64_t>(num_strs), 1}); | ||
out_t->mutable_data<float>(ctx.GetPlace()); | ||
auto out = out_t->data<T>(); | ||
|
||
T distance = 0.0; | ||
for (size_t num = 0; num < num_strs; ++num) { | ||
auto m = static_cast<int64_t>(hyp_lod[num + 1] - hyp_lod[num]); | ||
auto n = static_cast<int64_t>(ref_lod[num + 1] - ref_lod[num]); | ||
|
||
if (m == 0) { | ||
distance = n; | ||
} else if (n == 0) { | ||
distance = m; | ||
} else { | ||
framework::Tensor dist_t; | ||
dist_t.Resize({m + 1, n + 1}); | ||
dist_t.mutable_data<T>(ctx.GetPlace()); | ||
auto dist = dist_t.data<T>(); | ||
auto x1 = x1_t->data<int>() + hyp_lod[num]; | ||
auto x2 = x2_t->data<int>() + ref_lod[num]; | ||
for (int64_t i = 0; i < m + 1; ++i) { | ||
dist[i * (n + 1)] = i; | ||
} | ||
for (int64_t j = 0; j < n + 1; ++j) { | ||
dist[j] = j; | ||
} | ||
for (int64_t i = 1; i < m + 1; ++i) { | ||
for (int64_t j = 1; j < n + 1; ++j) { | ||
int cost = x1[i - 1] == x2[j - 1] ? 0 : 1; | ||
int dels = dist[(i - 1) * (n + 1) + j] + 1; | ||
int ins = dist[i * (n + 1) + (j - 1)] + 1; | ||
int subs = dist[(i - 1) * (n + 1) + (j - 1)] + cost; | ||
dist[i * (n + 1) + j] = std::min(dels, std::min(ins, subs)); | ||
} | ||
} | ||
distance = dist[m * (n + 1) + n]; | ||
} | ||
|
||
if (normalized) { | ||
PADDLE_ENFORCE(n > 0, | ||
"The reference string (#%d) cannot be empty " | ||
"when Attr(normalized) is enabled.", | ||
n); | ||
distance = distance / n; | ||
} | ||
out[num] = distance; | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The GPU implementation may be less efficient, it may be slower than CPU implementation. The for loop in line 97 also can be paralleled. But you can not change it in this PR. We can optimize it in the future when necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. There should be a lot efficiency improvement for the batch input