wrap k2::Fsa and k2::Arc to Python. (#54)
* wrap k2::Array1 to Python.

* fix cpplint.

* construct a k2::Fsa_ from a string in Python.

* fix cpplint.

* resolve some comments.

* add python test cases for stride==3

* rebase.

* wrap k2::Fsa and k2::Arc.

* fix a typo.
csukuangfj authored Jul 17, 2020
1 parent 53e3317 commit d1754fb
Showing 20 changed files with 396 additions and 477 deletions.
9 changes: 4 additions & 5 deletions k2/csrc/arcsort.cc
@@ -15,7 +15,7 @@
#include "k2/csrc/fsa.h"

namespace k2 {
void ArcSorter::GetSizes(Array2Size<int32_t> *fsa_size) {
void ArcSorter::GetSizes(Array2Size<int32_t> *fsa_size) const {
CHECK_NOTNULL(fsa_size);
fsa_size->size1 = fsa_in_.size1;
fsa_size->size2 = fsa_in_.size2;
@@ -39,10 +39,9 @@ void ArcSorter::GetOutput(Fsa *fsa_out, int32_t *arc_map /*= nullptr*/) {
fsa_out->indexes[state] = num_arcs;
int32_t begin = fsa_in_.indexes[state] - arc_begin_index;
int32_t end = fsa_in_.indexes[state + 1] - arc_begin_index;
std::sort(indexes.begin() + begin, indexes.begin() + end,
[&arcs_in, arc_begin_index](int32_t i, int32_t j) {
return arcs_in[i] < arcs_in[j];
});
std::sort(
indexes.begin() + begin, indexes.begin() + end,
[&arcs_in](int32_t i, int32_t j) { return arcs_in[i] < arcs_in[j]; });
// copy sorted arcs to `fsa_out`
std::transform(indexes.begin() + begin, indexes.begin() + end,
fsa_out->data + num_arcs,
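For orientation, the hunk above is the classic argsort-then-gather pattern: sort a vector of arc indexes with a comparator that inspects the arcs, then copy the arcs out in sorted order. A minimal standalone sketch of the same pattern, using a toy Arc and a lexicographic ordering assumed for illustration (not the full k2 types):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <tuple>
#include <vector>

struct Arc {
  int32_t src_state, dest_state, label;
  bool operator<(const Arc &o) const {  // toy lexicographic order
    return std::tie(src_state, dest_state, label) <
           std::tie(o.src_state, o.dest_state, o.label);
  }
};

int main() {
  std::vector<Arc> arcs_in = {{0, 2, 5}, {0, 1, 3}, {0, 1, 1}};
  // argsort: order the indexes, not the arcs themselves
  std::vector<int32_t> indexes(arcs_in.size());
  std::iota(indexes.begin(), indexes.end(), 0);
  std::sort(indexes.begin(), indexes.end(),
            [&arcs_in](int32_t i, int32_t j) { return arcs_in[i] < arcs_in[j]; });
  // gather: write arcs out in sorted order (cf. std::transform above)
  std::vector<Arc> arcs_out(arcs_in.size());
  std::transform(indexes.begin(), indexes.end(), arcs_out.begin(),
                 [&arcs_in](int32_t i) { return arcs_in[i]; });
  // indexes now doubles as an arc map: arcs_out[k] == arcs_in[indexes[k]]
  return 0;
}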
4 changes: 2 additions & 2 deletions k2/csrc/arcsort.h
@@ -27,12 +27,12 @@ class ArcSorter {
explicit ArcSorter(const Fsa &fsa_in) : fsa_in_(fsa_in) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
*/
void GetSizes(Array2Size<int32_t> *fsa_size);
void GetSizes(Array2Size<int32_t> *fsa_size) const;

/*
Finish the operation and output the arc-sorted FSA to `fsa_out` and
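The `const` added above reflects the class's two-phase contract: `GetSizes()` only measures the input, the caller allocates, and `GetOutput()` then writes into caller-owned memory. A hedged caller sketch, assuming the `Init()` helper that `_Fsa` uses later in this commit is accessible here (real call sites may use `Array2Storage` from array.h instead):

#include <cstdint>
#include <vector>

#include "k2/csrc/arcsort.h"
#include "k2/csrc/fsa.h"

void ArcSortExample(const k2::Fsa &fsa_in) {
  k2::ArcSorter sorter(fsa_in);

  k2::Array2Size<int32_t> size;
  sorter.GetSizes(&size);  // phase 1: measure only, no allocation inside

  std::vector<int32_t> indexes(size.size1 + 1);  // caller owns all memory
  std::vector<k2::Arc> arcs(size.size2);
  k2::Fsa fsa_out;
  fsa_out.Init(size.size1, size.size2, indexes.data(), arcs.data());

  std::vector<int32_t> arc_map(size.size2);
  sorter.GetOutput(&fsa_out, arc_map.data());  // phase 2: fill caller memory
}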
4 changes: 2 additions & 2 deletions k2/csrc/array.h
@@ -29,7 +29,7 @@ template <typename T, typename I = int32_t>
struct StridedPtr {
T *data; // it is NOT owned here
I stride; // in number of elements, NOT number of bytes
StridedPtr(T *data = nullptr, I stride = 0) // NOLINT
explicit StridedPtr(T *data = nullptr, I stride = 0) // NOLINT
: data(data), stride(stride) {}

T &operator[](I i) { return data[i * stride]; }
@@ -290,7 +290,7 @@ struct DataPtrCreator<StridedPtr<ValueType, I>, I> {
template <typename Ptr, typename I>
struct Array2Storage {
using ValueType = typename Array2<Ptr, I>::ValueType;
explicit Array2Storage(const Array2Size<I> &array2_size, I stride)
Array2Storage(const Array2Size<I> &array2_size, I stride)
: indexes_storage_(new I[array2_size.size1 + 1]),
data_storage_(new ValueType[array2_size.size2 * stride]) {
array_.size1 = array2_size.size1;
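With the constructor now `explicit`, a raw pointer no longer converts silently into a strided view. A standalone sketch of the access pattern, including the stride == 3 case the new Python tests exercise:

#include <cassert>
#include <cstdint>

// Minimal restatement of the StridedPtr idea from this diff.
template <typename T, typename I = int32_t>
struct StridedPtr {
  T *data;   // NOT owned here
  I stride;  // in number of elements, NOT number of bytes
  explicit StridedPtr(T *data = nullptr, I stride = 0)
      : data(data), stride(stride) {}
  T &operator[](I i) { return data[i * stride]; }
};

int main() {
  int32_t buf[] = {10, 0, 0, 11, 0, 0, 12, 0, 0};
  StridedPtr<int32_t> p(buf, 3);  // views elements 0, 3, 6
  assert(p[0] == 10 && p[2] == 12);
  // StridedPtr<int32_t> q = buf;  // ill-formed now that the ctor is explicit
  return 0;
}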
1 change: 0 additions & 1 deletion k2/python/CMakeLists.txt
@@ -1,3 +1,2 @@
add_subdirectory(csrc)
add_subdirectory(tests)
add_subdirectory(tutorials)
12 changes: 6 additions & 6 deletions k2/python/csrc/CMakeLists.txt
@@ -1,10 +1,10 @@
# sort the files alphabetically
pybind11_add_module(k2
#fsa.cc
fsa_renderer.cc
# please sort the files alphabetically
pybind11_add_module(_k2
fsa.cc
fsa_util.cc
k2.cc
tensor.cc
)

target_include_directories(k2 PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(k2 PRIVATE fsa)
target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(_k2 PRIVATE fsa)
247 changes: 64 additions & 183 deletions k2/python/csrc/fsa.cc
@@ -6,192 +6,73 @@

#include "k2/python/csrc/fsa.h"

#include <vector>
#include <memory>

#include "k2/csrc/fsa.h"
#include "k2/csrc/fsa_util.h"
#include "k2/python/csrc/dlpack.h"

using k2::Arc;
using k2::Cfsa;
using k2::CfsaVec;
using k2::Fsa;

// refer to
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/Module.cpp#L375
// https://github.com/microsoft/onnxruntime-tvm/blob/master/python/tvm/_ffi/_ctypes/ndarray.py#L28
// https://github.com/cupy/cupy/blob/master/cupy/core/dlpack.pyx#L66
// PyTorch, TVM and CuPy name the created dltensor to be `dltensor`
static const char *kDLPackTensorName = "dltensor";

// refer to
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/Module.cpp#L402
// https://github.com/apache/incubator-tvm/blob/master/python/tvm/_ffi/_ctypes/ndarray.py#L29
// https://github.com/cupy/cupy/blob/master/cupy/core/dlpack.pyx#L62
// PyTorch, TVM and CuPy name the used dltensor to be `used_dltensor`
static const char *kDLPackUsedTensorName = "used_dltensor";

/*
static void DLPackDeleter(void *p) {
auto dl_managed_tensor = reinterpret_cast<DLManagedTensor *>(p);
if (dl_managed_tensor && dl_managed_tensor->deleter)
dl_managed_tensor->deleter(dl_managed_tensor);
// this will be invoked if you uncomment it, which
// means Python will indeed free the memory returned by the subsequent
// `CfsaVecFromDLPack()`.
//
// LOG(INFO) << "freed!";
#include "k2/python/csrc/tensor.h"

namespace k2 {

// it uses external memory passed from DLPack (e.g., by PyTorch)
// to construct an Fsa.
class _Fsa : public Fsa {
public:
_Fsa(py::capsule cap_indexes, py::capsule cap_data)
: indexes_tensor_(new Tensor(cap_indexes)),
data_tensor_(new Tensor(cap_data)) {
CHECK_EQ(indexes_tensor_->dtype(), kInt32Type);
CHECK_EQ(indexes_tensor_->NumDim(), 1);
CHECK_GT(indexes_tensor_->Shape(0), 1);
CHECK_EQ(indexes_tensor_->Stride(0), 1)
<< "Only contiguous index arrays are supported at present";

CHECK_EQ(data_tensor_->dtype(), kInt32Type);
CHECK_EQ(data_tensor_->NumDim(), 2);
CHECK_EQ(data_tensor_->Stride(1), 1)
<< "Only contiguous data arrays at supported at present";
CHECK_EQ(sizeof(Arc),
data_tensor_->Shape(1) * data_tensor_->BytesPerElement());

int32_t size1 = indexes_tensor_->Shape(0) - 1;
int32_t size2 = data_tensor_->Shape(0);
this->Init(size1, size2, indexes_tensor_->Data<int32_t>(),
data_tensor_->Data<Arc>());
}

private:
std::unique_ptr<Tensor> indexes_tensor_;
std::unique_ptr<Tensor> data_tensor_;
};

} // namespace k2
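`Tensor` is declared in k2/python/csrc/tensor.h, which this commit adds but this view does not expand. Judging only from the calls `_Fsa` makes above, its interface looks roughly like the following hypothetical reconstruction (names and return types inferred from usage, not copied from the header):

#include <cstddef>
#include <cstdint>

// Hypothetical sketch for orientation only; see k2/python/csrc/tensor.h.
enum Dtype { kInt32Type /* ... */ };

class Tensor {
 public:
  // explicit Tensor(py::capsule capsule);  // consumes a DLPack capsule
  Dtype dtype() const;                 // element type tag, e.g. kInt32Type
  int32_t NumDim() const;              // number of dimensions
  int64_t Shape(int32_t i) const;      // extent of dimension i
  int64_t Stride(int32_t i) const;     // stride of dimension i, in elements
  std::size_t BytesPerElement() const;
  template <typename T>
  T *Data();                           // typed pointer to the external memory
};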

void PybindArc(py::module &m) {
using PyClass = k2::Arc;
py::class_<PyClass>(m, "Arc")
.def(py::init<>())
.def(py::init<int32_t, int32_t, int32_t>(), py::arg("src_state"),
py::arg("dest_state"), py::arg("label"))
.def_readwrite("src_state", &PyClass::src_state)
.def_readwrite("dest_state", &PyClass::dest_state)
.def_readwrite("label", &PyClass::label)
.def("__str__", [](const PyClass &self) {
std::ostringstream os;
os << self;
return os.str();
});
}

// the returned pointer is freed by Python
static CfsaVec *CfsaVecFromDLPack(py::capsule *capsule,
const std::vector<Cfsa> *cfsas = nullptr) {
// the following error message is modified from
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/Module.cpp#L384
CHECK_EQ(strcmp(kDLPackTensorName, capsule->name()), 0)
<< "Expected capsule name: " << kDLPackTensorName << "\n"
<< "But got: " << capsule->name() << "\n"
<< "Note that DLTensor capsules can be consumed only once,\n"
<< "so you might have already constructed a tensor from it once.";
PyCapsule_SetName(capsule->ptr(), kDLPackUsedTensorName);
DLManagedTensor *managed_tensor = *capsule;
// (fangjun): the above assignment will either throw or succeed with a
// non-null ptr; so no need to check for nullptr below
auto tensor = &managed_tensor->dl_tensor;
CHECK_EQ(tensor->ndim, 1) << "Expect 1-D tensor";
CHECK_EQ(tensor->dtype.code, kDLInt);
CHECK_EQ(tensor->dtype.bits, 32);
CHECK_EQ(tensor->dtype.lanes, 1);
CHECK_EQ(tensor->strides[0], 1); // memory should be contiguous
auto ctx = &tensor->ctx;
// TODO(fangjun): enable GPU once k2 supports GPU.
CHECK_EQ(ctx->device_type, kDLCPU);
auto start_ptr = reinterpret_cast<char *>(tensor->data) + tensor->byte_offset;
CHECK_EQ((intptr_t)start_ptr % sizeof(int32_t), 0);
if (cfsas)
CreateCfsaVec(*cfsas, start_ptr, tensor->shape[0] * sizeof(int32_t));
// no memory leak here; python will deallocate it
auto cfsa_vec = new CfsaVec(tensor->shape[0], start_ptr);
cfsa_vec->SetDeleter(&DLPackDeleter, managed_tensor);
return cfsa_vec;
}
static void PybindCfsaVec(py::module &m) {
m.def("get_cfsa_vec_size",
overload_cast_<const Cfsa &>()(&k2::GetCfsaVecSize), py::arg("cfsa"));
m.def("get_cfsa_vec_size",
overload_cast_<const std::vector<Cfsa> &>()(&k2::GetCfsaVecSize),
py::arg("cfsas"));
py::class_<CfsaVec>(m, "CfsaVec")
.def("num_fsas", &CfsaVec::NumFsas)
.def("__getitem__", [](const CfsaVec &self, int i) { return self[i]; },
py::keep_alive<0, 1>());
m.def("create_cfsa_vec",
[](py::capsule *capsule, const std::vector<Cfsa> *cfsas = nullptr) {
return CfsaVecFromDLPack(capsule, cfsas);
},
py::arg("dlpack"), py::arg("cfsas") = nullptr,
py::return_value_policy::take_ownership);
}
*/

void PybindFsa(py::module &m) {
/*
py::class_<Arc>(m, "Arc")
.def(py::init<>())
.def(py::init<int32_t, int32_t, int32_t>(), py::arg("src_state"),
py::arg("dest_state"), py::arg("label"))
.def_readwrite("src_state", &Arc::src_state)
.def_readwrite("dest_state", &Arc::dest_state)
.def_readwrite("label", &Arc::label)
.def("__str__", [](const Arc &self) {
std::ostringstream os;
os << self;
return os.str();
});
py::class_<Fsa>(m, "Fsa")
.def(py::init<>())
.def("num_states", &Fsa::NumStates)
.def("final_state", &Fsa::FinalState)
.def("__str__", [](const Fsa &self) { return FsaToString(self); })
.def_readwrite("arc_indexes", &Fsa::arc_indexes)
.def_readwrite("arcs", &Fsa::arcs);
py::class_<std::vector<Fsa>>(m, "FsaVec")
.def(py::init<>())
.def("clear", &std::vector<Fsa>::clear)
.def("__len__", [](const std::vector<Fsa> &self) { return self.size(); })
.def("push_back",
[](std::vector<Fsa> *self, const Fsa &fsa) { self->push_back(fsa); })
.def("__iter__",
[](const std::vector<Fsa> &self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>());
// py::keep_alive<Nurse, Patient>
// 0 is the return value and 1 is the first argument.
// Keep the patient (i.e., `self`) alive as long as the Nurse (i.e., the
// return value) is not freed.
py::class_<std::vector<Arc>>(m, "ArcVec")
.def(py::init<>())
.def("clear", &std::vector<Arc>::clear)
.def("__len__", [](const std::vector<Arc> &self) { return self.size(); })
.def("__iter__",
[](const std::vector<Arc> &self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>());
py::class_<Cfsa>(m, "Cfsa")
.def(py::init<>())
.def(py::init<const Fsa &>(), py::arg("fsa"), py::keep_alive<1, 2>())
.def("num_states", &Cfsa::NumStates)
.def("num_arcs", &Cfsa::NumArcs)
.def("arc",
[](Cfsa *self, int s) {
DCHECK_GE(s, 0);
DCHECK_LT(s, self->NumStates());
auto begin = self->arc_indexes[s];
auto end = self->arc_indexes[s + 1];
return py::make_iterator(self->arcs + begin, self->arcs + end);
},
py::keep_alive<0, 1>())
.def("__str__",
[](const Cfsa &self) {
std::ostringstream os;
os << self;
return os.str();
})
.def("__eq__", // for test only
[](const Cfsa &self, const Cfsa &other) { return self == other; });
py::class_<std::vector<Cfsa>>(m, "CfsaStdVec")
.def(py::init<>())
.def("clear", &std::vector<Cfsa>::clear)
.def("push_back", [](std::vector<Cfsa> *self,
const Cfsa &cfsa) { self->push_back(cfsa); })
.def("__len__", [](const std::vector<Cfsa> &self) { return self.size(); })
.def("__iter__",
[](const std::vector<Cfsa> &self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>());
PybindCfsaVec(m);
*/
// Note(fangjun): Users are not supposed to use `k2::Fsa` directly
// in Python; the following wrapper is only used by pybind11 internally
// so that it knows `k2::_Fsa` is a subclass of `k2::Fsa`.
py::class_<k2::Fsa>(m, "__Fsa");

using PyClass = k2::_Fsa;
py::class_<PyClass, k2::Fsa>(m, "Fsa")
.def(py::init<py::capsule, py::capsule>(), py::arg("indexes"),
py::arg("data"))
.def("empty", &PyClass::Empty)
.def("num_states", &PyClass::NumStates)
.def("final_state", &PyClass::FinalState);
}
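Binding `k2::Fsa` under the private name `__Fsa` is the usual pybind11 idiom for a base class that users should never touch directly: the base must be registered so that pybind11 can wire up the inheritance when the derived class names it. A minimal self-contained illustration with toy names:

#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Base {
  int num_states() const { return 1; }
};

struct Derived : Base {
  explicit Derived(int n) : n_(n) {}
  int n_;
};

PYBIND11_MODULE(toy, m) {
  py::class_<Base>(m, "__Base");  // registered, but effectively private
  py::class_<Derived, Base>(m, "Derived")  // pybind11 now knows the base
      .def(py::init<int>(), py::arg("n"))
      .def("num_states", &Derived::num_states);
}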
1 change: 1 addition & 0 deletions k2/python/csrc/fsa.h
@@ -9,6 +9,7 @@

#include "k2/python/csrc/k2.h"

void PybindArc(py::module &m);
void PybindFsa(py::module &m);

#endif // K2_PYTHON_CSRC_FSA_H_
18 changes: 0 additions & 18 deletions k2/python/csrc/fsa_renderer.cc

This file was deleted.

14 changes: 0 additions & 14 deletions k2/python/csrc/fsa_renderer.h

This file was deleted.

3 changes: 1 addition & 2 deletions k2/python/csrc/fsa_util.cc
@@ -9,6 +9,5 @@
#include "k2/csrc/fsa_util.h"

void PybindFsaUtil(py::module &m) {
// m.def("string_to_fsa", &k2::StringToFsa);
m.def("fsa_to_string", &k2::FsaToString);
m.def("fsa_to_str", &k2::FsaToString, py::arg("fsa"));
}
7 changes: 3 additions & 4 deletions k2/python/csrc/k2.cc
@@ -7,12 +7,11 @@
#include "k2/python/csrc/k2.h"

#include "k2/python/csrc/fsa.h"
#include "k2/python/csrc/fsa_renderer.h"
#include "k2/python/csrc/fsa_util.h"

PYBIND11_MODULE(k2, m) {
PYBIND11_MODULE(_k2, m) {
m.doc() = "pybind11 binding of k2";
// PybindFsa(m);
PybindFsaRenderer(m);
PybindArc(m);
PybindFsa(m);
PybindFsaUtil(m);
}