From 2e44c39ea5a11477ade18a286b362b8d90267217 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Mon, 1 Jun 2020 19:08:33 +0800 Subject: [PATCH 1/3] implement MaxAuxLabels1(2) --- k2/csrc/CMakeLists.txt | 2 + k2/csrc/aux_labels.cc | 66 ++++++++++++++ k2/csrc/aux_labels_test.cc | 86 ++++++++++++++++++ k2/csrc/dense_fsa.h | 177 ++++++++++++++++++------------------- k2/csrc/fsa.h | 47 ++++------ 5 files changed, 255 insertions(+), 123 deletions(-) create mode 100644 k2/csrc/aux_labels.cc create mode 100644 k2/csrc/aux_labels_test.cc diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 118875048..cd5b730de 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -1,5 +1,6 @@ # please sort the source files alphabetically add_library(fsa + aux_labels.cc fsa_algo.cc fsa_equivalent.cc fsa_renderer.cc @@ -35,6 +36,7 @@ endfunction() # please sort the source files alphabetically set(fsa_tests + aux_labels_test fsa_algo_test fsa_equivalent_test fsa_renderer_test diff --git a/k2/csrc/aux_labels.cc b/k2/csrc/aux_labels.cc new file mode 100644 index 000000000..3d9c27c4d --- /dev/null +++ b/k2/csrc/aux_labels.cc @@ -0,0 +1,66 @@ +// k2/csrc/aux_labels.cc + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../LICENSE for clarification regarding multiple authors + +#include "k2/csrc/aux_labels.h" + +#include +#include + +#include "glog/logging.h" +#include "k2/csrc/fsa.h" + +namespace k2 { + +void MapAuxLabels1(const AuxLabels &labels_in, + const std::vector &arc_map, AuxLabels *labels_out) { + CHECK_NOTNULL(labels_out); + auto &start_pos = labels_out->start_pos; + auto &labels = labels_out->labels; + start_pos.clear(); + labels.clear(); + + int32_t num_labels = 0; + for (const auto &arc_index : arc_map) { + start_pos.push_back(num_labels); + int32_t pos_start = labels_in.start_pos[arc_index]; + int32_t pos_end = labels_in.start_pos[arc_index + 1]; + for (int32_t pos = pos_start; pos != pos_end; ++pos) { + int32_t label = labels_in.labels[pos]; + DCHECK_NE(label, kEpsilon); + labels.push_back(label); + ++num_labels; + } + } + start_pos.push_back(num_labels); +} + +void MapAuxLabels2(const AuxLabels &labels_in, + const std::vector> &arc_map, + AuxLabels *labels_out) { + CHECK_NOTNULL(labels_out); + auto &start_pos = labels_out->start_pos; + auto &labels = labels_out->labels; + start_pos.clear(); + labels.clear(); + + int32_t num_labels = 0; + for (const auto &arc_indexes : arc_map) { + start_pos.push_back(num_labels); + for (const auto &arc_index : arc_indexes) { + int32_t pos_start = labels_in.start_pos[arc_index]; + int32_t pos_end = labels_in.start_pos[arc_index + 1]; + for (int32_t pos = pos_start; pos != pos_end; ++pos) { + int32_t label = labels_in.labels[pos]; + DCHECK_NE(label, kEpsilon); + labels.push_back(label); + ++num_labels; + } + } + } + start_pos.push_back(num_labels); +} + +} // namespace k2 diff --git a/k2/csrc/aux_labels_test.cc b/k2/csrc/aux_labels_test.cc new file mode 100644 index 000000000..c70808b83 --- /dev/null +++ b/k2/csrc/aux_labels_test.cc @@ -0,0 +1,86 @@ +// k2/csrc/aux_labels_test.cc + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../LICENSE for clarification regarding multiple authors + +#include "k2/csrc/aux_labels.h" + +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "k2/csrc/fsa.h" + +namespace k2 { + +class AuxLablesTest : public ::testing::Test { + protected: + AuxLablesTest() { + std::vector start_pos = {0, 1, 3, 6, 7}; + std::vector labels = {1, 2, 3, 4, 5, 6, 7}; + aux_labels_in_.start_pos = std::move(start_pos); + aux_labels_in_.labels = std::move(labels); + } + + AuxLabels aux_labels_in_; +}; + +TEST_F(AuxLablesTest, MapAuxLabels1) { + { + // empty arc_map + std::vector arc_map; + AuxLabels aux_labels_out; + // some dirty data + aux_labels_out.start_pos = {1, 2, 3}; + aux_labels_out.labels = {4, 5}; + MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); + + EXPECT_TRUE(aux_labels_out.labels.empty()); + EXPECT_EQ(aux_labels_out.start_pos.size(), 1); + EXPECT_EQ(aux_labels_out.start_pos[0], 0); + } + + { + std::vector arc_map = {2, 0, 3}; + AuxLabels aux_labels_out; + MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); + + EXPECT_EQ(aux_labels_out.start_pos.size(), 4); + EXPECT_THAT(aux_labels_out.start_pos, ::testing::ElementsAre(0, 3, 4, 5)); + EXPECT_EQ(aux_labels_out.labels.size(), 5); + EXPECT_THAT(aux_labels_out.labels, ::testing::ElementsAre(4, 5, 6, 1, 7)); + } +} + +TEST_F(AuxLablesTest, MapAuxLabels2) { + { + // empty arc_map + std::vector> arc_map; + AuxLabels aux_labels_out; + // some dirty data + aux_labels_out.start_pos = {1, 2, 3}; + aux_labels_out.labels = {4, 5}; + MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out); + + EXPECT_TRUE(aux_labels_out.labels.empty()); + EXPECT_EQ(aux_labels_out.start_pos.size(), 1); + EXPECT_EQ(aux_labels_out.start_pos[0], 0); + } + + { + std::vector> arc_map = {{2, 3}, {0, 1}, {0}, {2}}; + AuxLabels aux_labels_out; + MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out); + + EXPECT_EQ(aux_labels_out.start_pos.size(), 5); + EXPECT_THAT(aux_labels_out.start_pos, + ::testing::ElementsAre(0, 4, 7, 8, 11)); + EXPECT_EQ(aux_labels_out.labels.size(), 11); + EXPECT_THAT(aux_labels_out.labels, + ::testing::ElementsAre(4, 5, 6, 7, 1, 2, 3, 1, 4, 5, 6)); + } +} + +} // namespace k2 diff --git a/k2/csrc/dense_fsa.h b/k2/csrc/dense_fsa.h index a7c29c297..e64f3a2e5 100644 --- a/k2/csrc/dense_fsa.h +++ b/k2/csrc/dense_fsa.h @@ -13,12 +13,11 @@ #include #include "glog/logging.h" -#include "k2/csrc/util.h" #include "k2/csrc/fsa.h" +#include "k2/csrc/util.h" namespace k2 { - /* DenseFsa represents an FSA stored as a matrix, representing something like CTC output from a neural net. `data` is a (T+1) by N @@ -53,30 +52,31 @@ struct DenseFsa { */ int32_t num_states() { return T + 2; } int32_t arc_indexes(int32_t state_index) { - return arc_offset + (state_index <= T ? state_index*num_symbols : - T*num_symbols + 1); + return arc_offset + + (state_index <= T ? state_index * num_symbols : T * num_symbols + 1); } Arc arc(int32_t arc_index) { arc_index -= arc_offset; - int32_t state_index = (arc_index-arc_offset) / num_symbols; - return Arc(state_index, state_index+1, - (state_index < T ? (arc_index%num_symbols) : kFinalSymbol)); + int32_t state_index = (arc_index - arc_offset) / num_symbols; + return Arc(state_index, state_index + 1, + (state_index < T ? (arc_index % num_symbols) : kFinalSymbol)); } - /* Constructor - @param [in] T number of frames / sequence length. This FSA has T+2 states; - state T+1 is the final-state, and state T has only a single - arc, with symbol kFinalSymbol, with arc-index - `arc_index = arc_offset+(T*num_symbols)`, to state T+1; the weight - on this is data[arc_index] == 0. - - All other states t( - seg_frame_index + (num_segs + 1)); + DenseFsaVecFrameInfo *frame_info() { + return reinterpret_cast(seg_frame_index + + (num_segs + 1)); } int32_t frame_info_dim() { return seg_frame_index[num_segs]; } int32_t num_frames_padded() { return seg_frame_index[num_segs]; } - // and next we have the following, which will be used from // the Python calling code to copy the neural-net output to the correct // location. This lists the elements of `frame_info` but with the @@ -138,15 +135,17 @@ struct DenseFsaVecMeta { // where num_frames == \sum_segment length(segment) == total number of nnet // output frames over all segments. // - // int32_t frame_index[num_frames]; # frame_index will be an index into `frame_info`; - // # this will be of the form 0 1 2 4 5 6 7 8 9 11 ... - // # (note the gaps where the zero padding was!) - int32_t* frame_index() { - return reinterpret_cast(frame_info() + frame_info_dim()); + // int32_t frame_index[num_frames]; # frame_index will be an index into + // `frame_info`; + // # this will be of the form 0 1 2 4 5 6 7 8 + // 9 11 ... # (note the gaps where the zero + // padding was!) + int32_t *frame_index() { + return reinterpret_cast(frame_info() + frame_info_dim()); } int32_t frame_index_dim() { int32 num_frames_padded = seg_frame_index[num_segs], - num_frames = num_frames_padded - num_segs; + num_frames = num_frames_padded - num_segs; return num_frames; } @@ -159,7 +158,6 @@ struct DenseFsaVecMeta { // 4 + 5*num_segs + 5*num_frames }; - struct DenseFsaVecFrameInfo { int32_t seg_id; // The segment-id that this frame is part of (0 <= seg_id <= // num_segs)... will be of the form 0 0 0 0 1 1 1 1 1 1 1 1 2 @@ -168,15 +166,16 @@ struct DenseFsaVecFrameInfo { // Would equal seg_id in the case where it was one segment // per sequence. int32_t frame_in_seg; // The frame-index within the segment, so would be 0 - // for the 1st frame of each segment, and `this_seg_num_frames` - // for the last (which could contain all zeros). - // Will be of the form 0 1 2 3 0 1 2 3 4 5 6 0 1 2 0 1 2.... - int32_t frame_in_seq; // The frame-index within the sequence that this segment - // is a part of. Would be the same as `frame_in_seg` if - // this segment starts at frame zero of its sequence. + // for the 1st frame of each segment, and + // `this_seg_num_frames` for the last (which could + // contain all zeros). Will be of the form 0 1 2 3 0 1 + // 2 3 4 5 6 0 1 2 0 1 2.... + int32_t + frame_in_seq; // The frame-index within the sequence that this segment + // is a part of. Would be the same as `frame_in_seg` if + // this segment starts at frame zero of its sequence. }; - /** Creates meta-info for DenseFsaVec (DenseFsaVecMeta) as one block in memory. For some of the terminology, see the comment above the definition class @@ -192,42 +191,38 @@ struct DenseFsaVecFrameInfo { There are `num_segs` segments. Each segment is a subset of the frames in a sequence. - @param [in] num_seqs Number of sequences of (e.g.) phone posteriors/loglikes - from the neural net + @param [in] num_seqs Number of sequences of (e.g.) phone + posteriors/loglikes from the neural net @param [in] frames_per_seq Number of frames in each sequence - @param [in] num_symbols Dimension of the neural network output, interpreted - for instance as epsilon/blank and the rest are phones - or letters. - @param [in] num_segs Number of segments. Each segment represents a range of - frames within a sequence. There will in general be - at least as many segments as sequences. - @param [in] seq_id Indexed by 0 <= seg_id < num_segs, seq_id[seg_id] contains the - sequence index 0 <= s < num_seqs to which this - segment belongs - @param [in] frame_begin Indexed by 0 <= seg_id < num_segs, frame_begin[seg_id] - contains the index of the first frame of that segment. + @param [in] num_symbols Dimension of the neural network output, + interpreted for instance as epsilon/blank + and the rest are phones or letters. + @param [in] num_segs Number of segments. Each segment represents a range + of frames within a sequence. There will in general + be at least as many segments as sequences. + @param [in] seq_id Indexed by 0 <= seg_id < num_segs, seq_id[seg_id] + contains the sequence index 0 <= s < num_seqs to + which this segment belongs + @param [in] frame_begin Indexed by 0 <= seg_id < num_segs, + frame_begin[seg_id] contains the index of the + first frame of that segment. @param [in] frame_end Indexed by 0 <= seg_id < num_segs, frame_end[seg_id] - contains the index of the last-plus-one frame of that segment. - @param [in] storage_size Size of `storage` array, in int32_t elements. Defining - num_frames = sum(frame_end) - sum(frame_begin), - storage_size must equal 4 + 5*num_segs + 5*num_frames. - It is provided as an arg for checking purposes. - @param [out] storage Pointer to an array of int32_t where we put the meta-info (probably - part of a torch.Tensor). It will be interpreted internally - as type DenseFsaVecMeta. + contains the index of the last-plus-one frame of that + segment. + @param [in] storage_size Size of `storage` array, in int32_t elements. + Defining num_frames = sum(frame_end) - + sum(frame_begin), storage_size must equal + 4 + 5*num_segs + 5*num_frames. It is provided as + an arg for checking purposes. + @param [out] storage Pointer to an array of int32_t where we put the + meta-info (probably part of a torch.Tensor). It will + be interpreted internally as type DenseFsaVecMeta. */ -void CreateDenseFsaVecMeta( - int32_t num_seqs, - int32_t frames_per_seq, - int32_t num_symbols, - int32_t num_segs, - const int32_t *seq_id, - const int32_t *frame_begin, - const int32_t *frame_end, - ssize_t storage_size, - int32_t *storage); - - +void CreateDenseFsaVecMeta(int32_t num_seqs, int32_t frames_per_seq, + int32_t num_symbols, int32_t num_segs, + const int32_t *seq_id, const int32_t *frame_begin, + const int32_t *frame_end, ssize_t storage_size, + int32_t *storage); /** DenseFsaVec represents a vector of FSAs with a special regular @@ -246,32 +241,30 @@ void CreateDenseFsaVecMeta( */ struct DenseFsaVec { - - /* Constructor. - @param [in] meta The meta-info, as written to by CreateDenseFsaVecMeta(). + @param [in] meta The meta-info, as written to by + CreateDenseFsaVecMeta(). @param [in] data A contiguous, row-major matrix of shape (meta_info->num_frames_padded(),meta_info->num_symbols), containing the neural-net outputs for each segment with zero rows in between for padding. */ - DenseFsaVec(const DenseFsaVecMeta *meta, - const float *data): - meta(meta), data(data) { Check(); } - + DenseFsaVec(const DenseFsaVecMeta *meta, const float *data) + : meta(meta), data(data) { + Check(); + } void Check(); // Sanity check (spot check, not thorough) on `meta_info` - const DenseFsaVecMeta *meta; const float *data; - DenseFsa operator[] (int32_t seg_id) { + DenseFsa operator[](int32_t seg_id) { CHECK_LT(seg_id, meta->num_segs); int32_t start_frame_index = meta->seg_frame_index[seg_id], - end_frame_index = meta->seg_frame_index[seg_id+1]; + end_frame_index = meta->seg_frame_index[seg_id + 1]; // below, the -1 is to exclude the zero-padding frame. int23_t T = end_frame_index - start_frame_index - 1; int32_t arc_offset = meta->num_symbols * start_frame_index; @@ -299,11 +292,9 @@ void IntersectPruned(const DenseFsa &a, const Fsa &b, float beam, Fsa *c, std::vector *arc_map_a = nullptr, std::vector *arc_map_b = nullptr); - /* Convert DenseFsa to regular Fsa (for testing purposes) */ void DenseToFsa(const DenseFsa &a, Fsa *b); - } // namespace k2 #endif // K2_CSRC_DENSE_FSA_H_ diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index 0c2b2f78d..164987e7c 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -25,7 +25,6 @@ enum { // like in OpenFst. }; - // CAUTION: the sizeof() this is probably 128, not 96. This could be a // waste of space. We may later either use the extra field for something, or // find a way to reduce the size. @@ -126,7 +125,6 @@ struct Fsa { } }; - /* Cfsa is a 'const' FSA, which we'll use as the input to operations. It is designed in such a way that the storage underlying it may either be an Fsa @@ -147,31 +145,27 @@ struct Cfsa { // one-past-the-last index in `arcs` for the arcs in this // FSA. Will be >= begin_arc. - const int32_t *arc_indexes; // an array, indexed by state index, giving the - // first arc index of each state. The last one - // is repeated, so for any valid state 0 <= s < - // num_states we can use arc_indexes[s+1]. That - // is: elements 0 through num_states (inclusive) - // are valid. CAUTION: arc_indexes[0] may be - // greater than zero. - + const int32_t *arc_indexes; // an array, indexed by state index, giving the + // first arc index of each state. The last one + // is repeated, so for any valid state 0 <= s < + // num_states we can use arc_indexes[s+1]. That + // is: elements 0 through num_states (inclusive) + // are valid. CAUTION: arc_indexes[0] may be + // greater than zero. - Arc *arcs; // Note: arcs[BeginArcIndex()] through arcs[EndArcIndex() - 1] - // are valid. + Arc *arcs; // Note: arcs[BeginArcIndex()] through arcs[EndArcIndex() - 1] + // are valid. // Constructor from Fsa - Cfsa(const Fsa &fsa); + explicit Cfsa(const Fsa &fsa); - Cfsa &operator = (const Cfsa &cfsa) = default; + Cfsa &operator=(const Cfsa &cfsa) = default; Cfsa(const Cfsa &cfsa) = default; int32_t NumStates() const { return num_states; } int32_t FinalState() const { return num_states - 1; } }; - - - class CfsaVec { public: /* @@ -208,8 +202,8 @@ class CfsaVec { FSA f is empty, and >= 2 otherwise. [possibly some padding here] - - arc_indexes[tot_states + 1] This gives the indexes into the `arcs` - array of where we can find the first of each state's arcs. + - arc_indexes[tot_states + 1] This gives the indexes into the + `arcs` array of where we can find the first of each state's arcs. [pad as needed for memory-alignment purposes then...] @@ -219,11 +213,11 @@ class CfsaVec { int32_t NumFsas() const { return num_fsas_; } - Cfsa operator[] (int32_t f) const; + Cfsa operator[](int32_t f) const; private: - CfsaVec &operator = (const CfsaVec &); // Disable - CfsaVec(const CfsaVec&); // Disable + CfsaVec &operator=(const CfsaVec &); // Disable + CfsaVec(const CfsaVec &); // Disable int32_t num_fsas_; @@ -233,9 +227,6 @@ class CfsaVec { size_t size_; }; - - - /* Return the number of bytes we'd need to represent this vector of Cfsas linearly as a CfsaVec. */ @@ -256,11 +247,7 @@ size_t GetCfsaVecSize(const Cfsa &fsa_in); must equal the return value of GetCfsaVecSize(fsas_in). */ -void CreateCfsaVec(const std::vector &fsas_in, - void *data, - size_t size); - - +void CreateCfsaVec(const std::vector &fsas_in, void *data, size_t size); struct Fst { Fsa core; From 8ba4db9b689793b13d587b6777c764972d6937d7 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Mon, 1 Jun 2020 22:21:23 +0800 Subject: [PATCH 2/3] fix some issues --- k2/csrc/aux_labels.cc | 23 +++++++++++------------ k2/csrc/aux_labels_test.cc | 26 ++++++++++++++++++++------ k2/csrc/dense_fsa.h | 4 ++-- k2/csrc/util.h | 2 -- 4 files changed, 33 insertions(+), 22 deletions(-) diff --git a/k2/csrc/aux_labels.cc b/k2/csrc/aux_labels.cc index 3d9c27c4d..8ddb0f242 100644 --- a/k2/csrc/aux_labels.cc +++ b/k2/csrc/aux_labels.cc @@ -20,19 +20,19 @@ void MapAuxLabels1(const AuxLabels &labels_in, auto &start_pos = labels_out->start_pos; auto &labels = labels_out->labels; start_pos.clear(); + start_pos.reserve(arc_map.size() + 1); labels.clear(); int32_t num_labels = 0; + auto labels_in_iter_begin = labels_in.labels.begin(); for (const auto &arc_index : arc_map) { start_pos.push_back(num_labels); int32_t pos_start = labels_in.start_pos[arc_index]; int32_t pos_end = labels_in.start_pos[arc_index + 1]; - for (int32_t pos = pos_start; pos != pos_end; ++pos) { - int32_t label = labels_in.labels[pos]; - DCHECK_NE(label, kEpsilon); - labels.push_back(label); - ++num_labels; - } + // TODO(haowen): should we check labels contains no Epsilon? + labels.insert(labels.end(), labels_in_iter_begin + pos_start, + labels_in_iter_begin + pos_end); + num_labels += pos_end - pos_start; } start_pos.push_back(num_labels); } @@ -44,20 +44,19 @@ void MapAuxLabels2(const AuxLabels &labels_in, auto &start_pos = labels_out->start_pos; auto &labels = labels_out->labels; start_pos.clear(); + start_pos.reserve(arc_map.size() + 1); labels.clear(); int32_t num_labels = 0; + auto labels_in_iter_begin = labels_in.labels.begin(); for (const auto &arc_indexes : arc_map) { start_pos.push_back(num_labels); for (const auto &arc_index : arc_indexes) { int32_t pos_start = labels_in.start_pos[arc_index]; int32_t pos_end = labels_in.start_pos[arc_index + 1]; - for (int32_t pos = pos_start; pos != pos_end; ++pos) { - int32_t label = labels_in.labels[pos]; - DCHECK_NE(label, kEpsilon); - labels.push_back(label); - ++num_labels; - } + labels.insert(labels.end(), labels_in_iter_begin + pos_start, + labels_in_iter_begin + pos_end); + num_labels += pos_end - pos_start; } } start_pos.push_back(num_labels); diff --git a/k2/csrc/aux_labels_test.cc b/k2/csrc/aux_labels_test.cc index c70808b83..be86eb0f5 100644 --- a/k2/csrc/aux_labels_test.cc +++ b/k2/csrc/aux_labels_test.cc @@ -38,7 +38,7 @@ TEST_F(AuxLablesTest, MapAuxLabels1) { MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); EXPECT_TRUE(aux_labels_out.labels.empty()); - EXPECT_EQ(aux_labels_out.start_pos.size(), 1); + ASSERT_EQ(aux_labels_out.start_pos.size(), 1); EXPECT_EQ(aux_labels_out.start_pos[0], 0); } @@ -47,11 +47,25 @@ TEST_F(AuxLablesTest, MapAuxLabels1) { AuxLabels aux_labels_out; MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); - EXPECT_EQ(aux_labels_out.start_pos.size(), 4); + ASSERT_EQ(aux_labels_out.start_pos.size(), 4); EXPECT_THAT(aux_labels_out.start_pos, ::testing::ElementsAre(0, 3, 4, 5)); - EXPECT_EQ(aux_labels_out.labels.size(), 5); + ASSERT_EQ(aux_labels_out.labels.size(), 5); EXPECT_THAT(aux_labels_out.labels, ::testing::ElementsAre(4, 5, 6, 1, 7)); } + + { + // all arcs in input fsa are remained + std::vector arc_map = {2, 0, 3, 1}; + AuxLabels aux_labels_out; + MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); + + ASSERT_EQ(aux_labels_out.start_pos.size(), 5); + EXPECT_THAT(aux_labels_out.start_pos, + ::testing::ElementsAre(0, 3, 4, 5, 7)); + ASSERT_EQ(aux_labels_out.labels.size(), 7); + EXPECT_THAT(aux_labels_out.labels, + ::testing::ElementsAre(4, 5, 6, 1, 7, 2, 3)); + } } TEST_F(AuxLablesTest, MapAuxLabels2) { @@ -65,7 +79,7 @@ TEST_F(AuxLablesTest, MapAuxLabels2) { MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out); EXPECT_TRUE(aux_labels_out.labels.empty()); - EXPECT_EQ(aux_labels_out.start_pos.size(), 1); + ASSERT_EQ(aux_labels_out.start_pos.size(), 1); EXPECT_EQ(aux_labels_out.start_pos[0], 0); } @@ -74,10 +88,10 @@ TEST_F(AuxLablesTest, MapAuxLabels2) { AuxLabels aux_labels_out; MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out); - EXPECT_EQ(aux_labels_out.start_pos.size(), 5); + ASSERT_EQ(aux_labels_out.start_pos.size(), 5); EXPECT_THAT(aux_labels_out.start_pos, ::testing::ElementsAre(0, 4, 7, 8, 11)); - EXPECT_EQ(aux_labels_out.labels.size(), 11); + ASSERT_EQ(aux_labels_out.labels.size(), 11); EXPECT_THAT(aux_labels_out.labels, ::testing::ElementsAre(4, 5, 6, 7, 1, 2, 3, 1, 4, 5, 6)); } diff --git a/k2/csrc/dense_fsa.h b/k2/csrc/dense_fsa.h index e64f3a2e5..5400f52e7 100644 --- a/k2/csrc/dense_fsa.h +++ b/k2/csrc/dense_fsa.h @@ -50,8 +50,8 @@ struct DenseFsa { it may be more efficient to use what is known about the structure of this object. But they may be useful for documentation and testing. */ - int32_t num_states() { return T + 2; } - int32_t arc_indexes(int32_t state_index) { + int32_t NumStates() const { return T + 2; } + int32_t ArcIndexes(int32_t state_index) const { return arc_offset + (state_index <= T ? state_index * num_symbols : T * num_symbols + 1); } diff --git a/k2/csrc/util.h b/k2/csrc/util.h index a4dd88b81..541f5846d 100644 --- a/k2/csrc/util.h +++ b/k2/csrc/util.h @@ -14,8 +14,6 @@ #include #include -#include "k2/csrc/fsa.h" - namespace k2 { #define EXPECT_DOUBLE_ARRAY_APPROX_EQ(expected, actual, abs_error) \ From 3d27c601ab54c477eedfcc936bbd145764e677e4 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Mon, 1 Jun 2020 22:39:00 +0800 Subject: [PATCH 3/3] documented that treat epsilon same as other symbols --- k2/csrc/aux_labels.cc | 1 - k2/csrc/aux_labels.h | 21 +++++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/k2/csrc/aux_labels.cc b/k2/csrc/aux_labels.cc index 8ddb0f242..5c72f2308 100644 --- a/k2/csrc/aux_labels.cc +++ b/k2/csrc/aux_labels.cc @@ -29,7 +29,6 @@ void MapAuxLabels1(const AuxLabels &labels_in, start_pos.push_back(num_labels); int32_t pos_start = labels_in.start_pos[arc_index]; int32_t pos_end = labels_in.start_pos[arc_index + 1]; - // TODO(haowen): should we check labels contains no Epsilon? labels.insert(labels.end(), labels_in_iter_begin + pos_start, labels_in_iter_begin + pos_end); num_labels += pos_end - pos_start; diff --git a/k2/csrc/aux_labels.h b/k2/csrc/aux_labels.h index e6ab199c8..9dabd61cd 100644 --- a/k2/csrc/aux_labels.h +++ b/k2/csrc/aux_labels.h @@ -29,7 +29,6 @@ namespace k2 { */ - /* This allows you to store auxiliary labels (e.g. olabels or ilabels) on each arc of an Fsa. @@ -40,13 +39,13 @@ struct AuxLabels { `labels` of the label sequence on arc i. start_pos.end() equals labels.size(). */ std::vector start_pos; - /* For arc i, (labels[start_pos[i] ], labels[start_pos[i]+1], ... labels[start_pos[i+1]-1]) - are the list of labels on that arc. None of the elements of `labels` are - expected to be zero (epsilon). */ + /* For arc i, (labels[start_pos[i] ], labels[start_pos[i]+1], ... + labels[start_pos[i+1]-1]) are the list of labels on that arc. + We treat epsilon the same as other symbols here, so there are no + requirements on elements of `labels`. */ std::vector labels; }; - /* Maps auxiliary labels after an FSA operation where each arc in the output FSA corresponds to exactly one arc in the input FSA. @@ -57,8 +56,7 @@ struct AuxLabels { @param [in] labels_out Labels on the arcs of the output FSA */ void MapAuxLabels1(const AuxLabels &labels_in, - const std::vector &arc_map, - AuxLabels *labels_out); + const std::vector &arc_map, AuxLabels *labels_out); /* Maps auxiliary labels after an FSA operation where each arc in the output @@ -70,10 +68,9 @@ void MapAuxLabels1(const AuxLabels &labels_in, @param [in] labels_out Labels on the arcs of the output FSA */ void MapAuxLabels2(const AuxLabels &labels_in, - const std::vector > &arc_map, + const std::vector> &arc_map, AuxLabels *labels_out); - /* Invert an FST, swapping the symbols in the FSA with the auxiliary labels. (e.g. swap input and output symbols in FST, but you decide which is which). @@ -92,13 +89,9 @@ void MapAuxLabels2(const AuxLabels &labels_in, `fsa_in`, although epsilons (kEpsilon, zeros) will be removed. */ -void InvertFst(const Fsa &fsa_in, - const AuxLabels &labels_in, - Fsa *fsa_out, +void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out, AuxLabels *aux_labels_out); - - } // namespace k2 #endif // K2_CSRC_AUX_LABELS_H_