From 858aab3f4c0e5fe45c7142579e6eb3e90ae51cb4 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sun, 29 Aug 2021 06:45:42 -0700 Subject: [PATCH] add extra feedforward to perceiver io --- perceiver_pytorch/perceiver_io.py | 3 +++ setup.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/perceiver_pytorch/perceiver_io.py b/perceiver_pytorch/perceiver_io.py index f5057a2..8e34606 100644 --- a/perceiver_pytorch/perceiver_io.py +++ b/perceiver_pytorch/perceiver_io.py @@ -149,6 +149,8 @@ def __init__( ])) self.decoder_cross_attn = PreNorm(queries_dim, Attention(queries_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head), context_dim = latent_dim) + self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) + self.to_logits = nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity() def forward( @@ -177,6 +179,7 @@ def forward( # cross attend from decoder queries to latents latents = self.decoder_cross_attn(queries, context = x) + latents = self.decoder_ff(latents) # final linear out diff --git a/setup.py b/setup.py index 7b36dd5..cb80ab9 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'perceiver-pytorch', packages = find_packages(), - version = '0.5.1', + version = '0.6.0', license='MIT', description = 'Perceiver - Pytorch', author = 'Phil Wang',