From 67960472dfdeafd245956900e954919a4b35963d Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Fri, 3 Jul 2020 17:55:06 +0800 Subject: [PATCH] remove arcs and arc_indexes in fsa and add some functions in Array3 --- k2/csrc/array.h | 89 ++++++++-- k2/csrc/array_test.cc | 92 +++++++++- k2/csrc/fsa.cc | 229 +----------------------- k2/csrc/fsa.h | 291 +++--------------------------- k2/csrc/fsa_equivalent_test.cc | 2 +- k2/csrc/fsa_test.cc | 312 +++++++-------------------------- k2/csrc/fsa_util.cc | 7 - k2/csrc/fsa_util.h | 9 +- k2/csrc/properties.h | 6 +- k2/python/csrc/CMakeLists.txt | 2 +- k2/python/csrc/fsa.cc | 170 +++++++++--------- k2/python/csrc/k2.cc | 2 +- k2/python/csrc/k2.h | 7 +- 13 files changed, 356 insertions(+), 862 deletions(-) diff --git a/k2/csrc/array.h b/k2/csrc/array.h index e3b6c14d9..a56e5c56a 100644 --- a/k2/csrc/array.h +++ b/k2/csrc/array.h @@ -130,11 +130,11 @@ struct Array2 { // as we require `indexes[0] == 0` if Array2 is empty, // the implementation of `begin` and `end` would be fine for empty object. - PtrT begin() { return data + indexes[0]; } - const PtrT begin() const { return data + indexes[0]; } + PtrT begin() { return data + indexes[0]; } // NOLINT + const PtrT begin() const { return data + indexes[0]; } // NOLINT - PtrT end() { return data + indexes[size1]; } - const PtrT end() const { return data + indexes[size1]; } + PtrT end() { return data + indexes[size1]; } // NOLINT + const PtrT end() const { return data + indexes[size1]; } // NOLINT // just to replace `Swap` functions for Fsa and AuxLabels for now, // may delete it if we finally find that we don't need to call it. @@ -161,24 +161,33 @@ struct Array3 { using IndexT = I; using PtrT = Ptr; - IndexT size; - IndexT *indexes1; // indexes1[0,1,...size] should be defined; note, - // this means the array must be of at least size+1. + IndexT size1; // equal to the number of Array2 object in this Array3 object; + // `size1 + 1` will be the number of elements in indexes1. + + IndexT size2; // equal to indexes1[size1] - indexes1[0]; + // `size2 + 1` will be the number of elements in indexes2; + + IndexT size3; // the number of elements in `data`, equal to + // indexes2[indexes1[size1]] - indexes2[indexes1[0]]. + + IndexT *indexes1; // indexes1[0,1,...size1] should be defined; note, + // this means the array must be of at least size1+1. // We require that indexes[i] <= indexes[i+1], but it - // is not required that indexes[0] == 0, it may be - // greater than 0. + // is not required that indexes[0] == 0, it may be greater + // than 0. IndexT *indexes2; // indexes2[indexes1[0]] - // .. indexes2[indexes1[size]-1] should be defined. + // .. indexes2[indexes1[size1]] should be defined; + // note, this means the array must be of at least size2+1. Ptr data; // `data` might be an actual pointer, or might be some object // supporting operator []. data[indexes2[indexes1[0]]] through - // data[indexes2[indexes1[size] - 1]] must be accessible through - // this object. + // data[indexes2[indexes1[size1]] - 1] must be accessible + // through this object. Array2 operator[](I i) const { DCHECK_GE(i, 0); - DCHECK_LT(i, size); + DCHECK_LT(i, size1); Array2 array; array.size1 = indexes1[i + 1] - indexes1[i]; @@ -187,6 +196,59 @@ struct Array3 { array.data = data; return array; } + + /* + Set `size1`, `size2` and `size3` so that we can know how much memory we need + to allocate for `indexes1`, `indexes2` and `data` to represent the vector + of Array2 as an Array3. + @param [in] arrays A vector of Array2; + @param [in] array_size The number element of vector `arrays` + */ + void GetSizes(const Array2 *arrays, I array_size) { + size1 = array_size; + size2 = size3 = 0; + for (I i = 0; i != array_size; ++i) { + size2 += arrays[i].size1; + size3 += arrays[i].size2; + } + } + + /* + Create Array3 from the vector of Array2. `size1`, `size2` and `size3` must + have been set by calling `GetSizes` above, and the memory of `indexes1`, + `indexes2`and `data` must have been allocated according to those size. + @param [in] arrays A vector of Array2; + @param [in] array_size The number element of vector `arrays` + */ + void Create(const Array2 *arrays, I array_size) { + CHECK_EQ(size1, array_size); + I size2_tmp = 0, size3_tmp = 0; + for (I i = 0; i != array_size; ++i) { + const auto &curr_array = arrays[i]; + + indexes1[i] = size2_tmp; + + // copy indexes + CHECK_LE(size2_tmp + curr_array.size1, size2); + I begin_index = curr_array.indexes[0]; // indexes[0] is always valid and + // may be greater than 0 + for (I j = 0; j != curr_array.size1; ++j) { + indexes2[size2_tmp++] = size3_tmp + curr_array.indexes[j] - begin_index; + } + + // copy data + CHECK_LE(size3_tmp + curr_array.size2, size3); + for (I n = 0; n != curr_array.size2; ++n) { + data[size3_tmp + n] = curr_array.data[n + begin_index]; + } + size3_tmp += curr_array.size2; + } + CHECK_EQ(size2_tmp, size2); + CHECK_EQ(size3_tmp, size3); + + indexes1[size1] = size2_tmp; + indexes2[indexes1[size1]] = size3_tmp; + } }; // Note: we can create Array4 later if we need it. @@ -261,7 +323,6 @@ struct Array2Storage { namespace std { template - struct iterator_traits> { typedef T value_type; }; diff --git a/k2/csrc/array_test.cc b/k2/csrc/array_test.cc index f9437a8e7..c14dc3832 100644 --- a/k2/csrc/array_test.cc +++ b/k2/csrc/array_test.cc @@ -7,6 +7,7 @@ #include "k2/csrc/array.h" #include +#include #include #include #include @@ -132,10 +133,97 @@ void TestArray2(int32_t stride) { } } -TEST(Array2Test, RawPointer) { TestArray2(1); } +template +void TestArray3(int32_t stride) { + using ValueType = typename std::iterator_traits::value_type; + + Array2Size size1 = {4, 10}; + std::vector indexes1 = {0, 3, 5, 9, 10}; + std::vector data1(size1.size2); + std::iota(data1.begin(), data1.end(), 0); + Array2Storage storage1(size1, stride); + storage1.FillIndexes(indexes1); + storage1.FillData(data1); + Array2 &array1 = storage1.GetArray2(); + EXPECT_EQ(array1.data[array1.indexes[0]], 0); + + Array2Size size2 = {3, 10}; + // note indexes2[0] starts from 3 instead of 0 + std::vector indexes2 = {3, 5, 8, 10}; + std::vector data2(10); // 10 instead of 7 here on purpose + std::iota(data2.begin(), data2.end(), 0); + Array2Storage storage2(size2, stride); + storage2.FillIndexes(indexes2); + storage2.FillData(data2); + Array2 &array2 = storage2.GetArray2(); + array2.size2 = 7; // change the size to the correct value + EXPECT_EQ(array2.data[array2.indexes[0]], 3); + + std::vector> arrays; + arrays.emplace_back(array1); + arrays.emplace_back(array2); + + Array3 array3; + array3.GetSizes(arrays.data(), 2); + EXPECT_EQ(array3.size1, 2); + EXPECT_EQ(array3.size2, 7); + EXPECT_EQ(array3.size3, 17); + + // Test Array3 Creation + std::vector array3_indexes1(array3.size1 + 1); + std::vector array3_indexes2(array3.size2 + 1); + std::unique_ptr array3_data( + new ValueType[array3.size3 * stride]); + array3.indexes1 = array3_indexes1.data(); + array3.indexes2 = array3_indexes2.data(); + array3.data = DataPtrCreator::Create(array3_data, stride); + + array3.Create(arrays.data(), 2); + EXPECT_THAT(array3_indexes1, ::testing::ElementsAre(0, 4, 7)); + EXPECT_THAT(array3_indexes2, + ::testing::ElementsAre(0, 3, 5, 9, 10, 12, 15, 17)); + for (auto i = array1.indexes[0]; i != array1.indexes[array1.size1]; ++i) { + EXPECT_EQ(array3.data[i], array1.data[i]); + } + EXPECT_EQ(array2.indexes[0], 3); + for (auto i = array2.indexes[0]; i != array2.indexes[array2.size1]; ++i) { + EXPECT_EQ(array3.data[array1.size2 + i - array2.indexes[0]], + array2.data[i]); + } + + // Test Array3's operator[] + Array2 array1_copy = array3[0]; + EXPECT_EQ(array1_copy.size1, array1.size1); + EXPECT_EQ(array1_copy.size2, array1.size2); + for (auto i = 0; i != array1.size1 + 1; ++i) { + EXPECT_EQ(array1_copy.indexes[i], array1.indexes[i]); + } + for (auto i = array1.indexes[0]; i != array1.indexes[array1.size1]; ++i) { + EXPECT_EQ(array1_copy.data[i], array1.data[i]); + } + + Array2 array2_copy = array3[1]; + EXPECT_EQ(array2_copy.size1, array2.size1); + EXPECT_EQ(array2_copy.size2, array2.size2); + for (auto i = 0; i != array2.size1 + 1; ++i) { + // output indexes may starts from n > 0 + EXPECT_EQ(array2_copy.indexes[i], + array2.indexes[i] + array1.size1 + array2.indexes[0]); + } + for (auto i = array2.indexes[0]; i != array2.indexes[array2.size1]; ++i) { + EXPECT_EQ(array1_copy.data[i + array1.size2 - array2.indexes[0]], + array1.data[i]); + } +} + +TEST(ArrayTest, RawPointer) { + TestArray2(1); + TestArray3(1); +} -TEST(Array2Test, StridedPtr) { +TEST(ArrayTest, StridedPtr) { TestArray2, int32_t>(2); + TestArray3, int32_t>(2); } } // namespace k2 diff --git a/k2/csrc/fsa.cc b/k2/csrc/fsa.cc index 1666564d4..9efc81cbd 100644 --- a/k2/csrc/fsa.cc +++ b/k2/csrc/fsa.cc @@ -14,7 +14,6 @@ constexpr std::size_t kAlignment = 64; static_assert((kAlignment & 15) == 0, "kAlignment should be at least multiple of 16"); static_assert(kAlignment % alignof(k2::Arc) == 0, ""); -static_assert(kAlignment % alignof(k2::CfsaVecHeader) == 0, ""); inline std::size_t AlignTo(std::size_t b, std::size_t alignment) { // alignment should be power of 2 @@ -30,231 +29,13 @@ std::ostream &operator<<(std::ostream &os, const Arc &arc) { return os; } -std::ostream &operator<<(std::ostream &os, const Cfsa &cfsa) { - os << "num_states: " << cfsa.num_states << "\n"; - os << "begin_arc: " << cfsa.begin_arc << "\n"; - os << "end_arc: " << cfsa.end_arc << "\n"; - os << "num_arcs: " << cfsa.NumArcs() << "\n"; - for (int i = cfsa.begin_arc; i != cfsa.end_arc; ++i) { - os << cfsa.arcs[i] << "\n"; +std::ostream &operator<<(std::ostream &os, const Fsa &fsa) { + os << "num_states: " << fsa.NumStates() << "\n"; + os << "num_arcs: " << fsa.size2 << "\n"; + for (const auto &arc : fsa) { + os << arc << "\n"; } return os; } -Cfsa::Cfsa() - : num_states(0), - begin_arc(0), - end_arc(0), - arc_indexes(nullptr), - arcs(nullptr) {} - -Cfsa::Cfsa(const Fsa &fsa) { - begin_arc = 0; - num_states = fsa.NumStates(); - if (num_states != 0) { - // this is not an empty fsa - arc_indexes = fsa.arc_indexes.data(); - end_arc = fsa.arc_indexes.back(); - arcs = const_cast(fsa.arcs.data()); - } else { - // this is an empty fsa - arc_indexes = nullptr; - end_arc = 0; - arcs = nullptr; - } -} - -CfsaVec::CfsaVec(std::size_t size, void *data) - : data_(reinterpret_cast(data)), size_(size) { - const auto header = reinterpret_cast(data_); - num_fsas_ = header->num_fsas; -} - -Cfsa CfsaVec::operator[](int32_t i) const { - DCHECK_GE(i, 0); - DCHECK_LT(i, num_fsas_); - - Cfsa cfsa; - - const auto header = reinterpret_cast(data_); - const auto state_offsets_array = data_ + header->state_offsets_start; - - int32_t num_states = state_offsets_array[i + 1] - state_offsets_array[i]; - if (num_states == 0) return cfsa; - - // we have to decrease num_states by one since the last entry of arc_indexes - // is repeated. - --num_states; - DCHECK_GE(num_states, 2); - - const auto arc_indexes_array = data_ + header->arc_indexes_start; - const auto arcs_array = reinterpret_cast(data_) + header->arcs_start; - - cfsa.num_states = num_states; - cfsa.begin_arc = arc_indexes_array[state_offsets_array[i]]; - cfsa.end_arc = arc_indexes_array[state_offsets_array[i + 1] - 1]; - cfsa.arc_indexes = arc_indexes_array; - cfsa.arcs = arcs_array; - - return cfsa; -} - -std::size_t GetCfsaVecSize(const Cfsa &cfsa) { - std::size_t res_bytes = 0; - - std::size_t header_bytes = sizeof(CfsaVecHeader); - res_bytes += header_bytes; - - // padding to the alignment boundary for state_offsets_array - res_bytes = AlignTo(res_bytes, kAlignment); - - // size in bytes for `int32_t state_offsets_array[num_fsas + 1];` - std::size_t state_offsets_array_bytes = sizeof(int32_t) * 2; - res_bytes += state_offsets_array_bytes; - - // padding to the alignment boundary for arc_indexes_array - res_bytes = AlignTo(res_bytes, kAlignment); - - // size in bytes for `int32_t arc_indexes_array[num_states + num_fsas];` - std::size_t arc_indexes_array_bytes = - sizeof(int32_t) * (cfsa.NumStates() + 1); - res_bytes += arc_indexes_array_bytes; - - // align res_bytes to be multiple of sizeof(Arc) - res_bytes = (res_bytes + sizeof(Arc) - 1) / sizeof(Arc) * sizeof(Arc); - - DCHECK_EQ(res_bytes % alignof(Arc), 0); - - // size in bytes for `Arc arcs[num_arcs];` - std::size_t arcs_array_bytes = sizeof(Arc) * cfsa.NumArcs(); - res_bytes += arcs_array_bytes; - - return res_bytes; -} - -std::size_t GetCfsaVecSize(const std::vector &cfsas) { - std::size_t res_bytes = 0; - - std::size_t header_bytes = sizeof(CfsaVecHeader); - res_bytes += header_bytes; - - // padding to the alignment boundary for state_offsets_array - res_bytes = AlignTo(res_bytes, kAlignment); - - // size in bytes for `int32_t state_offsets_array[num_fsas + 1];` - std::size_t state_offsets_array_bytes = sizeof(int32_t) * (cfsas.size() + 1); - res_bytes += state_offsets_array_bytes; - - // padding to the alignment boundary for arc_indexes_array - res_bytes = AlignTo(res_bytes, kAlignment); - - int32_t num_states = 0; - int32_t num_arcs = 0; - for (const auto &cfsa : cfsas) { - num_states += cfsa.NumStates(); - num_arcs += cfsa.NumArcs(); - } - - // size in bytes for `int32_t arc_indexes_array[num_states + num_fsas];` - std::size_t arc_indexes_array_bytes = - sizeof(int32_t) * (num_states + cfsas.size()); - res_bytes += arc_indexes_array_bytes; - - // align res_bytes to be multiple of sizeof(Arc) - res_bytes = (res_bytes + sizeof(Arc) - 1) / sizeof(Arc) * sizeof(Arc); - - DCHECK_EQ(res_bytes % alignof(Arc), 0); - - // size in bytes for `Arc arcs[num_arcs];` - std::size_t arcs_array_bytes = sizeof(Arc) * num_arcs; - res_bytes += arcs_array_bytes; - - return res_bytes; -} - -void CreateCfsaVec(const std::vector &cfsas, void *data, - std::size_t size) { - DCHECK_EQ(size, GetCfsaVecSize(cfsas)); - - auto header = reinterpret_cast(data); - header->version = kCfsaVecVersion; - header->num_fsas = static_cast(cfsas.size()); - - std::size_t offset = sizeof(CfsaVecHeader); - - // the state_offsets_array is aligned to the boundary `kAlignment`. - offset = AlignTo(offset, kAlignment); - header->state_offsets_start = offset / sizeof(int32_t); - - auto state_offsets_array = - reinterpret_cast(data) + header->state_offsets_start; - - int32_t num_states = 0; - - state_offsets_array[0] = 0; - int32_t i = 1; - for (; i < static_cast(cfsas.size()); ++i) { - if (cfsas[i - 1].NumStates() != 0) { - state_offsets_array[i] = - state_offsets_array[i - 1] + cfsas[i - 1].NumStates() + 1; - } else { - state_offsets_array[i] = state_offsets_array[i - 1]; - } - - num_states += cfsas[i].NumStates(); - } - - if (!cfsas.empty()) { - if (cfsas.back().NumStates() != 0) { - state_offsets_array[i] = - state_offsets_array[i - 1] + cfsas.back().NumStates() + 1; - } else { - state_offsets_array[i] = state_offsets_array[i - 1]; - } - - num_states += cfsas[0].NumStates(); - } - - // int32_t state_offsets_array[num_fsas + 1]; - offset += sizeof(int32_t) * (cfsas.size() + 1); - - // arc_indexes_array is aligned to the boundary `kAlignment` - offset = AlignTo(offset, kAlignment); - header->arc_indexes_start = offset / sizeof(int32_t); - - auto arc_indexes_array = - reinterpret_cast(data) + header->arc_indexes_start; - - // int32_t arc_indexes_array[num_states + num_fsas]; - // we add `num_fsas` here because each fsa has its final states repeated. - offset += sizeof(int32_t) * (num_states + cfsas.size()); - - // align offset to sizeof(Arc) - offset = (offset + sizeof(Arc) - 1) / sizeof(Arc) * sizeof(Arc); - DCHECK_EQ(offset % sizeof(Arc), 0); - - header->arcs_start = offset / sizeof(Arc); - Arc *arcs_array = reinterpret_cast(data) + header->arcs_start; - - int32_t num_states_so_far = 0; - int32_t num_arcs_so_far = 0; - - for (const auto &cfsa : cfsas) { - for (int32_t s = 0; s != cfsa.NumStates(); ++s, ++num_states_so_far) { - arc_indexes_array[num_states_so_far] = num_arcs_so_far; - int32_t num_arcs = cfsa.arc_indexes[s + 1] - cfsa.arc_indexes[s]; - - std::copy_n(cfsa.arcs + cfsa.arc_indexes[s], num_arcs, - arcs_array + num_arcs_so_far); - - num_arcs_so_far += num_arcs; - } - if (cfsa.NumStates() != 0) { - arc_indexes_array[num_states_so_far] = - arc_indexes_array[num_states_so_far - 1]; - ++num_states_so_far; - } - } -} - } // namespace k2 diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index 70ae55ca2..73cd4629c 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -1,6 +1,7 @@ // k2/csrc/fsa.h -// Copyright (c) 2020 Daniel Povey +// Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey +// Haowen Qiu) // See ../../LICENSE for clarification regarding multiple authors @@ -78,285 +79,45 @@ struct ArcHash { final_state. The start-state is always numbered zero and the final-state is always the - last-numbered state. However, we represent the empty FSA (the one that - accepts no strings) by having no states at all, so `arcs_indexes` would be - empty. - - TODO(haowen): add below comments in the final version: - (In Array2 representation) We represent an empty FSA(the one that accepts no + last-numbered state. We represent an empty FSA(the one that accepts no strings) by having no states at all, so `size1` would be 0 (As an empty FSA is - an initialized Array2 object, so `indexes` would be allocated and has at least - one element, but we don't care about it here). + an initialized Array2 object, and `indexes` would be allocated and has at + least one element, but we don't care about it here). */ -// TODO(haowen): finally we will remove `arc_indexes` and `arcs`, but for now, -// we would keep them to replace Fsa with Array2 incrementally. struct Fsa : public Array2 { - // `arc_indexes` is indexed by state-index, is of length num-states + 1; it - // contains the first arc-index leaving this state (index into `arcs`). The - // next element of this array gives the end of that range. Note: the - // final-state is numbered last, and implicitly has no arcs leaving it. For - // non-empty FSA, we put a duplicate of the final state at the end of - // `arc_indexes` to avoid boundary check for some FSA operations. Caution: - // users should never call `arc_indexes.size()` to get the number of states, - // they should call `NumStates()` to get the number. - std::vector arc_indexes; - - // Note: an index into the `arcs` array is called an arc-index. - std::vector arcs; + // `size1` is equal to num-states of the FSA. + // + // `size2` is equal to num-arcs of the FSA. + // + // `data` stores the arcs of the Fsa and is indexed by arc-index (an index + // into the `data` array is called an arc-index). We may use `arcs` as an + // alias of `data` in the context of FSA. + // + // `indexes` is indexed by state-index, is of length num-states + 1; it + // contains the first arc-index leaving this state (index into `arcs`). + // The next element of this array gives the end of that range. Note: the + // final-state is numbered last, and implicitly has no arcs leaving it. + // We may use `arc-indexes` as an alias of `indexes`. // inherits constructors in Array2 using Array2::Array2; - Fsa() : Array2() { - // TODO(haowen): remove this after replacing Fsa with Array2 - indexes = nullptr; - } - // just for creating testing FSA examples for now. - Fsa(std::vector fsa_arcs, int32_t final_state) - : arcs(std::move(fsa_arcs)) { - indexes = nullptr; - if (arcs.empty()) return; - - int32_t curr_state = -1; - int32_t index = 0; - for (const auto &arc : arcs) { - CHECK_LE(arc.src_state, final_state); - CHECK_LE(arc.dest_state, final_state); - CHECK_LE(curr_state, arc.src_state); - while (curr_state < arc.src_state) { - arc_indexes.push_back(index); - ++curr_state; - } - ++index; - } - // noted that here we push two `final_state` at the end, the last element is - // just to avoid boundary check for some FSA operations. - for (; curr_state <= final_state; ++curr_state) - arc_indexes.push_back(index); - } - - // TODO(haowen): finally we'll implement NumStates with: - // CHECK_GE(size1, 0); - // return size1; int32_t NumStates() const { - if (indexes != nullptr) { // Fsa is initialized as Array2 - // users should not use `arc_indexes` and `arcs` while using Array2 - CHECK(arc_indexes.empty()); - CHECK(arcs.empty()); - - CHECK_GE(size1, 0); - return size1; - } - return !arc_indexes.empty() ? (static_cast(arc_indexes.size()) - 1) - : 0; + CHECK_GE(size1, 0); + return size1; } - // TODO(haowen): finally we'll implement FinalStates with: - // CHECK_GE(size1, 2); - // return size1 - 1; int32_t FinalState() const { - if (indexes != nullptr) { // Fsa is initialized as Array2 - // users should not use `arc_indexes` and `arcs` while using Array2 - CHECK(arc_indexes.empty()); - CHECK(arcs.empty()); - - // It's not valid to call FinalState if the FSA is empty. - CHECK_GE(size1, 2); - return size1 - 1; - } - // It's not valid to call this if the FSA is empty. - CHECK(!arc_indexes.empty()); - return static_cast(arc_indexes.size()) - 2; - } -}; - -// TODO(haowen): replace Cfsa and CfsaVec with below definitions -using Cfsa_ = Array2; -using CfsaVec_ = Array3; - -/* - Cfsa is a 'const' FSA, which we'll use as the input to operations. It is - designed in such a way that the storage underlying it may either be an Fsa - (i.e. with std::vectors) or may be some kind of tensor (probably CfsaVec). - Note: the pointers it holds aren't const for now, because there may be - situations where it makes sense to change them (even though the number of - states and arcs can't be changed). - */ -struct Cfsa { - int32_t num_states; // number of states including final state. States are - // numbered `0 ... num_states - 1`. Start state is 0, - // final state is state `num_states - 1`. We store a - // redundant representation here out of a belief that it - // might reduce the number of instructions in code. - int32_t begin_arc; // a copy of arc_indexes[0]; gives the first index in - // `arcs` for the arcs in this FSA. Will be >= 0. - int32_t end_arc; // a copy of arc_indexes[num_states]; gives the - // one-past-the-last index in `arcs` for the arcs in this - // FSA. Will be >= begin_arc. - - const int32_t *arc_indexes; // an array, indexed by state index, giving the - // first arc index of each state. The last one - // is repeated, so for any valid state 0 <= s < - // num_states we can use arc_indexes[s+1]. That - // is: elements 0 through num_states (inclusive) - // are valid. CAUTION: arc_indexes[0] may be - // greater than zero. - - Arc *arcs; // Note: arcs[begin_arc] through arcs[end_arc - 1] - // are valid. - - Cfsa(); - // Constructor from Fsa. The passed `fsa` should be kept alive - // as long as this cfsa is alive. - explicit Cfsa(const Fsa &fsa); - - Cfsa &operator=(const Cfsa &cfsa) = default; - Cfsa(const Cfsa &cfsa) = default; - - int32_t NumStates() const { return num_states; } - int32_t NumArcs() const { return end_arc - begin_arc; } - int32_t FinalState() const { - CHECK_GE(num_states, 2) << "It's an error to invoke this method for " - << "an empty cfsa"; - return num_states - 1; - } - - // for test only - bool operator==(const Cfsa &other) const { - if (other.num_states != num_states) return false; - - if (other.NumArcs() != NumArcs()) return false; - - for (int32_t i = 0; i != NumArcs(); ++i) { - const auto &this_arc = arcs[begin_arc + i]; - const auto &other_arc = other.arcs[other.begin_arc + i]; - - if (this_arc != other_arc) return false; - } - - return true; + // It's not valid to call FinalState if the FSA is empty. + CHECK_GE(size1, 2); + return size1 - 1; } }; -std::ostream &operator<<(std::ostream &os, const Cfsa &cfsa); - -constexpr int32_t kCfsaVecVersion = 0x01; - -struct CfsaVecHeader { - int32_t version; - int32_t num_fsas; - int32_t state_offsets_start; - int32_t arc_indexes_start; - int32_t arcs_start; -}; - -class CfsaVec { - public: - /* - Constructor from linear data, e.g. from data stored in a torch.Tensor. - This would previously have been created using CreateCfsaVec(). - - @param [in] size size in int32_t elements of `data`, only - needed for checking purposes. - @param [in] data The underlying data. Format of data is - described below (all elements are of type - int32_t unless stated otherwise). Would have - been created by CreateCfsaVec(). - - - version Format version number, currently always 1. - - num_fsas The number of FSAs - - state_offsets_start The offset from the start of `data` of - where the `state_offsets` array is, in int32_t - (4-byte) elements. - - arc_indexes_start The offset from the start of `data` of - where the `arc_indexes` array is, in int32_t - (4-byte) elements. - - arcs_start The offset from the start of `data` of where - the first Arc is, in sizeof(Arc) multiples, i.e. - Arc *arcs = ((Arc*)data) + arcs_start - - [possibly some padding here] - - state_offsets[num_fsas + 1] state_offsets[f] is the sum of - the num-states of all the FSAs preceding f. It is - also the offset from the beginning of the - `arc_indexes` array of where the part corresponding - to FSA f starts. The number of states in FSA f - is given by - `state_offsets[f+1] - state_offsets[f] - 1`. - Caution: one is subtracted above because the last - entry in the arc_indexes array is repeated. - This is >= 0; it will be zero if the - FSA f is empty, and >= 2 otherwise. - [possibly some padding here] - - - arc_indexes[tot_states + num_fsas] This gives the indexes - into the `arcs` array of where we can find the - first of each state's arcs. `num_fsas` is needed - since the final state of every fsa is repeated in - `arc_indexes`. - - [pad as needed for memory-alignment purposes then...] - - - arcs[tot_arcs] - */ - CfsaVec(std::size_t size, void *data); - - int32_t NumFsas() const { return num_fsas_; } - - Cfsa operator[](int32_t i) const; - - CfsaVec &operator=(const CfsaVec &) = delete; - CfsaVec(const CfsaVec &) = delete; - - ~CfsaVec() { - if (opaque_deleter_) (*opaque_deleter_)(opaque_ptr_); - } - - void SetDeleter(void (*deleter)(void *), void *p) { - opaque_deleter_ = deleter; - opaque_ptr_ = p; - } - - private: - int32_t num_fsas_; +std::ostream &operator<<(std::ostream &os, const Fsa &fsa); - // The raw underlying data; - // CAUTION: we do NOT own the memory here. - int32_t *data_; - // The size of the underlying data; - // Caution: it is the number of `int32_t` in data_, NOT the number of bytes. - std::size_t size_; - - // the following two fields are for DLPack, which enables us to - // share memory with `torch::Tensor`. - // - // C++ code will in generate not touch them. - void (*opaque_deleter_)(void *) = nullptr; - void *opaque_ptr_ = nullptr; -}; - -/* - Return the number of bytes we'd need to represent this vector of Cfsas - linearly as a CfsaVec. */ -std::size_t GetCfsaVecSize(const std::vector &cfsas); - -// Return the number of bytes we'd need to represent this Cfsa -// linearly as a CfsaVec with one element -std::size_t GetCfsaVecSize(const Cfsa &cfsa); - -/* - Create a CfsaVec from a vector of Cfsas (this involves representing - the vector of Fsas in one big linear memory region). - - @param [in] cfsas The vector of Cfsas to be linearized; - must be nonempty - @param [in] data The allocated data of size `size` bytes - @param [in] size The size of the memory block in bytes passed; - must equal the return value of - GetCfsaVecSize(cfsas). - */ -void CreateCfsaVec(const std::vector &cfsas, void *data, - std::size_t size); +using Cfsa = Fsa; +using CfsaVec = Array3; struct Fst { Fsa core; diff --git a/k2/csrc/fsa_equivalent_test.cc b/k2/csrc/fsa_equivalent_test.cc index 589291320..e5a687c30 100644 --- a/k2/csrc/fsa_equivalent_test.cc +++ b/k2/csrc/fsa_equivalent_test.cc @@ -317,7 +317,7 @@ TEST(FsaEquivalent, RandomPathWithoutEpsilonArc) { bool status = rand_path.GetOutput(&path, state_map.data()); EXPECT_TRUE(status); EXPECT_GT(state_map.size(), 0); - for (const auto &arc : path.arcs) { + for (const auto &arc : path) { EXPECT_NE(arc.label, kEpsilon); } } diff --git a/k2/csrc/fsa_test.cc b/k2/csrc/fsa_test.cc index 11e02b065..c6d800a76 100644 --- a/k2/csrc/fsa_test.cc +++ b/k2/csrc/fsa_test.cc @@ -1,6 +1,7 @@ // k2/csrc/fsa_test.cc // Copyright (c) 2020 Fangjun Kuang (csukuangfj@gmail.com) +// Xiaomi Corporation (author: Haowen Qiu) // See ../../LICENSE for clarification regarding multiple authors @@ -9,274 +10,91 @@ #include #include +#include "gmock/gmock.h" #include "gtest/gtest.h" +#include "k2/csrc/array.h" #include "k2/csrc/fsa_util.h" #include "k2/csrc/util.h" namespace k2 { -TEST(Cfsa, ConstructorNonEmptyFsa) { - std::vector arcs = { - {0, 1, 1}, {0, 2, 2}, {1, 3, 3}, {2, 3, 3}, {3, 4, -1}, - }; - Fsa fsa(std::move(arcs), 4); - Cfsa cfsa(fsa); - - EXPECT_EQ(cfsa.num_states, 5); - EXPECT_EQ(cfsa.begin_arc, 0); - EXPECT_EQ(cfsa.end_arc, 5); - EXPECT_EQ(cfsa.arc_indexes, fsa.arc_indexes.data()); - EXPECT_EQ(cfsa.arcs, fsa.arcs.data()); - - EXPECT_EQ(cfsa.NumStates(), 5); - EXPECT_EQ(cfsa.FinalState(), 4); -} - -TEST(Cfsa, ConstructorEmptyFsa) { - Fsa fsa; - Cfsa cfsa(fsa); - EXPECT_EQ(cfsa.num_states, 0); - EXPECT_EQ(cfsa.begin_arc, 0); - EXPECT_EQ(cfsa.end_arc, 0); - EXPECT_EQ(cfsa.arc_indexes, nullptr); - EXPECT_EQ(cfsa.arcs, nullptr); -} - -TEST(GetCfsaVecSize, Empty) { - Cfsa cfsa; - std::size_t bytes = GetCfsaVecSize(cfsa); - // 20-byte header (20) - // 44-byte padding (64) - // 8-byte state_offsets_array (72) - // 56-byte padding (128) - // 4-byte arc_indexes_array (132) - EXPECT_EQ(bytes, 132u); - - std::vector cfsa_vec; - cfsa_vec.push_back(cfsa); - bytes = GetCfsaVecSize(cfsa_vec); - EXPECT_EQ(bytes, 132u); -} - -TEST(GetCfsaVecSize, NonEmpty) { - std::vector arcs = { - {0, 1, 1}, {0, 2, 2}, {1, 3, 3}, {2, 3, 3}, {3, 16, -1}, - }; - Fsa fsa(std::move(arcs), 16); - Cfsa cfsa(fsa); - - std::size_t bytes = GetCfsaVecSize(cfsa); - // 20-byte header (20) - // 44-byte padding (64) - // 8-byte state_offset_array (72) - // 56-byte padding (128) - // 72-byte arc_indexes_array (200) - // 4-byte padding (204) -> to be multiple of sizeof(Arc) - // 60-byte arcs_array (264) - EXPECT_EQ(bytes, 264u); - // Note that there are 5 arcs, sizeof(Arc) == 12. - // There are 17 states and each state needs 4 bytes - // and the last state is repeated, so the arc_indexes - // array needs 18*4 = 72-byte - - { - std::vector cfsa_vec; - cfsa_vec.push_back(cfsa); - bytes = GetCfsaVecSize(cfsa); - EXPECT_EQ(bytes, 264u); - } -} -TEST(GetCfsaVecSize, NonEmptyMutlipeFsas) { +TEST(CfsaVec, CreateCfsa) { std::vector arcs1 = { - {0, 1, 1}, {0, 2, 2}, {1, 3, 3}, {2, 3, 3}, {3, 16, -1}, + {0, 1, 1}, {0, 2, 2}, {1, 2, 3}, {1, 3, 4}, {3, 4, -1}, }; - // 5 arcs, 17 states - Fsa fsa1(std::move(arcs1), 16); - Cfsa cfsa1(fsa1); + FsaCreator fsa_creator1(arcs1, 4); + Cfsa cfsa1 = fsa_creator1.GetFsa(); + EXPECT_EQ(cfsa1.NumStates(), 5); + EXPECT_EQ(cfsa1.size2, 5); // num-arcs - // 4 arcs, 11 states std::vector arcs2 = { - {0, 1, 1}, - {0, 2, 2}, - {1, 3, 3}, - {3, 10, -1}, + {0, 2, 1}, + {0, 3, -1}, + {1, 3, -1}, + {2, 3, -1}, }; - // 5 arcs, 17 states - Fsa fsa2(std::move(arcs2), 10); - Cfsa cfsa2(fsa2); + FsaCreator fsa_creator2(arcs2, 3); + Cfsa cfsa2 = fsa_creator2.GetFsa(); + EXPECT_EQ(cfsa2.NumStates(), 4); + EXPECT_EQ(cfsa2.size2, 4); // num-arcs - std::vector cfsa_vec = {cfsa1, cfsa2}; - - std::size_t bytes = GetCfsaVecSize(cfsa_vec); - // 28 states,9 arcs - // - // 20-byte header (20) - // 44-byte padding (64) - // 12-byte state_offset_array (76) - // 52-byte padding (128) - // 120-byte arc_indexes_array (248) - // 4-byte padding (252) -> to be multiple of sizeof(Arc) - // 108-byte arcs_array (360) - EXPECT_EQ(bytes, 360u); -} - -TEST(CfsaVec, Empty) { - Cfsa cfsa; std::vector cfsas; - std::size_t bytes = GetCfsaVecSize(cfsas); - std::unique_ptr data(MemAlignedMalloc(bytes, 64), - &MemFree); - - CreateCfsaVec(cfsas, data.get(), bytes); - - CfsaVec cfsa_vec(bytes / sizeof(int32_t), data.get()); - EXPECT_EQ(cfsa_vec.NumFsas(), 0); -} - -TEST(CfsaVec, OneEmptyCfsa) { - Cfsa cfsa; - std::vector cfsas = {cfsa}; - std::size_t bytes = GetCfsaVecSize(cfsas); - std::unique_ptr data(MemAlignedMalloc(bytes, 64), - &MemFree); - - CreateCfsaVec(cfsas, data.get(), bytes); - - CfsaVec cfsa_vec(bytes / sizeof(int32_t), data.get()); - EXPECT_EQ(cfsa_vec.NumFsas(), 1); -} - -TEST(CfsaVec, OneNonEmptyCfsa) { - std::vector arcs = { - {0, 1, 10}, {0, 2, 2}, {1, 3, 3}, {2, 3, 3}, {2, 4, -1}, {3, 4, -1}, - }; - Fsa fsa(std::move(arcs), 16); - Cfsa cfsa(fsa); - - std::vector cfsas = {cfsa}; - std::size_t bytes = GetCfsaVecSize(cfsas); - std::unique_ptr data(MemAlignedMalloc(bytes, 64), - &MemFree); - - CreateCfsaVec(cfsas, data.get(), bytes); - - CfsaVec cfsa_vec(bytes / sizeof(int32_t), data.get()); - EXPECT_EQ(cfsa_vec.NumFsas(), 1); - - Cfsa f = cfsa_vec[0]; - EXPECT_EQ(f, cfsa); -} - -TEST(CfsaVec, TwoNoneEmptyCfsa) { - std::vector arcs1 = { - {0, 1, 10}, {0, 2, 2}, {1, 3, 3}, {2, 3, 3}, {2, 4, -1}, {3, 4, -1}, - }; - Fsa fsa1(std::move(arcs1), 4); - Cfsa cfsa1(fsa1); - - std::vector arcs2 = {{0, 1, 10}, {0, 2, 2}, {1, 3, 3}, {2, 3, 3}, - {2, 4, 3}, {3, 10, -1}, {4, 10, -1}}; - Fsa fsa2(std::move(arcs2), 10); - Cfsa cfsa2(fsa2); - - { - // both fsa are not empty - std::vector cfsas = {cfsa1, cfsa2}; - std::size_t bytes = GetCfsaVecSize(cfsas); - std::unique_ptr data(MemAlignedMalloc(bytes, 64), - &MemFree); - - CreateCfsaVec(cfsas, data.get(), bytes); - - CfsaVec cfsa_vec(bytes / sizeof(int32_t), data.get()); - EXPECT_EQ(cfsa_vec.NumFsas(), 2); - - Cfsa f = cfsa_vec[0]; - EXPECT_EQ(f, cfsa1); - - Cfsa g = cfsa_vec[1]; - EXPECT_EQ(g, cfsa2); + cfsas.emplace_back(cfsa1); + cfsas.emplace_back(cfsa2); + + CfsaVec cfsa_vec; + cfsa_vec.GetSizes(cfsas.data(), 2); + EXPECT_EQ(cfsa_vec.size1, 2); + EXPECT_EQ(cfsa_vec.size2, cfsa1.NumStates() + cfsa2.NumStates()); + EXPECT_EQ(cfsa_vec.size3, cfsa1.size2 + cfsa2.size2); + + // Test CfsaVec Creation + std::vector cfsa_vec_indexes1(cfsa_vec.size1 + 1); + std::vector cfsa_vec_indexes2(cfsa_vec.size2 + 1); + std::vector cfsa_vec_data(cfsa_vec.size3); + cfsa_vec.indexes1 = cfsa_vec_indexes1.data(); + cfsa_vec.indexes2 = cfsa_vec_indexes2.data(); + cfsa_vec.data = cfsa_vec_data.data(); + + cfsa_vec.Create(cfsas.data(), 2); + EXPECT_THAT(cfsa_vec_indexes1, ::testing::ElementsAre(0, 5, 9)); + EXPECT_THAT(cfsa_vec_indexes2, + ::testing::ElementsAre(0, 2, 4, 4, 5, 5, 7, 8, 9, 9)); + for (auto i = cfsa1.indexes[0]; i != cfsa1.indexes[cfsa1.size1]; ++i) { + EXPECT_EQ(cfsa_vec.data[i], cfsa1.data[i]); } - - { - // the first fsa is empty - Cfsa cfsa; - std::vector cfsas = {cfsa, cfsa2}; - std::size_t bytes = GetCfsaVecSize(cfsas); - std::unique_ptr data(MemAlignedMalloc(bytes, 64), - &MemFree); - - CreateCfsaVec(cfsas, data.get(), bytes); - - CfsaVec cfsa_vec(bytes / sizeof(int32_t), data.get()); - EXPECT_EQ(cfsa_vec.NumFsas(), 2); - - Cfsa f = cfsa_vec[0]; - EXPECT_EQ(f, cfsa); - - Cfsa g = cfsa_vec[1]; - EXPECT_EQ(g, cfsa2); + for (auto i = cfsa2.indexes[0]; i != cfsa2.indexes[cfsa2.size1]; ++i) { + EXPECT_EQ(cfsa_vec.data[cfsa1.size2 + i - cfsa2.indexes[0]], cfsa2.data[i]); } - { - // the second fsa is empty - Cfsa cfsa; - std::vector cfsas = {cfsa1, cfsa}; - std::size_t bytes = GetCfsaVecSize(cfsas); - std::unique_ptr data(MemAlignedMalloc(bytes, 64), - &MemFree); - - CreateCfsaVec(cfsas, data.get(), bytes); - - CfsaVec cfsa_vec(bytes / sizeof(int32_t), data.get()); - EXPECT_EQ(cfsa_vec.NumFsas(), 2); - - Cfsa f = cfsa_vec[0]; - EXPECT_EQ(f, cfsa1); - - Cfsa g = cfsa_vec[1]; - EXPECT_EQ(g, cfsa); + // Test operator[] + auto array1_copy = cfsa_vec[0]; + Cfsa *cfsa1_copy_ptr = static_cast(&array1_copy); // cast here + const auto &cfsa1_copy = *cfsa1_copy_ptr; + // should call `NumStates` successfully + EXPECT_EQ(cfsa1_copy.NumStates(), cfsa1.NumStates()); + EXPECT_EQ(cfsa1_copy.size2, cfsa1.size2); + for (auto i = 0; i != cfsa1.size1 + 1; ++i) { + EXPECT_EQ(cfsa1_copy.indexes[i], cfsa1.indexes[i]); } -} -// TODO(haowen): un-comment below lines after replacing Cfsa with Array3 -/* -TEST(CfsaVec, RandomFsa) { - RandFsaOptions opts; - opts.num_syms = 20; - opts.num_states = 30; - opts.num_arcs = 50; - opts.allow_empty = false; - opts.acyclic = false; - opts.seed = 20200531; - - int32_t n = 5; - std::vector fsa_vec; - fsa_vec.reserve(n); - for (int32_t i = 0; i != n; ++i) { - Fsa fsa; - GenerateRandFsa(opts, &fsa); - fsa_vec.emplace_back(std::move(fsa)); + for (auto i = cfsa1.indexes[0]; i != cfsa1.indexes[cfsa1.size1]; ++i) { + EXPECT_EQ(cfsa1_copy.data[i], cfsa1.data[i]); } - std::vector cfsas; - cfsas.reserve(n); - for (const auto &fsa : fsa_vec) { - cfsas.emplace_back(fsa); + auto array2_copy = cfsa_vec[1]; + Cfsa *cfsa2_copy_ptr = static_cast(&array2_copy); // cast here + const auto &cfsa2_copy = *cfsa2_copy_ptr; + // should call `NumStates` successfully + EXPECT_EQ(cfsa2_copy.NumStates(), cfsa2.NumStates()); + EXPECT_EQ(cfsa2_copy.size2, cfsa2.size2); + for (auto i = 0; i != cfsa2.size1 + 1; ++i) { + // output indexes may starts from n > 0 + EXPECT_EQ(cfsa2_copy.indexes[i], cfsa2.indexes[i] + cfsa1.size1); } - - std::size_t bytes = GetCfsaVecSize(cfsas); - std::unique_ptr data(MemAlignedMalloc(bytes, 64), - &MemFree); - - CreateCfsaVec(cfsas, data.get(), bytes); - - CfsaVec cfsa_vec(bytes / sizeof(int32_t), data.get()); - EXPECT_EQ(cfsa_vec.NumFsas(), n); - - for (int32_t i = 0; i != n; ++i) { - EXPECT_EQ(cfsa_vec[i], cfsas[i]); + for (auto i = cfsa2.indexes[0]; i != cfsa2.indexes[cfsa2.size1]; ++i) { + EXPECT_EQ(cfsa2_copy.data[i + cfsa1.size2 - cfsa2.indexes[0]], + cfsa2.data[i]); } } -*/ } // namespace k2 diff --git a/k2/csrc/fsa_util.cc b/k2/csrc/fsa_util.cc index 303881ec3..f36726a28 100644 --- a/k2/csrc/fsa_util.cc +++ b/k2/csrc/fsa_util.cc @@ -220,13 +220,6 @@ void ReorderArcs(const std::vector &arcs, Fsa *fsa, if (arc_map != nullptr) arc_map->swap(arc_map_out); } -void Swap(Fsa *a, Fsa *b) { - CHECK_NOTNULL(a); - CHECK_NOTNULL(b); - std::swap(a->arc_indexes, b->arc_indexes); - std::swap(a->arcs, b->arcs); -} - void StringToFsa::GetSizes(Array2Size *fsa_size) { CHECK_NOTNULL(fsa_size); fsa_size->size1 = fsa_size->size2 = 0; diff --git a/k2/csrc/fsa_util.h b/k2/csrc/fsa_util.h index c5420ac6b..e2d9ed817 100644 --- a/k2/csrc/fsa_util.h +++ b/k2/csrc/fsa_util.h @@ -123,18 +123,11 @@ void GetArcIndexes2(const std::vector> &arc_map, std::vector *indexes1, std::vector *indexes2); -void Swap(Fsa *a, Fsa *b); - // Create Fsa for test purpose. class FsaCreator { public: // Create an empty Fsa - FsaCreator() { - // TODO(haowen): remove below line and use `FsaCreator() = default` - // we need this for now as we reset `indexes = nullptr` in the constructor - // of Fsa - fsa_.indexes = &fsa_.size1; - } + FsaCreator() = default; /* Initialize Fsa with Array2size, search for 'initialized definition' in class diff --git a/k2/csrc/properties.h b/k2/csrc/properties.h index 258535c1f..67940a237 100644 --- a/k2/csrc/properties.h +++ b/k2/csrc/properties.h @@ -113,11 +113,7 @@ inline bool IsTopSortedAndConnected(const Fsa &fsa) { Returns true if `fsa` is empty. (Note: if `fsa` is not empty, it would contain at least two states, the start state and the final state). */ -// TODO(haowen): finally we'll implement just with: `return fsa.size1 == 0` -inline bool IsEmpty(const Fsa &fsa) { - if (fsa.indexes != nullptr) return fsa.size1 == 0; - return fsa.arc_indexes.empty() && fsa.arcs.empty(); -} +inline bool IsEmpty(const Fsa &fsa) { return fsa.size1 == 0; } /* Returns true if `fsa` is valid AND satisfies the list of properties diff --git a/k2/python/csrc/CMakeLists.txt b/k2/python/csrc/CMakeLists.txt index 817cb22ec..f2ad66a9f 100644 --- a/k2/python/csrc/CMakeLists.txt +++ b/k2/python/csrc/CMakeLists.txt @@ -1,6 +1,6 @@ # sort the files alphabetically pybind11_add_module(k2 - fsa.cc + #fsa.cc fsa_renderer.cc fsa_util.cc k2.cc diff --git a/k2/python/csrc/fsa.cc b/k2/python/csrc/fsa.cc index 35041142f..842c2f6f6 100644 --- a/k2/python/csrc/fsa.cc +++ b/k2/python/csrc/fsa.cc @@ -31,6 +31,7 @@ static const char *kDLPackTensorName = "dltensor"; // PyTorch, TVM and CuPy name the used dltensor to be `used_dltensor` static const char *kDLPackUsedTensorName = "used_dltensor"; +/* static void DLPackDeleter(void *p) { auto dl_managed_tensor = reinterpret_cast(p); @@ -48,7 +49,7 @@ static void DLPackDeleter(void *p) { static CfsaVec *CfsaVecFromDLPack(py::capsule *capsule, const std::vector *cfsas = nullptr) { // the following error message is modified from - // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Module.cpp#L384 + // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Module.cpp#L384 CHECK_EQ(strcmp(kDLPackTensorName, capsule->name()), 0) << "Expected capsule name: " << kDLPackTensorName << "\n" << "But got: " << capsule->name() << "\n" @@ -85,6 +86,7 @@ static CfsaVec *CfsaVecFromDLPack(py::capsule *capsule, return cfsa_vec; } + static void PybindCfsaVec(py::module &m) { m.def("get_cfsa_vec_size", overload_cast_()(&k2::GetCfsaVecSize), py::arg("cfsa")); @@ -105,89 +107,91 @@ static void PybindCfsaVec(py::module &m) { py::arg("dlpack"), py::arg("cfsas") = nullptr, py::return_value_policy::take_ownership); } +*/ void PybindFsa(py::module &m) { - py::class_(m, "Arc") - .def(py::init<>()) - .def(py::init(), py::arg("src_state"), - py::arg("dest_state"), py::arg("label")) - .def_readwrite("src_state", &Arc::src_state) - .def_readwrite("dest_state", &Arc::dest_state) - .def_readwrite("label", &Arc::label) - .def("__str__", [](const Arc &self) { - std::ostringstream os; - os << self; - return os.str(); - }); - - py::class_(m, "Fsa") - .def(py::init<>()) - .def("num_states", &Fsa::NumStates) - .def("final_state", &Fsa::FinalState) - .def("__str__", [](const Fsa &self) { return FsaToString(self); }) - .def_readwrite("arc_indexes", &Fsa::arc_indexes) - .def_readwrite("arcs", &Fsa::arcs); - - py::class_>(m, "FsaVec") - .def(py::init<>()) - .def("clear", &std::vector::clear) - .def("__len__", [](const std::vector &self) { return self.size(); }) - .def("push_back", - [](std::vector *self, const Fsa &fsa) { self->push_back(fsa); }) - .def("__iter__", - [](const std::vector &self) { - return py::make_iterator(self.begin(), self.end()); - }, - py::keep_alive<0, 1>()); - // py::keep_alive - // 0 is the return value and 1 is the first argument. - // Keep the patient (i.e., `self`) alive as long as the Nurse (i.e., the - // return value) is not freed. - - py::class_>(m, "ArcVec") - .def(py::init<>()) - .def("clear", &std::vector::clear) - .def("__len__", [](const std::vector &self) { return self.size(); }) - .def("__iter__", - [](const std::vector &self) { - return py::make_iterator(self.begin(), self.end()); - }, - py::keep_alive<0, 1>()); - - py::class_(m, "Cfsa") - .def(py::init<>()) - .def(py::init(), py::arg("fsa"), py::keep_alive<1, 2>()) - .def("num_states", &Cfsa::NumStates) - .def("num_arcs", &Cfsa::NumArcs) - .def("arc", - [](Cfsa *self, int s) { - DCHECK_GE(s, 0); - DCHECK_LT(s, self->NumStates()); - auto begin = self->arc_indexes[s]; - auto end = self->arc_indexes[s + 1]; - return py::make_iterator(self->arcs + begin, self->arcs + end); - }, - py::keep_alive<0, 1>()) - .def("__str__", - [](const Cfsa &self) { - std::ostringstream os; - os << self; - return os.str(); - }) - .def("__eq__", // for test only - [](const Cfsa &self, const Cfsa &other) { return self == other; }); - - py::class_>(m, "CfsaStdVec") - .def(py::init<>()) - .def("clear", &std::vector::clear) - .def("push_back", [](std::vector *self, - const Cfsa &cfsa) { self->push_back(cfsa); }) - .def("__len__", [](const std::vector &self) { return self.size(); }) - .def("__iter__", - [](const std::vector &self) { - return py::make_iterator(self.begin(), self.end()); - }, - py::keep_alive<0, 1>()); - + /* +py::class_(m, "Arc") + .def(py::init<>()) + .def(py::init(), py::arg("src_state"), + py::arg("dest_state"), py::arg("label")) + .def_readwrite("src_state", &Arc::src_state) + .def_readwrite("dest_state", &Arc::dest_state) + .def_readwrite("label", &Arc::label) + .def("__str__", [](const Arc &self) { + std::ostringstream os; + os << self; + return os.str(); + }); + +py::class_(m, "Fsa") + .def(py::init<>()) + .def("num_states", &Fsa::NumStates) + .def("final_state", &Fsa::FinalState) + .def("__str__", [](const Fsa &self) { return FsaToString(self); }) + .def_readwrite("arc_indexes", &Fsa::arc_indexes) + .def_readwrite("arcs", &Fsa::arcs); + +py::class_>(m, "FsaVec") + .def(py::init<>()) + .def("clear", &std::vector::clear) + .def("__len__", [](const std::vector &self) { return self.size(); }) + .def("push_back", + [](std::vector *self, const Fsa &fsa) { self->push_back(fsa); }) + .def("__iter__", + [](const std::vector &self) { + return py::make_iterator(self.begin(), self.end()); + }, + py::keep_alive<0, 1>()); +// py::keep_alive +// 0 is the return value and 1 is the first argument. +// Keep the patient (i.e., `self`) alive as long as the Nurse (i.e., the +// return value) is not freed. + +py::class_>(m, "ArcVec") + .def(py::init<>()) + .def("clear", &std::vector::clear) + .def("__len__", [](const std::vector &self) { return self.size(); }) + .def("__iter__", + [](const std::vector &self) { + return py::make_iterator(self.begin(), self.end()); + }, + py::keep_alive<0, 1>()); + +py::class_(m, "Cfsa") + .def(py::init<>()) + .def(py::init(), py::arg("fsa"), py::keep_alive<1, 2>()) + .def("num_states", &Cfsa::NumStates) + .def("num_arcs", &Cfsa::NumArcs) + .def("arc", + [](Cfsa *self, int s) { + DCHECK_GE(s, 0); + DCHECK_LT(s, self->NumStates()); + auto begin = self->arc_indexes[s]; + auto end = self->arc_indexes[s + 1]; + return py::make_iterator(self->arcs + begin, self->arcs + end); + }, + py::keep_alive<0, 1>()) + .def("__str__", + [](const Cfsa &self) { + std::ostringstream os; + os << self; + return os.str(); + }) + .def("__eq__", // for test only + [](const Cfsa &self, const Cfsa &other) { return self == other; }); + +py::class_>(m, "CfsaStdVec") + .def(py::init<>()) + .def("clear", &std::vector::clear) + .def("push_back", [](std::vector *self, + const Cfsa &cfsa) { self->push_back(cfsa); }) + .def("__len__", [](const std::vector &self) { return self.size(); }) + .def("__iter__", + [](const std::vector &self) { + return py::make_iterator(self.begin(), self.end()); + }, + py::keep_alive<0, 1>()); PybindCfsaVec(m); +*/ } diff --git a/k2/python/csrc/k2.cc b/k2/python/csrc/k2.cc index 29138df76..36f337b65 100644 --- a/k2/python/csrc/k2.cc +++ b/k2/python/csrc/k2.cc @@ -12,7 +12,7 @@ PYBIND11_MODULE(k2, m) { m.doc() = "pybind11 binding of k2"; - PybindFsa(m); + // PybindFsa(m); PybindFsaRenderer(m); PybindFsaUtil(m); } diff --git a/k2/python/csrc/k2.h b/k2/python/csrc/k2.h index 0293c26dc..baf5f18b2 100644 --- a/k2/python/csrc/k2.h +++ b/k2/python/csrc/k2.h @@ -9,18 +9,17 @@ #include +#include "k2/csrc/fsa.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" -#include "k2/csrc/fsa.h" - namespace py = pybind11; template using overload_cast_ = pybind11::detail::overload_cast_impl; PYBIND11_MAKE_OPAQUE(std::vector); -PYBIND11_MAKE_OPAQUE(std::vector); -PYBIND11_MAKE_OPAQUE(std::vector); +// PYBIND11_MAKE_OPAQUE(std::vector); +// PYBIND11_MAKE_OPAQUE(std::vector); #endif // K2_PYTHON_CSRC_K2_H_