
VisualBert Doesn't return attentions #1036

Open
EXUPLOOOOSION opened this issue Jul 31, 2021 · 2 comments

Comments

EXUPLOOOOSION commented Jul 31, 2021

🐛 Bug

The VisualBert model ignores the output_attentions config flag.


To Reproduce

Steps to reproduce the behavior:

In a python script:

  1. Get a configuration with output_attentions = True.
  2. Initialize and build any VisualBert model (any that uses VisualBertBase).
  3. Run an inference (forward) and observe in the output that the attention tuple exists but is empty.

Specifying output_attentions=True in the forward function's parameters doesn't help either.

Expected behavior

The attentions tuple should not be empty; it should contain the per-layer attention weights.

Environment

PyTorch version: 1.9.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26

Python version: 3.7 (64-bit runtime)
Python platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.0.221
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] pytorch-lightning==1.4.0
[pip3] torch==1.9.0
[pip3] torchmetrics==0.4.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.5.0
[pip3] torchvision==0.6.0
[conda] Could not collect

Additional context

Reason

VisualBERT uses VisualBERTForClassification, which uses VisualBERTBase, which uses BertEncoderJit. All of these receive their output_attentions attribute correctly. However, BertEncoderJit's forward function never consults the module's output_attentions attribute; it uses only its output_attentions parameter. Since VisualBERTBase does not pass its own output_attentions along as a parameter, no model built on VisualBERTBase outputs any attentions.

This also applies to output_hidden_states.
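The failure mode above can be reduced to a small, self-contained pattern. ToyEncoder below is a hypothetical stand-in for BertEncoderJit (it is not mmf's real class); it shows how a flag stored on the module is silently ignored when forward() only reads its own parameter:

```python
class ToyEncoder:
    """Toy stand-in for BertEncoderJit illustrating the bug pattern."""

    def __init__(self, output_attentions):
        # The config flag is stored on the module...
        self.output_attentions = output_attentions

    def forward(self, x, output_attentions=False):
        # ...but forward() consults only its parameter, so the stored
        # attribute never takes effect.
        attentions = ("attn",) if output_attentions else ()
        return x, attentions


encoder = ToyEncoder(output_attentions=True)
_, attentions = encoder.forward([1.0])
print(attentions)  # () — empty even though the config asked for attentions
```

If the caller (here standing in for VisualBERTBase) never forwards the flag, the parameter's default of False always wins.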

Fix

Either:

  1. make VisualBERTBase pass output_attentions as a parameter to its encoder's forward, or
  2. make BertEncoderJit's forward function consider both its parameter and BertEncoderJit's attribute when deciding whether to output attentions (output them if either one is true).

I implemented the second option in my local build of mmf and it works.
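Fix 2 can be sketched on a toy class (ToyEncoder is a hypothetical stand-in for BertEncoderJit, not mmf's real code): forward() outputs attentions when either the call-site parameter or the module attribute is true.

```python
class ToyEncoder:
    """Toy stand-in for BertEncoderJit with fix 2 applied."""

    def __init__(self, output_attentions=False):
        self.output_attentions = output_attentions

    def forward(self, x, output_attentions=False):
        # Fix 2: honor the stored config flag as well as the parameter.
        output_attentions = output_attentions or self.output_attentions
        attentions = ("attn",) if output_attentions else ()
        return x, attentions


# The config flag alone is now enough; the caller need not pass anything.
_, attns = ToyEncoder(output_attentions=True).forward([1.0])
print(attns)  # ('attn',)
```

One consequence of the OR is that a caller can no longer force attentions off when the config enables them, which is why the None-default variant discussed later in this thread may be preferable.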


vedanuj commented Jul 31, 2021

Hello @EXUPLOOOOSION, please feel free to open a PR for this feature. We welcome contributions!

abhinav-bohra commented Aug 19, 2021

Hello @vedanuj, I was able to reproduce this error on my local build of MMF. I think the cause is the following:

The default value of output_attentions in the forward() call of BertEncoderJit (in mmf/modules/hf_layers.py) is False. So even if the user/developer specifies output_attentions = True in the config, the default False is used, and VisualBERT returns an empty tuple for attentions.

A more robust fix would be to default the output_attentions argument to None in BertEncoderJit's forward() definition and fall back to self.output_attentions when it is not passed (i.e., it is None). output_attentions then takes the value of self.output_attentions, which was initialized from the config when the BertEncoderJit instance was created.

The issue with output_hidden_states is the same and can be fixed similarly.
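The None-default approach described above can be sketched as follows (ToyEncoder is a hypothetical stand-in, not mmf's real BertEncoderJit): None means "not specified at the call site, fall back to the config value", while an explicit True/False still overrides the config.

```python
class ToyEncoder:
    """Toy stand-in for BertEncoderJit with the None-default fix."""

    def __init__(self, output_attentions=False):
        # Value initialized from the config at instantiation time.
        self.output_attentions = output_attentions

    def forward(self, x, output_attentions=None):
        # None == "not passed": fall back to the config-derived attribute.
        if output_attentions is None:
            output_attentions = self.output_attentions
        attentions = ("attn",) if output_attentions else ()
        return x, attentions


enc = ToyEncoder(output_attentions=True)
print(enc.forward([1.0])[1])         # ('attn',) — config flag is honored
print(enc.forward([1.0], False)[1])  # ()        — explicit override still wins
```

Unlike the OR-based variant, this preserves the ability to disable attentions per call even when the config enables them.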

abhinav-bohra added a commit to abhinav-bohra/mmf that referenced this issue Aug 22, 2021
PROBLEM: The default value of output_attentions in forward( ) call of BertEncoderJit (in mmf/modules/hf_layers.py) is set as False. So even if the user/developer specifies output_attentions = True in config; its value is taken as default False and thus VisualBERT returns an empty tuple for attentions.

FIX: Set output_attentions as None in BertEncoderJit's forward( ) definition, and update output_attentions to self.output_attentions if it is not passed as an argument (i.e it is None). Therefore, now output_attentions will take the value of self.output_attentions (which was initialized using config during instantiation of BertEncoderJit class)

Same problem and same fix for output_hidden_states as well.

Tested locally.
abhinav-bohra added a commit to abhinav-bohra/mmf that referenced this issue Aug 22, 2021