Skip to content

Commit

Permalink
added padding in fusion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Madjid-CH committed May 17, 2024
1 parent f6741ea commit 1ee8fc4
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions auto_mixer/modules/fusion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math

import torch
import torch.nn.functional as F
from torch import nn


Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(self, dim=1, **_kwargs):
self.dim = dim

def __call__(self, *args):
args = pad_tensors(args)
return torch.cat(args, dim=self.dim)

def get_output_shape(self, *args, dim=None):
Expand Down Expand Up @@ -146,6 +148,15 @@ def get_output_shape(self, *args, dim=None):
return tuple(shape)


def pad_tensors(args):
max_size = max(arg.size(2) for arg in args)
padded_tensors = []
for tensor in args:
padding = (0, max_size - tensor.size(2))
padded_tensors.append(F.pad(tensor, padding))
return padded_tensors


class ConcatDynaFusion:
def __init__(self, dim=1, **_kwargs):
self.dim = dim
Expand Down

0 comments on commit 1ee8fc4

Please sign in to comment.