diff --git a/k2/csrc/dense_fsa.h b/k2/csrc/dense_fsa.h
new file mode 100644
index 000000000..a7c29c297
--- /dev/null
+++ b/k2/csrc/dense_fsa.h
@@ -0,0 +1,309 @@
+// k2/csrc/dense_fsa.h
+
+// Copyright (c) 2020 Daniel Povey
+
+// See ../../LICENSE for clarification regarding multiple authors
+
+#ifndef K2_CSRC_DENSE_FSA_H_
+#define K2_CSRC_DENSE_FSA_H_
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "glog/logging.h"
+#include "k2/csrc/util.h"
+#include "k2/csrc/fsa.h"
+
+namespace k2 {
+
+
+/*
+  DenseFsa represents an FSA stored as a matrix, representing something
+  like CTC output from a neural net.  `data` is a (T+1) by N
+  matrix, where N is the number of symbols (including blank/zero).
+  The last row of this matrix contains only zeros; this is where it
+  gets the (zero) weight for the final-arc.  It may seem odd to
+  actually have to store the zero, but for the autograd to work
+  correctly we need all arcs to have an arc-index.
+
+  Physically, we would access data[t,n] as data[arc_offset + t * num_symbols + n].
+
+  This FSA has T + 2 states, with state 0 the start state and state T + 1
+  the final state.  (Caution: if we formulated our FSAs more normally we
+  would have T + 1 states, but because we represent final-probs via an
+  arc with symbol kFinalSymbol on it to the last state, we need one
+  more state.)  For 0 <= t < T, we have an arc with symbol n on it for
+  each 0 <= n < N, from state t to state t+1, with weight equal to
+  data[t,n].
+ */
+struct DenseFsa {
+  int32_t T;
+  int32_t num_symbols;
+  int32_t arc_offset;
+
+  const float *data;  // Would typically be a log-prob or unnormalized log-prob
+
+  /*
+    The next few functions provide an interface more similar to struct
+    Fsa.  We don't necessarily recommend using these functions much;
+    it may be more efficient to use what is known about the structure
+    of this object.  But they may be useful for documentation and testing.
+  */
+  int32_t num_states() { return T + 2; }
+  int32_t arc_indexes(int32_t state_index) {
+    return arc_offset + (state_index <= T ? state_index*num_symbols :
+                         T*num_symbols + 1);
+  }
+  Arc arc(int32_t arc_index) {
+    arc_index -= arc_offset;
+    int32_t state_index = arc_index / num_symbols;
+    return Arc(state_index, state_index+1,
+               (state_index < T ? (arc_index%num_symbols) : kFinalSymbol));
+  }
+
+
+  /* Constructor
+     @param [in] T  number of frames / sequence length.  This FSA has T+2 states;
+          state T+1 is the final-state, and state T has only a single
+          arc, with symbol kFinalSymbol, with arc-index
+          `arc_index = arc_offset+(T*num_symbols)`, to state T+1; the weight
+          on this is data[arc_index] == 0.
+
+          All other states t < T have num_symbols arcs leaving them; the arc
+          for symbol 0 <= n < num_symbols goes from state t to state t+1,
+          has arc-index arc_offset + t*num_symbols + n, and has weight
+          data[arc_offset + t*num_symbols + n].
+     @param [in] num_symbols  the number of symbols, i.e. the dimension of
+          the neural-net output
+     @param [in] arc_offset  the index of this FSA's first arc, e.g. its
+          offset within the data of a containing DenseFsaVec
+     @param [in] data  the weights; the arc with arc-index i has weight
+          data[i].
+  */
+  DenseFsa(int32_t T, int32_t num_symbols, int32_t arc_offset,
+           const float *data);
+};
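+
+// For concreteness, a small worked example of the arc layout (illustrative
+// only): with T = 2, num_symbols = 3 and arc_offset = 0, the arcs that
+// arc() above would return are:
+//
+//   arc_index 0,1,2:  state 0 -> state 1, symbols 0,1,2, weights data[0..2]
+//   arc_index 3,4,5:  state 1 -> state 2, symbols 0,1,2, weights data[3..5]
+//   arc_index 6:      state 2 -> state 3, symbol kFinalSymbol, weight data[6] == 0
+//
+// and arc_indexes(s) would return 0, 3, 6, 7 for states s = 0, 1, 2, 3.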
+
+
+struct DenseFsaVecFrameInfo;  // defined below
+
+/*
+  DenseFsaVecMeta is the meta-information for a DenseFsaVec, stored as one
+  block of int32_t elements; it would be created by CreateDenseFsaVecMeta()
+  (see below) and interpreted as this type.
+ */
+struct DenseFsaVecMeta {
+  int32_t num_segs;        // The number of segments
+  int32_t num_symbols;     // The dimension of the neural-net output
+  int32_t frames_per_seq;  // The number of frames in each sequence
+
+  // seg_frame_index[s], for 0 <= s <= num_segs, is the index within
+  // frame_info[] of the first (padded) frame of segment s;
+  // seg_frame_index[num_segs] is the total number of padded frames.
+  int32_t seg_frame_index[1];  // actually of dimension [num_segs + 1]; the
+                               // variable-sized part of this object starts here.
+
+  DenseFsaVecFrameInfo *frame_info() {
+    return reinterpret_cast<DenseFsaVecFrameInfo*>(
+        seg_frame_index + (num_segs + 1));
+  }
+  int32_t frame_info_dim() { return seg_frame_index[num_segs]; }
+  int32_t num_frames_padded() { return seg_frame_index[num_segs]; }
+
+
+  // and next we have the following, which will be used from
+  // the Python calling code to copy the neural-net output to the correct
+  // location.  This lists the elements of `frame_info` but with the
+  // padding frames removed.
+  //
+  // DenseFsaVecFrameCopyInfo copy_info[num_frames];
+  // where num_frames == \sum_segment length(segment) == total number of nnet
+  // output frames over all segments.
+  //
+  // int32_t frame_index[num_frames];  # frame_index will be an index into `frame_info`;
+  //                                   # this will be of the form 0 1 2 4 5 6 7 8 9 11 ...
+  //                                   # (note the gaps where the zero padding was!)
+  int32_t* frame_index() {
+    return reinterpret_cast<int32_t*>(frame_info() + frame_info_dim());
+  }
+  int32_t frame_index_dim() {
+    int32_t num_frames_padded = seg_frame_index[num_segs],
+        num_frames = num_frames_padded - num_segs;
+    return num_frames;
+  }
+
+  // The total size of this object in int32_t elements will equal:
+  //   3 +                            # for the first 3 elements
+  //   num_segs + 1 +                 # for seg_frame_index[]
+  //   4 * (num_segs + num_frames) +  # == 4*num_frames_padded, for frame_info[]
+  //   num_frames                     # for frame_index[]
+  //
+  // == 4 + 5*num_segs + 5*num_frames
+};
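+
+// A worked example of this layout (illustrative only): suppose there are
+// num_segs = 2 segments of lengths 3 and 2 frames, so num_frames = 5 and,
+// with one zero-padding frame per segment, num_frames_padded = 7.  Then:
+//
+//   seg_frame_index = [ 0, 4, 7 ]
+//   frame_info      = 7 entries of 4 int32_t each = 28 elements
+//   frame_index     = [ 0, 1, 2, 4, 5 ]   # skips padding frames 3 and 6
+//
+// for a total of 3 + 3 + 28 + 5 == 39 == 4 + 5*2 + 5*5 int32_t elements.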
+
+
+struct DenseFsaVecFrameInfo {
+  int32_t seg_id;  // The segment-id that this frame is part of (0 <= seg_id <
+                   // num_segs)... will be of the form 0 0 0 0 1 1 1 1 1 1 1 1 2
+                   // 2 2 3 3 3 ...
+  int32_t seq_id;  // The sequence-id that the `seg_id`'th segment was part of.
+                   // Would equal seg_id in the case where there was one segment
+                   // per sequence.
+  int32_t frame_in_seg;  // The frame-index within the segment, so would be 0
+                         // for the 1st frame of each segment, and `this_seg_num_frames`
+                         // for the last (which could contain all zeros).
+                         // Will be of the form 0 1 2 3 0 1 2 3 4 5 6 0 1 2 0 1 2....
+  int32_t frame_in_seq;  // The frame-index within the sequence that this segment
+                         // is a part of.  Would be the same as `frame_in_seg` if
+                         // this segment starts at frame zero of its sequence.
+};
+
+
+/**
+   Creates meta-info for DenseFsaVec (DenseFsaVecMeta) as one block in memory.
+   For some of the terminology, see the comment above the definition of class
+   DenseFsaVec.
+
+   First, some terminology.  Note: some of this is more relevant to the
+   Python level here.  Please note that seq == sequence and seg == segment.
+   The neural-network outputs, consisting of log-likes, would be in a tensor
+   of shape (num_seqs, num_frames, num_symbols).  I.e. we have
+   `num_seqs` sequences; each sequence has `num_frames` outputs, and
+   each frame of output has `num_symbols` symbols (e.g. phones or letters).
+
+   There are `num_segs` segments.  Each segment is a subset of the frames
+   in a sequence.
+
+     @param [in] num_seqs  Number of sequences of (e.g.) phone posteriors/loglikes
+                      from the neural net
+     @param [in] frames_per_seq  Number of frames in each sequence
+     @param [in] num_symbols  Dimension of the neural network output, interpreted
+                      for instance as epsilon/blank plus phones or letters.
+     @param [in] num_segs  Number of segments.  Each segment represents a range of
+                      frames within a sequence.  There will in general be
+                      at least as many segments as sequences.
+     @param [in] seq_id  Indexed by 0 <= seg_id < num_segs, seq_id[seg_id] contains the
+                      sequence index 0 <= s < num_seqs to which this
+                      segment belongs
+     @param [in] frame_begin  Indexed by 0 <= seg_id < num_segs, frame_begin[seg_id]
+                      contains the index of the first frame of that segment.
+     @param [in] frame_end  Indexed by 0 <= seg_id < num_segs, frame_end[seg_id]
+                      contains the index of the last-plus-one frame of that segment.
+     @param [in] storage_size  Size of `storage` array, in int32_t elements.  Defining
+                      num_frames = sum(frame_end) - sum(frame_begin),
+                      storage_size must equal 4 + 5*num_segs + 5*num_frames.
+                      It is provided as an arg for checking purposes.
+     @param [out] storage  Pointer to an array of int32_t where we put the meta-info
+                      (probably part of a torch.Tensor).  It will be interpreted
+                      internally as type DenseFsaVecMeta.
+ */
+void CreateDenseFsaVecMeta(
+    int32_t num_seqs,
+    int32_t frames_per_seq,
+    int32_t num_symbols,
+    int32_t num_segs,
+    const int32_t *seq_id,
+    const int32_t *frame_begin,
+    const int32_t *frame_end,
+    ssize_t storage_size,
+    int32_t *storage);
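+
+// A sketch of how this might be called (illustrative only; the segment
+// values here are invented):
+//
+//   std::vector<int32_t> seq_id = {0, 0, 1},
+//       frame_begin = {0, 50, 10},
+//       frame_end = {40, 95, 90};
+//   int32_t num_segs = static_cast<int32_t>(seq_id.size()), num_frames = 0;
+//   for (int32_t s = 0; s < num_segs; s++)
+//     num_frames += frame_end[s] - frame_begin[s];
+//   std::vector<int32_t> storage(4 + 5 * num_segs + 5 * num_frames);
+//   CreateDenseFsaVecMeta(2, 100, 43, num_segs, seq_id.data(),
+//                         frame_begin.data(), frame_end.data(),
+//                         storage.size(), storage.data());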
+
+
+/**
+   DenseFsaVec represents a vector of FSAs with a special regular
+   structure.  Suppose there are N FSAs, numbered n=0 .. N-1;
+   and suppose the symbol space has S symbols numbered 0, ... S-1
+   (yes, 0 represents epsilon; and we're not including the "final symbol"
+   numbered -1).
+
+   The n'th FSA corresponds to a log-likelihood matrix (call this M_n with M a
+   matrix) with T_n frames.  Below, we'll just call this T for clarity.  This
+   FSA has T+2 states, numbered 0, .. T+1.  For 0 <= t < T and 0 <= s < S, there
+   is an arc from state t to state t+1 with symbol s and log-like/weight
+   equal to M_n(t, s).  From state T to T+1 there is a single arc with
+   symbol -1=kFinalSymbol and log-like/weight equal to 0.0.  (Of course, state
+   T+1 is the final state... this is how our framework works).
+ */
+struct DenseFsaVec {
+
+
+  /*
+    Constructor.
+
+    @param [in] meta  The meta-info, as written to by CreateDenseFsaVecMeta().
+    @param [in] data  A contiguous, row-major matrix of shape
+                      (meta->num_frames_padded(), meta->num_symbols),
+                      containing the neural-net outputs for each segment with
+                      zero rows in between for padding.
+  */
+  DenseFsaVec(const DenseFsaVecMeta *meta,
+              const float *data):
+      meta(meta), data(data) { Check(); }
+
+
+  void Check();  // Sanity check (spot check, not thorough) on `meta`
+
+
+  const DenseFsaVecMeta *meta;
+  const float *data;
+
+  DenseFsa operator[] (int32_t seg_id) {
+    CHECK_LT(seg_id, meta->num_segs);
+    int32_t start_frame_index = meta->seg_frame_index[seg_id],
+        end_frame_index = meta->seg_frame_index[seg_id+1];
+    // below, the -1 is to exclude the zero-padding frame.
+    int32_t T = end_frame_index - start_frame_index - 1;
+    int32_t arc_offset = meta->num_symbols * start_frame_index;
+    return DenseFsa(T, meta->num_symbols, arc_offset, this->data);
+  }
+};
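+
+// Illustrative usage (hypothetical variable names): given `storage` as filled
+// in by CreateDenseFsaVecMeta() above, and the matching nnet output `data`
+// of shape (num_frames_padded, num_symbols):
+//
+//   const DenseFsaVecMeta *meta =
+//       reinterpret_cast<const DenseFsaVecMeta*>(storage.data());
+//   DenseFsaVec vec(meta, data);
+//   for (int32_t s = 0; s < meta->num_segs; s++) {
+//     DenseFsa fsa = vec[s];
+//     // ... e.g. intersect `fsa` with a decoding graph.
+//   }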
+
+/*
+  Version of Intersect where `a` is dense?
+ */
+void Intersect(const DenseFsa &a, const Fsa &b, Fsa *c,
+               std::vector<int32_t> *arc_map_a = nullptr,
+               std::vector<int32_t> *arc_map_b = nullptr);
+
+/*
+  Version of Intersect where `a` is dense, pruned with pruning beam `beam`.
+  Suppose states in the output correspond to pairs (s_a, s_b), and have
+  forward-weights w(s_a, s_b), i.e. best-path from the start state...
+  then if a state has a forward-weight w(s_a, s_b) that is less than
+  (the largest w(s_a, x) for any x) minus the beam, we don't expand it.
+
+  This is the same as time-synchronous Viterbi beam pruning.
+*/
+void IntersectPruned(const DenseFsa &a, const Fsa &b, float beam, Fsa *c,
+                     std::vector<int32_t> *arc_map_a = nullptr,
+                     std::vector<int32_t> *arc_map_b = nullptr);
+
+
+/* Convert DenseFsa to regular Fsa (for testing purposes) */
+void DenseToFsa(const DenseFsa &a, Fsa *b);
+
+
+}  // namespace k2
+
+#endif  // K2_CSRC_DENSE_FSA_H_
diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h
index ef5747cc7..0c2b2f78d 100644
--- a/k2/csrc/fsa.h
+++ b/k2/csrc/fsa.h
@@ -25,6 +25,10 @@ enum {
                // like in OpenFst.
 };
 
+
+// CAUTION: the sizeof() of this is probably 128, not 96.  This could be a
+// waste of space.  We may later either use the extra field for something, or
+// find a way to reduce the size.
 struct Arc {
   int32_t src_state;
   int32_t dest_state;
@@ -122,39 +126,142 @@ struct Fsa {
   }
 };
 
+
 /*
-  DenseFsa represents an FSA stored as a matrix, representing something
-  like CTC output from a neural net.  We view `weights` as a T by N
-  matrix, where N is the number of symbols (including blank/zero).
-
-  Physically, we would access weights[t,n] as weights[t * t_stride + n].
-
-  This FSA has T + 2 states, with state 0 the start state and state T + 2
-  the final state.  (Caution: if we formulated our FSAs more normally we
-  would have T + 1 states, but because we represent final-probs via an
-  arc with symbol kFinalSymbol on it to the last state, we need one
-  more state).  For 0 <= t < T, we have an arc with symbol n on it for
-  each 0 <= n < N, from state t to state t+1, with weight equal to
-  weights[t,n].
+  Cfsa is a 'const' FSA, which we'll use as the input to operations.  It is
+  designed in such a way that the storage underlying it may either be an Fsa
+  (i.e. with std::vectors) or may be some kind of tensor (probably CfsaVec).
+  Note: the pointers it holds aren't const for now, because there may be
+  situations where it makes sense to change them (even though the number of
+  states and arcs can't be changed).
  */
-struct DenseFsa {
-  float *weights;  // Would typically be a log-prob or unnormalized log-prob
-  int32_t T;       // The number of time steps == rows in the matrix `weights`;
-                   // this FSA has T + 2 states, see explanation above.
-  int32_t num_symbols;  // The number of symbols == columns in the matrix
-                        // `weights`.
-  int32_t t_stride;  // The stride of the matrix `weights`
-
-  /* Constructor
-     @param [in] data   Pointer to the raw data, which is a T by num_symbols
-                        matrix with stride `stride`, containing logprobs
-
-      CAUTION: we may later enforce that stride == num_symbols, in order to
-      be able to know the layout of a phantom matrix of arcs.  (?)
-   */
-  DenseFsa(float *data, int32_t T, int32_t num_symbols, int32_t stride);
+struct Cfsa {
+  int32_t num_states;  // number of states including final state.  States are
+                       // numbered `0 ... num_states - 1`.  Start state is 0,
+                       // final state is state `num_states - 1`.  We store a
+                       // redundant representation here out of a belief that it
+                       // might reduce the number of instructions in code.
+  int32_t begin_arc;   // a copy of arc_indexes[0]; gives the first index in
+                       // `arcs` for the arcs in this FSA.  Will be >= 0.
+  int32_t end_arc;     // a copy of arc_indexes[num_states]; gives the
+                       // one-past-the-last index in `arcs` for the arcs in this
+                       // FSA.  Will be >= begin_arc.
+
+  const int32_t *arc_indexes;  // an array, indexed by state index, giving the
+                               // first arc index of each state.  The last one
+                               // is repeated, so for any valid state 0 <= s <
+                               // num_states we can use arc_indexes[s+1].  That
+                               // is: elements 0 through num_states (inclusive)
+                               // are valid.  CAUTION: arc_indexes[0] may be
+                               // greater than zero.
+
+
+  Arc *arcs;  // Note: arcs[begin_arc] through arcs[end_arc - 1]
+              // are valid.
+
+  // Constructor from Fsa
+  Cfsa(const Fsa &fsa);
+
+  Cfsa &operator = (const Cfsa &cfsa) = default;
+  Cfsa(const Cfsa &cfsa) = default;
+
+  int32_t NumStates() const { return num_states; }
+  int32_t FinalState() const { return num_states - 1; }
+};
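+
+// For example (illustrative, not from the original header), visiting every
+// arc of every state of a Cfsa would look like:
+//
+//   Cfsa cfsa(fsa);  // wrap an existing Fsa
+//   for (int32_t s = 0; s < cfsa.NumStates(); s++) {
+//     for (int32_t i = cfsa.arc_indexes[s]; i < cfsa.arc_indexes[s + 1]; i++) {
+//       const Arc &arc = cfsa.arcs[i];
+//       // ... use arc.src_state, arc.dest_state, etc.
+//     }
+//   }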
+
+
+
+
+class CfsaVec {
+ public:
+  /*
+     Constructor from linear data, e.g. from data stored in a torch.Tensor.
+     This would previously have been created using CreateCfsaVec().
+
+     @param [in] size  size in int32_t elements of `data`, only
+                       needed for checking purposes.
+     @param [in] data  The underlying data.  Format of data is
+                       described below (all elements are of type
+                       int32_t unless stated otherwise).  Would have
+                       been created by CreateCfsaVec().
+
+        - version      Format version number, currently always 1.
+        - num_fsas     The number of FSAs
+        - state_offsets_start  The offset from the start of `data` of
+                       where the `state_offsets` array is, in int32_t
+                       (4-byte) elements.
+        - arc_indexes_start  The offset from the start of `data` of
+                       where the `arc_indexes` array is, in int32_t
+                       (4-byte) elements.
+        - arcs_start   The offset from the start of `data` of where
+                       the first Arc is, in sizeof(Arc) multiples, i.e.
+                       Arc *arcs = ((Arc*)data) + arcs_start
+
+            [possibly some padding here]
+
+        - state_offsets[num_fsas + 1]   state_offsets[f] is the sum of
+                       the num-states of all the FSAs preceding f.  It is
+                       also the offset from the beginning of the
+                       `arc_indexes` array of where the part corresponding
+                       to FSA f starts.  The number of states in FSA f
+                       is given by state_offsets[f+1] - state_offsets[f].
+                       This is >= 0; it will be zero if the
+                       FSA f is empty, and >= 2 otherwise.
+
+            [possibly some padding here]
+
+        - arc_indexes[tot_states + 1]   This gives the indexes into the `arcs`
+                       array of where we can find the first of each state's arcs.
+
+        [pad as needed for memory-alignment purposes then...]
+
+        - arcs[tot_arcs]
+  */
+  CfsaVec(size_t size, void *data);
+
+  int32_t NumFsas() const { return num_fsas_; }
+
+  Cfsa operator[] (int32_t f) const;
+
+ private:
+  CfsaVec &operator = (const CfsaVec &);  // Disable
+  CfsaVec(const CfsaVec&);  // Disable
+
+  int32_t num_fsas_;
+
+  // The raw underlying data
+  int32_t *data_;
+  // The size of the underlying data
+  size_t size_;
+};
+
+
+
+
+/*
+  Return the number of bytes we'd need to represent this vector of Cfsas
+  linearly as a CfsaVec. */
+size_t GetCfsaVecSize(const std::vector<Cfsa> &fsas_in);
+
+// Return the number of bytes we'd need to represent this Cfsa
+// linearly as a CfsaVec with one element
+size_t GetCfsaVecSize(const Cfsa &fsa_in);
+
+/*
+  Create a CfsaVec from a vector of Cfsas (this involves representing
+  the vector of Fsas in one big linear memory region).
+
+     @param [in] fsas_in  The vector of Cfsas to be linearized;
+                          must be nonempty
+     @param [in] data     The allocated data of size `size` bytes
+     @param [in] size     The size of the memory block passed;
+                          must equal the return value of
+                          GetCfsaVecSize(fsas_in).
+ */
+void CreateCfsaVec(const std::vector<Cfsa> &fsas_in,
+                   void *data,
+                   size_t size);
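+
+// Illustrative usage (not from the original header): linearize a vector of
+// Cfsas and get them back out.  This assumes GetCfsaVecSize() returns a size
+// that is a multiple of sizeof(int32_t):
+//
+//   std::vector<Cfsa> cfsas = ...;
+//   size_t size = GetCfsaVecSize(cfsas);
+//   std::vector<int32_t> data(size / sizeof(int32_t));
+//   CreateCfsaVec(cfsas, data.data(), size);
+//   CfsaVec vec(data.size(), data.data());
+//   Cfsa first = vec[0];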
+
+
+
 struct Fst {
   Fsa core;
   std::vector<int32_t> aux_label;
@@ -185,7 +292,6 @@ class DeterministicGenericFsa {
 
 using FsaVec = std::vector<Fsa>;
 using FstVec = std::vector<Fst>;
-using DenseFsaVec = std::vector<DenseFsa>;
 
 }  // namespace k2
 
diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h
index c0f26fcc1..93c7c3c66 100644
--- a/k2/csrc/fsa_algo.h
+++ b/k2/csrc/fsa_algo.h
@@ -167,25 +167,6 @@ bool Intersect(const Fsa &a, const Fsa &b, Fsa *c,
                std::vector<int32_t> *arc_map_a = nullptr,
                std::vector<int32_t> *arc_map_b = nullptr);
 
-/*
-  Version of Intersect where `a` is dense?
- */
-void Intersect(const DenseFsa &a, const Fsa &b, Fsa *c,
-               std::vector<int32_t> *arc_map_a = nullptr,
-               std::vector<int32_t> *arc_map_b = nullptr);
-
-/*
-  Version of Intersect where `a` is dense, pruned with pruning beam `beam`.
-  Suppose states in the output correspond to pairs (s_a, s_b), and have
-  forward-weights w(s_a, s_b), i.e. best-path from the start state...
-  then if a state has a forward-weight w(s_a, s_b) that is less than
-  (the largest w(s_a, x) for any x) minus the beam, we don't expand it.
-
-  This is the same as time-synchronous Viterbi beam pruning.
-*/
-void IntersectPruned(const DenseFsa &a, const Fsa &b, float beam, Fsa *c,
-                     std::vector<int32_t> *arc_map_a = nullptr,
-                     std::vector<int32_t> *arc_map_b = nullptr);
 
 /**
    Intersection of two weighted FSA's: the same as Intersect(), but it prunes
@@ -249,44 +230,46 @@ bool TopSort(const Fsa &a, Fsa *b, std::vector<int32_t> *state_map = nullptr);
 /**
    Pruned determinization with log-sum on weights (interpret them as log-probs),
    equivalent to log semiring
 
-   @param [in] a  Input FSA `a` to be determinized.  Expected to be epsilon
-   free, but this is not checked; in any case, epsilon will be treated as a
-   normal symbol.  Forward-backward weights must be provided for pruning
-   purposes; a.weight_type must be kLogSumWeight.
+   @param [in] a  Input FSA `a` to be determinized.  Expected to be epsilon free, but this
+                  is not checked; in any case, epsilon will be treated as a normal symbol.
+                  Forward-backward weights must be provided for pruning purposes;
+                  a.weight_type must be kLogSumWeight.
    @param [in] beam  Pruning beam; should be greater than 0.
-   @param [in] max_step  Maximum number of computation steps before we
-   return (or if <= 0, there is no limit); provided so users can limit the time
-   taken in pathological cases.
+   @param [in] max_step  Maximum number of computation steps before we return
+                  (or if <= 0, there is no limit); provided so users can limit the time
+                  taken in pathological cases.
-   @param [out] b  Output FSA; will be deterministic.  For a symbol sequence
-   S accepted by a, the total (log-sum) weight of S in a should equal the total
-   (log-sum) weight of S in b (as discoverable by composition then finding the
-   total weight of the result), except as affected by pruning of course.
+   @param [out] b  Output FSA; will be deterministic.  For a symbol sequence S accepted by a,
+                  the total (log-sum) weight of S in a should equal the total (log-sum) weight
+                  of S in b (as discoverable by composition then finding the total
+                  weight of the result), except as affected by pruning of course.
    @param [out] b_arc_weights  Weights per arc of b.
-   @param [out] arc_derivs  Indexed by arc in b, this is a list of pairs
-   (arc_in_a, x) where 0 < x <= 1 is the derivative of that arc's weight w.r.t.
-   the weight of `arc_in_a` in a.  Note: the x values may actually be zero if
-   the pruning beam is very large, due to limited floating point range.
+   @param [out] arc_derivs  Indexed by arc in b, this is a list of pairs (arc_in_a, x)
+                  where 0 < x <= 1 is the derivative of that arc's weight w.r.t. the
+                  weight of `arc_in_a` in a.  Note: the x values may actually be zero
+                  if the pruning beam is very large, due to limited floating point range.
-   @return  Returns the effective pruning beam, a value >= 0 which is the
-   difference between the total weight of the output FSA and the cost of the
-   last arc expanded.
+   @return  Returns the effective pruning beam, a value >= 0 which is the difference
+            between the total weight of the output FSA and the cost of the last
+            arc expanded.
  */
 float DeterminizePrunedLogSum(
-    const WfsaWithFbWeights &a, float beam, int64_t max_step, Fsa *b,
+    const WfsaWithFbWeights &a,
+    float beam,
+    int64_t max_step,
+    Fsa *b,
     std::vector<float> *b_arc_weights,
-    std::vector<std::vector<std::pair<int32_t, float>>> *arc_derivs);
+    std::vector<std::vector<std::pair<int32_t, float> > > *arc_derivs);
 
 /**
-   Pruned determinization with max on weights, equivalent to the tropical
-   semiring.
+   Pruned determinization with max on weights, equivalent to the tropical semiring.
 
-  @param [in] a  Input FSA `a` to be determinized.  Expected to be epsilon
-  free, but this is not checked; in any case, epsilon will be treated as a
-  normal symbol.  Forward-backward weights must be provided for pruning
-  purposes; a.weight_type must be kMaxWeight.
+  @param [in] a  Input FSA `a` to be determinized.  Expected to be epsilon free, but this
+                 is not checked; in any case, epsilon will be treated as a normal symbol.
+                 Forward-backward weights must be provided for pruning purposes;
+                 a.weight_type must be kMaxWeight.
   @param [in] beam  Pruning beam; should be greater than 0.
-  @param [in] max_step  Maximum number of computation steps before we
-  return (or if <= 0, there is no limit); provided so users can limit the time
-  taken in pathological cases.
+  @param [in] max_step  Maximum number of computation steps before we return
+                 (or if <= 0, there is no limit); provided so users can limit
+                 the time taken in pathological cases.
   @param [out] b  Output FSA; will be deterministic  For a symbol sequence
                  S accepted by a, the best weight of symbol-sequence S in a should equal
                  the best weight of S in b (as discoverable
                  by composition then finding the total weight of the result),
                  except as affected by pruning of course.
   @param [out] b_arc_weights  Weights per arc of b.
   @param [out] arc_derivs  Indexed by arc in b, this is a list of
                  arcs in `a` that this arc in `b` corresponds to; the weight of
                  the arc in b will equal the sum of those input arcs' weights.
-  @return  Returns the effective pruning beam, a value >= 0 which is the
-  difference between the total weight of the output FSA and the cost of the
-  last arc expanded.
+  @return  Returns the effective pruning beam, a value >= 0 which is the difference
+           between the total weight of the output FSA and the cost of the last
+           arc expanded.
  */
-float DeterminizePrunedMax(const WfsaWithFbWeights &a, float beam,
-                           int64_t max_step, Fsa *b,
+float DeterminizePrunedMax(const WfsaWithFbWeights &a,
+                           float beam,
+                           int64_t max_step,
+                           Fsa *b,
                            std::vector<float> *b_arc_weights,
                            std::vector<std::vector<int32_t>> *arc_derivs);
 
+
 /*
   Create an acyclic FSA from a list of arcs.
 
   Arcs do not need to be pre-sorted by src_state.
diff --git a/k2/csrc/properties.h b/k2/csrc/properties.h
index 23db82ba6..3aabe114d 100644
--- a/k2/csrc/properties.h
+++ b/k2/csrc/properties.h
@@ -20,7 +20,8 @@ enum Properties {
   kTopSortedAndAcyclic,  // topologically sorted and no self-loops (which
                          // implies acyclic)
   kAcyclic,              // acyclic
-  kArcSorted,            // arcs leaving each state are sorted on label
+  kArcSorted,            // arcs leaving each state are sorted on label and then
+                         // destination state
   kDeterministic,  // no state has two arcs leaving it with the same label
   kConnected,  // all states are both accessible (i.e. from start state) and
                // coaccessible (i.e. can reach final-state)
@@ -32,7 +33,7 @@ enum Properties {
   `fsa` is valid if:
   1. it is empty, if not, it contains at least two states.
   2. only kFinalSymbol arcs enter the final state.
-  3. `arcs_indexes` and `arcs` in this state are not consistent.
+  3. `arc_indexes` and `arcs` in this state are consistent.
   TODO(haowen): add more rules?
  */
 bool IsValid(const Fsa &fsa);
diff --git a/notes/python.txt b/notes/python.txt
index 274e0ef87..33250d122 100644
--- a/notes/python.txt
+++ b/notes/python.txt
@@ -4,7 +4,6 @@
 
-
   # Assumes that A is an acceptor but B may
   # have auxiliary symbols (i.e. may be a transducer).
   def TransducerCompose(a: FsaVec, a_weights: Tensor,
@@ -86,3 +85,227 @@
   # TODO: handle transfers to/from GPU in case grad_out was on GPU.
   # Maybe mark this only once differentiable (it's twice differentiable,
   # I think, but this code doesn't currently support that).
+
+
+
+
+ =====================
+ From here are some ideas on how we'd use this in a program.
+
+
+  # decoding_graph is an FsaVec, graph_weights is a float Tensor,
+  # graph_word_syms is a LongTensor (both with 1 axis)
+
+  # decoding_graph will have the following extra attributes:
+  #   decoding_graph.weights,
+  #   decoding_graph.word_syms
+  # both of shape (num_arcs,), of dtypes float and long respectively.
+  decoding_graph = fsa.ReadDecodingGraph('a/b/c.fsa')
+
+
+  # nnet_output's shape is (num_seq, num_frames, num_symbols).. these symbols
+  # might be phones or letters or small word-pieces.
+  nnet_output = model(input_feats)
+
+  # Interpret each sequence numbered `n` in `nnet_output` as being a FSA with
+  # `num_frames + 1` states numbered 0, 1, ... num_frames, and for each 0 <= i
+  # < num_frames, arcs from state i to i+1 with each symbol `s` as the label
+  # and the output given by `nnet_output[n,i,s]`.  This is a lightweight
+  # operation (except for the fact that it transfers the matrix to CPU for
+  # now).
+  # NOTE: the above is a slight simplification because the cuts may not
+  # span the entirety of `nnet_output`, they may be sub-sequences of frames
+  # within there, and `cut_info` will describe that somehow.  So the ends
+  # of the cuts may be "ragged".
+  nnet_output_fsas = fsa.DenseFsaVec(nnet_output, cut_info)
+
+  # nnet_output_fsas.input_indexes is a LongTensor with dimension (num_arcs_in_nnet_output_fsas, 3)
+  # where the 3 is for the indexes n,i,s into nnet_output.
+  nnet_output_fsas.input_indexes = nnet_output_fsas.GetIndexes()
+
+  # ... and implicitly nnet_output_fsas has a `weights` vector, which is a
+  # reshape/copy of parts of `nnet_output`.  (It's not just a reshape of the
+  # whole thing due to its ragged structure).
+
+
+  # Composing with the FSAs representing the supervision; this is CTC alignment.
+  # alignment_fsas will be an FsaVec.
+  alignment_fsas = fsa.compose_pruned(nnet_output_fsas, supervision_fsas, beam=10.0)
+
+  # objf_part1 is the CTC part of the objective; we'll later add it with the others.
+  objf_part1 = fsa.GetTotalLogWeight(alignment_fsas).sum()
+
+
+  # first_pass_fsas will be an FsaVec.
+  # This differs from `alignment_fsas` because it allows all possible word sequences,
+  # not just the supervision one(s).
+  # For a while now, the code path will be the same as the one we'll take in test time.
+
+  first_pass_fsas = fsa.compose_pruned(nnet_output_fsas, decoding_graph, beam=10.0)
+
+
+  # `first_pass_fsas` will have the following attributes:
+  #
+  #   input0.{arc_indexes,weights,input_indexes}
+  #   input1.{arc_indexes,weights,word_syms}
+  #
+  # and the following derived/computed params:
+  #   weights = input0.weights + input1.weights
+  #   word_syms = input1.word_syms
+
+  # We'll want to propagate `nnet_arc_indexes` forward, so we can
+  # keep track of alignments.
+  first_pass_fsas.nnet_arc_indexes = first_pass_fsas.input0.arc_indexes
+
+  # the following computes arc-level log-posteriors
+  log_posts = first_pass_fsas.ComputeLogPosteriors()
+
+  # Also get the symbol-level posteriors:
+
+  # The following will be the posteriors of a particular phone (or blank) at a particular
+  # time-index of a particular sequence.  see fsautil.py.
+  first_pass_fsas.phone_log_posts = fsautil.sum_by_index(
+      log_posts.exp(),
+      first_pass_fsas.input0.arc_indexes).log()
+
+  # the following means "the posterior of the current word".
+  first_pass_fsas.word_log_posts = fsautil.sum_by_index(
+      log_posts.exp(),
+      torch.cat((first_pass_fsas.input0.arc_indexes,
+                 first_pass_fsas.word_syms)))
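+
+  # Hypothetical sketch of what `fsautil.sum_by_index` might do (it isn't
+  # defined in these notes; this is one plausible reading): sum `values`
+  # over groups of arcs that share the same index row, then map the sums
+  # back so there is one value per arc again.
+  def sum_by_index(values, indexes):
+      # values: (num_arcs,); indexes: (num_arcs,) or (num_arcs, k)
+      if indexes.dim() == 1:
+          indexes = indexes.unsqueeze(1)
+      _, inverse = torch.unique(indexes, dim=0, return_inverse=True)
+      sums = torch.zeros(int(inverse.max()) + 1, dtype=values.dtype)
+      sums.index_add_(0, inverse, values)
+      return sums[inverse]  # one value per arc again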
+
+  # aggregate all these weights ?
+  # first_pass_fsas.all_weights = torch.cat(
+  #    first_pass_fsas.input0.weights.unsqueeze(0),   # acoustic weight
+  #    first_pass_fsas.input1.weights.unsqueeze(0),   # graph weight
+  #    first_pass_fsas.phone_log_posts.unsqueeze(0),
+  #    first_pass_fsas.word_log_posts.unsqueeze(0))
+
+
+
+  objf_part2 =
+
+
+  # OK, now we want to do a pass of RNNLM rescoring.  Compose with a
+  # virtual FSA (DeterministicOnDemandFsa) that has words as the
+  # labels.  We can make this a bit more efficient by determinizing the
+  # input first.
+
+
+  # we need the following to propagate members weights, all_weights.
+  # Invert the FST, swapping its symbols with word_syms.
+  fsas_inverted = fsa.Invert(first_pass_fsas,
+                             first_pass_fsas.word_syms,
+                             other_label_name='phone_syms',
+                             keep=[...])
+
+  fsas_rmeps = fsa.RmEpsWeighted(first_pass_fsas)
+  fsas_det = fsa.DeterminizeTropicalPruned(fsas_rmeps, beam=..)
+
+  # note: fsas_det still has `phone_syms` and `nnet_arc_indexes`, now as sequences.
+
+
+
+class DenseFsaVec:
+
+    """Represents a vector of FSAs, but with a special regular
+    structure.  Each FSA would normally correspond to one supervised
+    segment within an acoustic sequence.  This wraps the data
+    output from the neural net.  Each segment has T+2 states
+    numbered 0, 1, .. T, T+1 (the T+1'th is the final-state and
+    only has final-arcs entering it).  From state i to i+1
+    there is an arc for each symbol, whose loglike comes from
+    lookup in the neural-net output."""
+
+
+    def __init__(self, loglikes, seq_indexes, start_times, end_times):
+        """ Constructor.
+        Params:
+            loglikes (torch.Tensor):  The tensor of log-likelihoods
+                 output by the neural network.  Will be interpreted
+                 as having shape (num_seqs, num_frames, num_symbols).
+                 Here `num_symbols` includes epsilon.
+            seq_indexes, start_times, end_times (torch.Tensor):
+                 These must all have the same shape, of the form
+                 (num_segments,).  Here, num_segments would normally
+                 be >= num_seqs (each sequence may have more than
+                 one supervised segment in it, and they may
+                 overlap).
+                 - seq_indexes says, for each segment, which sequence
+                   it is a part of
+                 - start_times says, for each segment, what the first
+                   frame-index in `loglikes` is
+                 - end_times says, for each segment, what the
+                   one-past-the-last frame-index in `loglikes` is.
+        """
+        # note: this will create a csrc.CfsaVec object internal to this
+        # object.
+        # It will also create self.arc_loglikes containing the
+        # loglikes, one per arc of the CfsaVec object.  This is
+        # a repeat of `loglikes` but possibly in a different
+        # order.
+        pass
+
+    @property
+    def loglikes(self):
+        return self.arc_loglikes
+
+
+    def seg_frames_for_arcs(self, arc_indexes):
+        """
+        Returns the frame-indexes relative to the start of each segment
+        for each of a provided list of arc indexes, as a torch.LongTensor.
+        """
+        # Note: self.seg_frame_indexes will be a torch.IntTensor containing
+        # the frame index for each arc.  Later we'll address not being
+        # able to index with IntTensor but only LongTensor.
+        return self.seg_frame_indexes[arc_indexes // self.num_symbols]
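+
+    # Hypothetical usage of this class (the tensors here are invented
+    # for illustration):
+    #
+    #   loglikes = nnet_output   # shape (num_seqs, num_frames, num_symbols)
+    #   seq_indexes = torch.tensor([0, 0, 1])
+    #   start_times = torch.tensor([0, 50, 10])
+    #   end_times = torch.tensor([40, 95, 90])
+    #   dense = DenseFsaVec(loglikes, seq_indexes, start_times, end_times)
+    #   frames = dense.seg_frames_for_arcs(torch.arange(100))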
+ """ + + # Note: self.seq_frame_indexes will be a torch.IntTensor containing + # the frame index for each arc. Later we'll address not being + # able to index with IntTensor but only LongTensor. + return self.seq_frame_indexes[arc_indexes / self.num_symbols] + + def segments_for_arcs(self, arc_indexes): + """ + Return the segment-indexes for each of a provided list of arcs, + which tells you which segment it was a part of. + """ + return self.segment_indexes[arc_indexes / self.num_symbols] + + def seqs_for_arcs(self, arc_indexes): + """ + Return the segment-indexes for each of a provided list of arcs, + which tells you which sequence it was a part of. + """ + return self.input_seq_indexes[self.segments_for_arcs(arc_indexes)] + + + + + + +# compute posteriors.. + first_pass_posts = + + + + + + + + nnet_post = log_softmax(nnet_output) # might use this later for something..