Skip to content

Commit

Permalink
[feat] Add standard poolers (#1173)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1173

Add identity, cls, and average poolers.
Registered and instantiable through registry.

Test Plan: Imported from OSS

Reviewed By: apsdehal

Differential Revision: D33032556

Pulled By: Ryan-Qiyu-Jiang

fbshipit-source-id: ddccba75f75ba61972760600903cd75810545ff6
  • Loading branch information
Ryan-Qiyu-Jiang authored and facebook-github-bot committed Dec 16, 2021
1 parent f3ede0e commit ee19bd9
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
48 changes: 47 additions & 1 deletion mmf/modules/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class CustomPool(nn.Module):
...
"""
from typing import List
from typing import Any, List

import torch
import torch.nn as nn
Expand Down Expand Up @@ -78,3 +78,49 @@ def forward(self, encoded_layers: List[torch.Tensor], pad_mask: torch.Tensor):
torch.sum(pad_mask, 1).float() + self.tol
)
return pooled_output


@registry.register_pooler("identity")
class IdentityPooler(nn.Module):
    """No-op pooler: whatever is passed to ``forward`` comes back unchanged."""

    def forward(self, x: Any):
        """Hand the input straight back to the caller.

        Args:
            x (Any): Value to pass through.

        Returns:
            Any: The same object that was given as ``x``.
        """
        return x


@registry.register_pooler("cls")
class ClsPooler(nn.Module):
    """Pooler that keeps a single slice of the input along one dimension.

    With the defaults (``dim=1``, ``cls_index=0``) and an input laid out as
    (bs, seq length, hidden size), this picks the first token of every
    sequence, i.e. the classification (cls) token.
    """

    def __init__(self, dim=1, cls_index=0):
        """Store which slice to pick.

        Args:
            dim (int): Dimension along which the slice is selected.
            cls_index (int): Position of the slice to keep along ``dim``.
        """
        super().__init__()
        self.cls_index = cls_index
        self.dim = dim

    def forward(self, last_hidden_state: torch.Tensor):
        """Select the slice at ``cls_index`` along ``dim``.

        Args:
            last_hidden_state (torch.Tensor): Sequence of hidden states at
                the output of the last layer of the model, expected as
                (bs, seq length, hidden size).

        Returns:
            torch.Tensor: The selected slice with dimension ``dim`` removed;
                (bs, hidden size) for the default settings.
        """
        return torch.select(last_hidden_state, self.dim, self.cls_index)


@registry.register_pooler("avg")
class MeanPooler(nn.Module):
    """Pooler that averages the input tensor along one dimension."""

    def __init__(self, dim=1):
        """Store the dimension to average over.

        Args:
            dim (int): Dimension that is reduced by the mean.
        """
        super().__init__()
        self.dim = dim

    def forward(self, last_hidden_state: torch.Tensor):
        """Average the hidden states along ``dim``.

        Every position along ``dim`` contributes equally; no padding mask
        is applied here.

        Args:
            last_hidden_state (torch.Tensor): Sequence of hidden states at
                the output of the last layer of the model, expected as
                (bs, seq length, hidden size).

        Returns:
            torch.Tensor: Mean of the input over ``dim``, with that
                dimension removed; (bs, hidden size) for the default
                ``dim=1``.
        """
        return last_hidden_state.mean(dim=self.dim)
27 changes: 24 additions & 3 deletions tests/modules/test_poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,41 @@ def setUp(self):
]
self.pad_mask = torch.randn(self.batch_size, self.token_len).to(self.device)

def test_AverageConcat(self):
def test_average_concat(self):
    """AverageConcatLastN output must be (batch_size, embedding_size * k)."""
    pooler = poolers.AverageConcatLastN(self.k).to(self.device)
    pooled = pooler(self.encoded_layers, self.pad_mask)

    expected_shape = torch.Size([self.batch_size, self.embedding_size * self.k])
    assert pooled.shape == expected_shape

def test_AverageKFromLast(self):
def test_average_k_from_last(self):
    """AverageKFromLast output must be (batch_size, embedding_size)."""
    pooler = poolers.AverageKFromLast(self.k).to(self.device)
    pooled = pooler(self.encoded_layers, self.pad_mask)

    expected_shape = torch.Size([self.batch_size, self.embedding_size])
    assert pooled.shape == expected_shape

def test_AverageSumLastK(self):
def test_average_sum_last_k(self):
    """AverageSumLastK output must be (batch_size, embedding_size)."""
    pooler = poolers.AverageSumLastK(self.k).to(self.device)
    pooled = pooler(self.encoded_layers, self.pad_mask)

    expected_shape = torch.Size([self.batch_size, self.embedding_size])
    assert pooled.shape == expected_shape

def test_identity(self):
    """IdentityPooler must leave the shape of the last encoded layer intact."""
    pooler = poolers.IdentityPooler().to(self.device)
    last_layer = self.encoded_layers[-1]
    pooled = pooler(last_layer)

    expected_shape = torch.Size(
        [self.batch_size, self.token_len, self.embedding_size]
    )
    assert pooled.shape == expected_shape

def test_cls(self):
    """Default ClsPooler must reduce the last layer to (batch_size, embedding_size)."""
    pooler = poolers.ClsPooler().to(self.device)
    last_layer = self.encoded_layers[-1]
    pooled = pooler(last_layer)

    expected_shape = torch.Size([self.batch_size, self.embedding_size])
    assert pooled.shape == expected_shape

def test_average(self):
    """Default MeanPooler must reduce the last layer to (batch_size, embedding_size)."""
    pooler = poolers.MeanPooler().to(self.device)
    last_layer = self.encoded_layers[-1]
    pooled = pooler(last_layer)

    expected_shape = torch.Size([self.batch_size, self.embedding_size])
    assert pooled.shape == expected_shape

0 comments on commit ee19bd9

Please sign in to comment.