From 829ea29d9a8d25abb975e098171062814807d594 Mon Sep 17 00:00:00 2001
From: Amanpreet Singh
Date: Thu, 17 Jun 2021 13:02:27 -0700
Subject: [PATCH] [feat] Support custom mask key in MMF Transformer

You can now specify a custom mask key for a modality using the `mask_key`
attribute. This also fixes an issue that caused MMFT to fail when the text
modality was not present.
---
 mmf/models/mmf_transformer.py   | 11 ++++++++++-
 mmf/models/transformers/base.py |  4 +++-
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/mmf/models/mmf_transformer.py b/mmf/models/mmf_transformer.py
index f181e95ef..74c134196 100644
--- a/mmf/models/mmf_transformer.py
+++ b/mmf/models/mmf_transformer.py
@@ -87,6 +87,7 @@ def __init__(self, config: BaseTransformer.Config, *args, **kwargs):
         self.modality_keys: List = []
         self.modality_type: List = []
         self.modality_segments: List = []
+        self.modality_masks: List = []
         for modality in self.config.modalities:
             self.modality_keys.append(modality.key)
             self.modality_type.append(modality.type)
@@ -95,6 +96,11 @@ def __init__(self, config: BaseTransformer.Config, *args, **kwargs):
             else:
                 self.modality_segments.append(-1)
 
+            if "mask_key" in modality:
+                self.modality_masks.append(modality.mask_key)
+            else:
+                self.modality_masks.append(f"{modality.key}_mask")
+
     # Backward compatibility for code for old mmft checkpoints
     @classmethod
     def format_state_key(cls, key):
@@ -136,6 +142,9 @@ def build_encoders(self):
 
     def tie_weights(self):
         """Tie some head weights with backend embeddings"""
+        if "text" not in self.modality_type:
+            return
+
         text_embedding_idx = self.modality_type.index("text")
         if text_embedding_idx >= 0:
             for head in self.heads:
@@ -294,7 +303,7 @@ def _infer_masks(
                 else:
                     masks[modality] = sample_list["input_mask"]
             else:
-                mask_attribute = f"{modality}_mask"
+                mask_attribute = self.modality_masks[idx]
                 if mask_attribute in sample_list:
                     masks[modality] = sample_list[mask_attribute]
                 else:
diff --git a/mmf/models/transformers/base.py b/mmf/models/transformers/base.py
index c6ddd3251..6d78e51c4 100644
--- a/mmf/models/transformers/base.py
+++ b/mmf/models/transformers/base.py
@@ -9,7 +9,7 @@
 from mmf.models import BaseModel
 from mmf.modules.encoders import IdentityEncoder
 from mmf.utils.modeling import get_bert_configured_parameters
-from omegaconf import MISSING, OmegaConf
+from omegaconf import MISSING, OmegaConf, SI
 from torch import Tensor, nn
 
 
@@ -39,6 +39,8 @@ class BaseTransformerModalityConfig:
     # This is actually: Union[EncoderFactory.Config, Encoder.Config]
    # NOTE: Waiting on https://github.com/omry/omegaconf/issues/144
     encoder: Any = IdentityEncoder.Config()
+    # separate mask key if needed, defaults to `{key}_mask`
+    mask_key: str = SI("${key}_mask")
 
 
 @dataclass
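
Usage note (not part of the commit): a minimal sketch of how the new `mask_key`
attribute resolves, runnable with OmegaConf alone. The modality and mask key
names below ("image_feature_0", "image_info_0_mask") are illustrative
assumptions, not taken from this diff.

    from omegaconf import OmegaConf

    # A plain (non-structured) modalities config; only the image modality
    # overrides its mask key, the text modality relies on the default.
    config = OmegaConf.create(
        {
            "modalities": [
                {"type": "text", "key": "text", "segment_id": 0},
                {
                    "type": "image",
                    "key": "image_feature_0",
                    "segment_id": 1,
                    # Without mask_key, MMFT would look up "image_feature_0_mask";
                    # with it, the mask is read from "image_info_0_mask" instead.
                    "mask_key": "image_info_0_mask",
                },
            ]
        }
    )

    # Same per-modality resolution as in MMFTransformer.__init__ above.
    modality_masks = []
    for modality in config.modalities:
        if "mask_key" in modality:
            modality_masks.append(modality.mask_key)
        else:
            modality_masks.append(f"{modality.key}_mask")

    print(modality_masks)  # ['text_mask', 'image_info_0_mask']

Because the structured config default is `SI("${key}_mask")`, configs that never
set `mask_key` should keep the previous `{key}_mask` lookup behavior.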