Skip to content

Commit

Permalink
Add post_process_depth_estimation for GLPN (#34413)
Browse files Browse the repository at this point in the history
* add depth postprocessing for GLPN

* remove previous temp fix for glpn tests

* Style changes for GLPN's `post_process_depth_estimation`

Co-authored-by: Arthur <[email protected]>

* additional style fix

---------

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
alex-bene and ArthurZucker authored Oct 28, 2024
1 parent 6cc4a67 commit a769ed4
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
54 changes: 52 additions & 2 deletions src/transformers/models/glpn/image_processing_glpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
# limitations under the License.
"""Image processor class for GLPN."""

from typing import List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union


if TYPE_CHECKING:
from ...modeling_outputs import DepthEstimatorOutput

import numpy as np
import PIL.Image
Expand All @@ -27,12 +31,17 @@
get_image_size,
infer_channel_dimension_format,
is_scaled_image,
is_torch_available,
make_list_of_images,
to_numpy_array,
valid_images,
validate_preprocess_arguments,
)
from ...utils import TensorType, filter_out_non_signature_kwargs, logging
from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends


if is_torch_available():
import torch


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -218,3 +227,44 @@ def preprocess(

data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)

def post_process_depth_estimation(
self,
outputs: "DepthEstimatorOutput",
target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None,
) -> List[Dict[str, TensorType]]:
"""
Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images.
Only supports PyTorch.
Args:
outputs ([`DepthEstimatorOutput`]):
Raw outputs of the model.
target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*):
Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
(height, width) of each image in the batch. If left to None, predictions will not be resized.
Returns:
`List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth
predictions.
"""
requires_backends(self, "torch")

predicted_depth = outputs.predicted_depth

if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the predicted depth"
)

results = []
target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes
for depth, target_size in zip(predicted_depth, target_sizes):
if target_size is not None:
depth = depth[None, None, ...]
depth = torch.nn.functional.interpolate(depth, size=target_size, mode="bicubic", align_corners=False)
depth = depth.squeeze()

results.append({"predicted_depth": depth})

return results
16 changes: 7 additions & 9 deletions src/transformers/models/glpn/modeling_glpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,20 +723,18 @@ def forward(
>>> with torch.no_grad():
... outputs = model(**inputs)
... predicted_depth = outputs.predicted_depth
>>> # interpolate to original size
>>> prediction = torch.nn.functional.interpolate(
... predicted_depth.unsqueeze(1),
... size=image.size[::-1],
... mode="bicubic",
... align_corners=False,
>>> post_processed_output = image_processor.post_process_depth_estimation(
... outputs,
... target_sizes=[(image.height, image.width)],
... )
>>> # visualize the prediction
>>> output = prediction.squeeze().cpu().numpy()
>>> formatted = (output * 255 / np.max(output)).astype("uint8")
>>> depth = Image.fromarray(formatted)
>>> predicted_depth = post_processed_output[0]["predicted_depth"]
>>> depth = predicted_depth * 255 / predicted_depth.max()
>>> depth = depth.detach().cpu().numpy()
>>> depth = Image.fromarray(depth.astype("uint8"))
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
Expand Down
8 changes: 0 additions & 8 deletions tests/models/glpn/test_modeling_glpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,6 @@ def setUp(self):
self.model_tester = GLPNModelTester(self)
self.config_tester = GLPNConfigTester(self, config_class=GLPNConfig)

@unittest.skip(reason="Failing after #32550")
def test_pipeline_depth_estimation(self):
pass

@unittest.skip(reason="Failing after #32550")
def test_pipeline_depth_estimation_fp16(self):
pass

def test_config(self):
self.config_tester.run_common_tests()

Expand Down

0 comments on commit a769ed4

Please sign in to comment.