Skip to content

Commit

Permalink
additional files
Browse files Browse the repository at this point in the history
  • Loading branch information
Fhrozen committed Jul 10, 2024
1 parent 3b50fa3 commit 7db870c
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
51 changes: 51 additions & 0 deletions fairseq/incremental_decoding_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import uuid
from typing import Dict, Optional

from torch import Tensor


class FairseqIncrementalState:
    """Mixin that namespaces a module's entries in a shared incremental-state dict.

    Every instance is assigned a random UUID string on construction. Keys
    passed to the getter/setter helpers are prefixed with that identifier, so
    several instances can store values in the same dictionary without
    overwriting each other's entries.
    """

    def __init__(self, *args, **kwargs):
        # Forward everything to the next class in the MRO, then assign the id.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign a fresh random identifier used to prefix this instance's keys."""
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return *key* prefixed with this instance's identifier and a dot."""
        prefix = self._incremental_state_id
        return "{}.{}".format(prefix, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Look up the value stored for *key* by this instance.

        Args:
            incremental_state: the shared state dictionary, or ``None``.
            key: the un-prefixed name of the entry.

        Returns:
            The stored value, or ``None`` when *incremental_state* is ``None``
            or holds no entry for this instance under *key*.
        """
        namespaced_key = self._get_full_incremental_state_key(key)
        if incremental_state is not None and namespaced_key in incremental_state:
            return incremental_state[namespaced_key]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store *value* for *key* under this instance's namespace.

        Args:
            incremental_state: the shared state dictionary, or ``None``.
            key: the un-prefixed name of the entry.
            value: the value to store.

        Returns:
            The same *incremental_state* object that was passed in (updated in
            place), or ``None`` when no dictionary was supplied.
        """
        if incremental_state is None:
            # Nothing to write into; hand the (absent) state straight back.
            return incremental_state
        namespaced_key = self._get_full_incremental_state_key(key)
        incremental_state[namespaced_key] = value
        return incremental_state


def with_incremental_state(cls):
    """Class decorator that makes ``FairseqIncrementalState`` the first base of *cls*.

    Any existing occurrence of ``FairseqIncrementalState`` among the bases is
    filtered out before it is prepended, so it appears exactly once. The class
    is modified in place and the same class object is returned.
    """
    other_bases = [base for base in cls.__bases__ if base != FairseqIncrementalState]
    cls.__bases__ = (FairseqIncrementalState, *other_bases)
    return cls
118 changes: 118 additions & 0 deletions fairseq/models/fairseq_incremental_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Dict, Optional

from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.models import FairseqDecoder
from torch import Tensor


logger = logging.getLogger(__name__)


@with_incremental_state
class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for decoders that support incremental decoding.

    Incremental decoding is an inference-time mode in which the model is fed a
    single timestep of input (the previous output token) and has to produce
    the next output *incrementally*. The model therefore needs to cache any
    long-term state it requires about the sequence, e.g., hidden states or
    convolutional states.

    Relative to the standard :class:`FairseqDecoder` interface, the
    :func:`forward` functions here accept an additional keyword argument
    (*incremental_state*) that can be used to cache state across time-steps.

    This interface also declares :func:`reorder_incremental_state`, which beam
    search uses to select and reorder the cached state according to the beams
    that were chosen.

    For background on incremental decoding, see `this blog
    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
    """

    def __init__(self, dictionary):
        super().__init__(dictionary)

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """Run the decoder; must be implemented by subclasses.

        Args:
            prev_output_tokens (LongTensor): shifted output tokens of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
    ):
        """Compute decoder features; must be implemented by subclasses.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder the cached incremental state (no-op in this base class).

        Invoked when the order of the input differs from the previous time
        step. The typical case is beam search, where the input order changes
        between time steps depending on which beams were selected.
        """
        return None

    def reorder_incremental_state_scripting(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder the incremental state of every module that supports it.

        Because of TorchScript limitations, this function is what
        :class:`fairseq.sequence_generator.SequenceGenerator` calls, rather
        than calling :func:`reorder_incremental_state` directly.
        """
        for submodule in self.modules():
            if not hasattr(submodule, "reorder_incremental_state"):
                continue
            reordered = submodule.reorder_incremental_state(
                incremental_state, new_order
            )
            # A module may hand back a replacement dict; later modules in the
            # loop then receive that one instead of the original.
            if reordered is not None:
                incremental_state = reordered

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        # Nothing to do when the requested size is already recorded.
        if getattr(self, "_beam_size", -1) == beam_size:
            return

        visited = set()

        def _propagate(child):
            # Skip this decoder itself and anything lacking the hook; each
            # remaining module is notified at most once.
            is_candidate = child != self and hasattr(child, "set_beam_size")
            if is_candidate and child not in visited:
                visited.add(child)
                child.set_beam_size(beam_size)

        self.apply(_propagate)
        self._beam_size = beam_size

0 comments on commit 7db870c

Please sign in to comment.