Skip to content

Commit

Permalink
fix deprecated aux
Browse files Browse the repository at this point in the history
  • Loading branch information
volgachen committed Aug 18, 2021
1 parent ce1b5a6 commit 62e6322
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
17 changes: 5 additions & 12 deletions classification/models/dpt/depatch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,13 @@ def forward(self, x, model_offset=None):
return self.get_output(img, pred_offset, model_offset), (self.patch_count, self.patch_count)

class Simple_DePatch(Simple_Patch):
def __init__(self, box_coder, show_dim=4, use_auxiliary=-1, **kwargs):
def __init__(self, box_coder, show_dim=4, **kwargs):
super().__init__(show_dim, **kwargs)
self.box_coder = box_coder
self.register_buffer("value_spatial_shapes", torch.as_tensor([[self.img_size, self.img_size]], dtype=torch.long))
self.register_buffer("value_level_start_index", torch.as_tensor([0], dtype=torch.long))
self.output_proj = nn.Linear(self.in_chans * self.patch_pixel * self.patch_pixel, self.embed_dim)
self.num_sample_points = self.patch_pixel * self.patch_pixel * self.patch_count * self.patch_count
self.use_auxiliary = use_auxiliary > -1
if self.use_auxiliary:
self.classifier = Sample_Classifier(use_auxiliary, kwargs["in_chans"])
if kwargs["with_norm"]:
self.with_norm=True
self.norm = nn.LayerNorm(self.embed_dim)
Expand All @@ -111,11 +108,7 @@ def get_output(self, img, pred_offset, model_offset=None):
output = MSDeformAttnFunction.apply(x, self.value_spatial_shapes, self.value_level_start_index, sampling_locations, attention_weights, 1)
# output_proj
output = output.view(B, self.num_patches, self.in_chans*self.patch_pixel*self.patch_pixel)
if self.use_auxiliary:
aux_logits = self.classifier(img, pred_offset[:, :, :4])
return self.output_proj(output), aux_logits
else:
output = self.output_proj(output)
if self.with_norm:
output = self.norm(output)
return output
output = self.output_proj(output)
if self.with_norm:
output = self.norm(output)
return output
2 changes: 1 addition & 1 deletion detection/dpt_models/depatch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def forward(self, x):
return self.get_output(img, pred_offset, img_size=(H, W), output_size=output_size), output_size

class Simple_DePatch(Simple_Patch):
def __init__(self, box_coder, show_dim=4, use_auxiliary=-1, **kwargs):
def __init__(self, box_coder, show_dim=4, **kwargs):
super().__init__(show_dim, **kwargs)
self.box_coder = box_coder
#self.register_buffer("value_spatial_shapes", torch.as_tensor([[self.img_size, self.img_size]], dtype=torch.long))
Expand Down

0 comments on commit 62e6322

Please sign in to comment.