Add propainter #33217
Conversation
Hi @molbap, I have addressed all the suggested changes, uniformized the kwargs, and everything else you mentioned. I have also put my questions and views about specific points in your review. Please check out the changes, iterate on the review, and let me know of any further changes that need to be made. Thank you 😄 cc @amyeroberts
All tests are green
soft ping @molbap
On my radar - I'll review it as soon as I can! |
Hey @RUFFY-369 thanks for all the work here 🚀 I did a first pass to cover a few things that were not transformers-compatible, similar to the previous PR :) I have yet to cover processing and tests, but have a couple questions and some of my comments re: naming are applicable to the complete modeling file even though I haven't commented all of it. Ping me back when you've had time to include these changes!
        The size of the sliding window for attention operations.
    pool_size (`List[int]`, *optional*, defaults to `[4, 4]`):
        The size of the pooling layers in the model.
    no_dis (`bool`, *optional*, defaults to `False`):
nit on variable naming: avoid abbreviations, `no_dis` --> `no_discriminator`. It would also be better to make it a positive action, `use_discriminator`, defaulting to `True`.
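For illustration, a minimal sketch of how the rename could look in the config, assuming the rest of the kwargs stay unchanged (the attribute name is only a suggestion):

from transformers import PretrainedConfig

class ProPainterConfig(PretrainedConfig):
    model_type = "propainter"

    def __init__(self, use_discriminator: bool = True, **kwargs):
        super().__init__(**kwargs)
        # Positive flag replacing `no_dis=False`: the discriminator is enabled by default.
        self.use_discriminator = use_discriminator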
done in the recent commits
stride=[3, 3],
stride_3d=[1, 1, 1],
num_hidden_layers=8,
num_attention_heads=4,
window_size=[5, 9],
pool_size=[4, 4],
no_dis=False,
in_channels=[64, 64, 96],
channels=[64, 96, 128],
strides=[1, 2, 2],
`stride`, `stride_3d` and `strides` are very similar - I'd advise going with different kwargs to remove some ambiguity; it's fine to use longer kwargs.
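Purely as an illustration of longer, less ambiguous kwargs - the names and the meanings attached to them below are guesses to be checked against the modeling code, not the final API:

def example_stride_kwargs(
    patch_stride=(3, 3),        # hypothetical replacement for `stride`
    patch_stride_3d=(1, 1, 1),  # hypothetical replacement for `stride_3d`
    encoder_strides=(1, 2, 2),  # hypothetical replacement for `strides`
):
    # Each kwarg name now hints at which component it configures instead of
    # relying on three near-identical names.
    return {
        "patch_stride": patch_stride,
        "patch_stride_3d": patch_stride_3d,
        "encoder_strides": encoder_strides,
    }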
done in the recent commits
if args.verify_logits:
    video, masks = prepare_input()
    image_processor = ProPainterVideoProcessor()
    inputs = image_processor(video, masks=masks, return_tensors="pt").to(device)
    outputs = model(**inputs)
    outputs_reconstruction = outputs.reconstruction

    assert torch.allclose(
        torch.tensor(outputs_reconstruction[0][0][-3:]),
        expected_output_reconstruction,
        atol=1e-4,
    )
    print("Looks good!")
Let's move this part to testing, no need to have it around the conversion file
Removed from the weight conversion file; as for the test_modeling file, this assertion is already there 👍
def rename_flow_completion(old_key, network_mapping):
    new_key = ""
    for old_prefix, new_prefix in network_mapping.items():
        if old_prefix in old_key:
            new_key = old_key.replace(f"{old_prefix}", f"{new_prefix}")
    # Handle specific layer/block transformations
    if "mid_dilation" in new_key:
        new_key = new_key.replace("mid_dilation", "intermediate_dilation")
    if "feat_prop_module" in new_key:
        new_key = new_key.replace("feat_prop_module", "feature_propagation_module")
    if "edgeDetector.mid_layer" in new_key:
        new_key = new_key.replace("edgeDetector.mid_layer", "edgeDetector.intermediate_layer")

    return new_key
Across the file, key renames can be handled by regexes to avoid some slightly harder-to-read if/else logic, and also to see at a glance in one place which key in the original model corresponds to which key in transformers.
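As a rough sketch of the regex approach (the mapping entries below are illustrative examples taken from this thread, not the complete set used in the conversion script):

import re

# One table of pattern -> replacement pairs; every rename lives in one place.
KEY_RENAMES = {
    r"^(downsample|encoder1|encoder2|decoder1|decoder2|upsample)\.": r"flow_completion_net.\1.",
    r"\bmid_dilation\b": "intermediate_dilation",
    r"\bfeat_prop_module\b": "feature_propagation_module",
    r"edgeDetector\.mid_layer": "edgeDetector.intermediate_layer",
}

def rename_key(old_key: str) -> str:
    new_key = old_key
    for pattern, replacement in KEY_RENAMES.items():
        new_key = re.sub(pattern, replacement, new_key)
    return new_key

print(rename_key("encoder1.0.weight"))              # flow_completion_net.encoder1.0.weight
print(rename_key("edgeDetector.mid_layer.2.bias"))  # edgeDetector.intermediate_layer.2.bias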
done in the recent commits
"encoder1": "flow_completion_net.encoder1", | ||
"encoder2": "flow_completion_net.encoder2", | ||
"decoder1": "flow_completion_net.decoder1", | ||
"decoder2": "flow_completion_net.decoder2", |
Related to the regex comment, this could be for instance:

network_mapping_completion = {
    r"(downsample|encoder1|encoder2|decoder1|decoder2|upsample)": r"flow_completion_net.\1",
    ...
}
downsample_inputs = self.downsample(inputs)

features_enc1 = self.encoder1(downsample_inputs)
nits: naming - separate numbers from names + avoid abbreviations + remove the comments in all the forward methods
done
        return flow, edge

    def forward_bidirect_flow(self, masked_flows_bi, masks):
naming, no abbreviations
done
super().__init__()
self.config = config
self.group = [1, 2, 4, 8, 1]
negative_slope = 0.2
that's good! let's add it to the config too
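For example, a minimal sketch of reading the slope from the config instead of hardcoding it (the module and attribute names here are placeholders, for illustration only):

import torch.nn as nn

class ProPainterExampleBlock(nn.Module):  # placeholder name
    def __init__(self, config):
        super().__init__()
        self.config = config
        # `config.negative_slope` replaces the local `negative_slope = 0.2`.
        self.activation = nn.LeakyReLU(negative_slope=config.negative_slope, inplace=True)

    def forward(self, hidden_states):
        return self.activation(hidden_states)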
done
if i == 8:
    x0 = features
    _, _, height, width = x0.size()
if i > 8 and i % 2 == 0:
maybe add a small comment here 😅 might not be intuitive to understand why there is this treatment of odd/even layers except at layer 8
done
Co-authored-by: Pablo Montalvo <[email protected]>
…ow completion net
Hi @molbap, I have addressed all your comments and left queries as well. Please review the rest of the remaining files and iterate on the ones that are addressed when you get the time. Thank you 😄
@molbap Soft ping
Hey @RUFFY-369! I did a pass on naming and configs 🧹 because I identified persistent issues there. Considering the very, very hefty size of the modeling code - 4600 LOC is indeed hefty - I think clarity is paramount to help code inspectors understand the logic flow. Let me know when you think you've addressed the suggestions on all the existing code, and I'll continue promptly 🤗
Also, I'd suggest taking a look at the currently failing tests - it's mostly docstring mismatches in the configs, I believe.
self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
well fair enough, maybe just add a comment there saying so since it's a rarely seen pattern
# using itertools makes flattening a little faster :)
self.resblocks = nn.ModuleList(list(itertools.chain.from_iterable(self.resblocks)))
Interesting - do you know what difference it makes in numbers?
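If it helps put numbers on it, here is a quick, purely illustrative micro-benchmark comparing the two common ways of flattening a nested list of modules (results will vary by machine and list size):

import itertools
import timeit

nested = [[object() for _ in range(8)] for _ in range(1000)]

def flatten_with_itertools():
    return list(itertools.chain.from_iterable(nested))

def flatten_with_comprehension():
    return [block for stage in nested for block in stage]

print("itertools.chain:", timeit.timeit(flatten_with_itertools, number=1000))
print("nested list comp:", timeit.timeit(flatten_with_comprehension, number=1000))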
        self.conv_corr1 = nn.Conv2d(correlation_planes, config.num_channels * 2, 1, padding=0)
        self.conv_corr2 = nn.Conv2d(config.num_channels * 2, 192, config.patch_size, padding=config.padding)
        self.conv_flow1 = nn.Conv2d(2, config.num_channels, config.kernel_size[0], padding=3)
        self.conv_flow2 = nn.Conv2d(
            config.num_channels,
            config.in_channels[0],
            config.patch_size,
            padding=config.padding,
        )
        self.conv = nn.Conv2d(
            config.in_channels[0] + 192,
            config.num_channels - 2,
            config.patch_size,
            padding=config.padding,
        )

    def forward(self, optical_flow, correlation):
        hidden_states_correlation = F.relu(self.conv_corr1(correlation))
        hidden_states_correlation = F.relu(self.conv_corr2(hidden_states_correlation))
        hidden_states_flow = F.relu(self.conv_flow1(optical_flow))
        hidden_states_flow = F.relu(self.conv_flow2(hidden_states_flow))

        hidden_states = torch.cat([hidden_states_correlation, hidden_states_flow], dim=1)
        hidden_states = F.relu(self.conv(hidden_states))
        hidden_states = torch.cat([hidden_states, optical_flow], dim=1)

        return hidden_states
nits on naming, the comeback: when possible, avoid abbreviations and try to space numbers and letters for legibility. `conv_corr_1` is better than `conv_corr1`, and so on.
        self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, hidden_states, motion_features):
        hidden_states_motion_features = torch.cat([hidden_states, motion_features], dim=1)
        z = torch.sigmoid(self.convz1(hidden_states_motion_features))
        r = torch.sigmoid(self.convr1(hidden_states_motion_features))
        q = torch.tanh(self.convq1(torch.cat([r * hidden_states, motion_features], dim=1)))
        hidden_states = (1 - z) * hidden_states + z * q
        hidden_states_motion_features = torch.cat([hidden_states, motion_features], dim=1)
        z = torch.sigmoid(self.convz2(hidden_states_motion_features))
        r = torch.sigmoid(self.convr2(hidden_states_motion_features))
        q = torch.tanh(self.convq2(torch.cat([r * hidden_states, motion_features], dim=1)))
        hidden_states = (1 - z) * hidden_states + z * q
nits on naming, book 3, the return of the nit:
- avoid single-letter variables
- space letters and numbers
- avoid abbreviations, we have enough space to afford some more letters
I understand this adds chores to the PR, but it's necessary for the legibility/harmony of the 200 models in the codebase! 🙇
hidden_dim: int = 128,
input_dim: int = 192 + 128,
two things here:
- why is the default `input_dim` defined like this?
- ideally modules should only be initialized with a configuration and a layer index - this avoids hardcoding too many things.
In addition, `ProPainterBasicUpdateBlock` is initialized with `hidden_dim` only, meaning that the other argument could simply be contained in the configuration.
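A hedged sketch of what config-driven initialization could look like; the config attribute names (`update_block_hidden_dim`, `motion_feature_dim`) and the simplified internals are hypothetical, not the actual ProPainter layers:

import torch
import torch.nn as nn

class ProPainterBasicUpdateBlock(nn.Module):
    def __init__(self, config, layer_idx: int = 0):
        super().__init__()
        # Hypothetical config attributes replacing the hardcoded defaults
        # `hidden_dim=128` and `input_dim=192 + 128`.
        hidden_dim = config.update_block_hidden_dim
        input_dim = config.motion_feature_dim + config.update_block_hidden_dim
        self.gru = nn.Conv2d(hidden_dim + input_dim, hidden_dim, kernel_size=3, padding=1)

    def forward(self, hidden_states, motion_features):
        return self.gru(torch.cat([hidden_states, motion_features], dim=1))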
        config: ProPainterConfig,
        in_channel: int = 2,
        out_channel: int = 1,
        intermediate_channel: int = 16,
    ):
same remark, we can move these values to the config and simply initialize the module with the configuration passed, with explicitly named keys
feat = (
    [feat_current]
    + [features[k][frame_id] for k in features if k not in ["spatial", module_name]]
    + [feature_propagation]
)

feat = torch.cat(feat, dim=1)
feature_propagation = feature_propagation + self.backbone[module_name](feat)
Suggested change:

aggregated_features = (
    [feat_current]
    + [features[k][frame_id] for k in features if k not in ["spatial", module_name]]
    + [feature_propagation]
)
aggregated_features = torch.cat(aggregated_features, dim=1)
feature_propagation = feature_propagation + self.backbone[module_name](aggregated_features)
        pooling_token: bool = True,
    ):
        super().__init__()
        assert hidden_size % num_attention_heads == 0
use raise rather than assert
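For instance, a drop-in replacement for the `assert` in `__init__`, in line with how other transformers models validate this:

if hidden_size % num_attention_heads != 0:
    raise ValueError(
        f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
    )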
mask_tl = torch.ones(self.window_size[0], self.window_size[1])
mask_tl[: -self.expand_size[0], : -self.expand_size[1]] = 0
mask_tr = torch.ones(self.window_size[0], self.window_size[1])
mask_tr[: -self.expand_size[0], self.expand_size[1] :] = 0
mask_bl = torch.ones(self.window_size[0], self.window_size[1])
mask_bl[self.expand_size[0] :, : -self.expand_size[1]] = 0
mask_br = torch.ones(self.window_size[0], self.window_size[1])
mask_br[self.expand_size[0] :, self.expand_size[1] :] = 0
masked_rolled_key = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
self.register_buffer("valid_ind_rolled", masked_rolled_key.nonzero(as_tuple=False).view(-1))
improve naming here too
        self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0))

    def forward(
considering the mighty size of this forward, it's acceptable to cut it down into at least 2 sub-methods, one for each branching
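A rough sketch of the kind of split suggested here; the class and method names and the simplified bodies are placeholders, not the actual ProPainter attention logic:

import torch
import torch.nn as nn

class ExampleSparseWindowAttention(nn.Module):  # placeholder class, for illustration only
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.projection = nn.Linear(hidden_size, hidden_size)

    def _attend_within_windows(self, hidden_states):
        # first branch: local window attention (stand-in body)
        return self.projection(hidden_states)

    def _attend_with_pooled_tokens(self, hidden_states):
        # second branch: attention over pooled global tokens (stand-in body)
        pooled = self.projection(hidden_states.mean(dim=1, keepdim=True))
        return pooled.expand_as(hidden_states)

    def forward(self, hidden_states, pooling_token: bool = True):
        # the long forward becomes a thin dispatcher over the two branches
        attended = self._attend_within_windows(hidden_states)
        if pooling_token:
            attended = attended + self._attend_with_pooled_tokens(hidden_states)
        return attended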
What does this PR do?
This PR adds ProPainter, a video inpainting model whose repo has 5.4k stars and 635 forks. It fixes #26360 and resolves the stale PR #26391 for that issue, rebuilt from scratch to follow the transformers standard.
Who can review?
@amyeroberts @ArthurZucker @NielsRogge (?)
@rafaelpadilla (as he was the initial reviewer on the stale PR)
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
The PR is more than ready for a first pass of review!
TODO (will be done in a jiffy :)):
Results:
Here, I am attaching the GIFs for the original video, the original model's output for object removal through video inpainting, and the current PR's HF model's output for the same:
Original video:
Original model output:
HF ported model output:
Example usage is provided in the doc file here