
Add propainter #33217

Open - wants to merge 169 commits into main

Conversation

@RUFFY-369 (Contributor) commented Aug 30, 2024

What does this PR do?

This PR adds ProPainter, a video inpainting model whose original repository has 5.4k stars and 635 forks. It fixes #26360 and resolves the stale PR #26391 for that issue, rebuilt from scratch to match transformers standards.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts @ArthurZucker @NielsRogge (?)
@rafaelpadilla (as he was the initial reviewer on the stale PR)

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

The PR is more than ready for a first pass of review!

TODO (will be done in a jiffy :)):

  • Fix all common test failures
  • Update weights conversion scripts with the working one on local machine
  • Review batching nits one more time in the applicable files
  • Update docs in corresponding files
  • Check for video 'outpainting' error

Results:

Here I am attaching GIFs of the original video, the original model's output for object removal through video inpainting, and this PR's HF-ported model's output for the same task:

Original video:
original

Original model output:
original_removal

HF ported model output:

hf_removal

Example usage is provided in the doc file here

@RUFFY-369 (Contributor Author)

> Hi @RUFFY-369, I added a couple of comments - overall it looks like there are a lot of tricks to keep the memory usage low, which is good! Now they need to be more in line with the lib standards :) Relative to processing, we are also uniformizing the way models process inputs, and I left a few additional comments on that. LMK what you think and I'll iterate on the review!

Hi @molbap, I have addressed all the suggested changes, uniformized the kwargs, and everything else you mentioned. I have also put my questions and views about specific points in your review.

Please check out the changes, iterate on the review, and let me know of any further changes that need to be made.

Thank you 😄

cc @amyeroberts

@RUFFY-369 (Contributor Author)

All tests are green

@RUFFY-369 (Contributor Author)

soft ping @molbap
Thank you 😄

@molbap (Contributor)

molbap commented Oct 11, 2024

On my radar - I'll review it as soon as I can!

@molbap (Contributor) left a comment

Hey @RUFFY-369, thanks for all the work here 🚀 I did a first pass to cover a few things that were not transformers-compatible, similar to the previous PR :) I have yet to cover processing and tests, but I have a couple of questions, and some of my comments re: naming apply to the complete modeling file even though I haven't commented on all of it. Ping me back when you've had time to include these changes!

The size of the sliding window for attention operations.
pool_size (`List[int]`, *optional*, defaults to `[4, 4]`):
The size of the pooling layers in the model.
no_dis (`bool`, *optional*, defaults to `False`):
Contributor

nit on variable naming: avoid abbreviations, no_dis --> no_discriminator.
It would also be better to make it a positive action, use_discriminator, defaulting to True.
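For illustration, the suggested rename could look like this (a minimal sketch, not the PR's actual config class):

```python
class ProPainterConfig:
    """Hypothetical sketch of the renamed flag, not the real ProPainterConfig."""

    def __init__(self, use_discriminator: bool = True, **kwargs):
        # Positive, unabbreviated name instead of `no_dis=False`
        self.use_discriminator = use_discriminator


config = ProPainterConfig()
print(config.use_discriminator)  # -> True
```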

Contributor Author

done in the recent commits

Comment on lines 155 to 164
stride=[3, 3],
stride_3d=[1, 1, 1],
num_hidden_layers=8,
num_attention_heads=4,
window_size=[5, 9],
pool_size=[4, 4],
no_dis=False,
in_channels=[64, 64, 96],
channels=[64, 96, 128],
strides=[1, 2, 2],
Contributor

stride, stride_3d and strides are very similar - I'd advise going with differently named kwargs to remove some ambiguity; it's fine to use longer kwargs

Contributor Author

done in the recent commits

Comment on lines 219 to 231
if args.verify_logits:
    video, masks = prepare_input()
    image_processor = ProPainterVideoProcessor()
    inputs = image_processor(video, masks=masks, return_tensors="pt").to(device)
    outputs = model(**inputs)
    outputs_reconstruction = outputs.reconstruction

    assert torch.allclose(
        torch.tensor(outputs_reconstruction[0][0][-3:]),
        expected_output_reconstruction,
        atol=1e-4,
    )
    print("Looks good!")
Contributor

Let's move this part to testing, no need to have it around the conversion file

Contributor Author

Removed from the weight conversion file, and as for the test_modeling file, this assertion is already there 👍

Comment on lines 62 to 75
def rename_flow_completion(old_key, network_mapping):
    # Fall back to the unmodified key if no prefix matches (the original
    # initialized new_key to "" and could return an empty string)
    new_key = old_key
    for old_prefix, new_prefix in network_mapping.items():
        if old_prefix in old_key:
            new_key = old_key.replace(old_prefix, new_prefix)
    # Handle specific layer/block transformations
    if "mid_dilation" in new_key:
        new_key = new_key.replace("mid_dilation", "intermediate_dilation")
    if "feat_prop_module" in new_key:
        new_key = new_key.replace("feat_prop_module", "feature_propagation_module")
    if "edgeDetector.mid_layer" in new_key:
        new_key = new_key.replace("edgeDetector.mid_layer", "edgeDetector.intermediate_layer")

    return new_key
Contributor

Across the file, key renames can be handled by regexes to avoid some slightly-harder-to-read if/else logic, and also to see at a glance, in one place, which key in the original model corresponds to which key in transformers

Contributor Author

done in the recent commits

Comment on lines 100 to 103
"encoder1": "flow_completion_net.encoder1",
"encoder2": "flow_completion_net.encoder2",
"decoder1": "flow_completion_net.decoder1",
"decoder2": "flow_completion_net.decoder2",
Contributor

Related to the regex comment, this could for instance be:

network_mapping_completion = {
    r"(downsample|encoder1|encoder2|decoder1|decoder2|upsample)": r"flow_completion_net.\1",
    ...
}
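A minimal sketch of applying such a regex mapping to checkpoint keys (the helper name and sample keys are hypothetical, not from the PR):

```python
import re


def rename_key(old_key, network_mapping):
    # Apply the first matching pattern; the backreference keeps the
    # original sub-name in the new transformers-style key.
    for pattern, replacement in network_mapping.items():
        new_key, num_subs = re.subn(pattern, replacement, old_key)
        if num_subs:
            return new_key
    return old_key


network_mapping_completion = {
    r"\b(downsample|encoder1|encoder2|decoder1|decoder2|upsample)\b": r"flow_completion_net.\1",
}

print(rename_key("encoder1.0.weight", network_mapping_completion))
# -> flow_completion_net.encoder1.0.weight
```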


downsample_inputs = self.downsample(inputs)

features_enc1 = self.encoder1(downsample_inputs)
Contributor

nits: naming - separate numbers from names, avoid abbreviations, and remove the comments in all the forward methods

Contributor Author

done


return flow, edge

def forward_bidirect_flow(self, masked_flows_bi, masks):
Contributor

naming, no abbreviations

Contributor Author

done

super().__init__()
self.config = config
self.group = [1, 2, 4, 8, 1]
negative_slope = 0.2
Contributor

that's good! let's add it to the config too

Contributor Author

done

Comment on lines 1466 to 1469
if i == 8:
    x0 = features
    _, _, height, width = x0.size()
if i > 8 and i % 2 == 0:
Contributor

maybe add a small comment here 😅 it might not be intuitive why odd/even layers are treated differently except at layer 8

Contributor Author

done

@RUFFY-369 (Contributor Author)

> Hey @RUFFY-369, thanks for all the work here 🚀 I did a first pass to cover a few things that were not transformers-compatible, similar to the previous PR :) I have yet to cover processing and tests, but I have a couple of questions, and some of my comments re: naming apply to the complete modeling file even though I haven't commented on all of it. Ping me back when you've had time to include these changes!

Hi @molbap, I have addressed all your comments and left queries as well. Please review the rest of the remaining files, and iterate on the addressed ones when you get the time.

Thank you 😄

@RUFFY-369 RUFFY-369 requested a review from molbap October 25, 2024 06:03
@RUFFY-369 (Contributor Author)

@molbap Soft ping
Thanks 😄

@molbap (Contributor) left a comment

Hey @RUFFY-369! I did a pass on naming and configs 🧹 because I identified persisting issues there. Considering the very hefty size of the modeling code (4600 LOC is indeed hefty), I think clarity is paramount to help code inspectors understand the logic flow. Let me know when you think you've addressed the suggestions on all the existing code, and I'll continue promptly 🤗
Also, I'd suggest taking a look at the currently failing tests; it's mostly docstring mismatches in the configs, I believe.


self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
Contributor

well fair enough, maybe just add a comment there saying so since it's a rarely seen pattern

Comment on lines +170 to +171
# using itertools makes flattening a little faster :)
self.resblocks = nn.ModuleList(list(itertools.chain.from_iterable(self.resblocks)))
Contributor

Interesting - do you know what difference it makes, in numbers?
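One way to put numbers on the flattening question is a standalone micro-benchmark (this is an illustration, not a measurement from the PR; absolute timings depend on the machine):

```python
import itertools
import timeit

# Stand-in for the nested list of residual blocks: 64 stages of 8 modules each
nested = [[object() for _ in range(8)] for _ in range(64)]

# Flatten with a nested comprehension vs. itertools.chain.from_iterable
t_comprehension = timeit.timeit(
    lambda: [module for block in nested for module in block], number=10_000
)
t_chain = timeit.timeit(
    lambda: list(itertools.chain.from_iterable(nested)), number=10_000
)

print(f"comprehension: {t_comprehension:.4f}s  chain.from_iterable: {t_chain:.4f}s")
```

Both produce the same flattened order; at module-construction time the difference is negligible either way, so this is mostly a readability choice.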

Comment on lines +210 to +236
self.conv_corr1 = nn.Conv2d(correlation_planes, config.num_channels * 2, 1, padding=0)
self.conv_corr2 = nn.Conv2d(config.num_channels * 2, 192, config.patch_size, padding=config.padding)
self.conv_flow1 = nn.Conv2d(2, config.num_channels, config.kernel_size[0], padding=3)
self.conv_flow2 = nn.Conv2d(
    config.num_channels,
    config.in_channels[0],
    config.patch_size,
    padding=config.padding,
)
self.conv = nn.Conv2d(
    config.in_channels[0] + 192,
    config.num_channels - 2,
    config.patch_size,
    padding=config.padding,
)

def forward(self, optical_flow, correlation):
    hidden_states_correlation = F.relu(self.conv_corr1(correlation))
    hidden_states_correlation = F.relu(self.conv_corr2(hidden_states_correlation))
    hidden_states_flow = F.relu(self.conv_flow1(optical_flow))
    hidden_states_flow = F.relu(self.conv_flow2(hidden_states_flow))

    hidden_states = torch.cat([hidden_states_correlation, hidden_states_flow], dim=1)
    hidden_states = F.relu(self.conv(hidden_states))
    hidden_states = torch.cat([hidden_states, optical_flow], dim=1)

    return hidden_states
Contributor

nits on naming, the comeback: when possible, avoid abbreviations and try to space numbers and letters for legibility.
conv_corr_1 is better than conv_corr1, and so on.

Comment on lines +249 to +267
self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))

self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))

def forward(self, hidden_states, motion_features):
    hidden_states_motion_features = torch.cat([hidden_states, motion_features], dim=1)
    z = torch.sigmoid(self.convz1(hidden_states_motion_features))
    r = torch.sigmoid(self.convr1(hidden_states_motion_features))
    q = torch.tanh(self.convq1(torch.cat([r * hidden_states, motion_features], dim=1)))
    hidden_states = (1 - z) * hidden_states + z * q
    hidden_states_motion_features = torch.cat([hidden_states, motion_features], dim=1)
    z = torch.sigmoid(self.convz2(hidden_states_motion_features))
    r = torch.sigmoid(self.convr2(hidden_states_motion_features))
    q = torch.tanh(self.convq2(torch.cat([r * hidden_states, motion_features], dim=1)))
    hidden_states = (1 - z) * hidden_states + z * q
Contributor

nits on naming, book 3, the return of the nit:

  • avoid single-letter variables
  • space letters and numbers
  • avoid abbreviations, we have enough space to afford some more letters

I understand this adds chores to the PR, but it's necessary for the legibility/harmony of the 200 models in the codebase! 🙇

Comment on lines +243 to +244
hidden_dim: int = 128,
input_dim: int = 192 + 128,
Contributor

two things here:

  • why is the default input_dim defined like this?
  • ideally, modules should only be initialized with a configuration and a layer index - this avoids hardcoding too many things.

In addition, ProPainterBasicUpdateBlock is initialized with hidden_dim only, meaning that the other argument could simply be contained in the configuration.
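For illustration, moving the defaults into the config could look like this (a minimal, framework-free sketch; the field names `motion_hidden_dim` and `motion_input_dim` are assumptions, not the PR's actual attributes):

```python
from dataclasses import dataclass


@dataclass
class ProPainterConfig:
    # Hypothetical fields: defaults that previously lived in module signatures
    motion_hidden_dim: int = 128
    motion_input_dim: int = 192 + 128  # correlation features + flow features


class BasicUpdateBlock:
    """Sketch of a module initialized from the configuration alone."""

    def __init__(self, config: ProPainterConfig):
        # All sizes come from the config instead of positional arguments
        self.input_dim = config.motion_hidden_dim + config.motion_input_dim


block = BasicUpdateBlock(ProPainterConfig())
print(block.input_dim)  # -> 448
```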

Comment on lines +553 to +557
config: ProPainterConfig,
in_channel: int = 2,
out_channel: int = 1,
intermediate_channel: int = 16,
):
Contributor

same remark, we can move these values to the config and simply initialize the module with the configuration passed, with explicitly named keys

Comment on lines +690 to +697
feat = (
    [feat_current]
    + [features[k][frame_id] for k in features if k not in ["spatial", module_name]]
    + [feature_propagation]
)

feat = torch.cat(feat, dim=1)
feature_propagation = feature_propagation + self.backbone[module_name](feat)
Contributor

Suggested change (rename feat to aggregated_features):

-  feat = (
-      [feat_current]
-      + [features[k][frame_id] for k in features if k not in ["spatial", module_name]]
-      + [feature_propagation]
-  )
-  feat = torch.cat(feat, dim=1)
-  feature_propagation = feature_propagation + self.backbone[module_name](feat)
+  aggregated_features = (
+      [feat_current]
+      + [features[k][frame_id] for k in features if k not in ["spatial", module_name]]
+      + [feature_propagation]
+  )
+  aggregated_features = torch.cat(aggregated_features, dim=1)
+  feature_propagation = feature_propagation + self.backbone[module_name](aggregated_features)

pooling_token: bool = True,
):
super().__init__()
assert hidden_size % num_attention_heads == 0
Contributor

use raise rather than assert
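The transformers-style replacement is an explicit ValueError (a sketch; the standalone helper is hypothetical, in the PR this check would live in `__init__`):

```python
def check_divisible(hidden_size: int, num_attention_heads: int) -> None:
    # Raise an informative error instead of asserting:
    # asserts are stripped when Python runs with -O.
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by "
            f"num_attention_heads ({num_attention_heads})"
        )


check_divisible(512, 8)  # passes silently
```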

Comment on lines +1599 to +1608
mask_tl = torch.ones(self.window_size[0], self.window_size[1])
mask_tl[: -self.expand_size[0], : -self.expand_size[1]] = 0
mask_tr = torch.ones(self.window_size[0], self.window_size[1])
mask_tr[: -self.expand_size[0], self.expand_size[1] :] = 0
mask_bl = torch.ones(self.window_size[0], self.window_size[1])
mask_bl[self.expand_size[0] :, : -self.expand_size[1]] = 0
mask_br = torch.ones(self.window_size[0], self.window_size[1])
mask_br[self.expand_size[0] :, self.expand_size[1] :] = 0
masked_rolled_key = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
self.register_buffer("valid_ind_rolled", masked_rolled_key.nonzero(as_tuple=False).view(-1))
Contributor

improve naming here too


self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0))

def forward(
Contributor

considering the mighty size of this forward, it's acceptable to cut it down into at least 2 sub-methods, one for each branching
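The suggested split could be sketched like this (a hypothetical illustration with the framework details omitted - the class and helper names are not the PR's actual code):

```python
class SlidingWindowAttentionSketch:
    """Hypothetical sketch: a large forward split into one helper per branch."""

    def forward(self, hidden_states, use_pooling: bool):
        # forward only dispatches; each branch lives in its own short,
        # independently readable and testable method.
        if use_pooling:
            return self._forward_with_pooling(hidden_states)
        return self._forward_windowed(hidden_states)

    def _forward_with_pooling(self, hidden_states):
        # ... pooling-token branch of the original forward would go here ...
        return ("pooled", hidden_states)

    def _forward_windowed(self, hidden_states):
        # ... plain sliding-window branch would go here ...
        return ("windowed", hidden_states)
```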


Successfully merging this pull request may close these issues.

Add ProPainter to transformers
7 participants