add pybindings for Array2 (#68)
* add pybindings for Array2

* fix some comment issues
qindazhu authored Jul 19, 2020
1 parent d1754fb commit 0667a60
Showing 13 changed files with 297 additions and 23 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -46,3 +46,6 @@ GSYMS
GPATH
tags
TAGS
+
+# python build files
+**/__pycache__
1 change: 1 addition & 0 deletions k2/python/csrc/CMakeLists.txt
@@ -1,5 +1,6 @@
# please sort the files alphabetically
pybind11_add_module(_k2
+  array.cc
fsa.cc
fsa_util.cc
k2.cc
141 changes: 141 additions & 0 deletions k2/python/csrc/array.cc
@@ -0,0 +1,141 @@
// k2/python/csrc/array.cc

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors

#include "k2/python/csrc/array.h"

#include <memory>
#include <utility>

#include "k2/csrc/array.h"
#include "k2/python/csrc/tensor.h"

namespace k2 {
/*
   DLPackArray2 initializes Array2 with `cap_indexes` and `cap_data`,
   which are DLManagedTensors.

   `cap_indexes` is usually a one-dimensional contiguous array, i.e.,
   `cap_indexes.ndim == 1 && cap_indexes.strides[0] == 1`.

   `cap_data` may have different shapes depending on `ValueType`:
      1. if `ValueType` is a primitive type (e.g. `int32_t`), it will be
         a one-dimensional contiguous array, i.e.,
         `cap_data.ndim == 1 && cap_data.strides[0] == 1`.
      2. if `ValueType` is a complex type (e.g. Arc), it will be a
         two-dimensional array, i.e., it meets the following requirements:
           a) cap_data.ndim == 2.
           b) cap_data.shape[0] == num-elements it stores; note the
              element's type is `ValueType`, which means we view each row
              of `cap_data.data` as one element with type `ValueType`.
           c) cap_data.shape[1] == num-primitive-values in `ValueType`,
              which means we require that `ValueType` can be viewed as a
              tensor; this is true for Arc, as it only holds primitive
              values of the same type (i.e. `int32_t`), but a type cast
              may be needed in other cases (e.g. when `ValueType` contains
              both `int32_t` and `float`).
           d) cap_data.strides[0] == num-primitive-values in `ValueType`.
           e) cap_data.strides[1] == 1.

   Note that if `data` in Array2 has stride > 1 (i.e. `data`'s type is
   StridedPtr<ValueType>), the requirements on `cap_data` are nearly the
   same as in case 2 above, except that cap_data.strides[0] will be
   greater than num-primitive-values in `ValueType`.
*/
template <typename ValueType, bool IsPrimitive, typename I>
class DLPackArray2;

template <typename ValueType, typename I>
class DLPackArray2<ValueType *, true, I> : public Array2<ValueType *, I> {
public:
DLPackArray2(py::capsule cap_indexes, py::capsule cap_data)
: indexes_tensor_(new Tensor(cap_indexes)),
data_tensor_(new Tensor(cap_data)) {
CHECK_EQ(indexes_tensor_->NumDim(), 1);
    CHECK_GE(indexes_tensor_->Shape(0), 1);  // must have at least one element
CHECK_EQ(indexes_tensor_->Stride(0), 1);

CHECK_EQ(data_tensor_->NumDim(), 1);
CHECK_GE(data_tensor_->Shape(0), 0); // num-elements
CHECK_EQ(data_tensor_->Stride(0), 1);

int32_t size1 = indexes_tensor_->Shape(0) - 1;
int32_t size2 = data_tensor_->Shape(0);
this->Init(size1, size2, indexes_tensor_->Data<I>(),
data_tensor_->Data<ValueType>());
}

private:
std::unique_ptr<Tensor> indexes_tensor_;
std::unique_ptr<Tensor> data_tensor_;
};

template <typename ValueType, typename I>
class DLPackArray2<ValueType *, false, I> : public Array2<ValueType *, I> {
public:
DLPackArray2(py::capsule cap_indexes, py::capsule cap_data)
: indexes_tensor_(new Tensor(cap_indexes)),
data_tensor_(new Tensor(cap_data)) {
CHECK_EQ(indexes_tensor_->NumDim(), 1);
    CHECK_GE(indexes_tensor_->Shape(0), 1);  // must have at least one element
CHECK_EQ(indexes_tensor_->Stride(0), 1);

CHECK_EQ(data_tensor_->NumDim(), 2);
CHECK_GE(data_tensor_->Shape(0), 0); // num-elements
CHECK_EQ(data_tensor_->Shape(1) * data_tensor_->BytesPerElement(),
sizeof(ValueType));
CHECK_EQ(data_tensor_->Stride(0) * data_tensor_->BytesPerElement(),
sizeof(ValueType));
CHECK_EQ(data_tensor_->Stride(1), 1);

int32_t size1 = indexes_tensor_->Shape(0) - 1;
int32_t size2 = data_tensor_->Shape(0);
this->Init(size1, size2, indexes_tensor_->Data<I>(),
data_tensor_->Data<ValueType>());
}

private:
std::unique_ptr<Tensor> indexes_tensor_;
std::unique_ptr<Tensor> data_tensor_;
};

// Note: we can specialize for `StridedPtr` later if we need it.

} // namespace k2

template <typename Ptr, bool IsPrimitive, typename I = int32_t>
void PybindArray2Tpl(py::module &m, const char *name) {
using PyClass = k2::DLPackArray2<Ptr, IsPrimitive, I>;
using Parent = k2::Array2<Ptr, I>;
py::class_<PyClass, Parent>(m, name)
.def(py::init<py::capsule, py::capsule>(), py::arg("indexes"),
py::arg("data"))
.def("empty", &PyClass::Empty)
.def(
"__iter__",
[](const PyClass &self) {
return py::make_iterator(self.begin(), self.end());
},
py::keep_alive<0, 1>())
.def_readonly("size1", &PyClass::size1)
.def_readonly("size2", &PyClass::size2)
.def("indexes", [](const PyClass &self, I i) { return self.indexes[i]; })
.def("data", [](const PyClass &self, I i) { return self.data[i]; });
// TODO(haowen): expose `indexes` and `data` as an array
// instead of a function call?
}

void PybindArray(py::module &m) {
  // Note: all the following wrappers whose names start with `_` are only used
// by pybind11 internally so that it knows `k2::DLPackArray2` is a subclass of
// `k2::Array2`.
py::class_<k2::Array2<int32_t *>>(m, "_IntArray2");
PybindArray2Tpl<int32_t *, true>(m, "DLPackIntArray2");

  // note there is a type cast, as the underlying Tensor has type `float`
py::class_<k2::Array2<std::pair<int32_t, float> *>>(m, "_LogSumArcDerivs");
PybindArray2Tpl<std::pair<int32_t, float> *, false>(m,
"DLPackLogSumArcDerivs");
}
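
The layout contract documented in DLPackArray2 above maps directly onto PyTorch tensors exported through DLPack. Below is a minimal sketch of the primitive case (ValueType == int32_t), assuming the compiled `_k2` extension is importable; the class and attribute names are exactly the ones registered above.

import torch
from torch.utils.dlpack import to_dlpack

from _k2 import DLPackIntArray2

# `indexes` must be one-dimensional, contiguous, with at least one entry,
# and its last entry must equal the number of elements in `data`
indexes = torch.tensor([0, 2, 5, 6, 10], dtype=torch.int32)
data = torch.arange(10, dtype=torch.int32)
assert indexes[-1].item() == data.numel()

array = DLPackIntArray2(to_dlpack(indexes), to_dlpack(data))
assert array.size1 == indexes.numel() - 1  # 4
assert array.size2 == data.numel()  # 10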
14 changes: 14 additions & 0 deletions k2/python/csrc/array.h
@@ -0,0 +1,14 @@
// k2/python/csrc/array.h

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors

#ifndef K2_PYTHON_CSRC_ARRAY_H_
#define K2_PYTHON_CSRC_ARRAY_H_

#include "k2/python/csrc/k2.h"

void PybindArray(py::module &m);

#endif // K2_PYTHON_CSRC_ARRAY_H_
42 changes: 27 additions & 15 deletions k2/python/csrc/fsa.cc
@@ -1,6 +1,7 @@
// k2/python/csrc/fsa.cc

// Copyright (c) 2020 Fangjun Kuang ([email protected])
+// Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors

@@ -15,23 +16,24 @@ namespace k2 {

// it uses external memory passed from DLPack (e.g., by PyTorch)
// to construct an Fsa.
-class _Fsa : public Fsa {
+class DLPackFsa : public Fsa {
 public:
-  _Fsa(py::capsule cap_indexes, py::capsule cap_data)
+  DLPackFsa(py::capsule cap_indexes, py::capsule cap_data)
: indexes_tensor_(new Tensor(cap_indexes)),
data_tensor_(new Tensor(cap_data)) {
CHECK_EQ(indexes_tensor_->dtype(), kInt32Type);
CHECK_EQ(indexes_tensor_->NumDim(), 1);
-    CHECK_GT(indexes_tensor_->Shape(0), 1);
-    CHECK_EQ(indexes_tensor_->Stride(0), 1)
-        << "Only contiguous index arrays are supported at present";
+    CHECK_GE(indexes_tensor_->Shape(0), 1);
+    CHECK_EQ(indexes_tensor_->Stride(0), 1);

CHECK_EQ(data_tensor_->dtype(), kInt32Type);
CHECK_EQ(data_tensor_->NumDim(), 2);
-    CHECK_EQ(data_tensor_->Stride(1), 1)
-        << "Only contiguous data arrays at supported at present";
-    CHECK_EQ(sizeof(Arc),
-             data_tensor_->Shape(1) * data_tensor_->BytesPerElement());
+    CHECK_GE(data_tensor_->Shape(0), 0);  // num-elements
+    CHECK_EQ(data_tensor_->Shape(1) * data_tensor_->BytesPerElement(),
+             sizeof(Arc));
+    CHECK_EQ(data_tensor_->Stride(0) * data_tensor_->BytesPerElement(),
+             sizeof(Arc));
+    CHECK_EQ(data_tensor_->Stride(1), 1);

int32_t size1 = indexes_tensor_->Shape(0) - 1;
int32_t size2 = data_tensor_->Shape(0);
@@ -63,16 +65,26 @@ void PybindArc(py::module &m) {
}

void PybindFsa(py::module &m) {
-  // Note(fangjun): Users are not supposed to use `k2::Fsa` directly
-  // in Python; the following wrapper is only used by pybind11 internally
-  // so that it knows `k2::_Fsa` is a subclass of `k2::Fsa`.
-  py::class_<k2::Fsa>(m, "__Fsa");
+  // The following wrapper is only used by pybind11 internally
+  // so that it knows `k2::DLPackFsa` is a subclass of `k2::Fsa`.
+  py::class_<k2::Fsa>(m, "_Fsa");

-  using PyClass = k2::_Fsa;
-  py::class_<PyClass, k2::Fsa>(m, "Fsa")
+  using PyClass = k2::DLPackFsa;
+  py::class_<PyClass, k2::Fsa>(m, "DLPackFsa")
.def(py::init<py::capsule, py::capsule>(), py::arg("indexes"),
py::arg("data"))
.def("empty", &PyClass::Empty)
+      .def(
+          "__iter__",
+          [](const PyClass &self) {
+            return py::make_iterator(self.begin(), self.end());
+          },
+          py::keep_alive<0, 1>())
+      .def_readonly("size1", &PyClass::size1)
+      .def_readonly("size2", &PyClass::size2)
+      .def("indexes",
+           [](const PyClass &self, int32_t i) { return self.indexes[i]; })
+      .def("data", [](const PyClass &self, int32_t i) { return self.data[i]; })
.def("num_states", &PyClass::NumStates)
.def("final_state", &PyClass::FinalState);
}
2 changes: 2 additions & 0 deletions k2/python/csrc/k2.cc
@@ -6,12 +6,14 @@

#include "k2/python/csrc/k2.h"

+#include "k2/python/csrc/array.h"
#include "k2/python/csrc/fsa.h"
#include "k2/python/csrc/fsa_util.h"

PYBIND11_MODULE(_k2, m) {
m.doc() = "pybind11 binding of k2";
PybindArc(m);
+  PybindArray(m);
PybindFsa(m);
PybindFsaUtil(m);
}
6 changes: 3 additions & 3 deletions k2/python/k2/__init__.py
@@ -1,4 +1,4 @@
-# TODO(fangjun): import only things we need
-from _k2 import *
-
+from _k2 import Arc
+from .array import *
+from .fsa import *
from .fsa_util import str_to_fsa
21 changes: 21 additions & 0 deletions k2/python/k2/array.py
@@ -0,0 +1,21 @@
# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

# See ../../../LICENSE for clarification regarding multiple authors

import torch
from torch.utils.dlpack import to_dlpack

from _k2 import DLPackIntArray2
from _k2 import DLPackLogSumArcDerivs

class IntArray2(DLPackIntArray2):

# TODO(haowen): add methods to construct object with Array2Size
def __init__(self, indexes: torch.Tensor, data: torch.Tensor):
super().__init__(to_dlpack(indexes), to_dlpack(data))


class LogSumArcDerivs(DLPackLogSumArcDerivs):

def __init__(self, indexes: torch.Tensor, data: torch.Tensor):
super().__init__(to_dlpack(indexes), to_dlpack(data))
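
As noted in array.cc above, DLPackLogSumArcDerivs views each row of a two-dimensional float tensor as one std::pair<int32_t, float>, so `data` must have shape (num_elements, 2) and be contiguous. A small sketch under those assumptions, mirroring the test further below:

import torch

import k2

indexes = torch.tensor([0, 2, 3, 5], dtype=torch.int32)
data = torch.arange(10, dtype=torch.float32).reshape(5, 2)

derivs = k2.LogSumArcDerivs(indexes, data)
assert derivs.size1 == 3  # indexes.numel() - 1
assert derivs.size2 == 5  # number of (arc-index, deriv) pairs
# each row's bytes are viewed as (int32, float); row 0 holds the floats
# (0.0, 1.0), and 0.0f shares its bit pattern with int32 0
assert derivs.data(0) == (0, 1.0)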
16 changes: 16 additions & 0 deletions k2/python/k2/fsa.py
@@ -0,0 +1,16 @@
# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

# See ../../../LICENSE for clarification regarding multiple authors

import torch
from torch.utils.dlpack import to_dlpack

from _k2 import DLPackFsa

class Fsa(DLPackFsa):

# TODO(haowen): add methods to construct object with Array2Size
def __init__(self, indexes: torch.Tensor, data: torch.Tensor):
super().__init__(to_dlpack(indexes), to_dlpack(data))


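The Fsa wrapper follows the same pattern: `indexes` holds one arc offset per state plus a trailing sentinel equal to the number of arcs, and each row of `data` is one Arc. A hedged sketch follows; the arc field order (src_state, dest_state, label) is an assumption here, see k2/csrc/fsa.h for the authoritative layout:

import torch

import k2

# a 3-state linear FSA with two arcs, 0 -> 1 and 1 -> 2; the final state
# is the one with the largest id and has no arcs leaving it
data = torch.tensor([[0, 1, 1], [1, 2, 2]], dtype=torch.int32)  # one Arc per row
indexes = torch.tensor([0, 1, 2, 2], dtype=torch.int32)  # per-state arc offsets

fsa = k2.Fsa(indexes, data)
assert fsa.num_states() == 3
assert fsa.final_state() == 2
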
6 changes: 1 addition & 5 deletions k2/python/k2/fsa_util.py
@@ -6,9 +6,8 @@
from collections import defaultdict

import torch
-from torch.utils.dlpack import to_dlpack

-from _k2 import Fsa
+from .fsa import Fsa


def str_to_fsa(s: str) -> Fsa:
@@ -64,8 +63,5 @@ def str_to_fsa(s: str) -> Fsa:
data = torch.tensor(arcs, dtype=torch.int32)
indexes = torch.tensor(indexes, dtype=torch.int32)

-    data = to_dlpack(data)
-    indexes = to_dlpack(indexes)
-
fsa = Fsa(indexes, data)
return fsa
1 change: 1 addition & 0 deletions k2/python/tests/CMakeLists.txt
@@ -18,6 +18,7 @@ endfunction()

# please sort the files in alphabetic order
set(py_test_files
+  array_test.py
fsa_test.py
)

63 changes: 63 additions & 0 deletions k2/python/tests/array_test.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
#
# Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)
#
# See ../../../LICENSE for clarification regarding multiple authors

# To run this single test, use
#
# ctest --verbose -R array_test_py
#

import unittest

import torch

import k2


class TestArray(unittest.TestCase):

def test_int_array2(self):
data = torch.arange(10).to(torch.int32)
indexes = torch.tensor([0, 2, 5, 6, 10]).to(torch.int32)
        self.assertEqual(data.numel(), indexes[-1].item())

array = k2.IntArray2(indexes, data)
self.assertFalse(array.empty())
self.assertIsInstance(array, k2.IntArray2)

# test iterator
for i, v in enumerate(array):
self.assertEqual(i, v)

self.assertEqual(indexes.numel(), array.size1 + 1)
self.assertEqual(data.numel(), array.size2)

        # the underlying memory is shared between k2 and torch,
        # so changing one changes the other
data[0] = 100
self.assertEqual(array.data(0), 100)

del data
# the array in k2 is still accessible
        self.assertEqual(array.data(0), 100)

def test_logsum_arc_derivs(self):
        data = torch.arange(10).reshape(5, 2).to(torch.float)
indexes = torch.tensor([0, 2, 3, 5]).to(torch.int32)
        self.assertEqual(data.shape[0], indexes[-1].item())

array = k2.LogSumArcDerivs(indexes, data)
self.assertFalse(array.empty())
self.assertIsInstance(array, k2.LogSumArcDerivs)

self.assertEqual(indexes.numel(), array.size1 + 1)
self.assertEqual(data.shape[0], array.size2)

        # row 0 holds floats (0.0, 1.0); the binding views each row's bytes
        # as (int32, float), and 0.0f has the same bit pattern as int32 0,
        # so it reads back as the tuple (0, 1.0)
        self.assertEqual(array.data(0), (0, 1.0))


if __name__ == '__main__':
unittest.main()