fix a bug in k2.ragged.normalize_scores. (#563)
* fix a bug in k2.ragged.normalize_scores.

If k2.ragged.RaggedFloat is constructed from a shape and a value,
its scores should be set to the given value to make autograd work
(see the sketch after this list).

* remove k2.Fsa.detach().

* fix style issues.

* rename scores -> values for k2.ragged.RaggedFloat.

* fix an error.

* resolve comments.

* check that scores.numel() == self.scores.numel().

* use k2.create_fsa_vec which already supports autograd.
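
A minimal sketch (not part of the commit) of the code path this change fixes, assuming a k2 build that includes it; the toy FSA is only illustrative. Constructing a RaggedFloat from an FSA's arc shape plus an explicit tensor should now let gradients flow back through k2.ragged.normalize_scores:

import k2
import torch

# Toy FSA used only for illustration.
s = '''
0 1 1 0.
0 1 2 0.
1 2 -1 0.
2
'''
fsa = k2.Fsa.from_str(s)

scores = torch.arange(fsa.scores.numel(), dtype=torch.float32)
scores.requires_grad_(True)

# Build the ragged tensor from a shape plus explicit values; before this
# commit the given values were not kept, so autograd broke at this point.
ragged_scores = k2.ragged.RaggedFloat(fsa.arcs.shape(), scores)
normalized = k2.ragged.normalize_scores(ragged_scores)

# Gradients now reach the original `scores` tensor.
normalized.values.sum().backward()
assert scores.grad is not None
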
csukuangfj authored Jan 4, 2021
1 parent 80abc77 commit 4c629e1
Showing 9 changed files with 148 additions and 61 deletions.
13 changes: 13 additions & 0 deletions k2/python/csrc/torch/ragged_ops.cu
@@ -160,6 +160,18 @@ static void PybindNormalizePerSublistBackward(py::module &m, const char *name) {
py::arg("out_grad"));
}

template <typename T, typename Op>
static void PybindOpPerSublist(py::module &m, Op op, const char *name) {
m.def(
name,
[op](Ragged<T> &src, T initial_value) -> torch::Tensor {
Array1<T> values(src.Context(), src.TotSize(src.NumAxes() - 2));
op(src, initial_value, &values);
return ToTensor(values);
},
py::arg("src"), py::arg("initial_value"));
}

} // namespace k2

void PybindRaggedOps(py::module &m) {
@@ -171,4 +183,5 @@ void PybindRaggedOps(py::module &m) {
PybindRaggedIntToList(m, "ragged_int_to_list");
PybindNormalizePerSublist<float>(m, "normalize_per_sublist");
PybindNormalizePerSublistBackward<float>(m, "normalize_per_sublist_backward");
PybindOpPerSublist<float>(m, SumPerSublist<float>, "sum_per_sublist");
}
50 changes: 17 additions & 33 deletions k2/python/k2/fsa.py
@@ -1027,44 +1027,28 @@ def from_openfst(cls, s: str, acceptor: bool = True) -> 'Fsa':
arcs, aux_labels = _k2.fsa_from_str(s, acceptor, True)
return Fsa(arcs, aux_labels=aux_labels)

def set_scores_stochastic_(self) -> None:
'''Set `scores` to random numbers.
def set_scores_stochastic_(self, scores) -> None:
'''Normalize the given `scores` and assign it to `self.scores`.
Scores are normalized per state. That is, the sum of the probabilities
of all arcs leaving a state equals 1.
Args:
scores:
Tensor of scores of dtype torch.float32, and shape equal to
`self.scores.shape` (one axis). Will be normalized so the
sum, after exponentiating, of the scores leaving each state
that has at least one arc leaving it is 1.
Caution:
The function name ends with an underline indicating this function
will modify `self` **in-place**.
'''
scores = torch.randn_like(self.scores)
ragged_scores = k2.ragged.RaggedFloat(self.arcs.shape(), scores)
ragged_scores = k2.ragged.normalize_scores(ragged_scores)

# note that `self.scores` also works here, but [:] is more efficient
self.scores[:] = ragged_scores.scores

def detach(self) -> 'Fsa':
'''Return a new FSA, detached from the current graph.
Like torch.Tensor.detach(), the returned FSA shares the underlying
memory with `self`. The only difference is that the returned FSA's
requires_grad is False.
Caution:
The returned FSA shares memory with this FSA.
Returns:
Return an FSA whose `requires_grad` is False.
'''
# Keep this code in sync with that in to()
ans = Fsa(self.arcs, properties=self.properties)

for name, value in self.named_tensor_attr(include_scores=False):
setattr(ans, name, value)
assert scores.ndim == 1
assert scores.dtype == torch.float32
assert scores.numel() == self.scores.numel()

for name, value in self.named_non_tensor_attr():
setattr(ans, name, value)
ragged_scores = k2.ragged.RaggedFloat(
self.arcs.shape().to(scores.device), scores)
ragged_scores = k2.ragged.normalize_scores(ragged_scores)

ans.scores = self.scores.detach()
return ans
# Note we use `to` here since `scores` and `self.scores` may not
# be on the same device.
self.scores = ragged_scores.values.to(self.scores.device)
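
A short usage sketch for the new signature (not part of the diff; the toy FSA is illustrative): the caller now supplies the raw scores, which are normalized per state and written back into `fsa.scores` in-place:

import k2
import torch

fsa = k2.Fsa.from_str('''
0 1 1 0.
0 1 2 0.
1 2 -1 0.
2
''')

# Random scores; after the call, the exponentiated scores of the arcs
# leaving each state sum to 1.
scores = torch.randn_like(fsa.scores)
fsa.set_scores_stochastic_(scores)

assert torch.allclose(fsa.scores[0:2].exp().sum(), torch.tensor(1.0))
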
2 changes: 2 additions & 0 deletions k2/python/k2/ragged/__init__.py
@@ -4,6 +4,7 @@
from .ops import remove_axis
from .ops import remove_values_eq
from .ops import remove_values_leq
from .ops import sum_per_sublist
from .ops import to_list
from .ragged_shape import RaggedShape
from .ragged_shape import compose_ragged_shapes
@@ -22,5 +23,6 @@
'remove_axis',
'remove_values_eq',
'remove_values_leq',
'sum_per_sublist',
'to_list',
]
10 changes: 5 additions & 5 deletions k2/python/k2/ragged/autograd.py
@@ -40,16 +40,16 @@ def forward(ctx, src: RaggedFloat, out: List[RaggedFloat],
in the list since this function can only return values of type
`torch.Tensor`. On input, we check that `len(out) == 1`.
unused_scores:
Its sole purpose is for autograd. It equals to `src.scores`.
Its sole purpose is for autograd. It equals to `src.values`.
Returns:
Returns a tensor that equals to `out.scores`. Callers should
Returns a tensor that equals to `out.values`. Callers should
discard the return value.
'''
assert len(out) == 1
ans_ragged = _k2.normalize_per_sublist(src.ragged)
out[0] = RaggedFloat(ans_ragged)
ctx.out = out[0] # save for backward
return out[0].scores
return out[0].values

@staticmethod
def backward(ctx,
@@ -90,7 +90,7 @@ def normalize_scores(src: RaggedFloat) -> RaggedFloat:
out = [None] # placeholder

# the return value is discarded for the following call
# as it equals to out[0].scores
_NormalizeScores.apply(src, out, src.scores)
# as it equals to out[0].values
_NormalizeScores.apply(src, out, src.values)

return out[0]
26 changes: 24 additions & 2 deletions k2/python/k2/ragged/ops.py
@@ -9,9 +9,7 @@
from typing import Tuple
from typing import Union

import numpy as np
import torch

import _k2


@@ -117,3 +115,27 @@ def to_list(src: _k2.RaggedInt) -> List:
as `src`.
'''
return _k2.ragged_int_to_list(src)


def sum_per_sublist(src: _k2.RaggedFloat,
initial_value: float = 0) -> torch.Tensor:
'''Return the sum of each sublist.
For example, if `src` has the following values::
[ [a b] [h j k] [m] ]
Then it returns a 1-D tensor with 3 entries:
- entry 0: a + b + initial_value
- entry 1: h + j + k + initial_value
- entry 2: m + initial_value
Args:
src:
A ragged float tensor. Note that the sum is performed on the last axis.
Returns:
Return a 1-D torch.Tensor with dtype torch.float32. Its `numel` equals
`src.tot_size(src.num_axes() - 2)`.
'''
return _k2.sum_per_sublist(src, initial_value)
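
A small usage sketch for sum_per_sublist (not part of the diff), assuming the string constructor of `_k2.RaggedFloat` that tensor.py below relies on:

import _k2
import k2
import torch

src = _k2.RaggedFloat('[ [1 2] [3] [] ]')

# One sum per sublist; the empty sublist yields the initial value (0 here).
sums = k2.ragged.sum_per_sublist(src)
assert torch.allclose(sums, torch.tensor([3.0, 3.0, 0.0]))
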
20 changes: 13 additions & 7 deletions k2/python/k2/ragged/tensor.py
@@ -16,6 +16,8 @@ class RaggedFloat(object):
It is a wrapper of :class:`_k2.RaggedFloat`, whose purpose
is to implement autograd for :class:`_k2.RaggedFloat`.
Currently, it is used only in `k2.ragged.normalize_scores`.
'''

def __init__(self,
@@ -42,40 +44,44 @@ def __init__(self,
'''
if isinstance(ragged, str):
ragged = _k2.RaggedFloat(ragged)
assert values is None
elif isinstance(ragged, _k2.RaggedShape):
assert values is not None
ragged = _k2.RaggedFloat(ragged, values)

assert isinstance(ragged, _k2.RaggedFloat)

self.ragged = ragged
self._scores = ragged.values()
if values is not None:
self._values = values
else:
self._values = ragged.values()

def __str__(self) -> str:
return str(self.ragged)

@property
def scores(self) -> torch.Tensor:
def values(self) -> torch.Tensor:
'''Return the underlying array as a 1-D torch.Tensor.
'''
return self._scores
return self._values

@property
def grad(self) -> torch.Tensor:
return self._scores.grad
return self._values.grad

@property
def requires_grad(self) -> bool:
'''
Return True if this object requires grad.
Return False otherwise.
'''
return self._scores.requires_grad
return self._values.requires_grad

def requires_grad_(self, requires_grad: bool) -> 'RaggedFloat':
'''Change if autograd should record operations on this tensor.
Sets the `scores`'s requires_grad attribute in-place.
Sets the `values`'s requires_grad attribute in-place.
Returns this object.
You can test whether this object has the requires_grad property
true or false by accessing self.requires_grad property.
Expand All @@ -91,5 +97,5 @@ def requires_grad_(self, requires_grad: bool) -> 'RaggedFloat':
Returns:
This object itself.
'''
self._scores.requires_grad_(requires_grad)
self._values.requires_grad_(requires_grad)
return self
4 changes: 1 addition & 3 deletions k2/python/tests/compose_test.py
@@ -39,9 +39,7 @@ def test_compose(self):
ans = k2.compose(a_fsa, b_fsa, inner_labels='inner')
ans = k2.connect(ans)

# Convert a single FSA to a FsaVec.
# It will retain `requires_grad_` of `ans`.
ans.__dict__['arcs'] = _k2.create_fsa_vec([ans.arcs])
ans = k2.create_fsa_vec([ans])

scores = k2.get_tot_scores(ans,
log_semiring=True,
6 changes: 4 additions & 2 deletions k2/python/tests/fsa_test.py
@@ -544,7 +544,8 @@ def test_set_scores_stochastic(self):
3
'''
fsa = k2.Fsa.from_str(s)
fsa.set_scores_stochastic_()
scores = torch.randn_like(fsa.scores)
fsa.set_scores_stochastic_(scores)

# scores of state 0 should be normalized
assert torch.allclose(fsa.scores[0:2].exp().sum(), torch.Tensor([1]))
@@ -601,7 +602,8 @@ def test_scores_autograd_with_assignment(self):

# CAUTION: had we used fsa.scores = scores,
# would we have `fsa.scores != fsa.arcs.values()[:, -1]`.
# That is, `fsa.scores` shares memory with `scores`, but not with fsa.arcs.values!
# That is, `fsa.scores` shares memory with `scores`,
# but not with fsa.arcs.values!
assert _k2.as_float(fsa.arcs.values()[:, -1]).item() == 100

def test_detach(self):
78 changes: 69 additions & 9 deletions k2/python/tests/ragged_ops_test.py
@@ -10,6 +10,7 @@

import unittest

import _k2
import k2
import numpy as np
import torch
@@ -79,18 +80,18 @@ def test_normalize_scores_non_zero_stride(self):
[ [1 -1 0] [2 10] [] [3] [5 8] ]
'''
src = k2.ragged.RaggedFloat(s)
saved = src.scores.clone().detach()
saved = src.values.clone().detach()
saved.requires_grad_(True)
src.requires_grad_(True)

ans = k2.ragged.normalize_scores(src)

scale = torch.arange(ans.scores.numel())
scale = torch.arange(ans.values.numel())

# the stride of grad is not 0
(ans.scores * scale).sum().backward()
(ans.values * scale).sum().backward()

expected = saved.new_zeros(*ans.scores.shape)
expected = saved.new_zeros(*ans.values.shape)

normalizer = saved[:3].exp().sum().log()
expected[:3] = saved[:3] - normalizer
@@ -103,7 +104,7 @@ def test_normalize_scores_non_zero_stride(self):
normalizer = saved[6:8].exp().sum().log()
expected[6:8] = saved[6:8] - normalizer

self.assertTrue(torch.allclose(expected, ans.scores))
self.assertTrue(torch.allclose(expected, ans.values))
(expected * scale).sum().backward()

self.assertTrue(torch.allclose(saved.grad, src.grad))
@@ -113,16 +114,16 @@ def test_normalize_scores_zero_stride(self):
[ [1 3 5] [2 -1] [] [3] [5 2] ]
'''
src = k2.ragged.RaggedFloat(s)
saved = src.scores.clone().detach()
saved = src.values.clone().detach()
saved.requires_grad_(True)
src.requires_grad_(True)

ans = k2.ragged.normalize_scores(src)

# the stride of grad is 0
ans.scores.sum().backward()
ans.values.sum().backward()

expected = saved.new_zeros(*ans.scores.shape)
expected = saved.new_zeros(*ans.values.shape)

normalizer = saved[:3].exp().sum().log()
expected[:3] = saved[:3] - normalizer
@@ -135,11 +136,70 @@ def test_normalize_scores_zero_stride(self):
normalizer = saved[6:8].exp().sum().log()
expected[6:8] = saved[6:8] - normalizer

self.assertTrue(torch.allclose(expected, ans.scores))
self.assertTrue(torch.allclose(expected, ans.values))
expected.sum().backward()

self.assertTrue(torch.allclose(saved.grad, src.grad))

def test_normalize_scores_from_shape(self):
s = '''
0 1 1 0.
0 1 2 0.
0 1 3 0.
1 2 4 0.
1 2 5 0.
2 3 -1 0.
3
'''
fsa = k2.Fsa.from_str(s)
scores = torch.arange(fsa.scores.numel(), dtype=torch.float32)
scores.requires_grad_(True)

ragged_scores = k2.ragged.RaggedFloat(fsa.arcs.shape(), scores)
assert ragged_scores.requires_grad is True

normalized_scores = k2.ragged.normalize_scores(ragged_scores)
assert normalized_scores.requires_grad is True

fsa.scores = normalized_scores.values
assert fsa.scores.requires_grad is True

# arcs leaving state 0
self.assertAlmostEqual(fsa.scores[:3].exp().sum().item(),
1.0,
places=6)

# arcs leaving state 1
self.assertAlmostEqual(fsa.scores[3:5].exp().sum().item(),
1.0,
places=6)

# arcs leaving state 2
self.assertAlmostEqual(fsa.scores[5].exp().sum().item(), 1.0, places=6)

def test_sum_per_sublist(self):
s = '''
0 1 1 0.
0 1 2 0.
0 1 3 0.
1 2 4 0.
1 2 5 0.
2 3 -1 0.
3
'''
fsa = k2.Fsa.from_str(s)
scores = torch.randn_like(fsa.scores)
fsa.set_scores_stochastic_(scores)
normalized_scores = k2.ragged.sum_per_sublist(
_k2.RaggedFloat(fsa.arcs.shape(), fsa.scores.exp()))
assert normalized_scores.numel() == fsa.arcs.dim0()

assert torch.allclose(normalized_scores[:-1],
torch.ones(normalized_scores.numel() - 1))

# the final state has no leaving arcs
assert normalized_scores[-1].item() == 0


if __name__ == '__main__':
unittest.main()
