From 60f310976e57acacf06c5017e6f047d5034a7be7 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Thu, 11 Jun 2020 08:47:57 +0800 Subject: [PATCH] replace AuxLabels with Array2 --- k2/csrc/array.h | 37 ++++++- k2/csrc/aux_labels.cc | 188 ++++++++++++++++++-------------- k2/csrc/aux_labels.h | 169 +++++++++++++++++++---------- k2/csrc/aux_labels_test.cc | 215 ++++++++++++++++++++++--------------- 4 files changed, 383 insertions(+), 226 deletions(-) diff --git a/k2/csrc/array.h b/k2/csrc/array.h index cb1b266c1..8801b823a 100644 --- a/k2/csrc/array.h +++ b/k2/csrc/array.h @@ -35,15 +35,18 @@ struct StridedPtr { : data(data), stride(stride) {} StridedPtr(const StridedPtr &other) : data(other.data), stride(other.stride) {} - StridedPtr &operator=(const StridedPtr &other) { - StridedPtr tmp(other); - std::swap(data, tmp.data); - std::swap(stride, tmp.stride); + StridedPtr(StridedPtr &&other) { this->Swap(other); } + StridedPtr &operator=(StridedPtr other) { + this->Swap(other); return *this; } bool operator==(const StridedPtr &other) const { return data == other.data && stride == other.stride; } + void Swap(StridedPtr &other) { + std::swap(data, other.data); + std::swap(stride, other.stride); + } }; /* @@ -69,6 +72,10 @@ struct Array2 { using PtrT = Ptr; using ValueType = typename std::iterator_traits::value_type; + Array2() = default; + Array2(IndexT size1, IndexT *indexes, IndexT size2, PtrT data) + : size1(size1), indexes(indexes), size2(size2), data(data) {} + IndexT size1; IndexT *indexes; // indexes[0,1,...size1] should be defined; note, this // means the array must be of at least size1+1. We @@ -85,6 +92,17 @@ struct Array2 { bool Empty() const { return size1 == 0; } + // just to replace `Swap` functions for Fsa and AuxLabels for now, + // may delete it if we finally find that we don't need to call it. + void Swap(Array2 &other) { + std::swap(size1, other.size1); + std::swap(size2, other.size2); + std::swap(indexes, other.indexes); + // it's OK here for Ptr=StridedPtr as we have specialized + // std::swap for StridedPtr + std::swap(data, other.data); + } + /* initialized definition: An Array2 object is initialized if its `size1` member and `size2` member are set and its `indexes` and `data` pointer allocated, and the values of @@ -197,10 +215,21 @@ struct Array2Storage { namespace std { template + struct iterator_traits> { typedef T value_type; }; +template +void swap(k2::StridedPtr &lhs, k2::StridedPtr &rhs) { + lhs.Swap(rhs); +} + +template +void swap(k2::Array2 &lhs, k2::Array2 &rhs) { + lhs.Swap(rhs); +} + } // namespace std #endif // K2_CSRC_ARRAY_H_ diff --git a/k2/csrc/aux_labels.cc b/k2/csrc/aux_labels.cc index 4f6730094..305c65bd2 100644 --- a/k2/csrc/aux_labels.cc +++ b/k2/csrc/aux_labels.cc @@ -37,9 +37,9 @@ static void CountExtraStates(const k2::Fsa &fsa_in, auto &states = *num_extra_states; for (int32_t i = 0; i != fsa_in.arcs.size(); ++i) { const auto &arc = fsa_in.arcs[i]; - int32_t pos_start = labels_in.start_pos[i]; - int32_t pos_end = labels_in.start_pos[i + 1]; - states[arc.dest_state] += std::max(0, pos_end - pos_start - 1); + int32_t begin = labels_in.indexes[i]; + int32_t end = labels_in.indexes[i + 1]; + states[arc.dest_state] += std::max(0, end - begin - 1); } } @@ -87,129 +87,152 @@ static void MapStates(const std::vector &num_extra_states, namespace k2 { -void Swap(AuxLabels *labels1, AuxLabels *labels2) { - CHECK_NOTNULL(labels1); - CHECK_NOTNULL(labels2); - std::swap(labels1->start_pos, labels2->start_pos); - std::swap(labels1->labels, labels2->labels); +void AuxLabels1Mapper::GetSizes(Array2Size *aux_size) { + CHECK_NOTNULL(aux_size); + aux_size->size1 = arc_map_.size(); + int32_t num_labels = 0; + for (const auto &arc_index : arc_map_) { + int32_t begin = labels_in_.indexes[arc_index]; + int32_t end = labels_in_.indexes[arc_index + 1]; + num_labels += end - begin; + } + aux_size->size2 = num_labels; } -void MapAuxLabels1(const AuxLabels &labels_in, - const std::vector &arc_map, AuxLabels *labels_out) { +void AuxLabels1Mapper::GetOutput(AuxLabels *labels_out) { CHECK_NOTNULL(labels_out); - auto &start_pos = labels_out->start_pos; - auto &labels = labels_out->labels; - start_pos.clear(); - start_pos.reserve(arc_map.size() + 1); - labels.clear(); + auto &start_pos = labels_out->indexes; + auto &labels = labels_out->data; + int32_t num_labels = 0; + int32_t i = 0; + for (; i != arc_map_.size(); ++i) { + start_pos[i] = num_labels; + const auto arc_index = arc_map_[i]; + int32_t begin = labels_in_.indexes[arc_index]; + int32_t end = labels_in_.indexes[arc_index + 1]; + for (auto it = begin; it != end; ++it) { + labels[num_labels++] = labels_in_.data[it]; + } + } + start_pos[i] = num_labels; +} +void AuxLabels2Mapper::GetSizes(Array2Size *aux_size) { + CHECK_NOTNULL(aux_size); + aux_size->size1 = arc_map_.size(); int32_t num_labels = 0; - auto labels_in_iter_begin = labels_in.labels.begin(); - for (const auto &arc_index : arc_map) { - start_pos.push_back(num_labels); - int32_t pos_start = labels_in.start_pos[arc_index]; - int32_t pos_end = labels_in.start_pos[arc_index + 1]; - labels.insert(labels.end(), labels_in_iter_begin + pos_start, - labels_in_iter_begin + pos_end); - num_labels += pos_end - pos_start; + for (const auto &arc_indexes : arc_map_) { + for (const auto &arc_index : arc_indexes) { + int32_t begin = labels_in_.indexes[arc_index]; + int32_t end = labels_in_.indexes[arc_index + 1]; + num_labels += end - begin; + } } - start_pos.push_back(num_labels); + aux_size->size2 = num_labels; } -void MapAuxLabels2(const AuxLabels &labels_in, - const std::vector> &arc_map, - AuxLabels *labels_out) { +void AuxLabels2Mapper::GetOutput(AuxLabels *labels_out) { CHECK_NOTNULL(labels_out); - auto &start_pos = labels_out->start_pos; - auto &labels = labels_out->labels; - start_pos.clear(); - start_pos.reserve(arc_map.size() + 1); - labels.clear(); - + auto &start_pos = labels_out->indexes; + auto &labels = labels_out->data; int32_t num_labels = 0; - auto labels_in_iter_begin = labels_in.labels.begin(); - for (const auto &arc_indexes : arc_map) { - start_pos.push_back(num_labels); - for (const auto &arc_index : arc_indexes) { - int32_t pos_start = labels_in.start_pos[arc_index]; - int32_t pos_end = labels_in.start_pos[arc_index + 1]; - labels.insert(labels.end(), labels_in_iter_begin + pos_start, - labels_in_iter_begin + pos_end); - num_labels += pos_end - pos_start; + int32_t i = 0; + for (; i != arc_map_.size(); ++i) { + start_pos[i] = num_labels; + for (const auto &arc_index : arc_map_[i]) { + int32_t begin = labels_in_.indexes[arc_index]; + int32_t end = labels_in_.indexes[arc_index + 1]; + for (auto it = begin; it != end; ++it) { + labels[num_labels++] = labels_in_.data[it]; + } } } - start_pos.push_back(num_labels); + start_pos[i] = num_labels; +} + +void FstInverter::GetSizes(Array2Size *fsa_size, + Array2Size *aux_size) { + CHECK_NOTNULL(fsa_size); + CHECK_NOTNULL(aux_size); + int32_t num_extra_states = 0; + int32_t num_arcs = 0; + int32_t num_non_eps_labels = 0; + for (int32_t i = 0; i != fsa_in_.arcs.size(); ++i) { + const auto &arc = fsa_in_.arcs[i]; + int32_t begin = labels_in_.indexes[i]; + int32_t end = labels_in_.indexes[i + 1]; + num_extra_states += std::max(0, end - begin - 1); + num_arcs += std::max(1, end - begin); + if (arc.label != kEpsilon) ++num_non_eps_labels; + } + fsa_size->size1 = num_extra_states + fsa_in_.NumStates(); + fsa_size->size2 = num_arcs; + aux_size->size1 = num_arcs; + aux_size->size2 = num_non_eps_labels; } -void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out, - AuxLabels *aux_labels_out) { +void FstInverter::GetOutput(Fsa *fsa_out, AuxLabels *labels_out) { CHECK_NOTNULL(fsa_out); - CHECK_NOTNULL(aux_labels_out); - fsa_out->arc_indexes.clear(); - fsa_out->arcs.clear(); - aux_labels_out->start_pos.clear(); - aux_labels_out->labels.clear(); - - if (IsEmpty(fsa_in)) { - aux_labels_out->start_pos.push_back(0); + CHECK_NOTNULL(labels_out); + + if (IsEmpty(fsa_in_)) { + labels_out->indexes[0] = 0; return; } - auto num_states_in = fsa_in.NumStates(); + auto num_states_in = fsa_in_.NumStates(); // get the number of extra states we need to create for each state // in fsa_in when inverting std::vector num_extra_states(num_states_in, 0); - CountExtraStates(fsa_in, labels_in, &num_extra_states); + CountExtraStates(fsa_in_, labels_in_, &num_extra_states); // map state in fsa_in to state in fsa_out std::vector state_map(num_states_in, 0); std::vector state_ids(num_states_in, 0); MapStates(num_extra_states, &state_map, &state_ids); - // a maximal approximation - int32_t num_arcs_out = labels_in.labels.size() + fsa_in.arcs.size(); + // TODO(haowen): replace with fsa_out->size2 std::vector arcs; - arcs.reserve(num_arcs_out); - // `+1` for the end position of the last arc's olabel sequence + arcs.reserve(labels_out->size1); std::vector start_pos; - start_pos.reserve(num_arcs_out + 1); + start_pos.reserve(labels_out->size1 + 1); std::vector labels; - labels.reserve(fsa_in.arcs.size()); - int32_t final_state_in = fsa_in.FinalState(); + labels.reserve(labels_out->size2); + int32_t final_state_in = fsa_in_.FinalState(); int32_t num_non_eps_ilabel_processed = 0; start_pos.push_back(0); - for (auto i = 0; i != fsa_in.arcs.size(); ++i) { - const auto &arc = fsa_in.arcs[i]; - int32_t pos_start = labels_in.start_pos[i]; - int32_t pos_end = labels_in.start_pos[i + 1]; + for (auto i = 0; i != fsa_in_.arcs.size(); ++i) { + const auto &arc = fsa_in_.arcs[i]; + int32_t pos_begin = labels_in_.indexes[i]; + int32_t pos_end = labels_in_.indexes[i + 1]; int32_t src_state = arc.src_state; int32_t dest_state = arc.dest_state; if (dest_state == final_state_in) { // every arc entering the final state must have exactly // one olabel == kFinalSymbol - CHECK_EQ(pos_start + 1, pos_end); - CHECK_EQ(labels_in.labels[pos_start], kFinalSymbol); + CHECK_EQ(pos_begin + 1, pos_end); + CHECK_EQ(labels_in_.data[pos_begin], kFinalSymbol); } - if (pos_end - pos_start <= 1) { + if (pos_end - pos_begin <= 1) { int32_t curr_label = - (pos_end - pos_start == 0) ? kEpsilon : labels_in.labels[pos_start]; + (pos_end - pos_begin == 0) ? kEpsilon : labels_in_.data[pos_begin]; arcs.emplace_back(state_map[src_state], state_map[dest_state], curr_label); } else { // expand arcs with olabels arcs.emplace_back(state_map[src_state], state_ids[dest_state] + 1, - labels_in.labels[pos_start]); + labels_in_.data[pos_begin]); start_pos.push_back(num_non_eps_ilabel_processed); - for (int32_t pos = pos_start + 1; pos < pos_end - 1; ++pos) { + for (int32_t pos = pos_begin + 1; pos < pos_end - 1; ++pos) { ++state_ids[dest_state]; arcs.emplace_back(state_ids[dest_state], state_ids[dest_state] + 1, - labels_in.labels[pos]); + labels_in_.data[pos]); start_pos.push_back(num_non_eps_ilabel_processed); } ++state_ids[dest_state]; arcs.emplace_back(state_ids[dest_state], state_map[arc.dest_state], - labels_in.labels[pos_end - 1]); + labels_in_.data[pos_end - 1]); } // push non-epsilon ilabel in fsa_in as olabel of fsa_out if (arc.label != kEpsilon) { @@ -219,15 +242,18 @@ void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out, start_pos.push_back(num_non_eps_ilabel_processed); } - labels.resize(labels.size()); - arcs.resize(arcs.size()); - start_pos.resize(start_pos.size()); + // any failure indicates there are some errors + // TODO(haowen): replace with fsa_out->size2 + CHECK_EQ(arcs.size(), labels_out->size1); + CHECK_EQ(start_pos.size(), labels_out->size1 + 1); + CHECK_EQ(labels.size(), labels_out->size2); std::vector arc_map; ReorderArcs(arcs, fsa_out, &arc_map); - AuxLabels labels_tmp; - labels_tmp.start_pos = std::move(start_pos); - labels_tmp.labels = std::move(labels); - MapAuxLabels1(labels_tmp, arc_map, aux_labels_out); + AuxLabels labels_tmp(labels_out->size1, start_pos.data(), labels_out->size2, + labels.data()); + AuxLabels1Mapper aux_mapper(labels_tmp, arc_map); + // don't need to call `GetSizes` here as `labels_out` has been initialized + aux_mapper.GetOutput(labels_out); } } // namespace k2 diff --git a/k2/csrc/aux_labels.h b/k2/csrc/aux_labels.h index f0d9ab248..a14e1e71d 100644 --- a/k2/csrc/aux_labels.h +++ b/k2/csrc/aux_labels.h @@ -33,103 +33,162 @@ namespace k2 { /* This allows you to store auxiliary labels (e.g. olabels or ilabels) on each arc of an Fsa. - */ -struct AuxLabels { - /* Suppose this is associated with an Fsa f. start_pos will be of + + auto &start_pos = AuxLabels::Indexes; + + Suppose this is associated with an Fsa f. start_pos will be of size f.arcs.size() + 1; start_pos[i] is the start position in `labels` of the label sequence on arc i. start_pos.end() - equals labels.size(). */ - std::vector start_pos; - /* For arc i, (labels[start_pos[i] ], labels[start_pos[i]+1], ... - labels[start_pos[i+1]-1]) are the list of labels on that arc. - We treat epsilon the same as other symbols here, so there are no - requirements on elements of `labels`. */ - std::vector labels; -}; + equals labels.size(). -// TODO(haowen): replace AuxLabels above with below definition -using AuxLabels_ = Array2; + auto &labels = AuxLabels::data; -// Swap AuxLabels; it's cheap to to this as we are actually doing shallow swap. -void Swap(AuxLabels *labels1, AuxLabels *labels2); + For arc i, (labels[start_pos[i] ], labels[start_pos[i]+1], ... + labels[start_pos[i+1]-1]) are the list of labels on that arc. + We treat epsilon the same as other symbols here, so there are no + requirements on elements of `labels`. + */ +using AuxLabels = Array2; /* Maps auxiliary labels after an FSA operation where each arc in the output FSA corresponds to exactly one arc in the input FSA. - @param [in] labels_in Labels on the arcs of the input FSA + */ +class AuxLabels1Mapper { + public: + /* Lightweight constructor that just keeps const references to the input + parameters. + @param [in] labels_in Labels on the arcs of the input FSA @param [in] arc_map Vector of size (output_fsa.arcs.size()), saying which arc of the input FSA it corresponds to. - @param [in] labels_out Labels on the arcs of the output FSA - */ -void MapAuxLabels1(const AuxLabels &labels_in, - const std::vector &arc_map, AuxLabels *labels_out); + */ + AuxLabels1Mapper(const AuxLabels &labels_in, + const std::vector &arc_map) + : labels_in_(labels_in), arc_map_(arc_map) {} + + /* + Do enough work that know now much memory will be needed, and output + that information + @param [out] aux_size The number of lists in the output AuxLabels + (equals num-arcs in the output FSA) and + the number of elements (equals num-aux-labels + on the arcs in the output FSA) will be written + to here. + */ + void GetSizes(Array2Size *aux_size); + + /* + Finish the operation and output auxiliary labels to `labels_out`. + @param [out] labels_out Auxiliary labels on the arcs of the output FSA. + Must be initialized; search for 'initialized + definition' in class Array2 in array.h for + meaning. + */ + void GetOutput(AuxLabels *labels_out); + + private: + const AuxLabels &labels_in_; + const std::vector &arc_map_; +}; /* Maps auxiliary labels after an FSA operation where each arc in the output FSA can correspond to a sequence of arcs in the input FSA. - @param [in] labels_in Labels on the arcs of the input FSA + */ +class AuxLabels2Mapper { + public: + /* Lightweight constructor that just keeps const references to the input + parameters. + @param [in] labels_in Labels on the arcs of the input FSA @param [in] arc_map Vector of size (output_fsa.arcs.size()), giving the sequence of arc-indexes in the input FSA that it corresponds to. - @param [in] labels_out Labels on the arcs of the output FSA - */ -void MapAuxLabels2(const AuxLabels &labels_in, - const std::vector> &arc_map, - AuxLabels *labels_out); + */ + AuxLabels2Mapper(const AuxLabels &labels_in, + const std::vector> &arc_map) + : labels_in_(labels_in), arc_map_(arc_map) {} + + /* + Do enough work that know now much memory will be needed, and output + that information + @param [out] aux_size The number of lists in the output AuxLabels + (equals num-arcs in the output FSA) and + the number of elements (equals num-aux-labels + on the arcs in the output FSA) will be written + to here. + */ + void GetSizes(Array2Size *aux_size); + + /* + Finish the operation and output auxiliary labels to `labels_out`. + @param [out] labels_out Auxiliary labels on the arcs of the output FSA. + Must be initialized; search for 'initialized + definition' in class Array2 in array.h for + meaning. + */ + void GetOutput(AuxLabels *labels_out); + + private: + const AuxLabels &labels_in_; + const std::vector> &arc_map_; +}; /* Invert an FST, swapping the symbols in the FSA with the auxiliary labels. (e.g. swap input and output symbols in FST, but you decide which is which). Because each arc may have more than one auxiliary label, in general the output FSA may have more states than the input FSA. - - @param [in] fsa_in Input FSA - @param [in] labels_in Input aux-label sequences, one for each arc in - fsa_in - @param [out] fsa_out Output FSA. Will have a number of states - >= that in fsa_in. If fsa_in was top-sorted it - will be top-sorted. Labels in the FSA will - correspond to those in `labels_in`. - @param [out] aux_labels_out Auxiliary labels on the arcs of - fsa_out. Will be the same as the labels on - `fsa_in`, although epsilons (kEpsilon, zeros) will be - removed. */ -void InvertFst(const Fsa &fsa_in, const AuxLabels &labels_in, Fsa *fsa_out, - AuxLabels *aux_labels_out); - class FstInverter { - /* Constructor. Lightweight. */ - FstInverter(const Fsa &fsa_in, const AuxLabels &labels_in); + public: + /* Lightweight constructor that just keeps const references to the input + parameters. + @param [in] fsa_in Input FSA + @param [in] labels_in Input aux-label sequences, one for each arc in + fsa_in + */ + FstInverter(const Fsa &fsa_in, const AuxLabels &labels_in) + : fsa_in_(fsa_in), labels_in_(labels_in) {} /* Do enough work that know now much memory will be needed, and output that information @param [out] fsa_size The num-states and num-arcs of the FSA will be written to here - @param [out] aux_size The number of lists in the AuxLabels - output (==num-arcs) and the number of - elements will be written to here. + @param [out] aux_size The number of lists in the output AuxLabels + (equals num-arcs in the output FSA) and + the number of elements (equals the number of + labels on `fsa_in`, although epsilons + will be removed) will be written to here. */ void GetSizes(Array2Size *fsa_size, Array2Size *aux_size); /* Finish the operation and output inverted FSA to `fsa_out` and auxiliary labels to `labels_out`. - @param [out] fsa_out The inverted FSA will be written to - here. Must be initialized; search for - 'initialized definition' in class Array2 - in array.h for meaning. - @param [out] labels_out The auxiliary labels will be written to - here. Must be initialized; search for - 'initialized definition' in class Array2 - in array.h for meaning. + @param [out] fsa_out The inverted FSA will be written to here. + Must be initialized; search for 'initialized + definition' in class Array2 in array.h for meaning. + + Will have a number of states >= that in fsa_in. + If fsa_in was top-sorted it will be top-sorted. + Labels in the FSA will correspond to those in + `labels_in`. + @param [out] labels_out The auxiliary labels will be written to here. + Must be initialized; search for 'initialized + definition' in class Array2 in array.h for + meaning. + + Will be the same as the labels on `fsa_in`, + although epsilons (kEpsilon, zeros) will be + removed. */ void GetOutput(Fsa *fsa_out, AuxLabels *labels_out); private: - // ... + const Fsa &fsa_in_; + const AuxLabels &labels_in_; }; } // namespace k2 diff --git a/k2/csrc/aux_labels_test.cc b/k2/csrc/aux_labels_test.cc index 193b517e1..e6c60f574 100644 --- a/k2/csrc/aux_labels_test.cc +++ b/k2/csrc/aux_labels_test.cc @@ -19,81 +19,109 @@ namespace k2 { class AuxLablesTest : public ::testing::Test { protected: AuxLablesTest() { - std::vector start_pos = {0, 1, 3, 6, 7}; - std::vector labels = {1, 2, 3, 4, 5, 6, 7}; - aux_labels_in_.start_pos = std::move(start_pos); - aux_labels_in_.labels = std::move(labels); + aux_labels_in_.size1 = static_cast(start_pos_.size()) - 1; + aux_labels_in_.size2 = static_cast(labels_.size()); + aux_labels_in_.indexes = start_pos_.data(); + aux_labels_in_.data = labels_.data(); } + std::vector start_pos_ = {0, 1, 3, 6, 7}; + std::vector labels_ = {1, 2, 3, 4, 5, 6, 7}; AuxLabels aux_labels_in_; }; -TEST_F(AuxLablesTest, MapAuxLabels1) { +TEST_F(AuxLablesTest, AuxLabels1Mapper) { { // empty arc_map std::vector arc_map; - AuxLabels aux_labels_out; - // some dirty data - aux_labels_out.start_pos = {1, 2, 3}; - aux_labels_out.labels = {4, 5}; - MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); - - EXPECT_TRUE(aux_labels_out.labels.empty()); - ASSERT_EQ(aux_labels_out.start_pos.size(), 1); - EXPECT_EQ(aux_labels_out.start_pos[0], 0); + AuxLabels1Mapper aux_mapper(aux_labels_in_, arc_map); + Array2Size aux_size; + aux_mapper.GetSizes(&aux_size); + Array2Storage storage(aux_size, 1); + auto aux_labels_out = storage.GetArray2(); + aux_mapper.GetOutput(&aux_labels_out); + + ASSERT_EQ(aux_labels_out.size1, 0); + EXPECT_EQ(aux_labels_out.indexes[0], 0); + EXPECT_EQ(aux_labels_out.size2, 0); } { std::vector arc_map = {2, 0, 3}; - AuxLabels aux_labels_out; - MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); - - ASSERT_EQ(aux_labels_out.start_pos.size(), 4); - EXPECT_THAT(aux_labels_out.start_pos, ::testing::ElementsAre(0, 3, 4, 5)); - ASSERT_EQ(aux_labels_out.labels.size(), 5); - EXPECT_THAT(aux_labels_out.labels, ::testing::ElementsAre(4, 5, 6, 1, 7)); + AuxLabels1Mapper aux_mapper(aux_labels_in_, arc_map); + Array2Size aux_size; + aux_mapper.GetSizes(&aux_size); + Array2Storage storage(aux_size, 1); + auto aux_labels_out = storage.GetArray2(); + aux_mapper.GetOutput(&aux_labels_out); + + ASSERT_EQ(aux_labels_out.size1, 3); + ASSERT_EQ(aux_labels_out.size2, 5); + std::vector out_indexes( + aux_labels_out.indexes, + aux_labels_out.indexes + aux_labels_out.size1 + 1); + std::vector out_data(aux_labels_out.data, + aux_labels_out.data + aux_labels_out.size2); + EXPECT_THAT(out_indexes, ::testing::ElementsAre(0, 3, 4, 5)); + EXPECT_THAT(out_data, ::testing::ElementsAre(4, 5, 6, 1, 7)); } { // all arcs in input fsa are remained std::vector arc_map = {2, 0, 3, 1}; - AuxLabels aux_labels_out; - MapAuxLabels1(aux_labels_in_, arc_map, &aux_labels_out); - - ASSERT_EQ(aux_labels_out.start_pos.size(), 5); - EXPECT_THAT(aux_labels_out.start_pos, - ::testing::ElementsAre(0, 3, 4, 5, 7)); - ASSERT_EQ(aux_labels_out.labels.size(), 7); - EXPECT_THAT(aux_labels_out.labels, - ::testing::ElementsAre(4, 5, 6, 1, 7, 2, 3)); + AuxLabels1Mapper aux_mapper(aux_labels_in_, arc_map); + Array2Size aux_size; + aux_mapper.GetSizes(&aux_size); + Array2Storage storage(aux_size, 1); + auto aux_labels_out = storage.GetArray2(); + aux_mapper.GetOutput(&aux_labels_out); + + ASSERT_EQ(aux_labels_out.size2, 7); + ASSERT_EQ(aux_labels_out.size1, 4); + std::vector out_indexes( + aux_labels_out.indexes, + aux_labels_out.indexes + aux_labels_out.size1 + 1); + std::vector out_data(aux_labels_out.data, + aux_labels_out.data + aux_labels_out.size2); + EXPECT_THAT(out_indexes, ::testing::ElementsAre(0, 3, 4, 5, 7)); + EXPECT_THAT(out_data, ::testing::ElementsAre(4, 5, 6, 1, 7, 2, 3)); } } -TEST_F(AuxLablesTest, MapAuxLabels2) { +TEST_F(AuxLablesTest, AuxLabels2Mapper) { { // empty arc_map std::vector> arc_map; - AuxLabels aux_labels_out; - // some dirty data - aux_labels_out.start_pos = {1, 2, 3}; - aux_labels_out.labels = {4, 5}; - MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out); - - EXPECT_TRUE(aux_labels_out.labels.empty()); - ASSERT_EQ(aux_labels_out.start_pos.size(), 1); - EXPECT_EQ(aux_labels_out.start_pos[0], 0); + AuxLabels2Mapper aux_mapper(aux_labels_in_, arc_map); + Array2Size aux_size; + aux_mapper.GetSizes(&aux_size); + Array2Storage storage(aux_size, 1); + auto aux_labels_out = storage.GetArray2(); + aux_mapper.GetOutput(&aux_labels_out); + + ASSERT_EQ(aux_labels_out.size1, 0); + EXPECT_EQ(aux_labels_out.indexes[0], 0); + EXPECT_EQ(aux_labels_out.size2, 0); } { std::vector> arc_map = {{2, 3}, {0, 1}, {0}, {2}}; - AuxLabels aux_labels_out; - MapAuxLabels2(aux_labels_in_, arc_map, &aux_labels_out); - - ASSERT_EQ(aux_labels_out.start_pos.size(), 5); - EXPECT_THAT(aux_labels_out.start_pos, - ::testing::ElementsAre(0, 4, 7, 8, 11)); - ASSERT_EQ(aux_labels_out.labels.size(), 11); - EXPECT_THAT(aux_labels_out.labels, + AuxLabels2Mapper aux_mapper(aux_labels_in_, arc_map); + Array2Size aux_size; + aux_mapper.GetSizes(&aux_size); + Array2Storage storage(aux_size, 1); + auto aux_labels_out = storage.GetArray2(); + aux_mapper.GetOutput(&aux_labels_out); + + ASSERT_EQ(aux_labels_out.size2, 11); + ASSERT_EQ(aux_labels_out.size1, 4); + std::vector out_indexes( + aux_labels_out.indexes, + aux_labels_out.indexes + aux_labels_out.size1 + 1); + std::vector out_data(aux_labels_out.data, + aux_labels_out.data + aux_labels_out.size2); + EXPECT_THAT(out_indexes, ::testing::ElementsAre(0, 4, 7, 8, 11)); + EXPECT_THAT(out_data, ::testing::ElementsAre(4, 5, 6, 7, 1, 2, 3, 1, 4, 5, 6)); } } @@ -102,24 +130,24 @@ TEST(AuxLabels, InvertFst) { { // empty input FSA Fsa fsa_in; - AuxLabels labels_in; std::vector start_pos = {0, 1, 3, 6, 7}; std::vector labels = {1, 2, 3, 4, 5, 6, 7}; - labels_in.start_pos = std::move(start_pos); - labels_in.labels = std::move(labels); - - std::vector arcs = {{0, 1, 1}, {1, 2, -1}}; - Fsa fsa_out(std::move(arcs), 2); - AuxLabels labels_out; - // some dirty data - labels_out.start_pos = {1, 2, 3}; - labels_out.labels = {4, 5}; - InvertFst(fsa_in, labels_in, &fsa_out, &labels_out); + AuxLabels labels_in(static_cast(start_pos.size()) - 1, + start_pos.data(), static_cast(labels.size()), + labels.data()); + + FstInverter fst_inverter(fsa_in, labels_in); + Array2Size fsa_size, aux_size; + fst_inverter.GetSizes(&fsa_size, &aux_size); + Array2Storage aux_storage(aux_size, 1); + auto labels_out = aux_storage.GetArray2(); + Fsa fsa_out; + fst_inverter.GetOutput(&fsa_out, &labels_out); EXPECT_TRUE(IsEmpty(fsa_out)); - EXPECT_TRUE(labels_out.labels.empty()); - ASSERT_EQ(labels_out.start_pos.size(), 1); - EXPECT_EQ(labels_out.start_pos[0], 0); + ASSERT_EQ(labels_out.size1, 0); + EXPECT_EQ(labels_out.indexes[0], 0); + EXPECT_EQ(labels_out.size2, 0); } { @@ -129,16 +157,20 @@ TEST(AuxLabels, InvertFst) { {2, 3, 0}, {2, 5, -1}, {4, 5, -1}}; Fsa fsa_in(std::move(arcs), 5); EXPECT_TRUE(IsTopSorted(fsa_in)); - AuxLabels labels_in; std::vector start_pos = {0, 2, 3, 3, 6, 6, 7, 7, 8, 9}; EXPECT_EQ(start_pos.size(), fsa_in.arcs.size() + 1); std::vector labels = {1, 2, 3, 5, 6, 7, -1, -1, -1}; - labels_in.start_pos = std::move(start_pos); - labels_in.labels = std::move(labels); - + AuxLabels labels_in(static_cast(start_pos.size()) - 1, + start_pos.data(), static_cast(labels.size()), + labels.data()); + + FstInverter fst_inverter(fsa_in, labels_in); + Array2Size fsa_size, aux_size; + fst_inverter.GetSizes(&fsa_size, &aux_size); + Array2Storage aux_storage(aux_size, 1); + auto labels_out = aux_storage.GetArray2(); Fsa fsa_out; - AuxLabels labels_out; - InvertFst(fsa_in, labels_in, &fsa_out, &labels_out); + fst_inverter.GetOutput(&fsa_out, &labels_out); EXPECT_TRUE(IsTopSorted(fsa_out)); std::vector arcs_out = { @@ -152,11 +184,15 @@ TEST(AuxLabels, InvertFst) { ASSERT_EQ(fsa_out.arc_indexes.size(), 10); EXPECT_THAT(fsa_out.arc_indexes, ::testing::ElementsAre(0, 3, 4, 7, 8, 9, 11, 11, 12, 12)); - ASSERT_EQ(labels_out.labels.size(), 7); - EXPECT_THAT(labels_out.labels, - ::testing::ElementsAre(2, 1, 4, -1, 3, -1, -1)); - ASSERT_EQ(labels_out.start_pos.size(), 13); - EXPECT_THAT(labels_out.start_pos, + + ASSERT_EQ(labels_out.size1, 12); + ASSERT_EQ(labels_out.size2, 7); + std::vector out_indexes(labels_out.indexes, + labels_out.indexes + labels_out.size1 + 1); + std::vector out_data(labels_out.data, + labels_out.data + labels_out.size2); + EXPECT_THAT(out_data, ::testing::ElementsAre(2, 1, 4, -1, 3, -1, -1)); + EXPECT_THAT(out_indexes, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 3, 4, 4, 5, 5, 6, 7)); } @@ -167,16 +203,20 @@ TEST(AuxLabels, InvertFst) { {2, 5, -1}, {3, 1, 6}, {4, 5, -1}}; Fsa fsa_in(std::move(arcs), 5); EXPECT_FALSE(IsTopSorted(fsa_in)); - AuxLabels labels_in; std::vector start_pos = {0, 2, 3, 3, 6, 6, 7, 8, 10, 11}; EXPECT_EQ(start_pos.size(), fsa_in.arcs.size() + 1); std::vector labels = {1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1}; - labels_in.start_pos = std::move(start_pos); - labels_in.labels = std::move(labels); - + AuxLabels labels_in(static_cast(start_pos.size()) - 1, + start_pos.data(), static_cast(labels.size()), + labels.data()); + + FstInverter fst_inverter(fsa_in, labels_in); + Array2Size fsa_size, aux_size; + fst_inverter.GetSizes(&fsa_size, &aux_size); + Array2Storage aux_storage(aux_size, 1); + auto labels_out = aux_storage.GetArray2(); Fsa fsa_out; - AuxLabels labels_out; - InvertFst(fsa_in, labels_in, &fsa_out, &labels_out); + fst_inverter.GetOutput(&fsa_out, &labels_out); EXPECT_FALSE(IsTopSorted(fsa_out)); std::vector arcs_out = {{0, 1, 1}, {0, 3, 3}, {0, 7, 0}, {1, 3, 2}, @@ -190,13 +230,16 @@ TEST(AuxLabels, InvertFst) { ASSERT_EQ(fsa_out.arc_indexes.size(), 11); EXPECT_THAT(fsa_out.arc_indexes, ::testing::ElementsAre(0, 3, 4, 5, 7, 8, 9, 11, 12, 13, 13)); - ASSERT_EQ(labels_out.labels.size(), 8); - EXPECT_THAT(labels_out.labels, - ::testing::ElementsAre(2, 1, 6, 4, 3, 5, -1, -1)); - ASSERT_EQ(labels_out.start_pos.size(), 14); - EXPECT_THAT( - labels_out.start_pos, - ::testing::ElementsAre(0, 0, 0, 1, 2, 3, 3, 4, 4, 5, 6, 7, 7, 8)); + + ASSERT_EQ(labels_out.size1, 13); + ASSERT_EQ(labels_out.size2, 8); + std::vector out_indexes(labels_out.indexes, + labels_out.indexes + labels_out.size1 + 1); + std::vector out_data(labels_out.data, + labels_out.data + labels_out.size2); + EXPECT_THAT(out_data, ::testing::ElementsAre(2, 1, 6, 4, 3, 5, -1, -1)); + EXPECT_THAT(out_indexes, ::testing::ElementsAre(0, 0, 0, 1, 2, 3, 3, 4, 4, + 5, 6, 7, 7, 8)); } }