Skip to content

Commit

Permalink
rewrite pybind interface for Array2 and Fsa
Browse files Browse the repository at this point in the history
  • Loading branch information
qindazhu committed Jul 22, 2020
1 parent 0667a60 commit ea62216
Show file tree
Hide file tree
Showing 20 changed files with 522 additions and 68 deletions.
11 changes: 7 additions & 4 deletions k2/csrc/arcsort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ void ArcSorter::GetSizes(Array2Size<int32_t> *fsa_size) const {
fsa_size->size2 = fsa_in_.size2;
}

void ArcSorter::GetOutput(Fsa *fsa_out, int32_t *arc_map /*= nullptr*/) {
void ArcSorter::GetOutput(Fsa *fsa_out,
Array1<int32_t *> *arc_map /*= nullptr*/) {
CHECK_NOTNULL(fsa_out);
CHECK_EQ(fsa_out->size1, fsa_in_.size1);
CHECK_EQ(fsa_out->size2, fsa_in_.size2);
Expand Down Expand Up @@ -50,10 +51,11 @@ void ArcSorter::GetOutput(Fsa *fsa_out, int32_t *arc_map /*= nullptr*/) {
}
fsa_out->indexes[num_states] = num_arcs;

if (arc_map != nullptr) std::copy(indexes.begin(), indexes.end(), arc_map);
if (arc_map != nullptr)
std::copy(indexes.begin(), indexes.end(), arc_map->data);
}

void ArcSort(Fsa *fsa, int32_t *arc_map /*= nullptr*/) {
void ArcSort(Fsa *fsa, Array1<int32_t *> *arc_map /*= nullptr*/) {
CHECK_NOTNULL(fsa);

std::vector<int32_t> indexes(fsa->size2);
Expand All @@ -74,7 +76,8 @@ void ArcSort(Fsa *fsa, int32_t *arc_map /*= nullptr*/) {
[](const Arc &left, const Arc &right) { return left < right; });
}

if (arc_map != nullptr) std::copy(indexes.begin(), indexes.end(), arc_map);
if (arc_map != nullptr)
std::copy(indexes.begin(), indexes.end(), arc_map->data);
}

} // namespace k2
9 changes: 5 additions & 4 deletions k2/csrc/arcsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <vector>

#include "glog/logging.h"
#include "k2/csrc/array.h"
#include "k2/csrc/fsa.h"

namespace k2 {
Expand Down Expand Up @@ -43,19 +44,19 @@ class ArcSorter {
@param [out] arc_map If non-NULL, will output a map from the arc-index
in `fsa_out` to the corresponding arc-index in
`fsa_in`.
If non-NULL, at entry it must be allocated with
If non-NULL, at entry it must be initialized with
size num-arcs of `fsa_out`, e.g. `fsa_out->size2`.
*/
void GetOutput(Fsa *fsa_out, int32_t *arc_map = nullptr);
void GetOutput(Fsa *fsa_out, Array1<int32_t *> *arc_map = nullptr);

private:
const Fsa &fsa_in_;
};

// In-place version of ArcSorter; see its documentation;
// Note that if `arc_map` is non-NULL, then at entry it must be allocated with
// Note that if `arc_map` is non-NULL, then at entry it must be initialized with
// size num-arcs of `fsa`, e.g. `fsa->size2`
void ArcSort(Fsa *fsa, int32_t *arc_map = nullptr);
void ArcSort(Fsa *fsa, Array1<int32_t *> *arc_map = nullptr);

} // namespace k2

Expand Down
12 changes: 8 additions & 4 deletions k2/csrc/arcsort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ TEST(ArcSortTest, ArcSorter) {
FsaCreator fsa_creator_out(fsa_size);
auto &arc_sorted = fsa_creator_out.GetFsa();
std::vector<int32_t> arc_map(fsa_size.size2);
sorter.GetOutput(&arc_sorted, arc_map.data());
Array1<int32_t *> arc_map_array1(arc_map.size(), arc_map.data());
sorter.GetOutput(&arc_sorted, &arc_map_array1);

EXPECT_TRUE(IsEmpty(arc_sorted));
EXPECT_TRUE(arc_map.empty());
Expand All @@ -52,7 +53,8 @@ TEST(ArcSortTest, ArcSorter) {
FsaCreator fsa_creator_out(fsa_size);
auto &arc_sorted = fsa_creator_out.GetFsa();
std::vector<int32_t> arc_map(fsa_size.size2);
sorter.GetOutput(&arc_sorted, arc_map.data());
Array1<int32_t *> arc_map_array1(arc_map.size(), arc_map.data());
sorter.GetOutput(&arc_sorted, &arc_map_array1);

EXPECT_FALSE(arc_map.empty());
EXPECT_TRUE(IsArcSorted(arc_sorted));
Expand Down Expand Up @@ -86,7 +88,8 @@ TEST(ArcSortTest, ArcSort) {
FsaCreator fsa_creator;
auto &fsa = fsa_creator.GetFsa();
std::vector<int32_t> arc_map(fsa.size2);
ArcSort(&fsa, arc_map.data());
Array1<int32_t *> arc_map_array1(arc_map.size(), arc_map.data());
ArcSort(&fsa, &arc_map_array1);

EXPECT_TRUE(IsEmpty(fsa));
EXPECT_TRUE(arc_map.empty());
Expand All @@ -99,7 +102,8 @@ TEST(ArcSortTest, ArcSort) {
FsaCreator fsa_creator(src_arcs, 4);
auto &fsa = fsa_creator.GetFsa();
std::vector<int32_t> arc_map(fsa.size2);
ArcSort(&fsa, arc_map.data());
Array1<int32_t *> arc_map_array1(arc_map.size(), arc_map.data());
ArcSort(&fsa, &arc_map_array1);

EXPECT_TRUE(IsArcSorted(fsa));

Expand Down
43 changes: 39 additions & 4 deletions k2/csrc/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,43 @@ struct StridedPtr {
}
};

/*
  A view of a one-dimensional array of something (comparable to vector<X>),
  where `Ptr` is, or behaves like, `X*`.

  The struct only stores the pointer together with the index range
  [begin, end) that may be accessed through it; nothing in this struct
  allocates or frees `data`.
 */
template <typename Ptr, typename I = int32_t>
struct Array1 {
  using IndexT = I;
  using PtrT = Ptr;
  using ValueType = typename std::iterator_traits<Ptr>::value_type;

  // Creates an empty array: zero-length range and a null `data`.
  Array1() : begin(0), end(0), size(0), data(nullptr) {}

  // Creates an array covering indexes [begin, end) of `data`.
  // `end` must not be less than `begin`.
  Array1(IndexT begin, IndexT end, PtrT data)
      : begin(begin), end(end), data(data) {
    CHECK_GE(end, begin);
    size = end - begin;
  }

  // Creates an array covering the first `size` elements of `data`,
  // i.e. the index range [0, size). `size` must be non-negative.
  Array1(IndexT size, PtrT data)
      : begin(0), end(size), size(size), data(data) {
    CHECK_GE(size, 0);
  }

  // Re-points this array at indexes [begin, end) of `data`, replacing
  // whatever it referred to before. `end` must not be less than `begin`.
  void Init(IndexT begin, IndexT end, PtrT data) {
    CHECK_GE(end, begin);
    this->begin = begin;
    this->end = end;
    this->size = end - begin;
    this->data = data;
  }

  // Returns true if the accessible range contains no elements.
  bool Empty() const { return begin == end; }

  // First index into `data` that we are allowed to use.
  IndexT begin;
  // One-past-the-last index into `data` that we are allowed to use.
  IndexT end;
  // Number of elements in `data` that can be accessed; the constructors and
  // `Init` set it to `end - begin`.
  IndexT size;
  PtrT data;
};

/*
This struct stores the size of an Array2 object; it will generally be used as
an output argument by functions that work out this size.
Expand Down Expand Up @@ -293,12 +330,10 @@ struct Array2Storage {
Array2Storage(const Array2Size<I> &array2_size, I stride)
: indexes_storage_(new I[array2_size.size1 + 1]),
data_storage_(new ValueType[array2_size.size2 * stride]) {
array_.size1 = array2_size.size1;
array_.size2 = array2_size.size2;
array_.indexes = indexes_storage_.get();
array_.Init(array2_size.size1, array2_size.size2, indexes_storage_.get(),
DataPtrCreator<Ptr, I>::Create(data_storage_, stride));
// just for case of empty Array2 object, may be written by the caller
array_.indexes[0] = 0;
array_.data = DataPtrCreator<Ptr, I>::Create(data_storage_, stride);
}

void FillIndexes(const std::vector<I> &indexes) {
Expand Down
8 changes: 6 additions & 2 deletions k2/csrc/fsa_equivalent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,12 @@ static void ArcSort(const k2::Fsa &fsa_in, k2::FsaCreator *fsa_out,

fsa_out->Init(fsa_size);
auto &sorted_fsa = fsa_out->GetFsa();
if (arc_map != nullptr) arc_map->resize(fsa_size.size2);
sorter.GetOutput(&sorted_fsa, arc_map == nullptr ? nullptr : arc_map->data());
k2::Array1<int32_t *> arc_map_array1;
if (arc_map != nullptr) {
arc_map->resize(fsa_size.size2);
arc_map_array1.Init(0, arc_map->size(), arc_map->data());
}
sorter.GetOutput(&sorted_fsa, arc_map == nullptr ? nullptr : &arc_map_array1);
}

/*
Expand Down
1 change: 1 addition & 0 deletions k2/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pybind11_add_module(_k2
array.cc
fsa.cc
fsa_algo.cc
fsa_util.cc
k2.cc
tensor.cc
Expand Down
110 changes: 95 additions & 15 deletions k2/python/csrc/array.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// k2/python/csrc/array.cc

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors
Expand All @@ -10,9 +8,39 @@
#include <utility>

#include "k2/csrc/array.h"
#include "k2/csrc/determinize_impl.h"
#include "k2/python/csrc/tensor.h"

namespace k2 {

/*
DLPackArray1 initializes Array1 with `cap_data` which is a DLManagedTensor.
`cap_data` is usually a one dimensional array with stride >= 1, i.e.,
`cap_data.ndim == 1 && cap_data.strides[0] >= 1`.
*/
template <typename ValueType, typename I>
class DLPackArray1;

// Specialization of DLPackArray1 for plain pointers (`ValueType *`).
// It builds a `Tensor` from the given capsule, verifies that the tensor is
// one-dimensional with stride 1, and then initializes the `Array1` base so
// that it covers indexes [0, Shape(0)) of the tensor's data.
template <typename ValueType, typename I>
class DLPackArray1<ValueType *, I> : public Array1<ValueType *, I> {
public:
// @param [in] cap_data  Capsule from which the underlying `Tensor` is
//                       constructed. The checks below abort unless the
//                       tensor has exactly one dimension, a non-negative
//                       number of elements and stride 1.
explicit DLPackArray1(py::capsule cap_data)
: data_tensor_(new Tensor(cap_data)) {
CHECK_EQ(data_tensor_->NumDim(), 1);
CHECK_GE(data_tensor_->Shape(0), 0); // num-elements
CHECK_EQ(data_tensor_->Stride(0), 1);

// Point the Array1 base at the tensor's data: begin = 0, end = num-elements.
int32_t size = data_tensor_->Shape(0);
this->Init(0, size, data_tensor_->Data<ValueType>());
}

private:
// The Tensor whose data the Array1 base points into; held here so that it
// is destroyed together with this object.
// NOTE(review): presumably this is what keeps the memory behind `data`
// valid -- confirm against Tensor's ownership semantics.
std::unique_ptr<Tensor> data_tensor_;
};
// Note: we can specialize for `StridedPtr` later if we need it;
// `cap_data.strides[0]` will be greater than 1 in that case.

/*
DLPackArray2 initializes Array2 with `cap_indexes` and `cap_data` which are
DLManagedTensors.
Expand Down Expand Up @@ -100,11 +128,32 @@ class DLPackArray2<ValueType *, false, I> : public Array2<ValueType *, I> {
std::unique_ptr<Tensor> indexes_tensor_;
std::unique_ptr<Tensor> data_tensor_;
};

// Note: we can specialize for `StridedPtr` later if we need it.

} // namespace k2

/*
  Registers `k2::DLPackArray1<Ptr, I>` in module `m` as a Python class called
  `name`, declaring `k2::Array1<Ptr, I>` as its base class (the base must
  therefore be registered with pybind11 as well).

  Exposed to Python:
    - __init__(data)  constructs the object from a capsule.
    - empty()         forwards to `Array1::Empty`.
    - get_base()      returns the same object viewed as its `Array1` base;
                      `reference_internal` ties the returned reference's
                      lifetime to `self`.
    - size            read-only attribute.
    - get_data(i)     returns the i-th accessible element, i.e.
                      `data[begin + i]`; raises IndexError unless
                      `0 <= i < size`.

  @param [in] m     Module the class is added to.
  @param [in] name  Python-visible class name.
 */
template <typename Ptr, typename I = int32_t>
void PybindArray1Tpl(py::module &m, const char *name) {
  using PyClass = k2::DLPackArray1<Ptr, I>;
  using Parent = k2::Array1<Ptr, I>;
  py::class_<PyClass, Parent>(m, name)
      .def(py::init<py::capsule>(), py::arg("data"))
      .def("empty", &PyClass::Empty)
      .def(
          "get_base",
          // Derived-to-base conversion: static_cast is the well-defined cast
          // for this (reinterpret_cast does not adjust the pointer).
          [](PyClass &self) { return static_cast<Parent *>(&self); },
          py::return_value_policy::reference_internal)
      .def_readonly("size", &PyClass::size)
      .def(
          "get_data",
          [](const PyClass &self, I i) {
            // Reject negative indexes too; otherwise `begin + i` would
            // address memory before the accessible range.
            if (i < 0 || i >= self.size) throw py::index_error();
            return self.data[self.begin + i];
          },
          "just for test purpose to check if k2::Array1 and the "
          "underlying tensor are sharing memory.");
}

template <typename Ptr, bool IsPrimitive, typename I = int32_t>
void PybindArray2Tpl(py::module &m, const char *name) {
using PyClass = k2::DLPackArray2<Ptr, IsPrimitive, I>;
Expand All @@ -114,28 +163,59 @@ void PybindArray2Tpl(py::module &m, const char *name) {
py::arg("data"))
.def("empty", &PyClass::Empty)
.def(
"__iter__",
[](const PyClass &self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>())
"get_base",
[](PyClass &self) { return reinterpret_cast<Parent *>(&self); },
py::return_value_policy::reference_internal)
.def_readonly("size1", &PyClass::size1)
.def_readonly("size2", &PyClass::size2)
.def("indexes", [](const PyClass &self, I i) { return self.indexes[i]; })
.def("data", [](const PyClass &self, I i) { return self.data[i]; });
// TODO(haowen): expose `indexes` and `data` as an array
// instead of a function call?
.def(
"get_indexes",
[](const PyClass &self, I i) {
if (i > self.size1) // note indexes.size == size1+1
throw py::index_error();
return self.indexes[i];
},
"just for test purpose to check if k2::Array1 and the "
"underlying tensor are sharing memory.")
.def(
"get_data",
[](const PyClass &self, I i) {
if (i >= self.size2) throw py::index_error();
return self.data[self.indexes[0] + i];
},
"just for test purpose to check if k2::Array1 and the "
"underlying tensor are sharing memory.");
}

/*
  Registers `k2::Array2Size<I>` in module `m` as a Python class called `name`.

  Exposed to Python:
    - __init__()              default construction.
    - __init__(size1, size2)  construction from the two sizes; the argument
                              types follow the index type `I` so the binding
                              stays consistent for any instantiation
                              (identical to the previous `int32_t` binding
                              when `I` is `int32_t`).
    - size1, size2            read/write attributes.

  @param [in] m     Module the class is added to.
  @param [in] name  Python-visible class name.
 */
template <typename I>
void PybindArray2SizeTpl(py::module &m, const char *name) {
  using PyClass = k2::Array2Size<I>;
  py::class_<PyClass>(m, name)
      .def(py::init<>())
      .def(py::init<I, I>(), py::arg("size1"), py::arg("size2"))
      .def_readwrite("size1", &PyClass::size1)
      .def_readwrite("size2", &PyClass::size2);
}

void PybindArray(py::module &m) {
// Note: all the following wrappers whose name starts with `_` are only used
// by pybind11 internally so that it knows `k2::DLPackArray1` is a subclass of
// `k2::Array1`.
py::class_<k2::Array1<int32_t *>>(m, "_IntArray1");
PybindArray1Tpl<int32_t *>(m, "DLPackIntArray1");

// Note: all the following wrappers whose name starts with `_` are only used
// by pybind11 internally so that it knows `k2::DLPackArray2` is a subclass of
// `k2::Array2`.
py::class_<k2::Array2<int32_t *>>(m, "_IntArray2");
PybindArray2Tpl<int32_t *, true>(m, "DLPackIntArray2");

// note there is a type cast as the underlying Tensor is with type `float`
py::class_<k2::Array2<std::pair<int32_t, float> *>>(m, "_LogSumArcDerivs");
PybindArray2Tpl<std::pair<int32_t, float> *, false>(m,
"DLPackLogSumArcDerivs");
using LogSumDerivType = typename k2::LogSumTracebackState::DerivType;
py::class_<k2::Array2<LogSumDerivType *>>(m, "_LogSumArcDerivs");
PybindArray2Tpl<LogSumDerivType *, false>(m, "DLPackLogSumArcDerivs");
}

// Adds the Python class `IntArray2Size`, the `int32_t` instantiation of the
// Array2Size binding, to module `m`.
void PybindArray2Size(py::module &m) {
  const char *const py_class_name = "IntArray2Size";
  PybindArray2SizeTpl<int32_t>(m, py_class_name);
}
1 change: 1 addition & 0 deletions k2/python/csrc/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
#include "k2/python/csrc/k2.h"

void PybindArray(py::module &m);
void PybindArray2Size(py::module &m);

#endif // K2_PYTHON_CSRC_ARRAY_H_
Loading

0 comments on commit ea62216

Please sign in to comment.