Merge pull request k2-fsa#1 from qindazhu/haowen-arc-score
Fix some issues of k2-fsa#587 to pass test
danpovey authored Jan 15, 2021
2 parents d7ebea3 + 6fdb210 commit 02c30ed
Showing 11 changed files with 122 additions and 163 deletions.
15 changes: 7 additions & 8 deletions k2/python/csrc/torch/fsa.cu
@@ -115,8 +115,7 @@ static void PybindFsaUtil(py::module &m) {
   // the following methods are for debugging only
   m.def("fsa_to_fsa_vec", &FsaToFsaVec, py::arg("fsa"));
 
-  m.def("get_fsa_vec_element", &GetFsaVecElement, py::arg("vec"),
-        py::arg("i"));
+  m.def("get_fsa_vec_element", &GetFsaVecElement, py::arg("vec"), py::arg("i"));
 
   m.def(
       "create_fsa_vec",
@@ -210,8 +209,8 @@ static void PybindGetBackwardScores(py::module &m, const char *name) {
       [](FsaVec &fsas, Ragged<int32_t> &state_batches,
          Ragged<int32_t> &leaving_arc_batches,
          bool log_semiring = true) -> torch::Tensor {
-        Array1<T> ans = GetBackwardScores<T>(
-            fsas, state_batches, leaving_arc_batches, nullptr, log_semiring);
+        Array1<T> ans = GetBackwardScores<T>(fsas, state_batches,
+                                             leaving_arc_batches, log_semiring);
         return ToTensor(ans);
       },
       py::arg("fsas"), py::arg("state_batches"), py::arg("leaving_arc_batches"),
@@ -415,8 +414,8 @@ static void PybindGetTotScoresTropicalBackward(py::module &m,
 
 template <typename T>
 static void PybindGetTotScoresLogBackward(py::module &m, const char *name) {
-  m.def(name, &GetTotScoresLogBackward<T>, py::arg("fsas"),
-        py::arg("arc_post"), py::arg("tot_scores_grad"));
+  m.def(name, &GetTotScoresLogBackward<T>, py::arg("fsas"), py::arg("arc_post"),
+        py::arg("tot_scores_grad"));
 }
 
 }  // namespace k2
@@ -438,8 +437,8 @@ void PybindFsa(py::module &m) {
       m, "get_tot_scores_float_tropical_backward");
   k2::PybindGetTotScoresTropicalBackward<double>(
       m, "get_tot_scores_double_tropical_backward");
-  k2::PybindGetTotScoresLogBackward<float>(
-      m, "get_tot_scores_float_log_backward");
+  k2::PybindGetTotScoresLogBackward<float>(m,
+                                           "get_tot_scores_float_log_backward");
   k2::PybindGetTotScoresLogBackward<double>(
       m, "get_tot_scores_double_log_backward");
 }
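
The <float>/<double> instantiations above register dtype-specific names on the _k2 extension module. A minimal sketch of the dispatch a Python caller is expected to do (the wrapper function below is hypothetical; the _k2.* names and argument order are the ones bound in this file, and k2's own autograd.py passes fsas.arcs as the first argument):

import torch

import _k2  # the pybind11 extension module this file is compiled into


def tot_scores_log_backward(ragged_arcs, arc_post: torch.Tensor,
                            tot_scores_grad: torch.Tensor) -> torch.Tensor:
    # Pick the instantiation matching the dtype, mirroring the
    # PybindGetTotScoresLogBackward<float>/<double> registrations above.
    if arc_post.dtype == torch.float64:
        func = _k2.get_tot_scores_double_log_backward
    else:
        func = _k2.get_tot_scores_float_log_backward
    return func(ragged_arcs, arc_post, tot_scores_grad)
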
37 changes: 17 additions & 20 deletions k2/python/k2/autograd.py
@@ -97,10 +97,6 @@ def backward(ctx, tot_scores_grad: torch.Tensor
             bprop_func = _k2.get_tot_scores_double_log_backward
         else:
             bprop_func = _k2.get_tot_scores_float_log_backward
-
-        arc_post = func(fsas=fsas.arcs,
-                        forward_scores=forward_scores,
-                        backward_scores=backward_scores)
         scores_grad = bprop_func(fsas.arcs, arc_post, tot_scores_grad)
         return None, None, None, scores_grad
@@ -136,8 +132,9 @@ def forward(ctx, fsas: Fsa, log_semiring: bool, use_double_scores: bool,
         # that, the backward_fn of forward_scores, which is cached in `fsas`,
         # would be set to this object, giving `fsas` a reference to this object,
         # which also has a reference to `fsas`.
-        forward_scores = fsas._get_forward_scores(use_double_scores=use_double_scores,
-                                                  log_semiring=log_semiring).detach()
+        forward_scores = fsas._get_forward_scores(
+            use_double_scores=use_double_scores,
+            log_semiring=log_semiring).detach()
 
         # NOTE: since `fsas`, `log_semiring` and `use_double_scores` are
         # not tensors, they are saved as attributes of `ctx`.
@@ -159,24 +156,25 @@ def backward(ctx, backward_scores_grad: torch.Tensor
         entering_arcs = fsas._get_entering_arcs(use_double_scores)
         state_batches = fsas._get_state_batches()
         leaving_arc_batches = fsas._get_leaving_arc_batches()
-        backward_scores = fsas._get_backward_score(use_double_scores=use_double_scores,
-                                                   log_semiring=log_semiring)
-
+        backward_scores = fsas._get_backward_score(
+            use_double_scores=use_double_scores, log_semiring=log_semiring)
+
         # Note: perhaps _k2.backprop_get_backward_scores() can figure out the
         # type, float vs. double. Whatever works and is easy, though..
-        scores_grad = _k2.backprop_get_backward_scores(
-            fsas, state_batches, leaving_arc_batches,
-            log_semiring, backward_scores, backward_scores_grad)
+        scores_grad = _k2.backprop_get_backward_scores(fsas, state_batches,
+                                                       leaving_arc_batches,
+                                                       log_semiring,
+                                                       backward_scores,
+                                                       backward_scores_grad)
 
         return None, None, None, scores_grad
 
 
 class _GetArcPostFunction(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, fsas: Fsa, log_semiring: bool, use_double_scores: bool,
-                unused_scores: torch.Tensor,
-                forward_scores: torch.Tensor,
+                unused_scores: torch.Tensor, forward_scores: torch.Tensor,
                 backward_scores: torch.Tensor) -> torch.Tensor:
         '''Compute the arc-level posteriors of an FsaVec
@@ -227,8 +225,8 @@ def forward(ctx, fsas: Fsa, log_semiring: bool, use_double_scores: bool,
 
     @staticmethod
     def backward(ctx, arc_post_grad: torch.Tensor
-                 ) -> Tuple[None, None, None, torch.Tensor,
-                            torch.Tensor, torch.Tensor]:  # noqa
+                 ) -> Tuple[None, None, None, torch.Tensor, torch.Tensor, torch.
+                            Tensor]:  # noqa
         fsas = ctx.fsas
         log_semiring = ctx.log_semiring
         use_double_scores = ctx.use_double_scores
@@ -238,8 +236,9 @@ def backward(ctx, arc_post_grad: torch.Tensor
                       else _k2.get_arc_scores_float_log_backward)
 
         incoming_arcs = fsas._get_incoming_arcs()
-        (arc_scores_grad, forward_scores_grad, backward_scores_grad) = bprop_func(
-            fsas.arcs, incoming_arcs, arc_post_grad)
+        (arc_scores_grad, forward_scores_grad,
+         backward_scores_grad) = bprop_func(fsas.arcs, incoming_arcs,
+                                            arc_post_grad)
 
         return None, None, None, arc_scores_grad, forward_scores_grad, backward_scores_grad
 
Expand Down Expand Up @@ -480,8 +479,6 @@ def backward(ctx, out_fsa_grad: torch.Tensor
return None, None, ans




def intersect_dense_pruned(a_fsas: Fsa, b_fsas: DenseFsaVec,
search_beam: float, output_beam: float,
min_active_states: int,
Expand Down
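
The _get_* calls above are the non-differentiable cached computations; the torch.autograd.Function subclasses in this file are what expose them differentiably. A short usage sketch, assuming the public API of this snapshot (the toy FSA is invented; the text format is 'src dest label score', with label -1 entering the final state):

import k2
import torch

s = '''
0 1 1 0.1
0 1 2 0.3
1 2 -1 0.7
2
'''
fsa = k2.Fsa.from_str(s)
fsa.requires_grad_(True)
fsa_vec = k2.create_fsa_vec([fsa])

# Differentiable: routes through _GetTotScoresFunction.apply under the hood.
tot_scores = fsa_vec.get_tot_scores(log_semiring=True, use_double_scores=False)
tot_scores.sum().backward()
print(fsa.scores.grad)  # in the log semiring, these gradients are the arc posteriors
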
102 changes: 46 additions & 56 deletions k2/python/k2/fsa.py
@@ -463,9 +463,10 @@ def _get_leaving_arc_batches(self) -> _k2.RaggedInt:
                 self.arcs, self._get_state_batches())
         return cache[name]
 
-    def _get_forward_scores(self, use_double_scores: bool, log_semiring: bool) -> torch.Tensor:
+    def _get_forward_scores(self, use_double_scores: bool,
+                            log_semiring: bool) -> torch.Tensor:
         '''Get (and compute if necessary) cached property
-        self.forward_scores_xxx (where xxx indicates float-type and semiring).
+        self.forward_scores_xxx_yyy (where xxx indicates float-type and yyy indicates semiring).
         For use by internal k2 code; returns the total score from start-state to
         each state. Not differentiable; see "get_forward_scores" which is the
         differentiable version.
@@ -496,8 +497,7 @@ def _get_forward_scores(self, use_double_scores: bool, log_semiring: bool) -> torch.Tensor:
             cache['entering_arcs'] = entering_arcs
         return cache[name]
 
-    def get_forward_scores(self,
-                           use_double_scores: bool,
+    def get_forward_scores(self, use_double_scores: bool,
                            log_semiring: bool) -> torch.Tensor:
         '''Compute forward-scores, i.e. total weight (or best-path weight)
         from start state to each state. Supports autograd.
@@ -512,29 +512,9 @@ def get_forward_scores(self,
             self, log_semiring, use_double_scores, self.scores)
         return forward_scores
 
-    def get_tot_scores(self, use_double_scores: bool, log_semiring: bool) -> torch.Tensor:
-        '''Compute total-scores in log semiring (one per FSA) as the
-        best-path score. This version is differentiable.
-
-        CAUTION:
-          These are just the raw total-scores and are not differentiable.
-          Use `k2.get_tot_scores(self)` to get differentiable total-scores.
-
-        Args:
-          use_double_scores:
-            True to use `double precision` floating point;
-            False to use `single precision`.
-          log_semiring:
-            True to use log semiring (log-sum), false to use tropical (i.e. max
-            on scores).
-        '''
-        tot_scores = k2.autograd._GetTotScoresFunction.apply(self, log_semiring,
-                                                             use_double_scores, self.scores)
-        return tot_scores
-
-
-    def _get_tot_scores(self, use_double_scores: bool, log_semiring: bool) -> torch.Tensor:
-        '''Compute total-scores in log semiring (one per FSA) as the
+    def _get_tot_scores(self, use_double_scores: bool,
+                        log_semiring: bool) -> torch.Tensor:
+        '''Compute total-scores (one per FSA) as the
         best-path score. This version is not differentiable; see also
         self.get_tot_scores() which is differentiable.
@@ -554,18 +534,33 @@ def _get_tot_scores(self, use_double_scores: bool, log_semiring: bool) -> torch.Tensor:
                 func = _k2.get_tot_scores_double
             else:
                 func = _k2.get_tot_scores_float
-            forward_scores = self.get_forward_scores(use_double_scores,
-                                                     log_semiring)
+            forward_scores = self._get_forward_scores(use_double_scores,
+                                                      log_semiring)
             total_scores = func(self.arcs, forward_scores)
             cache[name] = total_scores
         return cache[name]
 
-    def _get_backward_scores(self,
-                             use_double_scores: bool,
-                             log_semiring: bool) -> torch.Tensor:
+    def get_tot_scores(self, use_double_scores: bool,
+                       log_semiring: bool) -> torch.Tensor:
+        '''Compute total-scores (one per FSA) as the
+        best-path score. This version is differentiable.
+
+        Args:
+          use_double_scores:
+            True to use `double precision` floating point;
+            False to use `single precision`.
+          log_semiring:
+            True to use log semiring (log-sum), false to use tropical (i.e. max
+            on scores).
+        '''
+        tot_scores = k2.autograd._GetTotScoresFunction.apply(
+            self, log_semiring, use_double_scores, self.scores)
+        return tot_scores
+
+    def _get_backward_scores(self, use_double_scores: bool,
+                             log_semiring: bool) -> torch.Tensor:
         '''Compute backward-scores, i.e. total weight (or best-path weight)
-        from each state to end state. For internal k2 use. Not differentiable.
+        from each state to the final state. For internal k2 use. Not differentiable.
         See also get_backward_scores() which is differentiable.
 
         Args:
@@ -586,9 +581,8 @@ def _get_backward_scores(self,
             else:
                 func = _k2.get_backward_scores_float
 
-            state_batches = self.get_state_batches()
-            leaving_arc_batches = self.get_leaving_arc_batches()
-            tot_scores = self.get_tot_scores(use_double_scores, log_semiring)
+            state_batches = self._get_state_batches()
+            leaving_arc_batches = self._get_leaving_arc_batches()
             backward_scores_tropical = func(
                 self.arcs,
                 state_batches=state_batches,
@@ -597,11 +591,10 @@ def _get_backward_scores(self,
             cache[name] = backward_scores_tropical
         return cache[name]
 
-    def get_backward_scores(self,
-                            use_double_scores: bool,
+    def get_backward_scores(self, use_double_scores: bool,
                             log_semiring: bool) -> torch.Tensor:
         '''Compute backward-scores, i.e. total weight (or best-path weight)
-        from each state to end state. Supports autograd.
+        from each state to the final state. Supports autograd.
 
         Args:
           use_double_scores: if True, use double precision.
self, log_semiring, use_double_scores, self.scores)
return backward_scores

def _get_arc_post(self,
use_double_scores: bool,
log_semiring: bool) -> torch.Tensor:
def _get_arc_post(self, use_double_scores: bool,
log_semiring: bool) -> torch.Tensor:
'''Compute scores on arcs, representing log probabilities;
with log_semiring=True you could call these log posteriors,
but if log_semiring=False they can only be interpreted as the
@@ -637,18 +629,17 @@ def _get_arc_post(self,
         if name not in cache:
             forward_scores = self._get_forward_scores(use_double_scores,
                                                       log_semiring)
-            backward_scores = self._get_backward_scores(use_double_scores,
-                                                        log_semiring)
-            func = (_k2.get_arc_post_double if use_double_scores else
-                    _k2.get_arc_post_float)
+            backward_scores = self._get_backward_scores(
+                use_double_scores, log_semiring)
+            func = (_k2.get_arc_post_double
+                    if use_double_scores else _k2.get_arc_post_float)
             arc_post = func(fsas=self.arcs,
                             forward_scores=forward_scores,
                             backward_scores=backward_scores)
             cache[name] = arc_post
         return cache[name]
 
-    def get_arc_post(self,
-                     use_double_scores: bool,
+    def get_arc_post(self, use_double_scores: bool,
                      log_semiring: bool) -> torch.Tensor:
         '''Compute scores on arcs, representing log probabilities;
         with log_semiring=True you could call these log posteriors,
@@ -669,20 +660,19 @@ def get_arc_post(self,
         # We don't cache this! User should store it if needed more than once,
         # to avoid duplicate code in backprop. We may be able to partially fix
         # this at some point with a weak dictionary.
-        forward_scores = self.get_forward_scores(use_double_scores,
-                                                 log_semiring)
+        forward_scores = self._get_forward_scores(use_double_scores,
+                                                  log_semiring)
         backward_scores = self._get_backward_scores(use_double_scores,
                                                     log_semiring)
 
         # Below, the last 3 args are active w.r.t. autograd, the backward function
         # will return non-None derivatives for them.
-        arc_post = k2.autograd._GetArcPostFunction(use_double_scores, log_semiring,
-                                                   self.scores,
-                                                   forward_scores, backward_scores);
+        arc_post = k2.autograd._GetArcPostFunction(use_double_scores,
+                                                   log_semiring, self.scores,
+                                                   forward_scores,
+                                                   backward_scores)
         return arc_post
 
 
-
-
     def _get_entering_arcs(self, use_double_scores: bool) -> torch.Tensor:
         '''Compute, for each state, the index of the best arc entering it.
         For internal k2 use.
@@ -695,7 +685,7 @@ def _get_entering_arcs(self, use_double_scores: bool) -> torch.Tensor:
         name, cache = 'entering_arcs', self._cache
         if name not in cache:
             # the following will set self._cache['entering_arcs']
-            self.get_forward_scores_tropical(use_double_scores)
+            self._get_forward_scores(use_double_scores, False)
         return cache[name]
 
     def requires_grad_(self, requires_grad: bool) -> 'Fsa':
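
The pattern running through this file is that every public, differentiable method (get_forward_scores, get_backward_scores, get_tot_scores, get_arc_post) delegates caching to a private _get_* twin, and the private twins call each other rather than the public ones, which is what most of the renames above fix. In the log semiring these quantities are tied together by the standard identity arc_post(a) = forward[src(a)] + score(a) + backward[dst(a)] - tot_score. A sketch under the assumption that this snapshot's signatures match the diff (the toy FSA is invented):

import k2

fsa = k2.Fsa.from_str('''
0 1 1 0.1
0 1 2 0.3
1 2 -1 0.7
2
''')
fsas = k2.create_fsa_vec([fsa])

forward = fsas.get_forward_scores(use_double_scores=True, log_semiring=True)
backward = fsas.get_backward_scores(use_double_scores=True, log_semiring=True)
tot = fsas.get_tot_scores(use_double_scores=True, log_semiring=True)

# The total score can be read off either end of the lattice:
# forward score of the final state == backward score of the start state.
print(tot, forward[-1], backward[0])

arc_post = fsas.get_arc_post(use_double_scores=True, log_semiring=True)
print(arc_post.exp())  # per-arc posterior probabilities
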
2 changes: 1 addition & 1 deletion k2/python/k2/fsa_algo.py
@@ -346,7 +346,7 @@ def shortest_path(fsa: Fsa, use_double_scores: bool) -> Fsa:
     Returns:
       FsaVec, it contains the best paths as linear FSAs
     '''
-    entering_arcs = fsa.get_entering_arcs(use_double_scores)
+    entering_arcs = fsa._get_entering_arcs(use_double_scores)
     ragged_arc, ragged_int = _k2.shortest_path(fsa.arcs, entering_arcs)
     out_fsa = Fsa(ragged_arc)
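
A usage sketch of shortest_path() under the same assumed snapshot API (the FSA is invented): the _get_entering_arcs() call it now uses is populated as a side effect of the cached tropical forward pass.

import k2

fsa = k2.Fsa.from_str('''
0 1 1 0.4
0 2 2 0.2
1 3 -1 0.5
2 3 -1 1.0
3
''')
fsas = k2.create_fsa_vec([fsa])
best = k2.shortest_path(fsas, use_double_scores=False)
print(best)  # a linear FsaVec holding each FSA's best path (here 0->2->3)
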
4 changes: 1 addition & 3 deletions k2/python/tests/compose_test.py
@@ -41,9 +41,7 @@ def test_compose(self):
 
         ans = k2.create_fsa_vec([ans])
 
-        scores = k2.get_tot_scores(ans,
-                                   log_semiring=True,
-                                   use_double_scores=False)
+        scores = ans.get_tot_scores(log_semiring=True, use_double_scores=False)
         # The reference values for `scores`, `a_fsa.grad` and `b_fsa.grad`
         # are computed using GTN.
         # See https://bit.ly/3heLAJq