Showing 2 changed files with 169 additions and 0 deletions.
fairseq/incremental_decoding_utils.py
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import uuid
from typing import Dict, Optional

from torch import Tensor


class FairseqIncrementalState(object):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        # Each instance gets a unique id so that its cached entries cannot
        # collide with another module's entries in the shared state dict.
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        full_key = self._get_full_incremental_state_key(key)
        if incremental_state is None or full_key not in incremental_state:
            return None
        return incremental_state[full_key]

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            full_key = self._get_full_incremental_state_key(key)
            incremental_state[full_key] = value
        return incremental_state


def with_incremental_state(cls):
    # Prepend FairseqIncrementalState to the class's bases, dropping any
    # existing occurrence so it appears exactly once and takes precedence
    # in the method resolution order.
    cls.__bases__ = (FairseqIncrementalState,) + tuple(
        b for b in cls.__bases__ if b != FairseqIncrementalState
    )
    return cls
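To make the decorator concrete, here is a minimal usage sketch (not part of the commit): a toy nn.Module that accumulates its inputs across calls through the helpers above. The class name `ToyCache` and the key names are illustrative assumptions, not fairseq API.

import torch
import torch.nn as nn

from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state
class ToyCache(nn.Module):
    """Illustrative module: appends each new input to a cached buffer."""

    def forward(self, x, incremental_state=None):
        # "prev" and "prev_x" are hypothetical keys chosen for this sketch.
        saved = self.get_incremental_state(incremental_state, "prev")
        if saved is not None and saved["prev_x"] is not None:
            x = torch.cat([saved["prev_x"], x], dim=0)
        self.set_incremental_state(incremental_state, "prev", {"prev_x": x})
        return x


state = {}  # one dict shared across time steps
m = ToyCache()
m(torch.ones(1, 2), incremental_state=state)
out = m(torch.ones(1, 2), incremental_state=state)
print(out.shape)  # torch.Size([2, 2]) -- the first step was cached

Because each instance namespaces its keys with a fresh UUID, two `ToyCache` instances can share the same `incremental_state` dict without clobbering each other's entries.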
@@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, Optional

from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import FairseqDecoder
from torch import Tensor


logger = logging.getLogger(__name__)


@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders.

    Incremental decoding is a special mode at inference time where the Model
    only receives a single timestep of input corresponding to the previous
    output token (for teacher forcing) and must produce the next output
    *incrementally*. Thus the model must cache any long-term state that is
    needed about the sequence, e.g., hidden states, convolutional states, etc.

    Compared to the standard :class:`FairseqDecoder` interface, the incremental
    decoder interface allows :func:`forward` functions to take an extra keyword
    argument (*incremental_state*) that can be used to cache state across
    time-steps.

    The :class:`FairseqIncrementalDecoder` interface also defines the
    :func:`reorder_incremental_state` method, which is used during beam search
    to select and reorder the incremental state based on the selection of beams.

    To learn more about how incremental decoding works, refer to `this blog
    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """
        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        raise NotImplementedError

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder incremental state.

        This will be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        pass

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Main entry point for reordering the incremental state.

        Due to limitations in TorchScript, we call this function in
        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
        calling :func:`reorder_incremental_state` directly.
        """
        # Walk the whole module tree so every submodule that caches state
        # gets a chance to permute its cached tensors.
        for module in self.modules():
            if hasattr(module, "reorder_incremental_state"):
                result = module.reorder_incremental_state(incremental_state, new_order)
                if result is not None:
                    incremental_state = result

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, "_beam_size", -1) != beam_size:
            seen = set()

            def apply_set_beam_size(module):
                # Track visited modules so shared submodules are updated once.
                if (
                    module != self
                    and hasattr(module, "set_beam_size")
                    and module not in seen
                ):
                    seen.add(module)
                    module.set_beam_size(beam_size)

            self.apply(apply_set_beam_size)
            self._beam_size = beam_size
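To show how the interface is typically driven, here is a minimal greedy-decoding sketch (not part of the commit); `decoder`, `encoder_out`, `bos`, and `max_len` are hypothetical stand-ins for a concrete FairseqIncrementalDecoder subclass and its inputs.

import torch


def greedy_decode(decoder, encoder_out, bos: int, max_len: int):
    """Sketch: feed one token per step, reusing cached decoder state."""
    incremental_state = {}  # shared across all time steps
    tokens = torch.full((1, 1), bos, dtype=torch.long)
    for _ in range(max_len):
        out, _ = decoder(
            tokens[:, -1:],  # only the most recent output token
            encoder_out=encoder_out,
            incremental_state=incremental_state,
        )
        next_tok = out[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
    return tokens

Under beam search the batch dimension holds beams, and whenever beams are re-ranked the cached tensors must be permuted to match; that is what `reorder_incremental_state_scripting(incremental_state, new_order)` walks the module tree to do on behalf of `SequenceGenerator`.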