From be8acfe0d23325d3ac66e79f78be3040a08f0d6d Mon Sep 17 00:00:00 2001 From: Martin Huschenbett Date: Wed, 20 Nov 2024 05:13:53 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 698353624 --- official/vision/modeling/heads/instance_heads.py | 4 ++-- .../vision/modeling/heads/segmentation_heads.py | 2 +- official/vision/modeling/layers/roi_aligner.py | 10 ++++++---- official/vision/modeling/retinanet_model.py | 14 ++++++++------ official/vision/modeling/segmentation_model.py | 5 +++-- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/official/vision/modeling/heads/instance_heads.py b/official/vision/modeling/heads/instance_heads.py index cf0b5ea3077..9a17537f13b 100644 --- a/official/vision/modeling/heads/instance_heads.py +++ b/official/vision/modeling/heads/instance_heads.py @@ -173,7 +173,7 @@ def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]): super(DetectionHead, self).build(input_shape) - def call(self, inputs: tf.Tensor, training: bool = None): + def call(self, inputs: tf.Tensor, training: bool = None): # pytype: disable=annotation-type-mismatch """Forward pass of box and class branches for the Mask-RCNN model. Args: @@ -379,7 +379,7 @@ def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]): super(MaskHead, self).build(input_shape) - def call(self, inputs: List[tf.Tensor], training: bool = None): + def call(self, inputs: List[tf.Tensor], training: bool = None): # pytype: disable=annotation-type-mismatch """Forward pass of mask branch for the Mask-RCNN model. Args: diff --git a/official/vision/modeling/heads/segmentation_heads.py b/official/vision/modeling/heads/segmentation_heads.py index 5234ba29188..cba82a222ca 100644 --- a/official/vision/modeling/heads/segmentation_heads.py +++ b/official/vision/modeling/heads/segmentation_heads.py @@ -167,7 +167,7 @@ def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]): super(MaskScoring, self).build(input_shape) - def call(self, inputs: tf.Tensor, training: bool = None): # pytype: disable=signature-mismatch # overriding-parameter-count-checks + def call(self, inputs: tf.Tensor, training: bool = None): # pytype: disable=annotation-type-mismatch,signature-mismatch """Forward pass mask scoring head. Args: diff --git a/official/vision/modeling/layers/roi_aligner.py b/official/vision/modeling/layers/roi_aligner.py index 71b5c7886ae..3debfcf1a7b 100644 --- a/official/vision/modeling/layers/roi_aligner.py +++ b/official/vision/modeling/layers/roi_aligner.py @@ -38,10 +38,12 @@ def __init__(self, crop_size: int = 7, sample_offset: float = 0.5, **kwargs): } super(MultilevelROIAligner, self).__init__(**kwargs) - def call(self, - features: Mapping[str, tf.Tensor], - boxes: tf.Tensor, - training: bool = None): + def call( + self, # pytype: disable=annotation-type-mismatch + features: Mapping[str, tf.Tensor], + boxes: tf.Tensor, + training: bool = None, + ): """Generates ROIs. Args: diff --git a/official/vision/modeling/retinanet_model.py b/official/vision/modeling/retinanet_model.py index 3be6785cc10..7ea889a120c 100644 --- a/official/vision/modeling/retinanet_model.py +++ b/official/vision/modeling/retinanet_model.py @@ -81,12 +81,14 @@ def __init__(self, self._detection_generator = detection_generator self._anchor_boxes = anchor_boxes - def call(self, - images: Union[tf.Tensor, Sequence[tf.Tensor]], - image_shape: Optional[tf.Tensor] = None, - anchor_boxes: Mapping[str, tf.Tensor] | None = None, - output_intermediate_features: bool = False, - training: bool = None) -> Mapping[str, tf.Tensor]: + def call( + self, # pytype: disable=annotation-type-mismatch + images: Union[tf.Tensor, Sequence[tf.Tensor]], + image_shape: Optional[tf.Tensor] = None, + anchor_boxes: Mapping[str, tf.Tensor] | None = None, + output_intermediate_features: bool = False, + training: bool = None, + ) -> Mapping[str, tf.Tensor]: """Forward pass of the RetinaNet model. Args: diff --git a/official/vision/modeling/segmentation_model.py b/official/vision/modeling/segmentation_model.py index 74134aa043a..c82e59fa024 100644 --- a/official/vision/modeling/segmentation_model.py +++ b/official/vision/modeling/segmentation_model.py @@ -58,8 +58,9 @@ def __init__(self, backbone: tf_keras.Model, decoder: tf_keras.Model, self.head = head self.mask_scoring_head = mask_scoring_head - def call(self, inputs: tf.Tensor, training: bool = None # pytype: disable=signature-mismatch # overriding-parameter-count-checks - ) -> Dict[str, tf.Tensor]: + def call( + self, inputs: tf.Tensor, training: bool = None # pytype: disable=annotation-type-mismatch,signature-mismatch + ) -> Dict[str, tf.Tensor]: backbone_features = self.backbone(inputs) if self.decoder: