Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement MaxAuxLabels1(2) #50

Merged
merged 3 commits into from
Jun 1, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# please sort the source files alphabetically
add_library(fsa
aux_labels.cc
fsa_algo.cc
fsa_equivalent.cc
fsa_renderer.cc
Expand Down Expand Up @@ -35,6 +36,7 @@ endfunction()

# please sort the source files alphabetically
set(fsa_tests
aux_labels_test
fsa_algo_test
fsa_equivalent_test
fsa_renderer_test
Expand Down
66 changes: 66 additions & 0 deletions k2/csrc/aux_labels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// k2/csrc/aux_labels.cc

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/aux_labels.h"

#include <numeric>
#include <vector>

#include "glog/logging.h"
#include "k2/csrc/fsa.h"

namespace k2 {

void MapAuxLabels1(const AuxLabels &labels_in,
const std::vector<int32_t> &arc_map, AuxLabels *labels_out) {
CHECK_NOTNULL(labels_out);
auto &start_pos = labels_out->start_pos;
auto &labels = labels_out->labels;
start_pos.clear();
labels.clear();

qindazhu marked this conversation as resolved.
Show resolved Hide resolved
int32_t num_labels = 0;
for (const auto &arc_index : arc_map) {
start_pos.push_back(num_labels);
int32_t pos_start = labels_in.start_pos[arc_index];
int32_t pos_end = labels_in.start_pos[arc_index + 1];
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
for (int32_t pos = pos_start; pos != pos_end; ++pos) {
int32_t label = labels_in.labels[pos];
DCHECK_NE(label, kEpsilon);
labels.push_back(label);
++num_labels;
}
}
start_pos.push_back(num_labels);
}

void MapAuxLabels2(const AuxLabels &labels_in,
const std::vector<std::vector<int32_t>> &arc_map,
AuxLabels *labels_out) {
CHECK_NOTNULL(labels_out);
auto &start_pos = labels_out->start_pos;
auto &labels = labels_out->labels;
start_pos.clear();
labels.clear();
qindazhu marked this conversation as resolved.
Show resolved Hide resolved

int32_t num_labels = 0;
for (const auto &arc_indexes : arc_map) {
start_pos.push_back(num_labels);
for (const auto &arc_index : arc_indexes) {
int32_t pos_start = labels_in.start_pos[arc_index];
int32_t pos_end = labels_in.start_pos[arc_index + 1];
for (int32_t pos = pos_start; pos != pos_end; ++pos) {
int32_t label = labels_in.labels[pos];
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
DCHECK_NE(label, kEpsilon);
labels.push_back(label);
++num_labels;
}
}
}
start_pos.push_back(num_labels);
}

} // namespace k2
86 changes: 86 additions & 0 deletions k2/csrc/aux_labels_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// k2/csrc/aux_labels_test.cc

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/aux_labels.h"

#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "k2/csrc/fsa.h"

namespace k2 {

class AuxLablesTest : public ::testing::Test {
protected:
AuxLablesTest() {
std::vector<int32_t> start_pos = {0, 1, 3, 6, 7};
std::vector<int32_t> labels = {1, 2, 3, 4, 5, 6, 7};
aux_labels_in_.start_pos = std::move(start_pos);
aux_labels_in_.labels = std::move(labels);
}

AuxLabels aux_labels_in_;
};

TEST_F(AuxLablesTest, MapAuxLabels1) {
{
// empty arc_map
std::vector<int32_t> arc_map;
AuxLabels aux_labels_out;
// some dirty data
aux_labels_out.start_pos = {1, 2, 3};
aux_labels_out.labels = {4, 5};
MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out);

EXPECT_TRUE(aux_labels_out.labels.empty());
EXPECT_EQ(aux_labels_out.start_pos.size(), 1);
EXPECT_EQ(aux_labels_out.start_pos[0], 0);
}

{
std::vector<int32_t> arc_map = {2, 0, 3};
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
AuxLabels aux_labels_out;
MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out);

EXPECT_EQ(aux_labels_out.start_pos.size(), 4);
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
EXPECT_THAT(aux_labels_out.start_pos, ::testing::ElementsAre(0, 3, 4, 5));
EXPECT_EQ(aux_labels_out.labels.size(), 5);
EXPECT_THAT(aux_labels_out.labels, ::testing::ElementsAre(4, 5, 6, 1, 7));
}
}

TEST_F(AuxLablesTest, MapAuxLabels2) {
{
// empty arc_map
std::vector<std::vector<int32_t>> arc_map;
AuxLabels aux_labels_out;
// some dirty data
aux_labels_out.start_pos = {1, 2, 3};
aux_labels_out.labels = {4, 5};
MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out);

EXPECT_TRUE(aux_labels_out.labels.empty());
EXPECT_EQ(aux_labels_out.start_pos.size(), 1);
EXPECT_EQ(aux_labels_out.start_pos[0], 0);
}

{
std::vector<std::vector<int32_t>> arc_map = {{2, 3}, {0, 1}, {0}, {2}};
AuxLabels aux_labels_out;
MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out);

EXPECT_EQ(aux_labels_out.start_pos.size(), 5);
EXPECT_THAT(aux_labels_out.start_pos,
::testing::ElementsAre(0, 4, 7, 8, 11));
EXPECT_EQ(aux_labels_out.labels.size(), 11);
EXPECT_THAT(aux_labels_out.labels,
::testing::ElementsAre(4, 5, 6, 7, 1, 2, 3, 1, 4, 5, 6));
}
}

} // namespace k2
Loading