Skip to content

Commit

Permalink
wrap properties and some fsa algorithms to Python (#70)
Browse files Browse the repository at this point in the history
* wrap properties and some fsa algorithms to Python

* fix comment issues
  • Loading branch information
qindazhu authored Jul 25, 2020
1 parent bf72293 commit 07ae981
Show file tree
Hide file tree
Showing 22 changed files with 907 additions and 36 deletions.
6 changes: 3 additions & 3 deletions k2/csrc/aux_labels.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class AuxLabels1Mapper {
: labels_in_(labels_in), arc_map_(arc_map) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] aux_size The number of lists in the output AuxLabels
(equals num-arcs in the output FSA) and
Expand Down Expand Up @@ -110,7 +110,7 @@ class AuxLabels2Mapper {
: labels_in_(labels_in), arc_map_(arc_map) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] aux_size The number of lists in the output AuxLabels
(equals num-arcs in the output FSA) and
Expand Down Expand Up @@ -152,7 +152,7 @@ class FstInverter {
: fsa_in_(fsa_in), labels_in_(labels_in) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the FSA
will be written to here
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/connect.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class Connection {
explicit Connection(const Fsa &fsa_in) : fsa_in_(fsa_in) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/determinize.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Determinizer {
}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/fsa_equivalent.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class RandPath {
}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
Expand Down
4 changes: 2 additions & 2 deletions k2/csrc/fsa_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class StringToFsa {
explicit StringToFsa(const std::string &s) : s_(s) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
Expand Down Expand Up @@ -297,7 +297,7 @@ class RandFsaGenerator {
explicit RandFsaGenerator(const RandFsaOptions &opts) : opts_(opts) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the generated FSA
will be written to here
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/intersect.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Intersection {
Intersection(const Fsa &a, const Fsa &b) : a_(a), b_(b) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/rmepsilon.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class EpsilonsRemover {
}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
Expand Down
2 changes: 1 addition & 1 deletion k2/csrc/topsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TopSorter {
explicit TopSorter(const Fsa &fsa_in) : fsa_in_(fsa_in) {}

/*
Do enough work that know now much memory will be needed, and output
Do enough work to know how much memory will be needed, and output
that information
@param [out] fsa_size The num-states and num-arcs of the output FSA
will be written to here
Expand Down
1 change: 1 addition & 0 deletions k2/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pybind11_add_module(_k2
fsa_algo.cc
fsa_util.cc
k2.cc
properties.cc
tensor.cc
)

Expand Down
72 changes: 63 additions & 9 deletions k2/python/csrc/fsa_algo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

#include "k2/csrc/arcsort.h"
#include "k2/csrc/array.h"
#include "k2/csrc/connect.h"
#include "k2/csrc/intersect.h"
#include "k2/csrc/topsort.h"
#include "k2/python/csrc/array.h"

namespace k2 {} // namespace k2

void PyBindArcSort(py::module &m) {
using PyClass = k2::ArcSorter;
py::class_<PyClass>(m, "_ArcSorter")
Expand All @@ -24,19 +25,72 @@ void PyBindArcSort(py::module &m) {
"get_output",
[](PyClass &self, k2::Fsa *fsa_out,
k2::Array1<int32_t *> *arc_map = nullptr) {
self.GetOutput(fsa_out,
arc_map == nullptr ? nullptr : arc_map->data);
return self.GetOutput(fsa_out,
arc_map == nullptr ? nullptr : arc_map->data);
},
py::arg("fsa_out"),
py::arg("arc_map") = (k2::Array1<int32_t *> *)nullptr);
py::arg("fsa_out"), py::arg("arc_map").none(true));

m.def(
"_arc_sort",
[](k2::Fsa *fsa, k2::Array1<int32_t *> *arc_map = nullptr) {
k2::ArcSort(fsa, arc_map == nullptr ? nullptr : arc_map->data);
return k2::ArcSort(fsa, arc_map == nullptr ? nullptr : arc_map->data);
},
"in-place version of ArcSorter", py::arg("fsa"),
py::arg("arc_map") = (k2::Array1<int32_t *> *)nullptr);
py::arg("arc_map").none(true));
}

void PybindFsaAlgo(py::module &m) { PyBindArcSort(m); }
void PyBindTopSort(py::module &m) {
using PyClass = k2::TopSorter;
py::class_<PyClass>(m, "_TopSorter")
.def(py::init<const k2::Fsa &>(), py::arg("fsa_in"))
.def("get_sizes", &PyClass::GetSizes, py::arg("fsa_size"))
.def(
"get_output",
[](PyClass &self, k2::Fsa *fsa_out,
k2::Array1<int32_t *> *state_map = nullptr) -> bool {
return self.GetOutput(
fsa_out, state_map == nullptr ? nullptr : state_map->data);
},
py::arg("fsa_out"), py::arg("state_map").none(true));
}

void PyBindConnect(py::module &m) {
using PyClass = k2::Connection;
py::class_<PyClass>(m, "_Connection")
.def(py::init<const k2::Fsa &>(), py::arg("fsa_in"))
.def("get_sizes", &PyClass::GetSizes, py::arg("fsa_size"))
.def(
"get_output",
[](PyClass &self, k2::Fsa *fsa_out,
k2::Array1<int32_t *> *arc_map = nullptr) -> bool {
return self.GetOutput(fsa_out,
arc_map == nullptr ? nullptr : arc_map->data);
},
py::arg("fsa_out"), py::arg("arc_map").none(true));
}

void PyBindIntersect(py::module &m) {
using PyClass = k2::Intersection;
py::class_<PyClass>(m, "_Intersection")
.def(py::init<const k2::Fsa &, const k2::Fsa &>(), py::arg("fsa_a"),
py::arg("fsa_b"))
.def("get_sizes", &PyClass::GetSizes, py::arg("fsa_size"))
.def(
"get_output",
[](PyClass &self, k2::Fsa *fsa_out,
k2::Array1<int32_t *> *arc_map_a = nullptr,
k2::Array1<int32_t *> *arc_map_b = nullptr) -> bool {
return self.GetOutput(
fsa_out, arc_map_a == nullptr ? nullptr : arc_map_a->data,
arc_map_b == nullptr ? nullptr : arc_map_b->data);
},
py::arg("fsa_out"), py::arg("arc_map_a").none(true),
py::arg("arc_map_b").none(true));
}

void PybindFsaAlgo(py::module &m) {
PyBindArcSort(m);
PyBindTopSort(m);
PyBindConnect(m);
PyBindIntersect(m);
}
2 changes: 2 additions & 0 deletions k2/python/csrc/k2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "k2/python/csrc/fsa.h"
#include "k2/python/csrc/fsa_algo.h"
#include "k2/python/csrc/fsa_util.h"
#include "k2/python/csrc/properties.h"

PYBIND11_MODULE(_k2, m) {
m.doc() = "pybind11 binding of k2";
Expand All @@ -19,4 +20,5 @@ PYBIND11_MODULE(_k2, m) {
PybindFsa(m);
PybindFsaUtil(m);
PybindFsaAlgo(m);
PybindProperties(m);
}
34 changes: 34 additions & 0 deletions k2/python/csrc/properties.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// k2/python/csrc/properties.cc

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors

#include "k2/python/csrc/properties.h"

#include <vector>

#include "k2/csrc/array.h"
#include "k2/csrc/fsa.h"
#include "k2/csrc/properties.h"
#include "k2/python/csrc/array.h"

// We would never pass `order` parameter to k2::IsAcyclic in Python code.
// We can make it accept `None` with `std::optional` in pybind11, but
// that will require C++17, so we here choose to write a version without
// `order`.
static bool IsAcyclic(const k2::Fsa &fsa) {
return k2::IsAcyclic(fsa /*, std::vector<int32_t>* order = nullptr*/);
}

void PybindProperties(py::module &m) {
m.def("_is_valid", &k2::IsValid, py::arg("fsa"));
m.def("_is_top_sorted", &k2::IsTopSorted, py::arg("fsa"));
m.def("_is_arc_sorted", &k2::IsArcSorted, py::arg("fsa"));
m.def("_has_self_loops", &k2::HasSelfLoops, py::arg("fsa"));
m.def("_is_acyclic", &IsAcyclic, py::arg("fsa"));
m.def("_is_deterministic", &k2::IsDeterministic, py::arg("fsa"));
m.def("_is_epsilon_free", &k2::IsEpsilonFree, py::arg("fsa"));
m.def("_is_connected", &k2::IsConnected, py::arg("fsa"));
m.def("_is_empty", &k2::IsEmpty, py::arg("fsa"));
}
14 changes: 14 additions & 0 deletions k2/python/csrc/properties.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// k2/python/csrc/properties.h

// Copyright (c) 2020 Xiaomi Corporation (author: Haowen Qiu)

// See ../../../LICENSE for clarification regarding multiple authors

#ifndef K2_PYTHON_CSRC_PROPERTIES_H_
#define K2_PYTHON_CSRC_PROPERTIES_H_

#include "k2/python/csrc/k2.h"

void PybindProperties(py::module &m);

#endif // K2_PYTHON_CSRC_PROPERTIES_H_
1 change: 1 addition & 0 deletions k2/python/k2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .fsa import *
from .fsa_algo import *
from .fsa_util import str_to_fsa
from .properties import *
62 changes: 56 additions & 6 deletions k2/python/k2/fsa_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,71 @@
from .array import IntArray2
from _k2 import _ArcSorter
from _k2 import _arc_sort
from _k2 import _TopSorter
from _k2 import _Connection
from _k2 import _Intersection


class ArcSorter(_ArcSorter):

def __init__(self, fsa_in: Fsa):
super().__init__(fsa_in.get_base())

def get_sizes(self, array_size: IntArray2):
super().get_sizes(array_size)
def get_sizes(self, array_size: IntArray2) -> None:
return super().get_sizes(array_size)

def get_output(self, fsa_out: Fsa, arc_map: IntArray1 = None):
super().get_output(fsa_out.get_base(),
arc_map.get_base() if arc_map is not None else None)
def get_output(self, fsa_out: Fsa, arc_map: IntArray1 = None) -> None:
return super().get_output(
fsa_out.get_base(),
arc_map.get_base() if arc_map is not None else None)


def arc_sort(fsa: Fsa, arc_map: IntArray1 = None):
def arc_sort(fsa: Fsa, arc_map: IntArray1 = None) -> None:
return _arc_sort(fsa.get_base(),
arc_map.get_base() if arc_map is not None else None)


class TopSorter(_TopSorter):

def __init__(self, fsa_in: Fsa):
super().__init__(fsa_in.get_base())

def get_sizes(self, array_size: IntArray2) -> None:
return super().get_sizes(array_size)

def get_output(self, fsa_out: Fsa, state_map: IntArray1 = None) -> bool:
return super().get_output(
fsa_out.get_base(),
state_map.get_base() if state_map is not None else None)


class Connection(_Connection):

def __init__(self, fsa_in: Fsa):
super().__init__(fsa_in.get_base())

def get_sizes(self, array_size: IntArray2) -> None:
return super().get_sizes(array_size)

def get_output(self, fsa_out: Fsa, arc_map: IntArray1 = None) -> bool:
return super().get_output(
fsa_out.get_base(),
arc_map.get_base() if arc_map is not None else None)


class Intersection(_Intersection):

def __init__(self, fsa_a: Fsa, fsa_b: Fsa):
super().__init__(fsa_a.get_base(), fsa_b.get_base())

def get_sizes(self, array_size: IntArray2) -> None:
return super().get_sizes(array_size)

def get_output(self,
fsa_out: Fsa,
arc_map_a: IntArray1 = None,
arc_map_b: IntArray1 = None) -> bool:
return super().get_output(
fsa_out.get_base(),
arc_map_a.get_base() if arc_map_a is not None else None,
arc_map_b.get_base() if arc_map_b is not None else None)
Loading

0 comments on commit 07ae981

Please sign in to comment.