#!/usr/bin/env python2.7
from __future__ import print_function
from __future__ import division
import numpy
import theano
import pickle
from theano import tensor as T
from copy import deepcopy
from Log import log
from LmDataset import Lexicon, StateTying
from os.path import isfile
import itertools
class Edge:
"""
class to represent an edge
"""
# label placeholder
SIL = '_'
EPS = '*'
BLANK = '%'
def __init__(self, source_state_idx, target_state_idx, label, weight=0.0):
"""
:param int source_state_idx: the starting node of the edge
:param int target_state_idx: the ending node of the edge
:param int|str|None label: the label of the edge (normally a letter or a phoneme ...)
:param float weight: probability of the word/phon in -log space
"""
self.source_state_idx = source_state_idx
self.target_state_idx = target_state_idx
self.label = label
self.weight = weight
# int|str|None label_prev: previous label
self.label_prev = None
# int|str|None label_next: next label
self.label_next = None
# int|None idx_word_in_sentence: index of word in the given sentence
self.idx_word_in_sentence = None
# int|None idx_phon_in_word: index of phon in a word
self.idx_phon_in_word = None
# int|None idx: label index within the sentence/word/phon
self.idx = None
# int|None allo_idx: allophone position
self.allo_idx = None
# bool phon_at_word_begin: flag indicates if phon at the beginning of a word
self.phon_at_word_begin = False
# bool phon_at_word_end: flag indicates if phon at the end of a word
self.phon_at_word_end = False
# float|None score: score of the edge
self.score = None
# bool is_loop: is the edge a loop within the graph
self.is_loop = False
def __repr__(self):
return "".join(("[",
str(self.source_state_idx), ", ",
str(self.target_state_idx), ", ",
str(self.label), ", ",
str(self.weight),
"]"))
def __str__(self):
return "".join(("Edge:\n",
"Source state: ",
str(self.source_state_idx), "\n",
"Target state: ",
str(self.target_state_idx), "\n",
"Label: ",
str(self.label), "\n",
"Weight: ",
str(self.weight)))
def as_tuple(self):
return self.source_state_idx, self.target_state_idx, self.label, self.weight
def __eq__(self, other):
return self.as_tuple() == other.as_tuple()
def __ne__(self, other):
return self.as_tuple() != other.as_tuple()
def __le__(self, other):
return self.as_tuple() <= other.as_tuple()
def __lt__(self, other):
return self.as_tuple() < other.as_tuple()
def __ge__(self, other):
return self.as_tuple() >= other.as_tuple()
def __gt__(self, other):
return self.as_tuple() > other.as_tuple()
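# A minimal usage sketch for Edge (hypothetical values; weight defaults to 0.0
# and comparisons are tuple comparisons over (source, target, label, weight)):
#   e = Edge(0, 1, 'a', weight=0.5)
#   print(repr(e))                              # -> [0, 1, a, 0.5]
#   Edge(0, 1, 'a') == Edge(0, 1, 'a', 0.0)     # -> True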
class Graph:
"""
class holds the Graph representing the Finite State Automaton
holds the input and the created output (ASG, CTC, HMM)
states between input and output may be held if necessary
"""
def __init__(self, lemma):
"""
:param str|list[str]|list[Edge]|None lemma: a sentence or word
list[str] lem_list: lemma transformed into list of strings
list[Edge] lem_edges: the lemma is provided as a list of edges, so it is basically already an FSA
"""
# TODO use dict to distinguish between str and list?
if isinstance(lemma, str):
self.lemma = lemma.strip()
self.lem_list = self.lemma.lower().split()
self.lem_edges = None
elif isinstance(lemma, list) and isinstance(lemma[0], str):
self.lemma = None
self.lem_list = lemma
self.lem_edges = None
elif isinstance(lemma, list) and isinstance(lemma[0], Edge):
self.lemma = None
self.lem_list = None
self.lem_edges = lemma
else:
assert False, ("The input you provided is not acceptable!", lemma)
self.filename = None
# int num_states: number of states of FSA during creation and final
self.num_states = -1
self.num_states_asg = -1
self.num_states_ctc = -1
self.num_states_hmm = -1
self.num_states_word = -1
# list[Edge] edges: edges of FSA during creation and final state
self.edges = []
self.edges_asg = []
self.edges_ctc = []
self.edges_hmm = []
self.edges_word = []
def __repr__(self):
return "Graph()"
def __str__(self):
prettygraph = "Graph:\n"\
+ str(self.lem_list)\
+ "\nASG:\nNum states: "\
+ str(self.num_states_asg)\
+ "\nEdges:\n"\
+ str(self.edges_asg)\
+ "\nCTC:\nNum states: "\
+ str(self.num_states_ctc)\
+ "\nEdges:\n"\
+ str(self.edges_ctc)\
+ "\nHMM:\nNum states: "\
+ str(self.num_states_hmm)\
+ "\nEdges:\n"\
+ str(self.edges_hmm)
return prettygraph
def is_empty(self):
return self.num_states <= 0 and len(self.edges) <= 0
@staticmethod
def make_single_state_graph(num_states, edges):
"""
takes a graph with several states and transforms it into a single-state graph
:param int num_states: number of states
:param list[Edge] edges: list of Edge objects symbolizing the graph
:return: returns the transformed list of edges with one state
:rtype: list[Edge]
"""
edges_single_state = deepcopy(edges)
if num_states > 1:
for edge in edges_single_state:
edge.source_state_idx = 0
edge.target_state_idx = 0
return edges_single_state
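# A minimal usage sketch for Graph (hypothetical inputs; the ASG/CTC/HMM fields
# are only filled once the builder classes below have run):
#   g1 = Graph("Hello World")                        # sentence -> g1.lem_list == ['hello', 'world']
#   g2 = Graph(['hello', 'world'])                   # list of words -> g2.lem_list
#   g3 = Graph([Edge(0, 1, 'a'), Edge(1, 2, 'b')])   # edge list -> g3.lem_edges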
class Asg:
"""
class to create ASG FSA
"""
def __init__(self, fsa, num_labels=256, asg_repetition=2, label_conversion=False):
"""
:param Graph fsa: represents the Graph on which the class operates
:param int num_labels: number of labels without blank, silence, eps and repetitions
where num_labels > 0
:param int asg_repetition: asg repeat symbol which stands for x repetitions
where asg_repetition > 1
:param bool label_conversion: shall the labels be converted into numbers (only ASG and CTC)
"""
if isinstance(fsa, Graph) and isinstance(num_labels, int)\
and isinstance(asg_repetition, int) and isinstance(label_conversion, bool):
self.fsa = fsa
self.num_labels = num_labels
self.asg_repetition = asg_repetition
self.label_conversion = label_conversion
self.separator = False # words in the sentence will be separated by Edge.BLANK
else:
assert False, ("The ASG init went wrong!", fsa)
def run(self):
"""
creates the ASG FSA
"""
print("Starting ASG FSA Creation")
label_prev = None
rep_count = 0
label_repetitions = [] # marks the labels which will be replaced with a rep symbol
# goes through the list of strings
for lem in self.fsa.lem_list:
# goes through the string
reps_label = []
for label in lem:
label_cur = label
# check if current label matches previous label and generates label reps list
if label_cur == label_prev:
# adds reps symbol
if rep_count < self.asg_repetition:
rep_count += 1
else:
reps_label.append(self.num_labels + rep_count)
rep_count = 1
else:
# adds normal label
if rep_count != 0:
reps_label.append(self.num_labels + rep_count)
rep_count = 0
reps_label.append(label)
label_prev = label
# put reps list back into list -> list[list[str|int]]
label_repetitions.append(reps_label)
# create states
self.fsa.num_states = 0
cur_idx = 0
src_idx = 0
trgt_idx = 0
for rep_index, rep_label in enumerate(label_repetitions):
for idx, lab in enumerate(rep_label):
src_idx = cur_idx
trgt_idx = src_idx + 1
if cur_idx == 0: # one extra state to account for the start node
self.fsa.num_states += 1
self.fsa.num_states += 1
edge = Edge(src_idx, trgt_idx, lab)
edge.idx_word_in_sentence = rep_index
edge.idx_phon_in_word = idx
edge.idx = cur_idx
if idx == 0:
edge.phon_at_word_begin = True
if idx == len(rep_label) - 1:
edge.phon_at_word_end = True
self.fsa.edges.append(edge)
cur_idx += 1
# adds separator between words in sentence
if self.separator and rep_index < len(label_repetitions) - 1:
self.fsa.edges.append(Edge(src_idx + 1, trgt_idx + 1, Edge.BLANK))
self.fsa.num_states += 1
cur_idx += 1
# adds loops to ASG FSA
for loop_idx in range(1, self.fsa.num_states):
edges_add_loop = [edge_idx for edge_idx, edge_cur in enumerate(self.fsa.edges)
if (edge_cur.target_state_idx == loop_idx and edge_cur.label != Edge.EPS
and edge_cur.label != Edge.SIL)]
for add_loop_edge in edges_add_loop:
edge = deepcopy(self.fsa.edges[add_loop_edge])
edge.source_state_idx = edge.target_state_idx
edge.is_loop = True
self.fsa.edges.append(edge)
self.fsa.edges.sort()
# label conversion
if self.label_conversion:
Store.label_conversion(self.fsa.edges)
self.fsa.num_states_asg = deepcopy(self.fsa.num_states)
self.fsa.num_states = -1
self.fsa.edges_asg = deepcopy(self.fsa.edges)
self.fsa.edges = []
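# A minimal usage sketch for Asg (hypothetical sentence; parameters are the
# defaults documented in __init__ above):
#   g = Graph("aabb abc")
#   asg = Asg(g, num_labels=256, asg_repetition=2, label_conversion=False)
#   asg.run()          # fills g.num_states_asg and g.edges_asg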
class Ctc:
"""
class to create CTC FSA
"""
def __init__(self, fsa, num_labels=256, label_conversion=False):
"""
:param Graph fsa: represents the Graph on which the class operates
:param int num_labels: number of labels without blank, silence, eps and repetitions
:param bool label_conversion: shall the labels be converted into numbers (only ASG and CTC)
"""
assert isinstance(fsa, Graph)
assert isinstance(num_labels, int)
assert isinstance(label_conversion, bool)
self.fsa = fsa
self.num_labels = num_labels
self.label_conversion = label_conversion
# list[int] final_states: list of final states
self.final_states = []
def run(self):
"""
creates the CTC FSA
"""
print("Starting CTC FSA Creation")
self.fsa.num_states = 0
cur_idx = 0
# if the graph fsa is empty use the provided list of strings
if self.fsa.lem_list is not None:
# goes through the list of strings
for idx, seq in enumerate(self.fsa.lem_list):
# goes through string
for i, label in enumerate(seq):
src_idx = 2 * cur_idx
if cur_idx == 0:
self.fsa.num_states += 1
trgt_idx = src_idx + 2
e_norm = Edge(src_idx, trgt_idx, seq[i])
e_norm.idx = cur_idx
e_norm.idx_word_in_sentence = idx
e_norm.idx_phon_in_word = i
# if two equal labels back to back in string -> skip repetition
if i == 0 or seq[i] != seq[i - 1]:
self.fsa.edges.append(e_norm)
# adds blank labels and label repetitions
e_blank = Edge(src_idx, trgt_idx - 1, Edge.BLANK)
self.fsa.edges.append(e_blank)
e_rep = deepcopy(e_norm)
e_rep.source_state_idx = src_idx + 1
self.fsa.edges.append(e_rep)
cur_idx += 1
# add number of states
self.fsa.num_states += 2
# adds separator between words in sentence
if idx < len(self.fsa.lem_list) - 1:
self.fsa.edges.append(Edge(2 * cur_idx, 2 * cur_idx + 1, Edge.BLANK))
self.fsa.edges.append(Edge(2 * cur_idx + 1, 2 * cur_idx + 2, Edge.SIL))
self.fsa.edges.append(Edge(2 * cur_idx, 2 * cur_idx + 2, Edge.SIL))
self.fsa.num_states += 2
cur_idx += 1
# add node number of final state
self.final_states.append(self.fsa.num_states - 1)
# add all final possibilities
e_end_1 = Edge(self.fsa.num_states - 3, self.fsa.num_states, Edge.BLANK, 1.)
self.fsa.edges.append(e_end_1)
e_end_2 = Edge(self.fsa.num_states + 1, self.fsa.num_states + 2, Edge.BLANK, 1.)
self.fsa.edges.append(e_end_2)
e_end_3 = Edge(self.fsa.num_states, self.fsa.num_states + 1, self.fsa.lem_list[-1][-1], 1.)
self.fsa.edges.append(e_end_3)
self.fsa.num_states += 3
# add node number of final state
self.final_states.append(self.fsa.num_states - 1)
elif self.fsa.lem_edges is not None:
self.fsa.lem_edges.sort()
for idx, edge in enumerate(self.fsa.lem_edges):
# goes through fsa (list)
src_idx = 2 * cur_idx
if cur_idx == 0:
self.fsa.num_states += 1
trgt_idx = src_idx + 2
e_norm = deepcopy(edge)
e_norm.source_state_idx = src_idx
e_norm.target_state_idx = trgt_idx
e_norm.idx = cur_idx
e_norm.idx_word_in_sentence = idx
# if two equal labels back to back in string -> skip repetition
if idx == 0 or self.fsa.lem_edges[idx].label != self.fsa.lem_edges[idx - 1].label:
self.fsa.edges.append(e_norm)
# adds blank labels and label repetitions
e_blank = Edge(src_idx, trgt_idx - 1, Edge.BLANK)
self.fsa.edges.append(e_blank)
e_rep = deepcopy(e_norm)
e_rep.source_state_idx = src_idx + 1
self.fsa.edges.append(e_rep)
cur_idx += 1
# add number of states
self.fsa.num_states += 2
# add node number of final state
self.final_states.append(self.fsa.num_states - 1)
# add all final possibilities
e_end_1 = Edge(self.fsa.num_states - 3, self.fsa.num_states, Edge.BLANK, 1.)
self.fsa.edges.append(e_end_1)
e_end_2 = Edge(self.fsa.num_states + 1, self.fsa.num_states + 2, Edge.BLANK, 1.)
self.fsa.edges.append(e_end_2)
e_end_3 = deepcopy(self.fsa.lem_edges[-1])
e_end_3.source_state_idx = self.fsa.num_states
e_end_3.target_state_idx = self.fsa.num_states + 1
self.fsa.edges.append(e_end_3)
self.fsa.num_states += 3
# add node number of final state
self.final_states.append(self.fsa.num_states - 1)
else:
assert False, "Something went wrong! Graph does not have a lemma list or fsa for CTC"
# make single final node
if not (len(self.final_states) == 1 and self.final_states[0] == self.fsa.num_states - 1):
# add new single final node
self.fsa.num_states += 1
for fstate in self.final_states:
# find edges which end in final nodes
final_state_idx_list = [edge_idx for edge_idx, edge in enumerate(self.fsa.edges)
if edge.target_state_idx == fstate]
# add edge from final nodes to new single final node
final_state_node = Edge(fstate, self.fsa.num_states - 1,
self.fsa.edges[final_state_idx_list[0]].label)
self.fsa.edges.append(final_state_node)
for final_state_idx in final_state_idx_list:
# add edges from nodes which go to final nodes
final_state_edge = deepcopy(self.fsa.edges[final_state_idx])
final_state_edge.target_state_idx = self.fsa.num_states - 1
self.fsa.edges.append(final_state_edge)
# add loops to CTC FSA
for loop_idx in range(1, self.fsa.num_states - 1):
edges_add_loop = [edge_idx for edge_idx, edge_cur in enumerate(self.fsa.edges)
if (edge_cur.target_state_idx == loop_idx)]
edge = deepcopy(self.fsa.edges[edges_add_loop[0]])
edge.source_state_idx = edge.target_state_idx
edge.is_loop = True
self.fsa.edges.append(edge)
# label conversion
if self.label_conversion:
Store.label_conversion(self.fsa.edges)
self.fsa.edges.sort()
self.fsa.num_states_ctc = deepcopy(self.fsa.num_states)
self.fsa.num_states = -1
self.fsa.edges_ctc = deepcopy(self.fsa.edges)
self.fsa.edges = []
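# A minimal usage sketch for Ctc (hypothetical sentence):
#   g = Graph("hello world")
#   ctc = Ctc(g, num_labels=256, label_conversion=False)
#   ctc.run()          # fills g.num_states_ctc and g.edges_ctc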
class Hmm:
"""
class to create HMM FSA
"""
def __init__(self, fsa, depth=6, allo_num_states=3, state_tying_conversion=False):
"""
:param Graph fsa: represents the Graph on which the class operates
:param int depth: the depth of the HMM FSA process
:param int allo_num_states: number of allophone states
where allo_num_states > 0
:param bool state_tying_conversion: flag for state tying conversion
"""
if isinstance(fsa, Graph) and isinstance(depth, int) and isinstance(allo_num_states, int):
self.fsa = fsa
self.depth = depth
self.allo_num_states = allo_num_states
self.state_tying_conversion = state_tying_conversion
else:
assert False, ('The HMM init went wrong', fsa)
# Lexicon|None lexicon: lexicon for transforming a word into allophones
self.lexicon = None
# StateTying|None state_tying: holds the transformation from created label to number
self.state_tying = None
# dict phon_dict: dictionary of phonemes, loaded from lexicon file
self.phon_dict = {}
@staticmethod
def _find_node_in_edges(node, edges):
"""
find a specific node in all edges
:param int node: node number
:param list edges: all edges
:return node_dict: dict of nodes where
key: edge index
value: 0 = specific node is the source state idx
value: 1 = specific node is the target state idx
value: 2 = specific node is both source and target state idx
:rtype: dict
"""
node_dict = {}
pos_start = [edge_index for edge_index, edge in enumerate(edges)
if (edge.source_state_idx == node)]
pos_end = [edge_index for edge_index, edge in enumerate(edges)
if (edge.target_state_idx == node)]
pos_start_end = [edge_index for edge_index, edge in enumerate(edges) if
(edge.source_state_idx == node and edge.target_state_idx == node)]
for pos in pos_start:
node_dict[pos] = 0
for pos in pos_end:
node_dict[pos] = 1
for pos in pos_start_end:
node_dict[pos] = 2
return node_dict
@staticmethod
def _build_allo_syntax_for_mapping(edge):
"""
builds a conforming allo syntax for mapping
:param Edge edge: edge to build the allo syntax from
:return allo_map: an allo syntax ready for mapping
:rtype: str
"""
if edge.label == Edge.SIL:
allo_map = "%s{#+#}" % '[SILENCE]'
elif edge.label == Edge.EPS:
allo_map = "*"
else:
if edge.label_prev == '' and edge.label_next == '':
allo_map = "%s{#+#}" % edge.label
elif edge.label_prev == '':
allo_map = "%s{#+%s}" % (edge.label, edge.label_next)
elif edge.label_next == '':
allo_map = "%s{%s+#}" % (edge.label, edge.label_prev)
else:
allo_map = "%s{%s+%s}" % (edge.label, edge.label_prev, edge.label_next)
if edge.phon_at_word_begin:
allo_map += '@i'
if edge.phon_at_word_end:
allo_map += '@f'
if edge.label == Edge.SIL:
allo_map += ".0"
elif edge.label == Edge.EPS:
allo_map += ""
elif edge.allo_idx is not None:
allo_map += "." + str(edge.allo_idx)
return allo_map
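# Illustration of the allophone syntax built above (hypothetical triphone edge):
# an edge with label 'eh', label_prev 'h', label_next 'l', phon_at_word_begin set
# and allo_idx 0 maps to "eh{h+l}@i.0"; a silence edge maps to "[SILENCE]{#+#}.0".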
def run(self):
"""
creates the HMM FSA
"""
print("Starting HMM FSA Creation")
self.fsa.num_states_hmm = 0
split_node = 0
merge_node = 0
for word_idx, word in enumerate(self.fsa.lem_list):
if word_idx == 0:
# add first silence and eps
self.fsa.edges.append(Edge(0, 1, Edge.SIL))
self.fsa.edges.append(Edge(0, 1, Edge.EPS))
self.fsa.num_states += 2
# get word with phons from lexicon
self.phon_dict[word] = self.lexicon.lemmas[word]['phons']
# go through all phoneme variations for a given word
for lemma_idx, lemma in enumerate(self.phon_dict[word]):
# go through the phoneme variations phoneme by phoneme
lem = lemma['phon'].split(' ')
phon_dict_len = len(self.phon_dict[word])
for phon_idx, phon in enumerate(lem):
if phon_dict_len == 1:
# only one phoneme variation - no split!!!
source_node = self.fsa.num_states
target_node = self.fsa.num_states + 1
self.fsa.num_states += 1
else:
# several phoneme variations - split and merge required
if lemma_idx == 0:
# save split node
if phon_idx == 0:
split_node = self.fsa.num_states
# save merge node
if phon_idx == len(lem) - 1:
merge_node = self.fsa.num_states + 1
# add appropriate number of states
if lemma_idx == 0:
# set source and target node for first phoneme variation
source_node = self.fsa.num_states
target_node = self.fsa.num_states + 1
self.fsa.num_states += 1
else:
if phon_idx != 0:
self.fsa.num_states += 1
# set source and target node for split / merge
source_node = split_node if phon_idx == 0 else self.fsa.num_states
target_node = merge_node if phon_idx == len(lem) - 1 else self.fsa.num_states + 1
# edge creation
phon_edge = Edge(source_node, target_node, phon)
# triphone labels if current pos at first or last phon
if phon_idx == 0:
phon_edge.label_prev = ''
else:
phon_edge.label_prev = lem[phon_idx - 1]
if phon_idx == len(lem) - 1:
phon_edge.label_next = ''
else:
phon_edge.label_next = lem[phon_idx + 1]
# assign score
if phon_idx == 0:
phon_edge.score = lemma['score']
# position of phon in word and word in sentence
phon_edge.idx_word_in_sentence = word_idx
phon_edge.idx_phon_in_word = phon_idx
phon_edge.idx = self.fsa.num_states + phon_idx
# phon at word begin / end
if phon_idx == 0:
phon_edge.phon_at_word_begin = True
if phon_idx == len(lem) - 1:
phon_edge.phon_at_word_end = True
# add to graph
self.fsa.edges.append(phon_edge)
# add silence and eps after word
self.fsa.edges.append(Edge(target_node, self.fsa.num_states + 1, Edge.SIL))
self.fsa.edges.append(Edge(target_node, self.fsa.num_states + 1, Edge.EPS))
self.fsa.num_states += 1
# final node
self.fsa.num_states += 1
edges_allo_tmp = []
if self.allo_num_states > 1:
for edge in self.fsa.edges: # collect new edges in edges_allo_tmp; do not append to the list while iterating over it
if edge.label != Edge.SIL and edge.label != Edge.EPS:
allo_target_idx = edge.target_state_idx
for state in range(self.allo_num_states):
if state == 0:
edge.target_state_idx = self.fsa.num_states
edge.allo_idx = state
elif state == self.allo_num_states - 1:
edge_1 = deepcopy(edge)
edge_1.allo_idx = state
edge_1.source_state_idx = self.fsa.num_states
edge_1.target_state_idx = allo_target_idx
self.fsa.num_states += 1
edges_allo_tmp.append(edge_1)
else:
self.fsa.num_states += 1
edge_2 = deepcopy(edge)
edge_2.allo_idx = state
edge_2.source_state_idx = self.fsa.num_states - 1
edge_2.target_state_idx = self.fsa.num_states
edges_allo_tmp.append(edge_2)
self.fsa.edges.extend(edges_allo_tmp)
sort_idx = 0
while sort_idx < len(self.fsa.edges):
cur_source_state = self.fsa.edges[sort_idx].source_state_idx
cur_target_state = self.fsa.edges[sort_idx].target_state_idx
if cur_source_state > cur_target_state: # swap is needed
edges_with_cur_source_state = self._find_node_in_edges(
cur_source_state, self.fsa.edges) # find start node in all edges
edges_with_cur_target_state = self._find_node_in_edges(
cur_target_state, self.fsa.edges) # find end node in all edges
for edge_key in edges_with_cur_source_state.keys(): # loop over edges with specific node
if edges_with_cur_source_state[edge_key] == 0: # swap source state
self.fsa.edges[edge_key].source_state_idx = cur_target_state
elif edges_with_cur_source_state[edge_key] == 1:
self.fsa.edges[edge_key].target_state_idx = cur_target_state
elif edges_with_cur_source_state[edge_key] == 2:
self.fsa.edges[edge_key].source_state_idx = cur_target_state
self.fsa.edges[edge_key].target_state_idx = cur_target_state
else:
assert False, ("Dict has a non matching value:",
edge_key, edges_with_cur_source_state[edge_key])
for edge_key in edges_with_cur_target_state.keys(): # edge_key: idx from edge in edges
if edges_with_cur_target_state[edge_key] == 0: # swap target state
self.fsa.edges[edge_key].source_state_idx = cur_source_state
elif edges_with_cur_target_state[edge_key] == 1:
self.fsa.edges[edge_key].target_state_idx = cur_source_state
elif edges_with_cur_target_state[edge_key] == 2:
self.fsa.edges[edge_key].source_state_idx = cur_source_state
self.fsa.edges[edge_key].target_state_idx = cur_source_state
else:
assert False, ("Dict has a non matching value:",
edge_key, edges_with_cur_target_state[edge_key])
# reset idx: restarts traversing at the beginning of graph
# swapping may introduce new disorders
sort_idx = 0
sort_idx += 1
# add loops
for state in range(1, self.fsa.num_states):
edges_included = [edge_index for edge_index, edge in enumerate(self.fsa.edges) if
(edge.target_state_idx == state and edge.label != Edge.EPS)]
for edge_inc in edges_included:
edge_loop = deepcopy(self.fsa.edges[edge_inc])
edge_loop.source_state_idx = edge_loop.target_state_idx
self.fsa.edges.append(edge_loop)
# state tying labels or numbers
for edge in self.fsa.edges:
allo_syntax = self._build_allo_syntax_for_mapping(edge)
edge.label = allo_syntax
if self.state_tying_conversion:
if edge.label == Edge.EPS:
pass
elif allo_syntax in self.state_tying.allo_map:
allo_id = self.state_tying.allo_map[allo_syntax]
edge.label = allo_id
else:
print("Error converting label:", edge.label, allo_syntax)
self.fsa.edges.sort()
self.fsa.num_states_hmm = deepcopy(self.fsa.num_states)
self.fsa.num_states = -1
self.fsa.edges_hmm = deepcopy(self.fsa.edges)
self.fsa.edges = []
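# A minimal usage sketch for Hmm (hypothetical file names; lexicon must be
# assigned before run(), state_tying only if state_tying_conversion is True,
# see load_lexicon/load_state_tying below):
#   g = Graph("hello world")
#   hmm = Hmm(g, depth=6, allo_num_states=3, state_tying_conversion=False)
#   hmm.lexicon = load_lexicon('recog.150k.final.lex.gz')
#   hmm.run()          # fills g.num_states_hmm and g.edges_hmm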
class AllPossibleWordsFsa:
"""
constructs an FSA from all words in a lexicon
"""
def __init__(self, fsa):
"""
takes a lexicon file and constructs an FSA over all words
:param Graph fsa: the graph which holds the constructed fsa
"""
self.fsa = fsa
self.lexicon = None
def run(self):
print("Starting All Possible Words FSA Creation")
for key, value in self.lexicon.lemmas.iteritems(): # for python 3: .items()
edge = Edge(0, 0, key, 0)
self.fsa.edges_word.append(edge)
self.fsa.num_states_word = 1
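# A minimal usage sketch for AllPossibleWordsFsa (hypothetical lexicon file;
# only the edges_word/num_states_word fields of the Graph are filled):
#   g = Graph("placeholder")
#   words_fsa = AllPossibleWordsFsa(g)
#   words_fsa.lexicon = load_lexicon('recog.150k.final.lex.gz')
#   words_fsa.run()    # one self-loop edge on state 0 per lexicon word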
class Ngram:
"""
constructs an FSA with an n-gram LM
"""
def __init__(self, n):
"""
constructs an FSA over a lexicon with n-grams
:param int n: size of the gram (1, 2, 3)
"""
self.n = n
self.lexicon = None # type: Lexicon
# lexicon consists of 3 entries: phoneme_list, phonemes and lemmas
# phoneme_list: list of string phonemes in the lexicon
# phonemes: dict of dict of str {phone: {index: , symbol: , variation:}}
# lemmas: dict of dict of (str, list of dict) {orth: {orth: , phons: [{score: , phon:}]}}
self.lemma_list = []
self.ngram_scores = None
self.ngram_list = []
self.num_states = 0
self.edges = []
# TODO take lexicon and generate a fsa for ngram lm
def _create_lemma_list(self):
"""
create list of lemmas from lexicon
transform lexicon.lemmas into a list of str
"""
for lemma, lemma_dict in self.lexicon.lemmas.items():
self.lemma_list.append(lemma)
def _create_ngram_list(self):
"""
creates an ngram list from the list of lemmas
permute over the created list
"""
for perm in itertools.permutations(self.lemma_list, self.n):
self.ngram_list.append(perm)
def _create_fsa_from_ngram_list(self):
"""
takes an ngram list and converts it into an FSA
"""
for idx, ngram in enumerate(self.ngram_list):
ngram_edge = Edge(idx, idx + 1, ngram, 0.)
self.edges.append(ngram_edge)
self.num_states += 1
if self.ngram_list:
self.num_states += 1
def run(self):
print("Starting {}-gram FSA Creation".format(self.n))
if not self.lemma_list:
self._create_lemma_list()
node_expand = []
node_expand.append(0)
ngram_counter = 1
while node_expand:
cur_start = node_expand.pop()
for idx, lemma in enumerate(self.lemma_list):
cur_end = self.num_states + 1 # cur_start + idx + 1
edge = Edge(cur_start, cur_end, lemma, 0.)
self.edges.append(edge)
self.num_states += 1
if ngram_counter < self.n:
node_expand.append(cur_end)
print(self.num_states, cur_start, cur_end, idx, lemma)
ngram_counter += 1
if self.lemma_list:
self.num_states += 1
print(self.num_states)
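# A minimal usage sketch for Ngram (hypothetical lexicon file; n=2 expands a
# bigram tree, so the state count grows quickly with lexicon size):
#   ngram = Ngram(2)
#   ngram.lexicon = load_lexicon('recog.150k.final.lex.gz')
#   ngram.run()        # fills ngram.num_states and ngram.edges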
def load_lexicon(lexicon_name='recog.150k.final.lex.gz', pickleflag=False):
"""
loads Lexicon
takes a file, loads the xml and returns as Lexicon
a pickled file can be loaded for a speed improvement
where:
lex.lemmas and lex.phonemes important
:param str lexicon_name: holds the path and name of the lexicon file
:param bool pickleflag: flag to indicate if the lexicon datastructure is to be pickled
:return lexicon: lexicon datastructure
:rtype: Lexicon
"""
log.initialize(verbosity=[5])
# strip the '.gz' suffix before adding '.pickle' (rstrip would strip a character set, not the suffix)
lexicon_dumpname = (lexicon_name[:-len('.gz')] if lexicon_name.endswith('.gz') else lexicon_name) + '.pickle'
if pickleflag:
# loads from pickled lexicon file
if isfile(lexicon_dumpname):
print("Loading pickled lexicon")
with open(lexicon_dumpname, 'rb') as lexicon_load:
lexicon = pickle.load(lexicon_load)
else: # pickled lexicon file does not exist -> create it now
assert isfile(lexicon_name), "Lexicon file does not exist"
lexicon = Lexicon(lexicon_name)
print("Saving pickled lexicon")
with open(lexicon_dumpname, 'wb') as lexicon_dump:
pickle.dump(lexicon, lexicon_dump)
else:
# loads from non-pickled lexicon file
assert isfile(lexicon_name), "Lexicon file does not exist"
lexicon = Lexicon(lexicon_name)
return lexicon
def load_state_tying(state_tying_name='state-tying.txt'):
"""
loads a state tying map from a file and returns its content
pickling is not used here since it does not speed up state tying loading
where:
statetying.allo_map important
:param str state_tying_name: holds the path and name of the state tying file
:return state_tying: state tying datastructure
:rtype: StateTying
"""
log.initialize(verbosity=[5])
assert isfile(state_tying_name), "State tying file does not exist"
state_tying = StateTying(state_tying_name)
return state_tying
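# A minimal usage sketch for the two loaders (file names are the hypothetical
# defaults above; pickleflag=True caches the parsed lexicon on disk):
#   lex = load_lexicon('recog.150k.final.lex.gz', pickleflag=True)
#   st = load_state_tying('state-tying.txt')
#   print(len(lex.lemmas), len(st.allo_map))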
class Store:
"""
Conversion and save class for FSA
"""
def __init__(self, num_states, edges, filename='edges', path='./tmp/', file_format='svg'):
"""
:param int num_states: number of states of FSA
:param list[Edge] edges: list of edges representing FSA
:param str filename: name of the output file
:param str path: location
:param str file_format: format in which to save the file
"""
self.num_states = num_states
self.edges = edges
self.filename = filename
self.path = path
self.file_format = file_format
# noinspection PyPackageRequirements,PyUnresolvedReferences
import graphviz
self.graph = graphviz.Digraph(format=self.file_format)
def fsa_to_dot_format(self):
"""
converts num_states and edges within the graph to dot format
"""
self.add_nodes(self.graph, self.num_states)
self.add_edges(self.graph, self.edges)
def save_to_file(self):
"""
saves dot graph to file
settings: filename, path
caution: overwrites already present files
"""
# noinspection PyArgumentList
save_path = self.graph.render(filename=self.filename, directory=self.path)
print("FSA saved in:", save_path)
@staticmethod
def label_conversion(edges):
"""
converts the string labels to int labels
:param list[Edge] edges: list of edges describing the fsa graph
:return edges:
:rtype: list[Edge]
"""
for edge in edges:
lbl = edge.label
if lbl == Edge.BLANK:
edge.label = ord(' ')
elif lbl == Edge.SIL or lbl == Edge.EPS or isinstance(lbl, int):
pass
elif isinstance(lbl, str):
edge.label = ord(lbl)
else:
assert False, "Label Conversion failed!"
@staticmethod
def add_nodes(graph, num_states):
"""
add nodes to the dot graph
:param Digraph graph: add nodes to this graph
:param int num_states: number of states equals the number of nodes
"""
nodes = []
for i in range(0, num_states):
nodes.append(str(i))
for n in nodes:
graph.node(n)
@staticmethod
def add_edges(graph, edges):
"""
add edges to the dot graph
:param Digraph graph: add edges to this graph
:param list[Edge] edges: list of edges
"""
for edge in edges:
if isinstance(edge.label, int):
label = edge.label
elif '{' in edge.label:
label = edge.label
elif edge.label_prev is not None and edge.label_next is not None:
label = [edge.label_prev, edge.label, edge.label_next]
if edge.allo_idx is not None:
label.append(edge.allo_idx)
# TODO add label creation for fst
else:
label = edge.label
e = ((str(edge.source_state_idx), str(edge.target_state_idx)), {'label': str(label)})
graph.edge(*e[0], **e[1])
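# A minimal usage sketch for Store (hypothetical output location; requires the
# graphviz package and assumes a Graph g already populated by e.g. Ctc(g).run()):
#   store = Store(g.num_states_ctc, g.edges_ctc, filename='ctc_edges', path='./tmp/', file_format='svg')
#   store.fsa_to_dot_format()
#   store.save_to_file()   # writes ./tmp/ctc_edges.svg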
class BuildSimpleFsaOp(theano.Op):
itypes = (T.imatrix,)
# the first and last output are actually uint32
otypes = (T.fmatrix, T.fvector, T.fmatrix)
def __init__(self, state_models=None):
if state_models is None:
state_models = {}
self.state_models = state_models
def perform(self, node, inputs, output_storage, params=None):
labels = inputs[0]
from_states = []
to_states = []
emission_idxs = []
seq_idxs = []
weights = []
start_end_states = []
cur_state = 0
edges = []
weights = []
start_end_states = []
for b in range(labels.shape[1]):