-
Notifications
You must be signed in to change notification settings - Fork 217
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add pybindings for Array2 * fix some comment issues
- Loading branch information
Showing
13 changed files
with
297 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,3 +46,6 @@ GSYMS | |
GPATH | ||
tags | ||
TAGS | ||
|
||
# python build files | ||
**/__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# please sort the files alphabetically | ||
pybind11_add_module(_k2 | ||
array.cc | ||
fsa.cc | ||
fsa_util.cc | ||
k2.cc | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// k2/python/csrc/array.cc | ||
|
||
// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) | ||
|
||
// See ../../../LICENSE for clarification regarding multiple authors | ||
|
||
#include "k2/python/csrc/array.h" | ||
|
||
#include <memory> | ||
#include <utility> | ||
|
||
#include "k2/csrc/array.h" | ||
#include "k2/python/csrc/tensor.h" | ||
|
||
namespace k2 { | ||
/* | ||
DLPackArray2 initializes Array2 with `cap_indexes` and `cap_data` which are | ||
DLManagedTensors. | ||
`cap_indexes` is usually a one-dimensional contiguous array, i.e., | ||
`cap_indexes.ndim == 1 && cap_indexes.strides[0] == 1`. | ||
`cap_data` may have different shapes depending on `ValueType`: | ||
1. if `ValueType` is a primitive type (e.g. `int32_t`), it will be | ||
a one-dimensional contiguous array, i.e., | ||
`cap_data.ndim == 1 && cap_data.strides[0] == 1`. | ||
2. if `ValueType` is a complex type (e.g. Arc), it will be a | ||
two-dimensional array, i.e., it meets the following requirements: | ||
a) cap_data.ndim == 2. | ||
b) cap_data.shape[0] == num-elements it stores; note the | ||
element's type is `ValueType`, which means we view each row of | ||
`cap_data.data` as one element with type `ValueType`. | ||
c) cap_data.shape[1] == num-primitive-values in `ValueType`, | ||
which means we require that `ValueType` can be viewed as a tensor, | ||
this is true for Arc as it only holds primitive values with same | ||
type (i.e. `int32_t`), but may need type cast in other cases | ||
(e.g. ValueType contains both `int32_t` and `float`). | ||
d) cap_data.strides[0] == num-primitive-values in `ValueType`. | ||
e) cap_data.strides[1] == 1. | ||
Note: if `data` in Array2 has stride > 1 (i.e. `data`'s type is | ||
StridedPtr<ValueType>), the requirements on `cap_data` are nearly the same | ||
as in case 2 above, except that cap_data.strides[0] will be greater than | ||
num-primitive-values in `ValueType`. | ||
*/ | ||
template <typename ValueType, bool IsPrimitive, typename I> | ||
class DLPackArray2; | ||
|
||
template <typename ValueType, typename I> | ||
class DLPackArray2<ValueType *, true, I> : public Array2<ValueType *, I> { | ||
public: | ||
DLPackArray2(py::capsule cap_indexes, py::capsule cap_data) | ||
: indexes_tensor_(new Tensor(cap_indexes)), | ||
data_tensor_(new Tensor(cap_data)) { | ||
CHECK_EQ(indexes_tensor_->NumDim(), 1); | ||
CHECK_GE(indexes_tensor_->Shape(0), 1); // must have one element at least | ||
CHECK_EQ(indexes_tensor_->Stride(0), 1); | ||
|
||
CHECK_EQ(data_tensor_->NumDim(), 1); | ||
CHECK_GE(data_tensor_->Shape(0), 0); // num-elements | ||
CHECK_EQ(data_tensor_->Stride(0), 1); | ||
|
||
int32_t size1 = indexes_tensor_->Shape(0) - 1; | ||
int32_t size2 = data_tensor_->Shape(0); | ||
this->Init(size1, size2, indexes_tensor_->Data<I>(), | ||
data_tensor_->Data<ValueType>()); | ||
} | ||
|
||
private: | ||
std::unique_ptr<Tensor> indexes_tensor_; | ||
std::unique_ptr<Tensor> data_tensor_; | ||
}; | ||
|
||
template <typename ValueType, typename I> | ||
class DLPackArray2<ValueType *, false, I> : public Array2<ValueType *, I> { | ||
public: | ||
DLPackArray2(py::capsule cap_indexes, py::capsule cap_data) | ||
: indexes_tensor_(new Tensor(cap_indexes)), | ||
data_tensor_(new Tensor(cap_data)) { | ||
CHECK_EQ(indexes_tensor_->NumDim(), 1); | ||
CHECK_GE(indexes_tensor_->Shape(0), 1); // must have one element at least | ||
CHECK_EQ(indexes_tensor_->Stride(0), 1); | ||
|
||
CHECK_EQ(data_tensor_->NumDim(), 2); | ||
CHECK_GE(data_tensor_->Shape(0), 0); // num-elements | ||
CHECK_EQ(data_tensor_->Shape(1) * data_tensor_->BytesPerElement(), | ||
sizeof(ValueType)); | ||
CHECK_EQ(data_tensor_->Stride(0) * data_tensor_->BytesPerElement(), | ||
sizeof(ValueType)); | ||
CHECK_EQ(data_tensor_->Stride(1), 1); | ||
|
||
int32_t size1 = indexes_tensor_->Shape(0) - 1; | ||
int32_t size2 = data_tensor_->Shape(0); | ||
this->Init(size1, size2, indexes_tensor_->Data<I>(), | ||
data_tensor_->Data<ValueType>()); | ||
} | ||
|
||
private: | ||
std::unique_ptr<Tensor> indexes_tensor_; | ||
std::unique_ptr<Tensor> data_tensor_; | ||
}; | ||
|
||
// Note: we can specialize for `StridedPtr` later if we need it. | ||
|
||
} // namespace k2 | ||
|
||
template <typename Ptr, bool IsPrimitive, typename I = int32_t> | ||
void PybindArray2Tpl(py::module &m, const char *name) { | ||
using PyClass = k2::DLPackArray2<Ptr, IsPrimitive, I>; | ||
using Parent = k2::Array2<Ptr, I>; | ||
py::class_<PyClass, Parent>(m, name) | ||
.def(py::init<py::capsule, py::capsule>(), py::arg("indexes"), | ||
py::arg("data")) | ||
.def("empty", &PyClass::Empty) | ||
.def( | ||
"__iter__", | ||
[](const PyClass &self) { | ||
return py::make_iterator(self.begin(), self.end()); | ||
}, | ||
py::keep_alive<0, 1>()) | ||
.def_readonly("size1", &PyClass::size1) | ||
.def_readonly("size2", &PyClass::size2) | ||
.def("indexes", [](const PyClass &self, I i) { return self.indexes[i]; }) | ||
.def("data", [](const PyClass &self, I i) { return self.data[i]; }); | ||
// TODO(haowen): expose `indexes` and `data` as an array | ||
// instead of a function call? | ||
} | ||
|
||
// Registers all `Array2` related classes in module `m`.
void PybindArray(py::module &m) {
  // Every class whose Python name starts with `_` is registered only so that
  // pybind11 knows the `k2::Array2` base of the `k2::DLPackArray2` registered
  // right after it; those classes are not meant to be used directly.
  using IntPtr = int32_t *;
  using IntArray2Base = k2::Array2<IntPtr>;
  py::class_<IntArray2Base>(m, "_IntArray2");
  PybindArray2Tpl<IntPtr, true>(m, "DLPackIntArray2");

  // The tensor backing this array has element type `float`; each of its rows
  // is reinterpreted as a `std::pair<int32_t, float>`.
  using LogSumArcDerivPtr = std::pair<int32_t, float> *;
  using LogSumArcDerivsBase = k2::Array2<LogSumArcDerivPtr>;
  py::class_<LogSumArcDerivsBase>(m, "_LogSumArcDerivs");
  PybindArray2Tpl<LogSumArcDerivPtr, false>(m, "DLPackLogSumArcDerivs");
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// k2/python/csrc/array.h | ||
|
||
// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) | ||
|
||
// See ../../../LICENSE for clarification regarding multiple authors | ||
|
||
#ifndef K2_PYTHON_CSRC_ARRAY_H_ | ||
#define K2_PYTHON_CSRC_ARRAY_H_ | ||
|
||
#include "k2/python/csrc/k2.h" | ||
|
||
void PybindArray(py::module &m); | ||
|
||
#endif // K2_PYTHON_CSRC_ARRAY_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
// k2/python/csrc/fsa.cc | ||
|
||
// Copyright (c) 2020 Fangjun Kuang ([email protected]) | ||
// Xiaomi Corporation (author: Haowen Qiu) | ||
|
||
// See ../../../LICENSE for clarification regarding multiple authors | ||
|
||
|
@@ -15,23 +16,24 @@ namespace k2 { | |
|
||
// it uses external memory passed from DLPack (e.g., by PyTorch) | ||
// to construct an Fsa. | ||
class _Fsa : public Fsa { | ||
class DLPackFsa : public Fsa { | ||
public: | ||
_Fsa(py::capsule cap_indexes, py::capsule cap_data) | ||
DLPackFsa(py::capsule cap_indexes, py::capsule cap_data) | ||
: indexes_tensor_(new Tensor(cap_indexes)), | ||
data_tensor_(new Tensor(cap_data)) { | ||
CHECK_EQ(indexes_tensor_->dtype(), kInt32Type); | ||
CHECK_EQ(indexes_tensor_->NumDim(), 1); | ||
CHECK_GT(indexes_tensor_->Shape(0), 1); | ||
CHECK_EQ(indexes_tensor_->Stride(0), 1) | ||
<< "Only contiguous index arrays are supported at present"; | ||
CHECK_GE(indexes_tensor_->Shape(0), 1); | ||
CHECK_EQ(indexes_tensor_->Stride(0), 1); | ||
|
||
CHECK_EQ(data_tensor_->dtype(), kInt32Type); | ||
CHECK_EQ(data_tensor_->NumDim(), 2); | ||
CHECK_EQ(data_tensor_->Stride(1), 1) | ||
<< "Only contiguous data arrays at supported at present"; | ||
CHECK_EQ(sizeof(Arc), | ||
data_tensor_->Shape(1) * data_tensor_->BytesPerElement()); | ||
CHECK_GE(data_tensor_->Shape(0), 0); // num-elements | ||
CHECK_EQ(data_tensor_->Shape(1) * data_tensor_->BytesPerElement(), | ||
sizeof(Arc)); | ||
CHECK_EQ(data_tensor_->Stride(0) * data_tensor_->BytesPerElement(), | ||
sizeof(Arc)); | ||
CHECK_EQ(data_tensor_->Stride(1), 1); | ||
|
||
int32_t size1 = indexes_tensor_->Shape(0) - 1; | ||
int32_t size2 = data_tensor_->Shape(0); | ||
|
@@ -63,16 +65,26 @@ void PybindArc(py::module &m) { | |
} | ||
|
||
void PybindFsa(py::module &m) { | ||
// Note(fangjun): Users are not supposed to use `k2::Fsa` directly | ||
// in Python; the following wrapper is only used by pybind11 internally | ||
// so that it knows `k2::_Fsa` is a subclass of `k2::Fsa`. | ||
py::class_<k2::Fsa>(m, "__Fsa"); | ||
// The following wrapper is only used by pybind11 internally | ||
// so that it knows `k2::DLPackFsa` is a subclass of `k2::Fsa`. | ||
py::class_<k2::Fsa>(m, "_Fsa"); | ||
|
||
using PyClass = k2::_Fsa; | ||
py::class_<PyClass, k2::Fsa>(m, "Fsa") | ||
using PyClass = k2::DLPackFsa; | ||
py::class_<PyClass, k2::Fsa>(m, "DLPackFsa") | ||
.def(py::init<py::capsule, py::capsule>(), py::arg("indexes"), | ||
py::arg("data")) | ||
.def("empty", &PyClass::Empty) | ||
.def( | ||
"__iter__", | ||
[](const PyClass &self) { | ||
return py::make_iterator(self.begin(), self.end()); | ||
}, | ||
py::keep_alive<0, 1>()) | ||
.def_readonly("size1", &PyClass::size1) | ||
.def_readonly("size2", &PyClass::size2) | ||
.def("indexes", | ||
[](const PyClass &self, int32_t i) { return self.indexes[i]; }) | ||
.def("data", [](const PyClass &self, int32_t i) { return self.data[i]; }) | ||
.def("num_states", &PyClass::NumStates) | ||
.def("final_state", &PyClass::FinalState); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# TODO(fangjun): import only things we need | ||
from _k2 import * | ||
|
||
from _k2 import Arc | ||
from .array import * | ||
from .fsa import * | ||
from .fsa_util import str_to_fsa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) | ||
|
||
# See ../../../LICENSE for clarification regarding multiple authors | ||
|
||
import torch | ||
from torch.utils.dlpack import to_dlpack | ||
|
||
from _k2 import DLPackIntArray2 | ||
from _k2 import DLPackLogSumArcDerivs | ||
|
||
class IntArray2(DLPackIntArray2):
    """An `Array2` of int32 values built on top of two torch tensors.

    The tensors are passed to the C++ binding as DLPack capsules.
    """

    # TODO(haowen): add methods to construct object with Array2Size
    def __init__(self, indexes: torch.Tensor, data: torch.Tensor):
        """
        Args:
          indexes: one-dimensional contiguous `torch.int32` tensor with at
            least one element.
          data: one-dimensional contiguous `torch.int32` tensor holding the
            values.

        Raises:
          TypeError: if `indexes` or `data` is not of dtype `torch.int32`.
        """
        # The binding reads both buffers as int32; fail early with a Python
        # exception rather than letting another dtype be misinterpreted.
        for name, tensor in (('indexes', indexes), ('data', data)):
            if tensor.dtype != torch.int32:
                raise TypeError('`{}` must have dtype torch.int32, '
                                'got {}'.format(name, tensor.dtype))
        super().__init__(to_dlpack(indexes), to_dlpack(data))
|
||
|
||
class LogSumArcDerivs(DLPackLogSumArcDerivs):
    """An `Array2` of `(int32, float)` pairs built on top of two torch tensors.

    The tensors are passed to the C++ binding as DLPack capsules.
    """

    def __init__(self, indexes: torch.Tensor, data: torch.Tensor):
        """
        Args:
          indexes: one-dimensional contiguous `torch.int32` tensor with at
            least one element.
          data: two-dimensional contiguous tensor; each row is viewed by the
            binding as one `(int32, float)` pair. Its layout is validated by
            the binding itself.

        Raises:
          TypeError: if `indexes` is not of dtype `torch.int32`.
        """
        # The binding reads `indexes` as int32; fail early with a Python
        # exception rather than letting another dtype be misinterpreted.
        if indexes.dtype != torch.int32:
            raise TypeError('`indexes` must have dtype torch.int32, '
                            'got {}'.format(indexes.dtype))
        super().__init__(to_dlpack(indexes), to_dlpack(data))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) | ||
|
||
# See ../../../LICENSE for clarification regarding multiple authors | ||
|
||
import torch | ||
from torch.utils.dlpack import to_dlpack | ||
|
||
from _k2 import DLPackFsa | ||
|
||
class Fsa(DLPackFsa):
    """An Fsa built on top of two torch tensors.

    The tensors are passed to the C++ binding as DLPack capsules.
    """

    # TODO(haowen): add methods to construct object with Array2Size
    def __init__(self, indexes: torch.Tensor, data: torch.Tensor):
        """
        Args:
          indexes: tensor describing the row boundaries; forwarded as-is.
          data: tensor holding the arcs; forwarded as-is.
        """
        indexes_capsule = to_dlpack(indexes)
        data_capsule = to_dlpack(data)
        super().__init__(indexes_capsule, data_capsule)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#!/usr/bin/env python3 | ||
# | ||
# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu) | ||
# | ||
# See ../../../LICENSE for clarification regarding multiple authors | ||
|
||
# To run this single test, use | ||
# | ||
# ctest --verbose -R array_test_py | ||
# | ||
|
||
import unittest | ||
|
||
import torch | ||
|
||
import k2 | ||
|
||
|
||
class TestArray(unittest.TestCase):
    """Tests for the Python wrappers of `Array2`."""

    def test_int_array2(self):
        data = torch.arange(10).to(torch.int32)
        indexes = torch.tensor([0, 2, 5, 6, 10]).to(torch.int32)
        # the last index must point one past the last value
        self.assertEqual(data.numel(), indexes[-1].item())

        array = k2.IntArray2(indexes, data)
        self.assertFalse(array.empty())
        self.assertIsInstance(array, k2.IntArray2)

        # iterating yields the stored values, which are 0, 1, ..., 9 here
        for position, value in enumerate(array):
            self.assertEqual(position, value)

        self.assertEqual(indexes.numel(), array.size1 + 1)
        self.assertEqual(data.numel(), array.size2)

        # k2 and torch share the underlying memory, so a write through the
        # tensor is visible through the array
        data[0] = 100
        self.assertEqual(array.data(0), 100)

        # dropping the Python name of the tensor must not invalidate the array
        del data
        self.assertEqual(array.data(0), 100)

    def test_logsum_arc_derivs(self):
        data = torch.arange(10).reshape(5, 2).to(torch.float)
        indexes = torch.tensor([0, 2, 3, 5]).to(torch.int32)
        num_rows = data.shape[0]
        self.assertEqual(num_rows, indexes[-1].item())

        array = k2.LogSumArcDerivs(indexes, data)
        self.assertFalse(array.empty())
        self.assertIsInstance(array, k2.LogSumArcDerivs)

        self.assertEqual(indexes.numel(), array.size1 + 1)
        self.assertEqual(num_rows, array.size2)

        # the first row of `data` viewed as an (int32, float) pair
        first_pair = array.data(0)
        self.assertEqual(first_pair, (0, 1.0))
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Oops, something went wrong.