From ea6221668ae2905c1a1b770be326630fffb31d15 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Wed, 22 Jul 2020 19:12:22 +0800 Subject: [PATCH] rewrite pybind interface for Array2 and Fsa --- k2/csrc/arcsort.cc | 11 ++- k2/csrc/arcsort.h | 9 ++- k2/csrc/arcsort_test.cc | 12 ++- k2/csrc/array.h | 43 ++++++++++- k2/csrc/fsa_equivalent.cc | 8 +- k2/python/csrc/CMakeLists.txt | 1 + k2/python/csrc/array.cc | 110 ++++++++++++++++++++++++---- k2/python/csrc/array.h | 1 + k2/python/csrc/fsa.cc | 39 ++++++---- k2/python/csrc/fsa_algo.cc | 28 +++++++ k2/python/csrc/fsa_algo.h | 14 ++++ k2/python/csrc/k2.cc | 3 + k2/python/k2/__init__.py | 3 +- k2/python/k2/array.py | 42 ++++++++++- k2/python/k2/fsa.py | 43 ++++++++++- k2/python/k2/fsa_algo.py | 30 ++++++++ k2/python/tests/CMakeLists.txt | 1 + k2/python/tests/array_test.py | 63 +++++++++++++--- k2/python/tests/fsa_arcsort_test.py | 102 ++++++++++++++++++++++++++ k2/python/tests/fsa_test.py | 27 ++++++- 20 files changed, 522 insertions(+), 68 deletions(-) create mode 100644 k2/python/csrc/fsa_algo.cc create mode 100644 k2/python/csrc/fsa_algo.h create mode 100644 k2/python/k2/fsa_algo.py create mode 100644 k2/python/tests/fsa_arcsort_test.py diff --git a/k2/csrc/arcsort.cc b/k2/csrc/arcsort.cc index 1d8c8d855..65b70ac30 100644 --- a/k2/csrc/arcsort.cc +++ b/k2/csrc/arcsort.cc @@ -21,7 +21,8 @@ void ArcSorter::GetSizes(Array2Size *fsa_size) const { fsa_size->size2 = fsa_in_.size2; } -void ArcSorter::GetOutput(Fsa *fsa_out, int32_t *arc_map /*= nullptr*/) { +void ArcSorter::GetOutput(Fsa *fsa_out, + Array1 *arc_map /*= nullptr*/) { CHECK_NOTNULL(fsa_out); CHECK_EQ(fsa_out->size1, fsa_in_.size1); CHECK_EQ(fsa_out->size2, fsa_in_.size2); @@ -50,10 +51,11 @@ void ArcSorter::GetOutput(Fsa *fsa_out, int32_t *arc_map /*= nullptr*/) { } fsa_out->indexes[num_states] = num_arcs; - if (arc_map != nullptr) std::copy(indexes.begin(), indexes.end(), arc_map); + if (arc_map != nullptr) + std::copy(indexes.begin(), indexes.end(), 
arc_map->data + arc_map->begin); } -void ArcSort(Fsa *fsa, int32_t *arc_map /*= nullptr*/) { +void ArcSort(Fsa *fsa, Array1 *arc_map /*= nullptr*/) { CHECK_NOTNULL(fsa); std::vector indexes(fsa->size2); @@ -74,7 +76,8 @@ void ArcSort(Fsa *fsa, int32_t *arc_map /*= nullptr*/) { [](const Arc &left, const Arc &right) { return left < right; }); } - if (arc_map != nullptr) std::copy(indexes.begin(), indexes.end(), arc_map); + if (arc_map != nullptr) + std::copy(indexes.begin(), indexes.end(), arc_map->data + arc_map->begin); } } // namespace k2 diff --git a/k2/csrc/arcsort.h b/k2/csrc/arcsort.h index e08c1e772..fad1a8428 100644 --- a/k2/csrc/arcsort.h +++ b/k2/csrc/arcsort.h @@ -11,6 +11,7 @@ #include #include "glog/logging.h" +#include "k2/csrc/array.h" #include "k2/csrc/fsa.h" namespace k2 { @@ -43,19 +44,19 @@ class ArcSorter { @param [out] arc_map If non-NULL, will output a map from the arc-index in `fsa_out` to the corresponding arc-index in `fsa_in`. - If non-NULL, at entry it must be allocated with + If non-NULL, at entry it must be initialized with size num-arcs of `fsa_out`, e.g. `fsa_out->size2`. */ - void GetOutput(Fsa *fsa_out, int32_t *arc_map = nullptr); + void GetOutput(Fsa *fsa_out, Array1 *arc_map = nullptr); private: const Fsa &fsa_in_; }; // In-place version of ArcSorter; see its documentation; -// Note that if `arc_map` is non-NULL, then at entry it must be allocated with +// Note that if `arc_map` is non-NULL, then at entry it must be initialized with // size num-arcs of `fsa`, e.g.
`fsa->size2` -void ArcSort(Fsa *fsa, int32_t *arc_map = nullptr); +void ArcSort(Fsa *fsa, Array1 *arc_map = nullptr); } // namespace k2 diff --git a/k2/csrc/arcsort_test.cc b/k2/csrc/arcsort_test.cc index 18d37a2f9..bcdf649a2 100644 --- a/k2/csrc/arcsort_test.cc +++ b/k2/csrc/arcsort_test.cc @@ -33,7 +33,8 @@ TEST(ArcSortTest, ArcSorter) { FsaCreator fsa_creator_out(fsa_size); auto &arc_sorted = fsa_creator_out.GetFsa(); std::vector arc_map(fsa_size.size2); - sorter.GetOutput(&arc_sorted, arc_map.data()); + Array1 arc_map_array1(arc_map.size(), arc_map.data()); + sorter.GetOutput(&arc_sorted, &arc_map_array1); EXPECT_TRUE(IsEmpty(arc_sorted)); EXPECT_TRUE(arc_map.empty()); @@ -52,7 +53,8 @@ TEST(ArcSortTest, ArcSorter) { FsaCreator fsa_creator_out(fsa_size); auto &arc_sorted = fsa_creator_out.GetFsa(); std::vector arc_map(fsa_size.size2); - sorter.GetOutput(&arc_sorted, arc_map.data()); + Array1 arc_map_array1(arc_map.size(), arc_map.data()); + sorter.GetOutput(&arc_sorted, &arc_map_array1); EXPECT_FALSE(arc_map.empty()); EXPECT_TRUE(IsArcSorted(arc_sorted)); @@ -86,7 +88,8 @@ TEST(ArcSortTest, ArcSort) { FsaCreator fsa_creator; auto &fsa = fsa_creator.GetFsa(); std::vector arc_map(fsa.size2); - ArcSort(&fsa, arc_map.data()); + Array1 arc_map_array1(arc_map.size(), arc_map.data()); + ArcSort(&fsa, &arc_map_array1); EXPECT_TRUE(IsEmpty(fsa)); EXPECT_TRUE(arc_map.empty()); @@ -99,7 +102,8 @@ TEST(ArcSortTest, ArcSort) { FsaCreator fsa_creator(src_arcs, 4); auto &fsa = fsa_creator.GetFsa(); std::vector arc_map(fsa.size2); - ArcSort(&fsa, arc_map.data()); + Array1 arc_map_array1(arc_map.size(), arc_map.data()); + ArcSort(&fsa, &arc_map_array1); EXPECT_TRUE(IsArcSorted(fsa)); diff --git a/k2/csrc/array.h b/k2/csrc/array.h index 1b6383706..1fa8aaaba 100644 --- a/k2/csrc/array.h +++ b/k2/csrc/array.h @@ -71,6 +71,43 @@ struct StridedPtr { } }; +template +struct Array1 { + // One dimensional array of something, like vector + // where Ptr is, or behaves like, X*. 
+ using IndexT = I; + using PtrT = Ptr; + using ValueType = typename std::iterator_traits::value_type; + + Array1() : begin(0), end(0), size(0), data(nullptr) {} + Array1(IndexT begin, IndexT end, PtrT data) + : begin(begin), end(end), data(data) { + CHECK_GE(end, begin); + this->size = end - begin; + } + Array1(IndexT size, PtrT data) : begin(0), end(size), size(size), data(data) { + CHECK_GE(size, 0); + } + void Init(IndexT begin, IndexT end, PtrT data) { + CHECK_GE(end, begin); + this->begin = begin; + this->end = end; + this->size = end - begin; + this->data = data; + } + bool Empty() const { return begin == end; } + + // 'begin' and 'end' are the first and one-past-the-last indexes into `data` + // that we are allowed to use. + IndexT begin; + IndexT end; + IndexT size; // the number of elements in `data` that can be accessed, equals + // to `end - begin` + PtrT data; + + private: +}; + /* This struct stores the size of an Array2 object; it will generally be used as an output argument by functions that work out this size. 
@@ -293,12 +330,10 @@ struct Array2Storage { Array2Storage(const Array2Size &array2_size, I stride) : indexes_storage_(new I[array2_size.size1 + 1]), data_storage_(new ValueType[array2_size.size2 * stride]) { - array_.size1 = array2_size.size1; - array_.size2 = array2_size.size2; - array_.indexes = indexes_storage_.get(); + array_.Init(array2_size.size1, array2_size.size2, indexes_storage_.get(), + DataPtrCreator::Create(data_storage_, stride)); // just for case of empty Array2 object, may be written by the caller array_.indexes[0] = 0; - array_.data = DataPtrCreator::Create(data_storage_, stride); } void FillIndexes(const std::vector &indexes) { diff --git a/k2/csrc/fsa_equivalent.cc b/k2/csrc/fsa_equivalent.cc index 8019becf9..92f30e44f 100644 --- a/k2/csrc/fsa_equivalent.cc +++ b/k2/csrc/fsa_equivalent.cc @@ -62,8 +62,12 @@ static void ArcSort(const k2::Fsa &fsa_in, k2::FsaCreator *fsa_out, fsa_out->Init(fsa_size); auto &sorted_fsa = fsa_out->GetFsa(); - if (arc_map != nullptr) arc_map->resize(fsa_size.size2); - sorter.GetOutput(&sorted_fsa, arc_map == nullptr ? nullptr : arc_map->data()); + k2::Array1 arc_map_array1; + if (arc_map != nullptr) { + arc_map->resize(fsa_size.size2); + arc_map_array1.Init(0, arc_map->size(), arc_map->data()); + } + sorter.GetOutput(&sorted_fsa, arc_map == nullptr ? 
nullptr : &arc_map_array1); } /* diff --git a/k2/python/csrc/CMakeLists.txt b/k2/python/csrc/CMakeLists.txt index 6f4703b7a..32cefb6a6 100644 --- a/k2/python/csrc/CMakeLists.txt +++ b/k2/python/csrc/CMakeLists.txt @@ -2,6 +2,7 @@ pybind11_add_module(_k2 array.cc fsa.cc + fsa_algo.cc fsa_util.cc k2.cc tensor.cc diff --git a/k2/python/csrc/array.cc b/k2/python/csrc/array.cc index eee7bafca..ba0e55fd3 100644 --- a/k2/python/csrc/array.cc +++ b/k2/python/csrc/array.cc @@ -1,5 +1,3 @@ -// k2/python/csrc/array.cc - // Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) // See ../../../LICENSE for clarification regarding multiple authors @@ -10,9 +8,39 @@ #include #include "k2/csrc/array.h" +#include "k2/csrc/determinize_impl.h" #include "k2/python/csrc/tensor.h" namespace k2 { + +/* + DLPackArray1 initializes Array1 with `cap_data` which is a DLManagedTensor. + + `cap_data` is usually a one dimensional array with stride >= 1, i.e., + `cap_data.ndim == 1 && cap_data.strides[0] >= 1`. +*/ +template +class DLPackArray1; + +template +class DLPackArray1 : public Array1 { + public: + explicit DLPackArray1(py::capsule cap_data) + : data_tensor_(new Tensor(cap_data)) { + CHECK_EQ(data_tensor_->NumDim(), 1); + CHECK_GE(data_tensor_->Shape(0), 0); // num-elements + CHECK_EQ(data_tensor_->Stride(0), 1); + + int32_t size = data_tensor_->Shape(0); + this->Init(0, size, data_tensor_->Data()); + } + + private: + std::unique_ptr data_tensor_; +}; +// Note: we can specialize for `StridedPtr` later if we need it, +// `cap_data.strides[0]` will be greater than 1 in that case. + /* DLPackArray2 initializes Array2 with `cap_indexes` and `cap_data` which are DLManagedTensors. @@ -100,11 +128,32 @@ class DLPackArray2 : public Array2 { std::unique_ptr indexes_tensor_; std::unique_ptr data_tensor_; }; - // Note: we can specialized for `StridedPtr` later if we need it.
} // namespace k2 +template +void PybindArray1Tpl(py::module &m, const char *name) { + using PyClass = k2::DLPackArray1; + using Parent = k2::Array1; + py::class_(m, name) + .def(py::init(), py::arg("data")) + .def("empty", &PyClass::Empty) + .def( + "get_base", + [](PyClass &self) { return reinterpret_cast(&self); }, + py::return_value_policy::reference_internal) + .def_readonly("size", &PyClass::size) + .def( + "get_data", + [](const PyClass &self, I i) { + if (i < 0 || i >= self.size) throw py::index_error(); + return self.data[self.begin + i]; + }, + "just for test purpose to check if k2::Array1 and the " + "underlying tensor are sharing memory."); +} + template void PybindArray2Tpl(py::module &m, const char *name) { using PyClass = k2::DLPackArray2; @@ -114,20 +163,47 @@ void PybindArray2Tpl(py::module &m, const char *name) { py::arg("data")) .def("empty", &PyClass::Empty) .def( - "__iter__", - [](const PyClass &self) { - return py::make_iterator(self.begin(), self.end()); - }, - py::keep_alive<0, 1>()) + "get_base", + [](PyClass &self) { return reinterpret_cast(&self); }, + py::return_value_policy::reference_internal) .def_readonly("size1", &PyClass::size1) .def_readonly("size2", &PyClass::size2) - .def("indexes", [](const PyClass &self, I i) { return self.indexes[i]; }) - .def("data", [](const PyClass &self, I i) { return self.data[i]; }); - // TODO(haowen): expose `indexes` and `data` as an array - // instead of a function call?
+ .def( + "get_indexes", + [](const PyClass &self, I i) { + if (i < 0 || i > self.size1) // note indexes.size == size1+1 + throw py::index_error(); + return self.indexes[i]; + }, + "just for test purpose to check if k2::Array2 and the " + "underlying tensor are sharing memory.") + .def( + "get_data", + [](const PyClass &self, I i) { + if (i < 0 || i >= self.size2) throw py::index_error(); + return self.data[self.indexes[0] + i]; + }, + "just for test purpose to check if k2::Array2 and the " + "underlying tensor are sharing memory."); +} + +template +void PybindArray2SizeTpl(py::module &m, const char *name) { + using PyClass = k2::Array2Size; + py::class_(m, name) + .def(py::init<>()) + .def(py::init(), py::arg("size1"), py::arg("size2")) + .def_readwrite("size1", &PyClass::size1) + .def_readwrite("size2", &PyClass::size2); } void PybindArray(py::module &m) { + // Note: all the following wrappers whose name starts with `_` are only used + // by pybind11 internally so that it knows `k2::DLPackArray1` is a subclass of + // `k2::Array1`. + py::class_>(m, "_IntArray1"); + PybindArray1Tpl(m, "DLPackIntArray1"); + // Note: all the following wrappers whose name starts with `_` are only used // by pybind11 internally so that it knows `k2::DLPackArray2` is a subclass of // `k2::Array2`.
@@ -135,7 +211,11 @@ void PybindArray(py::module &m) { PybindArray2Tpl(m, "DLPackIntArray2"); // note there is a type cast as the underlying Tensor is with type `float` - py::class_ *>>(m, "_LogSumArcDerivs"); - PybindArray2Tpl *, false>(m, - "DLPackLogSumArcDerivs"); + using LogSumDerivType = typename k2::LogSumTracebackState::DerivType; + py::class_>(m, "_LogSumArcDerivs"); + PybindArray2Tpl(m, "DLPackLogSumArcDerivs"); +} + +void PybindArray2Size(py::module &m) { + PybindArray2SizeTpl(m, "IntArray2Size"); } diff --git a/k2/python/csrc/array.h b/k2/python/csrc/array.h index 35b926014..14d5d267f 100644 --- a/k2/python/csrc/array.h +++ b/k2/python/csrc/array.h @@ -10,5 +10,6 @@ #include "k2/python/csrc/k2.h" void PybindArray(py::module &m); +void PybindArray2Size(py::module &m); #endif // K2_PYTHON_CSRC_ARRAY_H_ diff --git a/k2/python/csrc/fsa.cc b/k2/python/csrc/fsa.cc index 123a4ee4c..67c8e3867 100644 --- a/k2/python/csrc/fsa.cc +++ b/k2/python/csrc/fsa.cc @@ -14,8 +14,8 @@ namespace k2 { -// it uses external memory passed from DLPack (e.g., by PyTorch) -// to construct an Fsa. +// DLPackFsa initializes Fsa with `cap_indexes` and `cap_data` which are +// DLManagedTensors. 
class DLPackFsa : public Fsa { public: DLPackFsa(py::capsule cap_indexes, py::capsule cap_data) @@ -50,7 +50,7 @@ class DLPackFsa : public Fsa { void PybindArc(py::module &m) { using PyClass = k2::Arc; - py::class_(m, "Arc") + py::class_(m, "_Arc") .def(py::init<>()) .def(py::init(), py::arg("src_state"), py::arg("dest_state"), py::arg("label")) @@ -70,21 +70,34 @@ void PybindFsa(py::module &m) { py::class_(m, "_Fsa"); using PyClass = k2::DLPackFsa; - py::class_(m, "DLPackFsa") + using Parent = k2::Fsa; + py::class_(m, "DLPackFsa") .def(py::init(), py::arg("indexes"), py::arg("data")) - .def("empty", &PyClass::Empty) .def( - "__iter__", - [](const PyClass &self) { - return py::make_iterator(self.begin(), self.end()); - }, - py::keep_alive<0, 1>()) + "get_base", + [](PyClass &self) { return reinterpret_cast(&self); }, + py::return_value_policy::reference_internal) + .def("empty", &PyClass::Empty) .def_readonly("size1", &PyClass::size1) .def_readonly("size2", &PyClass::size2) - .def("indexes", - [](const PyClass &self, int32_t i) { return self.indexes[i]; }) - .def("data", [](const PyClass &self, int32_t i) { return self.data[i]; }) + .def( + "get_indexes", + [](const PyClass &self, int32_t i) { + if (i < 0 || i > self.size1) // note indexes.size == size1+1 + throw py::index_error(); + return self.indexes[i]; + }, + "just for test purpose to check if k2::Fsa and the " + "underlying tensor are sharing memory.") + .def( + "get_data", + [](const PyClass &self, int32_t i) { + if (i < 0 || i >= self.size2) throw py::index_error(); + return self.data[self.indexes[0] + i]; + }, + "just for test purpose to check if k2::Fsa and the " + "underlying tensor are sharing memory.") .def("num_states", &PyClass::NumStates) .def("final_state", &PyClass::FinalState); } diff --git a/k2/python/csrc/fsa_algo.cc b/k2/python/csrc/fsa_algo.cc new file mode 100644 index 000000000..70fdf1a9b --- /dev/null +++ b/k2/python/csrc/fsa_algo.cc @@ -0,0 +1,28 @@ +// k2/python/csrc/fsa_algo.cc + +// Copyright (c) 2020
Xiaomi Corporation (author: Haowen Qiu) + +// See ../../../LICENSE for clarification regarding multiple authors + +#include +#include + +#include "k2/csrc/arcsort.h" +#include "k2/csrc/array.h" +#include "k2/python/csrc/array.h" + +namespace k2 {} // namespace k2 + +void PyBindArcSort(py::module &m) { + using PyClass = k2::ArcSorter; + py::class_(m, "_ArcSorter") + .def(py::init(), py::arg("fsa_in")) + .def("get_sizes", &PyClass::GetSizes, py::arg("fsa_size")) + .def("get_output", &PyClass::GetOutput, py::arg("fsa_out"), + py::arg("arc_map") = (k2::Array1 *)nullptr); + + m.def("_arc_sort", &k2::ArcSort, "in-place version of ArcSorter", + py::arg("fsa"), py::arg("arc_map") = (k2::Array1 *)nullptr); +} + +void PybindFsaAlgo(py::module &m) { PyBindArcSort(m); } diff --git a/k2/python/csrc/fsa_algo.h b/k2/python/csrc/fsa_algo.h new file mode 100644 index 000000000..531f86590 --- /dev/null +++ b/k2/python/csrc/fsa_algo.h @@ -0,0 +1,14 @@ +// k2/python/csrc/fsa_algo.h + +// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +// See ../../../LICENSE for clarification regarding multiple authors + +#ifndef K2_PYTHON_CSRC_FSA_ALGO_H_ +#define K2_PYTHON_CSRC_FSA_ALGO_H_ + +#include "k2/python/csrc/k2.h" + +void PybindFsaAlgo(py::module &m); + +#endif // K2_PYTHON_CSRC_FSA_ALGO_H_ diff --git a/k2/python/csrc/k2.cc b/k2/python/csrc/k2.cc index 3e20b8e76..01535061b 100644 --- a/k2/python/csrc/k2.cc +++ b/k2/python/csrc/k2.cc @@ -8,12 +8,15 @@ #include "k2/python/csrc/array.h" #include "k2/python/csrc/fsa.h" +#include "k2/python/csrc/fsa_algo.h" #include "k2/python/csrc/fsa_util.h" PYBIND11_MODULE(_k2, m) { m.doc() = "pybind11 binding of k2"; PybindArc(m); PybindArray(m); + PybindArray2Size(m); PybindFsa(m); PybindFsaUtil(m); + PybindFsaAlgo(m); } diff --git a/k2/python/k2/__init__.py b/k2/python/k2/__init__.py index 8addabe13..0fac89108 100644 --- a/k2/python/k2/__init__.py +++ b/k2/python/k2/__init__.py @@ -1,4 +1,5 @@ -from _k2 import Arc +from _k2 import 
IntArray2Size from .array import * from .fsa import * +from .fsa_algo import * from .fsa_util import str_to_fsa diff --git a/k2/python/k2/array.py b/k2/python/k2/array.py index 33c5cf8bb..eddc420fd 100644 --- a/k2/python/k2/array.py +++ b/k2/python/k2/array.py @@ -5,17 +5,53 @@ import torch from torch.utils.dlpack import to_dlpack +from _k2 import IntArray2Size from _k2 import DLPackIntArray2 +from _k2 import DLPackIntArray1 from _k2 import DLPackLogSumArcDerivs + +class IntArray1(DLPackIntArray1): + + def __init__(self, data: torch.Tensor): + assert data.dtype == torch.int32 + self.data = data + super().__init__(to_dlpack(self.data)) + + @staticmethod + def create_array_with_size(size: int): + data = torch.zeros(size, dtype=torch.int32) + return IntArray1(data) + + class IntArray2(DLPackIntArray2): - # TODO(haowen): add methods to construct object with Array2Size def __init__(self, indexes: torch.Tensor, data: torch.Tensor): - super().__init__(to_dlpack(indexes), to_dlpack(data)) + assert indexes.dtype == torch.int32 + assert data.dtype == torch.int32 + self.indexes = indexes + self.data = data + super().__init__(to_dlpack(self.indexes), to_dlpack(self.data)) + + @staticmethod + def create_array_with_size(array_size: IntArray2Size): + indexes = torch.zeros(array_size.size1 + 1, dtype=torch.int32) + data = torch.zeros(array_size.size2, dtype=torch.int32) + return IntArray2(indexes, data) class LogSumArcDerivs(DLPackLogSumArcDerivs): def __init__(self, indexes: torch.Tensor, data: torch.Tensor): - super().__init__(to_dlpack(indexes), to_dlpack(data)) + assert indexes.dtype == torch.int32 + assert data.dtype == torch.float32 + assert data.shape[1] == 2 + self.indexes = indexes + self.data = data + super().__init__(to_dlpack(self.indexes), to_dlpack(self.data)) + + @staticmethod + def create_arc_derivs_with_size(array_size: IntArray2Size): + indexes = torch.zeros(array_size.size1 + 1, dtype=torch.int32) + data = torch.zeros([array_size.size2, 2], dtype=torch.float32) 
+ return LogSumArcDerivs(indexes, data) diff --git a/k2/python/k2/fsa.py b/k2/python/k2/fsa.py index c83e87820..9191f8298 100644 --- a/k2/python/k2/fsa.py +++ b/k2/python/k2/fsa.py @@ -5,12 +5,49 @@ import torch from torch.utils.dlpack import to_dlpack +from _k2 import IntArray2Size +from _k2 import _Arc from _k2 import DLPackFsa +from _k2 import IntArray2Size + + +class Arc(_Arc): + + def __init__(self, src_state: int, dest_state: int, label: int): + super().__init__(src_state, dest_state, label) + + def to_tensor(self): + return torch.tensor([self.src_state, self.dest_state, self.label], + dtype=torch.int32) + + @staticmethod + def from_tensor(tensor: torch.Tensor): + assert tensor.shape == torch.Size([3]) + assert tensor.dtype == torch.int32 + return Arc(*tensor.tolist()) + class Fsa(DLPackFsa): + """ + Corresponds to k2::Fsa class, initializes k2::Fsa with torch.Tensors. - # TODO(haowen): add methods to construct object with Array2Size - def __init__(self, indexes: torch.Tensor, data: torch.Tensor): - super().__init__(to_dlpack(indexes), to_dlpack(data)) + Note that we view each row of self.data as a k2::Arc, usually users + can convert the type between tensor and k2::Arc by calling + `k2.Arc.from_tensor()` and `k2.Arc.to_tensor()`. + If users want to change the values in Fsa, just call `fsa.data[i] = some_tensor`.
+ + """ + def __init__(self, indexes: torch.Tensor, data: torch.Tensor): + assert indexes.dtype == torch.int32 + assert data.dtype == torch.int32 + assert data.shape[1] == 3 + self.indexes = indexes + self.data = data + super().__init__(to_dlpack(self.indexes), to_dlpack(self.data)) + @staticmethod + def create_fsa_with_size(array_size: IntArray2Size): + indexes = torch.zeros(array_size.size1 + 1, dtype=torch.int32) + data = torch.zeros([array_size.size2, 3], dtype=torch.int32) + return Fsa(indexes, data) diff --git a/k2/python/k2/fsa_algo.py b/k2/python/k2/fsa_algo.py new file mode 100644 index 000000000..ea50d4754 --- /dev/null +++ b/k2/python/k2/fsa_algo.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) + +# See ../../../LICENSE for clarification regarding multiple authors + +import torch +from torch.utils.dlpack import to_dlpack + +from .fsa import Fsa +from .array import IntArray1 +from .array import IntArray2 +from _k2 import _ArcSorter +from _k2 import _arc_sort + + +class ArcSorter(_ArcSorter): + + def __init__(self, fsa_in: Fsa): + super().__init__(fsa_in.get_base()) + + def get_sizes(self, array_size: IntArray2): + super().get_sizes(array_size) + + def get_output(self, fsa_out: Fsa, arc_map: IntArray1 = None): + super().get_output(fsa_out.get_base(), + arc_map.get_base() if arc_map is not None else None) + + +def arc_sort(fsa: Fsa, arc_map: IntArray1 = None): + return _arc_sort(fsa.get_base(), + arc_map.get_base() if arc_map is not None else None) diff --git a/k2/python/tests/CMakeLists.txt b/k2/python/tests/CMakeLists.txt index 72dd5df3c..ec8476f2b 100644 --- a/k2/python/tests/CMakeLists.txt +++ b/k2/python/tests/CMakeLists.txt @@ -19,6 +19,7 @@ endfunction() # please sort the files in alphabetic order set(py_test_files array_test.py + fsa_arcsort_test.py fsa_test.py ) diff --git a/k2/python/tests/array_test.py b/k2/python/tests/array_test.py index 5cdfacb90..298fd72f4 100644 --- a/k2/python/tests/array_test.py +++ 
b/k2/python/tests/array_test.py @@ -9,6 +9,7 @@ # ctest --verbose -R array_test_py # +from struct import pack, unpack import unittest import torch @@ -18,36 +19,60 @@ class TestArray(unittest.TestCase): + def test_int_array1(self): + data = torch.arange(10).to(torch.int32) + + array = k2.IntArray1(data) + self.assertFalse(array.empty()) + self.assertIsInstance(array, k2.IntArray1) + self.assertEqual(data.numel(), array.size) + self.assertEqual(array.data[9], 9) + + # the underlying memory is shared between k2 and torch; + # so change one will change another + data[0] = 100 + self.assertEqual(array.data[0], 100) + self.assertEqual(array.get_data(0), 100) + + del data + # the array in k2 is still accessible + self.assertEqual(array.data[0], 100) + self.assertEqual(array.get_data(0), 100) + def test_int_array2(self): data = torch.arange(10).to(torch.int32) indexes = torch.tensor([0, 2, 5, 6, 10]).to(torch.int32) - self.assertEqual(data.numel(),indexes[-1].item()) + self.assertEqual(data.numel(), indexes[-1].item()) array = k2.IntArray2(indexes, data) self.assertFalse(array.empty()) self.assertIsInstance(array, k2.IntArray2) - # test iterator - for i, v in enumerate(array): - self.assertEqual(i, v) - self.assertEqual(indexes.numel(), array.size1 + 1) self.assertEqual(data.numel(), array.size2) + self.assertEqual(array.data[9], 9) # the underlying memory is shared between k2 and torch; # so change one will change another data[0] = 100 - self.assertEqual(array.data(0), 100) + self.assertEqual(array.data[0], 100) + self.assertEqual(array.get_data(0), 100) + indexes[1] = 3 + self.assertEqual(array.indexes[1], 3) + self.assertEqual(array.get_indexes(1), 3) del data + del indexes # the array in k2 is still accessible - self.assertEqual(array.data(0), 100) - + self.assertEqual(array.data[0], 100) + self.assertEqual(array.get_data(0), 100) + self.assertEqual(array.indexes[1], 3) + self.assertEqual(array.get_indexes(1), 3) def test_logsum_arc_derivs(self): - data = 
torch.arange(10).reshape(5,2).to(torch.float) + data = torch.arange(10).reshape(5, 2).to(torch.float) indexes = torch.tensor([0, 2, 3, 5]).to(torch.int32) - self.assertEqual(data.shape[0],indexes[-1].item()) + self.assertEqual(data.shape[0], indexes[-1].item()) array = k2.LogSumArcDerivs(indexes, data) self.assertFalse(array.empty()) @@ -55,8 +80,24 @@ def test_logsum_arc_derivs(self): self.assertEqual(indexes.numel(), array.size1 + 1) self.assertEqual(data.shape[0], array.size2) + self.assertTrue(torch.equal(array.data[0], torch.FloatTensor([0, 1]))) - self.assertEqual(array.data(0), (0,1.0)) + # the underlying memory is shared between k2 and torch; + # so change one will change another + data[0] = torch.FloatTensor([100, 200]) + self.assertTrue( + torch.equal(array.data[0], torch.FloatTensor([100, 200]))) + self.assertEqual(array.get_data(0)[1], 200) + # we need pack and then unpack here to interpret arc_id (int) as a float, + # this is only for test purpose as users would usually never call + # `array.get_data` to retrieve data. Instead, it is supposed to call + # `array.data` to retrieve or update data in the array object. 
+ arc_id = pack('i', array.get_data(0)[0]) + self.assertEqual(unpack('f', arc_id)[0], 100) + + del data + # the array in k2 is still accessible + self.assertEqual(array.get_data(0)[1], 200) if __name__ == '__main__': diff --git a/k2/python/tests/fsa_arcsort_test.py b/k2/python/tests/fsa_arcsort_test.py new file mode 100644 index 000000000..edf437a6a --- /dev/null +++ b/k2/python/tests/fsa_arcsort_test.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) +# +# See ../../../LICENSE for clarification regarding multiple authors + +# To run this single test, use +# +# ctest --verbose -R fsa_arcsort_test_py +# + +import unittest + +import torch + +import k2 + + +class TestArcSort(unittest.TestCase): + + def test_empty_fsa(self): + array_size = k2.IntArray2Size(0, 0) + fsa = k2.Fsa.create_fsa_with_size(array_size) + arc_map = k2.IntArray1.create_array_with_size(fsa.size2) + k2.arc_sort(fsa, arc_map) + self.assertTrue(fsa.empty()) + self.assertTrue(arc_map.empty()) + + # test without arc_map + k2.arc_sort(fsa) + self.assertTrue(fsa.empty()) + + def test_arc_sort(self): + s = r''' + 0 1 2 + 0 4 0 + 0 2 0 + 1 2 1 + 1 3 0 + 2 1 0 + 4 + ''' + + fsa = k2.str_to_fsa(s) + arc_map = k2.IntArray1.create_array_with_size(fsa.size2) + k2.arc_sort(fsa, arc_map) + expected_arc_indexes = torch.IntTensor([0, 3, 5, 6, 6, 6]) + expected_arcs = torch.IntTensor([[0, 2, 0], [0, 4, 0], [0, 1, 2], + [1, 3, 0], [1, 2, 1], [2, 1, 0]]) + expected_arc_map = torch.IntTensor([2, 1, 0, 4, 3, 5]) + self.assertTrue(torch.equal(fsa.indexes, expected_arc_indexes)) + self.assertTrue(torch.equal(fsa.data, expected_arcs)) + self.assertTrue(torch.equal(arc_map.data, expected_arc_map)) + + +class TestArcSorter(unittest.TestCase): + + def test_empty_fsa(self): + array_size = k2.IntArray2Size(0, 0) + fsa = k2.Fsa.create_fsa_with_size(array_size) + sorter = k2.ArcSorter(fsa) + array_size = k2.IntArray2Size() + sorter.get_sizes(array_size) + fsa_out = 
k2.Fsa.create_fsa_with_size(array_size) + arc_map = k2.IntArray1.create_array_with_size(fsa.size2) + sorter.get_output(fsa_out, arc_map) + self.assertTrue(fsa_out.empty()) + self.assertTrue(arc_map.empty()) + + # test without arc_map + sorter.get_output(fsa_out) + self.assertTrue(fsa_out.empty()) + + def test_arc_sort(self): + s = r''' + 0 1 2 + 0 4 0 + 0 2 0 + 1 2 1 + 1 3 0 + 2 1 0 + 4 + ''' + + fsa = k2.str_to_fsa(s) + sorter = k2.ArcSorter(fsa) + array_size = k2.IntArray2Size() + sorter.get_sizes(array_size) + fsa_out = k2.Fsa.create_fsa_with_size(array_size) + arc_map = k2.IntArray1.create_array_with_size(fsa.size2) + sorter.get_output(fsa_out, arc_map) + expected_arc_indexes = torch.IntTensor([0, 3, 5, 6, 6, 6]) + expected_arcs = torch.IntTensor([[0, 2, 0], [0, 4, 0], [0, 1, 2], + [1, 3, 0], [1, 2, 1], [2, 1, 0]]) + expected_arc_map = torch.IntTensor([2, 1, 0, 4, 3, 5]) + self.assertTrue(torch.equal(fsa_out.indexes, expected_arc_indexes)) + self.assertTrue(torch.equal(fsa_out.data, expected_arcs)) + self.assertTrue(torch.equal(arc_map.data, expected_arc_map)) + + +if __name__ == '__main__': + unittest.main() diff --git a/k2/python/tests/fsa_test.py b/k2/python/tests/fsa_test.py index c2a5716a9..5f98dd173 100644 --- a/k2/python/tests/fsa_test.py +++ b/k2/python/tests/fsa_test.py @@ -11,17 +11,33 @@ import unittest +import torch + import k2 class TestFsa(unittest.TestCase): def test_arc(self): + # construct arc arc = k2.Arc(1, 2, 3) self.assertEqual(arc.src_state, 1) self.assertEqual(arc.dest_state, 2) self.assertEqual(arc.label, 3) + # test from_tensor + arc_tensor = torch.tensor([1, 2, 3], dtype=torch.int32) + arc = k2.Arc.from_tensor(arc_tensor) + self.assertEqual(arc.src_state, 1) + self.assertEqual(arc.dest_state, 2) + self.assertEqual(arc.label, 3) + + # test to_tensor + arc.src_state = 2 + arc_tensor = arc.to_tensor() + arc_tensor_target = torch.tensor([2, 2, 3], dtype=torch.int32) + self.assertTrue(torch.equal(arc_tensor, arc_tensor_target)) + def
test_fsa(self): s = r''' 0 1 1 @@ -37,10 +53,13 @@ def test_fsa(self): self.assertEqual(fsa.final_state(), 4) self.assertFalse(fsa.empty()) self.assertIsInstance(fsa, k2.Fsa) - self.assertIsInstance(fsa.data(0), k2.Arc); - self.assertEqual(fsa.data(0).src_state, 0) - self.assertEqual(fsa.data(0).dest_state, 1) - self.assertEqual(fsa.data(0).label, 1) + # test get_data + self.assertEqual(fsa.get_data(0).src_state, 0) + self.assertEqual(fsa.get_data(0).dest_state, 1) + self.assertEqual(fsa.get_data(0).label, 1) + # fsa.data and the corresponding k2::Fsa object are sharing memory + fsa.data[0] = torch.IntTensor([5, 1, 6]) + self.assertEqual(fsa.get_data(0).src_state, 5) if __name__ == '__main__':