This layer is an extension of the existing `ParametricAttention` layer, adding support for transformations (such as a non-linear layer) of the key representation. This brings the model closer to the paper that proposed it (Yang et al., 2016), and it gave slightly better results in experiments.
Showing 5 changed files with 152 additions and 1 deletion.
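Before the diff itself, here is a minimal construction sketch. It is not part of the commit: it assumes the new layer is re-exported from `thinc.api` alongside the existing `ParametricAttention`, and it uses a `Gelu` layer as the key transform purely for illustration, echoing the non-linear key projection of Yang et al. (2016). The `chain`/`reduce_sum` wiring is likewise just one plausible pooling setup.

from thinc.api import Gelu, ParametricAttention_v2, chain, reduce_sum

# Attention weights are computed against Gelu-transformed keys; omitting
# key_transform falls back to the behaviour of the original layer.
pooler = chain(
    ParametricAttention_v2(key_transform=Gelu()),
    reduce_sum(),
)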
from typing import Callable, Optional, Tuple, cast

from ..config import registry
from ..model import Model
from ..types import Floats2d, Ragged
from ..util import get_width

InT = Ragged
OutT = Ragged


@registry.layers("ParametricAttention.v2")
def ParametricAttention_v2(
    *,
    key_transform: Optional[Model[Floats2d, Floats2d]] = None,
    nO: Optional[int] = None
) -> Model[InT, OutT]:
    """Weight inputs by similarity to a learned vector."""
    if key_transform is None:
        layers = []
        refs = {}
    else:
        layers = [key_transform]
        refs = {"key_transform": cast(Optional[Model], key_transform)}

    return Model(
        "para-attn",
        forward,
        init=init,
        params={"Q": None},
        dims={"nO": nO},
        layers=layers,
        refs=refs,
    )


def forward(model: Model[InT, OutT], Xr: InT, is_train: bool) -> Tuple[OutT, Callable]:
    Q = model.get_param("Q")
    key_transform = model.maybe_get_ref("key_transform")

    attention, bp_attention = _get_attention(
        model.ops, Q, key_transform, Xr.dataXd, Xr.lengths, is_train
    )
    output, bp_output = _apply_attention(model.ops, attention, Xr.dataXd, Xr.lengths)

    def backprop(dYr: OutT) -> InT:
        dX, d_attention = bp_output(dYr.dataXd)
        dQ, dX2 = bp_attention(d_attention)
        model.inc_grad("Q", dQ.ravel())
        dX += dX2
        return Ragged(dX, dYr.lengths)

    return Ragged(output, Xr.lengths), backprop


def init(
    model: Model[InT, OutT], X: Optional[InT] = None, Y: Optional[OutT] = None
) -> None:
    key_transform = model.maybe_get_ref("key_transform")
    width = get_width(X) if X is not None else None
    if width:
        model.set_dim("nO", width)
        if key_transform is not None:
            key_transform.set_dim("nO", width)

    # Randomly initialize the parameter, as though it were an embedding.
    Q = model.ops.alloc1f(model.get_dim("nO"))
    Q += model.ops.xp.random.uniform(-0.1, 0.1, Q.shape)
    model.set_param("Q", Q)

    X_array = X.dataXd if X is not None else None
    Y_array = Y.dataXd if Y is not None else None

    if key_transform is not None:
        key_transform.initialize(X_array, Y_array)


def _get_attention(ops, Q, key_transform, X, lengths, is_train):
    # Without a key transform the raw inputs serve as keys, which matches
    # the behaviour of the original ParametricAttention layer.
    if key_transform is None:
        K, K_bp = X, lambda dY: dY
    else:
        K, K_bp = key_transform(X, is_train=is_train)

    attention = ops.gemm(K, ops.reshape2f(Q, -1, 1))
    attention = ops.softmax_sequences(attention, lengths)

    def get_attention_bwd(d_attention):
        d_attention = ops.backprop_softmax_sequences(d_attention, attention, lengths)
        dQ = ops.gemm(K, d_attention, trans1=True)
        dY = ops.xp.outer(d_attention, Q)
        dX = K_bp(dY)
        return dQ, dX

    return attention, get_attention_bwd


def _apply_attention(ops, attention, X, lengths):
    # attention has shape (n_tokens, 1), so broadcasting scales each row of X.
    output = X * attention

    def apply_attention_bwd(d_output):
        d_attention = (X * d_output).sum(axis=1, keepdims=True)
        dX = d_output * attention
        return dX, d_attention

    return output, apply_attention_bwd
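To make the data flow concrete, here is a hedged end-to-end sketch. It is again not part of the commit and assumes the same `thinc.api` re-exports as above: the layer consumes a `Ragged` batch, weights each row by its softmax-normalised similarity to the learned query `Q` (computed against the transformed keys), and a downstream `reduce_sum` collapses every sequence into a single vector.

import numpy
from thinc.api import Gelu, ParametricAttention_v2, chain, reduce_sum
from thinc.types import Ragged

# Two sequences of lengths 3 and 4, each token a 4-dimensional vector.
data = numpy.random.uniform(-1, 1, (7, 4)).astype("f")
lengths = numpy.array([3, 4], dtype="i")
X = Ragged(data, lengths)

model = chain(ParametricAttention_v2(key_transform=Gelu()), reduce_sum())
model.initialize(X=X)

Y = model.predict(X)
print(Y.shape)  # (2, 4): one pooled vector per input sequence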