Skip to content

Commit

Permalink
rewrite pybind interface for Array2 and Fsa
Browse files Browse the repository at this point in the history
  • Loading branch information
qindazhu committed Jul 22, 2020
1 parent 0667a60 commit 4fa6ba9
Show file tree
Hide file tree
Showing 18 changed files with 512 additions and 55 deletions.
1 change: 1 addition & 0 deletions k2/csrc/arcsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <vector>

#include "glog/logging.h"
#include "k2/csrc/array.h"
#include "k2/csrc/fsa.h"

namespace k2 {
Expand Down
41 changes: 37 additions & 4 deletions k2/csrc/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,41 @@ struct StridedPtr {
}
};

/*
  A lightweight view over a one dimensional sequence of elements, comparable
  to vector<X>, where `Ptr` is, or behaves like, X*.

  The view does not allocate anything itself; it only records `data` together
  with the half-open index range [begin, end) that may be accessed through it.
 */
template <typename Ptr, typename I = int32_t>
struct Array1 {
  using IndexT = I;
  using PtrT = Ptr;
  using ValueType = typename std::iterator_traits<Ptr>::value_type;

  // Builds an empty view: all indexes are zero and `data` is null.
  Array1() : begin(0), end(0), size(0), data(nullptr) {}

  // Builds a view over the elements [begin, end) of `data`.
  // It is a fatal error (CHECK) if `end` is smaller than `begin`.
  Array1(IndexT begin, IndexT end, PtrT data)
      : begin(begin), end(end), size(0), data(data) {
    CHECK_GE(end, begin);
    size = end - begin;
  }

  // Builds a view over the first `size` elements of `data`, i.e. [0, size).
  // It is a fatal error (CHECK) if `size` is negative.
  Array1(IndexT size, PtrT data)
      : begin(0), end(size), size(size), data(data) {
    CHECK_GE(size, 0);
  }

  // Re-targets this view at the elements [begin, end) of `data`.
  // It is a fatal error (CHECK) if `end` is smaller than `begin`.
  void Init(IndexT begin, IndexT end, PtrT data) {
    CHECK_GE(end, begin);
    this->data = data;
    this->begin = begin;
    this->end = end;
    this->size = end - begin;
  }

  // True when the accessible range contains no element.
  bool Empty() const { return end == begin; }

  IndexT begin;  // first index into `data` that we are allowed to use
  IndexT end;    // one-past-the-last index into `data` that we may use
  IndexT size;   // number of accessible elements, equal to `end - begin`
  PtrT data;
};

/*
This struct stores the size of an Array2 object; it will generally be used as
an output argument by functions that work out this size.
Expand Down Expand Up @@ -293,12 +328,10 @@ struct Array2Storage {
Array2Storage(const Array2Size<I> &array2_size, I stride)
: indexes_storage_(new I[array2_size.size1 + 1]),
data_storage_(new ValueType[array2_size.size2 * stride]) {
array_.size1 = array2_size.size1;
array_.size2 = array2_size.size2;
array_.indexes = indexes_storage_.get();
array_.Init(array2_size.size1, array2_size.size2, indexes_storage_.get(),
DataPtrCreator<Ptr, I>::Create(data_storage_, stride));
// just for case of empty Array2 object, may be written by the caller
array_.indexes[0] = 0;
array_.data = DataPtrCreator<Ptr, I>::Create(data_storage_, stride);
}

void FillIndexes(const std::vector<I> &indexes) {
Expand Down
4 changes: 3 additions & 1 deletion k2/csrc/fsa_equivalent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ static void ArcSort(const k2::Fsa &fsa_in, k2::FsaCreator *fsa_out,

fsa_out->Init(fsa_size);
auto &sorted_fsa = fsa_out->GetFsa();
if (arc_map != nullptr) arc_map->resize(fsa_size.size2);
if (arc_map != nullptr) {
arc_map->resize(fsa_size.size2);
}
sorter.GetOutput(&sorted_fsa, arc_map == nullptr ? nullptr : arc_map->data());
}

Expand Down
1 change: 1 addition & 0 deletions k2/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pybind11_add_module(_k2
array.cc
fsa.cc
fsa_algo.cc
fsa_util.cc
k2.cc
tensor.cc
Expand Down
110 changes: 95 additions & 15 deletions k2/python/csrc/array.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// k2/python/csrc/array.cc

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors
Expand All @@ -10,9 +8,39 @@
#include <utility>

#include "k2/csrc/array.h"
#include "k2/csrc/determinize_impl.h"
#include "k2/python/csrc/tensor.h"

namespace k2 {

/*
DLPackArray1 initializes Array1 with `cap_data` which is a DLManagedTensor.
`cap_data` must be a one dimensional array with stride 1, i.e.,
`cap_data.ndim == 1 && cap_data.strides[0] == 1`.
*/
// Primary template: declared only. The raw pointer specialization below is
// the one that carries an implementation.
template <typename ValueType, typename I>
class DLPackArray1;

// An Array1 over `ValueType *` that is initialized from a capsule.
// The Tensor built from the capsule is kept as a member of this object.
template <typename ValueType, typename I>
class DLPackArray1<ValueType *, I> : public Array1<ValueType *, I> {
 public:
  // Wraps `cap_data` in a Tensor, verifies that it is one dimensional with
  // a non-negative number of elements and a stride of 1, and then points the
  // Array1 base at the tensor's data.
  explicit DLPackArray1(py::capsule cap_data)
      : data_tensor_(new Tensor(cap_data)) {
    CHECK_EQ(data_tensor_->NumDim(), 1);
    CHECK_GE(data_tensor_->Shape(0), 0);  // num-elements
    CHECK_EQ(data_tensor_->Stride(0), 1);

    const int32_t num_elements = data_tensor_->Shape(0);
    ValueType *elements = data_tensor_->Data<ValueType>();
    this->Init(0, num_elements, elements);
  }

 private:
  // Tensor created from the capsule handed to the constructor.
  std::unique_ptr<Tensor> data_tensor_;
};
// Note: we can specialize for `StridedPtr` later if we need it,
// `cap_data.strides[0]` will be greater than 1 in that case.

/*
DLPackArray2 initializes Array2 with `cap_indexes` and `cap_data` which are
DLManagedTensors.
Expand Down Expand Up @@ -100,11 +128,32 @@ class DLPackArray2<ValueType *, false, I> : public Array2<ValueType *, I> {
std::unique_ptr<Tensor> indexes_tensor_;
std::unique_ptr<Tensor> data_tensor_;
};

// Note: we can specialize for `StridedPtr` later if we need it.

} // namespace k2

// Registers `k2::DLPackArray1<Ptr, I>` in module `m` as a Python class called
// `name`, declared to pybind11 as a subclass of `k2::Array1<Ptr, I>`.
//
// Exposed members:
//   - __init__(data): constructs the array from a capsule.
//   - empty():        forwards to `Array1::Empty`.
//   - get_base():     returns the object viewed as its `Array1` base; the
//                     returned reference keeps `self` alive.
//   - size:           read-only number of elements.
//   - get_data(i):    returns the i-th accessible element; raises IndexError
//                     when `i` is outside [0, size).
template <typename Ptr, typename I = int32_t>
void PybindArray1Tpl(py::module &m, const char *name) {
  using PyClass = k2::DLPackArray1<Ptr, I>;
  using Parent = k2::Array1<Ptr, I>;
  py::class_<PyClass, Parent>(m, name)
      .def(py::init<py::capsule>(), py::arg("data"))
      .def("empty", &PyClass::Empty)
      .def(
          "get_base",
          // A derived-to-base conversion is what is wanted here, so use
          // static_cast: it adjusts the pointer correctly, whereas
          // reinterpret_cast merely reuses the address.
          [](PyClass &self) { return static_cast<Parent *>(&self); },
          py::return_value_policy::reference_internal)
      .def_readonly("size", &PyClass::size)
      .def(
          "get_data",
          [](const PyClass &self, I i) {
            // `I` is a signed type by default, so a negative index coming
            // from Python must be rejected too; otherwise it would read
            // memory in front of the accessible range.
            if (i < 0 || i >= self.size) throw py::index_error();
            return self.data[self.begin + i];
          },
          "just for test purpose to check if k2::Array1 and the "
          "underlying tensor are sharing memory.");
}

template <typename Ptr, bool IsPrimitive, typename I = int32_t>
void PybindArray2Tpl(py::module &m, const char *name) {
using PyClass = k2::DLPackArray2<Ptr, IsPrimitive, I>;
Expand All @@ -114,28 +163,59 @@ void PybindArray2Tpl(py::module &m, const char *name) {
py::arg("data"))
.def("empty", &PyClass::Empty)
.def(
"__iter__",
[](const PyClass &self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>())
"get_base",
[](PyClass &self) { return reinterpret_cast<Parent *>(&self); },
py::return_value_policy::reference_internal)
.def_readonly("size1", &PyClass::size1)
.def_readonly("size2", &PyClass::size2)
.def("indexes", [](const PyClass &self, I i) { return self.indexes[i]; })
.def("data", [](const PyClass &self, I i) { return self.data[i]; });
// TODO(haowen): expose `indexes` and `data` as an array
// instead of a function call?
.def(
"get_indexes",
[](const PyClass &self, I i) {
if (i > self.size1) // note indexes.size == size1+1
throw py::index_error();
return self.indexes[i];
},
"just for test purpose to check if k2::Array2 and the "
"underlying tensor are sharing memory.")
.def(
"get_data",
[](const PyClass &self, I i) {
if (i >= self.size2) throw py::index_error();
return self.data[self.indexes[0] + i];
},
"just for test purpose to check if k2::Array2 and the "
"underlying tensor are sharing memory.");
}

// Registers `k2::Array2Size<I>` in module `m` as a Python class called `name`.
//
// The class can be default constructed or constructed from (size1, size2),
// and both `size1` and `size2` are exposed as read/write attributes.
template <typename I>
void PybindArray2SizeTpl(py::module &m, const char *name) {
  using PyClass = k2::Array2Size<I>;
  py::class_<PyClass>(m, name)
      .def(py::init<>())
      // Use the index type `I` of this instantiation rather than a
      // hard-coded `int32_t`, so the binding stays correct for other `I`.
      .def(py::init<I, I>(), py::arg("size1"), py::arg("size2"))
      .def_readwrite("size1", &PyClass::size1)
      .def_readwrite("size2", &PyClass::size2);
}

// Registers every array wrapper in module `m`.
//
// Note: all wrappers whose name starts with `_` are only used by pybind11
// internally so that it knows `k2::DLPackArray1` is a subclass of `k2::Array1`
// and `k2::DLPackArray2` is a subclass of `k2::Array2`. Each base class is
// registered right before the DLPack class that names it as its parent.
void PybindArray(py::module &m) {
  py::class_<k2::Array1<int32_t *>>(m, "_IntArray1");
  PybindArray1Tpl<int32_t *>(m, "DLPackIntArray1");

  py::class_<k2::Array2<int32_t *>>(m, "_IntArray2");
  PybindArray2Tpl<int32_t *, true>(m, "DLPackIntArray2");

  // note there is a type cast as the underlying Tensor is with type `float`
  //
  // `_LogSumArcDerivs` and `DLPackLogSumArcDerivs` must each be registered
  // exactly once: pybind11 refuses to register a second class under a name
  // that already exists in the module.
  using LogSumDerivType = typename k2::LogSumTracebackState::DerivType;
  py::class_<k2::Array2<LogSumDerivType *>>(m, "_LogSumArcDerivs");
  PybindArray2Tpl<LogSumDerivType *, false>(m, "DLPackLogSumArcDerivs");
}

// Registers the `int32_t` flavour of `k2::Array2Size` in module `m` under the
// Python name `IntArray2Size`.
void PybindArray2Size(py::module &m) {
  using IndexType = int32_t;
  PybindArray2SizeTpl<IndexType>(m, "IntArray2Size");
}
1 change: 1 addition & 0 deletions k2/python/csrc/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
#include "k2/python/csrc/k2.h"

void PybindArray(py::module &m);
void PybindArray2Size(py::module &m);

#endif // K2_PYTHON_CSRC_ARRAY_H_
39 changes: 26 additions & 13 deletions k2/python/csrc/fsa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

namespace k2 {

// it uses external memory passed from DLPack (e.g., by PyTorch)
// to construct an Fsa.
// DLPackFsa initializes Fsa with `cap_indexes` and `cap_data` which are
// DLManagedTensors.
class DLPackFsa : public Fsa {
public:
DLPackFsa(py::capsule cap_indexes, py::capsule cap_data)
Expand Down Expand Up @@ -50,7 +50,7 @@ class DLPackFsa : public Fsa {

void PybindArc(py::module &m) {
using PyClass = k2::Arc;
py::class_<PyClass>(m, "Arc")
py::class_<PyClass>(m, "_Arc")
.def(py::init<>())
.def(py::init<int32_t, int32_t, int32_t>(), py::arg("src_state"),
py::arg("dest_state"), py::arg("label"))
Expand All @@ -70,21 +70,34 @@ void PybindFsa(py::module &m) {
py::class_<k2::Fsa>(m, "_Fsa");

using PyClass = k2::DLPackFsa;
py::class_<PyClass, k2::Fsa>(m, "DLPackFsa")
using Parent = k2::Fsa;
py::class_<PyClass, Parent>(m, "DLPackFsa")
.def(py::init<py::capsule, py::capsule>(), py::arg("indexes"),
py::arg("data"))
.def("empty", &PyClass::Empty)
.def(
"__iter__",
[](const PyClass &self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>())
"get_base",
[](PyClass &self) { return reinterpret_cast<Parent *>(&self); },
py::return_value_policy::reference_internal)
.def("empty", &PyClass::Empty)
.def_readonly("size1", &PyClass::size1)
.def_readonly("size2", &PyClass::size2)
.def("indexes",
[](const PyClass &self, int32_t i) { return self.indexes[i]; })
.def("data", [](const PyClass &self, int32_t i) { return self.data[i]; })
.def(
"get_indexes",
[](const PyClass &self, int32_t i) {
if (i > self.size1) // note indexes.size == size1+1
throw py::index_error();
return self.indexes[i];
},
"just for test purpose to check if k2::Fsa and the "
"underlying tensor are sharing memory.")
.def(
"get_data",
[](const PyClass &self, int32_t i) {
if (i >= self.size2) throw py::index_error();
return self.data[self.indexes[0] + i];
},
"just for test purpose to check if k2::Fsa and the "
"underlying tensor are sharing memory.")
.def("num_states", &PyClass::NumStates)
.def("final_state", &PyClass::FinalState);
}
42 changes: 42 additions & 0 deletions k2/python/csrc/fsa_algo.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// k2/python/csrc/fsa_algo.cc

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors

#include "k2/python/csrc/fsa_algo.h"

#include <memory>
#include <utility>

#include "k2/csrc/arcsort.h"
#include "k2/csrc/array.h"
#include "k2/python/csrc/array.h"

namespace k2 {} // namespace k2

// Exposes arc sorting in module `m`:
//   - class `_ArcSorter`, wrapping `k2::ArcSorter` with `get_sizes` and
//     `get_output`;
//   - function `_arc_sort`, wrapping the in-place `k2::ArcSort`.
// In both cases `arc_map` is optional; when it is omitted a null pointer is
// handed to the C++ function.
void PyBindArcSort(py::module &m) {
  using Sorter = k2::ArcSorter;
  using ArcMap = k2::Array1<int32_t *>;

  py::class_<Sorter> sorter_class(m, "_ArcSorter");
  sorter_class.def(py::init<const k2::Fsa &>(), py::arg("fsa_in"));
  sorter_class.def("get_sizes", &Sorter::GetSizes, py::arg("fsa_size"));
  sorter_class.def(
      "get_output",
      [](Sorter &self, k2::Fsa *fsa_out, ArcMap *arc_map) {
        // Only dereference `arc_map` when the caller actually supplied one.
        int32_t *raw_arc_map = nullptr;
        if (arc_map != nullptr) raw_arc_map = arc_map->data;
        self.GetOutput(fsa_out, raw_arc_map);
      },
      py::arg("fsa_out"), py::arg("arc_map") = static_cast<ArcMap *>(nullptr));

  m.def(
      "_arc_sort",
      [](k2::Fsa *fsa, ArcMap *arc_map) {
        // Only dereference `arc_map` when the caller actually supplied one.
        int32_t *raw_arc_map = nullptr;
        if (arc_map != nullptr) raw_arc_map = arc_map->data;
        k2::ArcSort(fsa, raw_arc_map);
      },
      "in-place version of ArcSorter", py::arg("fsa"),
      py::arg("arc_map") = static_cast<ArcMap *>(nullptr));
}

// Entry point for the FSA algorithm bindings: registers the arc sort
// wrappers in module `m`.
void PybindFsaAlgo(py::module &m) {
  PyBindArcSort(m);
}
14 changes: 14 additions & 0 deletions k2/python/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// k2/python/csrc/fsa_algo.h

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors

#ifndef K2_PYTHON_CSRC_FSA_ALGO_H_
#define K2_PYTHON_CSRC_FSA_ALGO_H_

#include "k2/python/csrc/k2.h"

void PybindFsaAlgo(py::module &m);

#endif // K2_PYTHON_CSRC_FSA_ALGO_H_
3 changes: 3 additions & 0 deletions k2/python/csrc/k2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@

#include "k2/python/csrc/array.h"
#include "k2/python/csrc/fsa.h"
#include "k2/python/csrc/fsa_algo.h"
#include "k2/python/csrc/fsa_util.h"

// Defines the `_k2` extension module: sets its docstring and then invokes each
// binding registration function on the module object `m`.
// NOTE(review): the order of the calls below presumably matters, since later
// bindings refer to types registered by earlier ones — confirm before
// reordering.
PYBIND11_MODULE(_k2, m) {
m.doc() = "pybind11 binding of k2";
PybindArc(m);
PybindArray(m);
PybindArray2Size(m);
PybindFsa(m);
PybindFsaUtil(m);
PybindFsaAlgo(m);
}
3 changes: 2 additions & 1 deletion k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from _k2 import Arc
from _k2 import IntArray2Size
from .array import *
from .fsa import *
from .fsa_algo import *
from .fsa_util import str_to_fsa
Loading

0 comments on commit 4fa6ba9

Please sign in to comment.