forked from rwth-i6/returnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBestPathDecoder.py
66 lines (57 loc) · 2.24 KB
/
BestPathDecoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import theano
import theano.tensor as T
import os
# Module-level alias for Theano's configured ``floatX`` setting.
# NOTE(review): not referenced by any code visible in this file -- confirm
# whether other modules import it before removing.
Tfloat = theano.config.floatX # @UndefinedVariable
class BestPathDecodeOp(theano.Op):
    """
    Theano Op that runs the C++ ``BestPathDecoder`` over a batch of sequences.

    For every sequence in the batch, the generated C code calls
    ``BestPathDecoder::labellingErrors`` and stores its result in an int32
    vector with one entry per sequence (number of edits for each sequence).
    The Op carries no parameters, so all instances are interchangeable.
    """

    def __eq__(self, other):
        """Instances compare equal exactly when they have the same type."""
        return type(self) == type(other)

    def __hash__(self):
        """Hash by type, consistent with :meth:`__eq__`."""
        return hash(type(self))

    def __str__(self):
        """Return the bare class name."""
        return type(self).__name__

    def make_node(self, x, y, seq_lengths):
        """
        Build the Apply node for this Op.

        :param x: tensor, nframes x nseqs x dim (must be 3-dimensional)
        :param y: matrix, nseqs x max_labelling_length (must be 2-dimensional)
        :param seq_lengths: vector of sequence lengths (must be 1-dimensional)
        :return: Apply node with a single int32 vector output:
          number of edits for each sequence
        """
        frames = T.as_tensor_variable(x)
        labels = T.as_tensor_variable(y)
        lengths = T.as_tensor_variable(seq_lengths)
        # Only the ranks are checked here; dtypes are not validated.
        assert frames.ndim == 3  # tensor: nframes x nseqs x dim
        assert labels.ndim == 2  # matrix: nseqs x max_labelling_length
        assert lengths.ndim == 1  # vector of seqs lengths
        return theano.Apply(self, [frames, labels, lengths], [T.ivector()])

    def c_code(self, node, name, inp, out, sub):
        """
        Return the C code implementing this Op.

        The generated code drops any previous output, allocates a zeroed int32
        vector whose length is dimension 1 of ``x`` (the number of sequences),
        and then, in an OpenMP parallel loop over the entries of
        ``seq_lengths``, calls ``decoder.labellingErrors`` for each sequence
        index.

        :param node: the Apply node (unused)
        :param name: unique name for this node (unused)
        :param inp: C variable names of the inputs (x, y, seq_lengths)
        :param out: C variable names of the outputs (a single entry)
        :param sub: substitution dict; only ``sub['fail']`` is used
        :return: C source as a string
        """
        x, y, seq_lengths = inp
        lev, = out
        # Explicit mapping of every placeholder used in the template below.
        substitutions = {
            'x': x,
            'y': y,
            'seq_lengths': seq_lengths,
            'lev': lev,
            'fail': sub['fail'],
        }
        template = """
Py_XDECREF(%(lev)s);
npy_intp dims[] = {PyArray_DIM(%(x)s,1)};
%(lev)s = (PyArrayObject*) PyArray_Zeros(1, dims, PyArray_DescrFromType(NPY_INT32), 0);
if(!%(lev)s)
%(fail)s;
{
CArrayF xWr(%(x)s);
CArrayI yWr(%(y)s);
CArrayI seqLensWr(%(seq_lengths)s);
ArrayI levWr(%(lev)s);
int numSeqs = seqLensWr.dim(0);
#pragma omp parallel for
for(int i = 0; i < numSeqs; ++i)
{
BestPathDecoder decoder;
decoder.labellingErrors(xWr, seqLensWr, i, yWr, levWr);
}
}
"""
        return template % substitutions

    def c_compile_args(self):
        """Compiler flags; OpenMP is required by the pragma in :meth:`c_code`."""
        return ['-fopenmp']

    # IMPORTANT: change this, if you change the c-code
    def c_code_cache_version(self):
        """Return the version tuple of the generated C code."""
        return (1.62,)

    def c_support_code(self):
        """
        Return the C++ support code needed by :meth:`c_code`.

        Concatenates ``C_Support_Code.cpp`` and ``BestPathDecoder.cpp`` (in
        that order), both read from the directory containing this file.
        """
        here = os.path.dirname(os.path.abspath(__file__))
        pieces = []
        for filename in ('/C_Support_Code.cpp', '/BestPathDecoder.cpp'):
            with open(here + filename, 'r') as handle:
                pieces.append(handle.read())
        return "".join(pieces)