From 7db870c0d4122fda4b6efb8f3d51cab4acc0cece Mon Sep 17 00:00:00 2001
From: Fhrozen
Date: Wed, 10 Jul 2024 21:34:16 +0900
Subject: [PATCH] additional files

---
Notes (review; this section is ignored when the patch is applied):
  * fairseq/incremental_decoding_utils.py adds FairseqIncrementalState and
    the with_incremental_state class decorator, which prepends
    FairseqIncrementalState to the decorated class's bases.
  * fairseq/models/fairseq_incremental_decoder.py adds the
    FairseqIncrementalDecoder base class built on that decorator.
  * NOTE(review): the "From:" header carries no e-mail address; it appears
    to have been stripped. Restore it before feeding this file to `git am`.
  * NOTE(review): the hyperlink target in the FairseqIncrementalDecoder
    class docstring was missing and has been restored from upstream
    fairseq - please confirm the URL.

 fairseq/incremental_decoding_utils.py         |  51 ++++++++
 fairseq/models/fairseq_incremental_decoder.py | 118 ++++++++++++++++++
 2 files changed, 169 insertions(+)
 create mode 100644 fairseq/incremental_decoding_utils.py
 create mode 100644 fairseq/models/fairseq_incremental_decoder.py

diff --git a/fairseq/incremental_decoding_utils.py b/fairseq/incremental_decoding_utils.py
new file mode 100644
index 0000000..b26e6cd
--- /dev/null
+++ b/fairseq/incremental_decoding_utils.py
@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import uuid
+from typing import Dict, Optional
+
+from torch import Tensor
+
+
+class FairseqIncrementalState(object):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.init_incremental_state()
+
+    def init_incremental_state(self):
+        self._incremental_state_id = str(uuid.uuid4())
+
+    def _get_full_incremental_state_key(self, key: str) -> str:
+        return "{}.{}".format(self._incremental_state_id, key)
+
+    def get_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+    ) -> Optional[Dict[str, Optional[Tensor]]]:
+        """Helper for getting incremental state for an nn.Module."""
+        full_key = self._get_full_incremental_state_key(key)
+        if incremental_state is None or full_key not in incremental_state:
+            return None
+        return incremental_state[full_key]
+
+    def set_incremental_state(
+        self,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
+        key: str,
+        value: Dict[str, Optional[Tensor]],
+    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
+        """Helper for setting incremental state for an nn.Module."""
+        if incremental_state is not None:
+            full_key = self._get_full_incremental_state_key(key)
+            incremental_state[full_key] = value
+        return incremental_state
+
+
+def with_incremental_state(cls):
+    cls.__bases__ = (FairseqIncrementalState,) + tuple(
+        b for b in cls.__bases__ if b != FairseqIncrementalState
+    )
+    return cls
diff --git a/fairseq/models/fairseq_incremental_decoder.py b/fairseq/models/fairseq_incremental_decoder.py
new file mode 100644
index 0000000..cc72a0f
--- /dev/null
+++ b/fairseq/models/fairseq_incremental_decoder.py
@@ -0,0 +1,118 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from typing import Dict, Optional
+
+from fairseq.incremental_decoding_utils import with_incremental_state
+from fairseq.models import FairseqDecoder
+from torch import Tensor
+
+
+logger = logging.getLogger(__name__)
+
+
+@with_incremental_state
+class FairseqIncrementalDecoder(FairseqDecoder):
+    """Base class for incremental decoders.
+
+    Incremental decoding is a special mode at inference time where the Model
+    only receives a single timestep of input corresponding to the previous
+    output token (for teacher forcing) and must produce the next output
+    *incrementally*. Thus the model must cache any long-term state that is
+    needed about the sequence, e.g., hidden states, convolutional states, etc.
+
+    Compared to the standard :class:`FairseqDecoder` interface, the incremental
+    decoder interface allows :func:`forward` functions to take an extra keyword
+    argument (*incremental_state*) that can be used to cache state across
+    time-steps.
+
+    The :class:`FairseqIncrementalDecoder` interface also defines the
+    :func:`reorder_incremental_state` method, which is used during beam search
+    to select and reorder the incremental state based on the selection of beams.
+
+    To learn more about how incremental decoding works, refer to `this blog
+    <http://www.telesens.co/2019/04/21/understanding-incremental-decoding-in-fairseq/>`_.
+    """
+
+    def __init__(self, dictionary):
+        super().__init__(dictionary)
+
+    def forward(
+        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
+    ):
+        """
+        Args:
+            prev_output_tokens (LongTensor): shifted output tokens of shape
+                `(batch, tgt_len)`, for teacher forcing
+            encoder_out (dict, optional): output from the encoder, used for
+                encoder-side attention
+            incremental_state (dict, optional): dictionary used for storing
+                state during :ref:`Incremental decoding`
+
+        Returns:
+            tuple:
+                - the decoder's output of shape `(batch, tgt_len, vocab)`
+                - a dictionary with any model-specific outputs
+        """
+        raise NotImplementedError
+
+    def extract_features(
+        self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs
+    ):
+        """
+        Returns:
+            tuple:
+                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+                - a dictionary with any model-specific outputs
+        """
+        raise NotImplementedError
+
+    def reorder_incremental_state(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
+        """Reorder incremental state.
+
+        This will be called when the order of the input has changed from the
+        previous time step. A typical use case is beam search, where the input
+        order changes between time steps based on the selection of beams.
+        """
+        pass
+
+    def reorder_incremental_state_scripting(
+        self,
+        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
+        new_order: Tensor,
+    ):
+        """Main entry point for reordering the incremental state.
+
+        Due to limitations in TorchScript, we call this function in
+        :class:`fairseq.sequence_generator.SequenceGenerator` instead of
+        calling :func:`reorder_incremental_state` directly.
+        """
+        for module in self.modules():
+            if hasattr(module, "reorder_incremental_state"):
+                result = module.reorder_incremental_state(incremental_state, new_order)
+                if result is not None:
+                    incremental_state = result
+
+    def set_beam_size(self, beam_size):
+        """Sets the beam size in the decoder and all children."""
+        if getattr(self, "_beam_size", -1) != beam_size:
+            seen = set()
+
+            def apply_set_beam_size(module):
+                if (
+                    module != self
+                    and hasattr(module, "set_beam_size")
+                    and module not in seen
+                ):
+                    seen.add(module)
+                    module.set_beam_size(beam_size)
+
+            self.apply(apply_set_beam_size)
+            self._beam_size = beam_size