Skip to content

Commit

Permalink
Add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
nsoranzo committed Feb 29, 2024
1 parent 711fd6a commit 3d3ad72
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 36 deletions.
51 changes: 31 additions & 20 deletions lib/galaxy/tool_util/verify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
Any,
Callable,
Dict,
List,
Optional,
TYPE_CHECKING,
)

import numpy
Expand All @@ -37,6 +39,9 @@
from .asserts import verify_assertions
from .test_data import TestDataResolver

if TYPE_CHECKING:
import numpy.typing

log = logging.getLogger(__name__)

DEFAULT_TEST_DATA_RESOLVER = TestDataResolver()
Expand Down Expand Up @@ -442,21 +447,23 @@ def files_contains(file1, file2, attributes=None):
raise AssertionError(f"Failed to find '{contains}' in history data. (lines_diff={lines_diff}).")


def _multiobject_intersection_over_union(mask1, mask2, repeat_reverse=True):
iou_list = list()
def _multiobject_intersection_over_union(
mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray", repeat_reverse: bool = True
) -> List[numpy.floating]:
iou_list = []
for label1 in numpy.unique(mask1):
cc1 = mask1 == label1
cc1_iou_list = list()
cc1_iou_list = []
for label2 in numpy.unique(mask2[cc1]):
cc2 = mask2 == label2
cc1_iou_list.append(intersection_over_union(cc1, cc2))
iou_list.append(max(cc1_iou_list))
if repeat_reverse:
iou_list += _multiobject_intersection_over_union(mask2, mask1, repeat_reverse=False)
iou_list.extend(_multiobject_intersection_over_union(mask2, mask1, repeat_reverse=False))
return iou_list


def intersection_over_union(mask1, mask2):
def intersection_over_union(mask1: "numpy.typing.NDArray", mask2: "numpy.typing.NDArray") -> numpy.floating:
assert mask1.dtype == mask2.dtype
assert mask1.ndim == mask2.ndim == 2
assert mask1.shape == mask2.shape
Expand All @@ -466,36 +473,40 @@ def intersection_over_union(mask1, mask2):
return min(_multiobject_intersection_over_union(mask1, mask2))


def get_image_metric(attributes):
def get_image_metric(
attributes: Dict[str, Any]
) -> Callable[["numpy.typing.NDArray", "numpy.typing.NDArray"], numpy.floating]:
metric_name = attributes.get("metric", DEFAULT_METRIC)
attributes = attributes or {}
metrics = {
"mae": lambda im1, im2: numpy.abs(im1 - im2).mean(),
"mse": lambda im1, im2: numpy.square((im1 - im2).astype(float)).mean(),
"rms": lambda im1, im2: math.sqrt(numpy.square((im1 - im2).astype(float)).mean()),
"fro": lambda im1, im2: numpy.linalg.norm((im1 - im2).reshape(1, -1), "fro"),
"iou": lambda im1, im2: 1 - intersection_over_union(im1, im2),
"mae": lambda arr1, arr2: numpy.abs(arr1 - arr2).mean(),
# Convert to float before squaring to prevent overflows
"mse": lambda arr1, arr2: numpy.square((arr1 - arr2).astype(float)).mean(),
"rms": lambda arr1, arr2: math.sqrt(numpy.square((arr1 - arr2).astype(float)).mean()),
"fro": lambda arr1, arr2: numpy.linalg.norm((arr1 - arr2).reshape(1, -1), "fro"),
"iou": lambda arr1, arr2: 1 - intersection_over_union(arr1, arr2),
}
try:
return metrics[metric_name]
except KeyError:
raise ValueError(f'No such metric: "{metric_name}"')


def files_image_diff(file1, file2, attributes=None):
def files_image_diff(file1: str, file2: str, attributes: Optional[Dict[str, Any]] = None) -> None:
"""Check the pixel data of 2 image files for differences."""
attributes = attributes or {}

im1 = numpy.array(Image.open(file1))
im2 = numpy.array(Image.open(file2))
with Image.open(file1) as im1:
arr1 = numpy.array(im1)
with Image.open(file2) as im2:
arr2 = numpy.array(im2)

if im1.dtype != im2.dtype:
raise AssertionError(f"Image data types did not match ({im1.dtype}, {im2.dtype}).")
if arr1.dtype != arr2.dtype:
raise AssertionError(f"Image data types did not match ({arr1.dtype}, {arr2.dtype}).")

if im1.shape != im2.shape:
raise AssertionError(f"Image dimensions did not match ({im1.shape}, {im2.shape}).")
if arr1.shape != arr2.shape:
raise AssertionError(f"Image dimensions did not match ({arr1.shape}, {arr2.shape}).")

distance = get_image_metric(attributes)(im1, im2)
distance = get_image_metric(attributes)(arr1, arr2)
distance_eps = attributes.get("eps", DEFAULT_EPS)
if distance > distance_eps:
raise AssertionError(f"Image difference {distance} exceeds eps={distance_eps}.")
6 changes: 2 additions & 4 deletions lib/galaxy/util/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,9 @@ def iter_zip(file_path: str):
yield (z.open(f), f)


def check_image(file_path: str):
def check_image(file_path: str) -> bool:
"""Simple wrapper around image_type to yield a True/False verdict"""
if image_type(file_path):
return True
return False
return bool(image_type(file_path))


COMPRESSION_CHECK_FUNCTIONS: Dict[str, CompressionChecker] = {
Expand Down
24 changes: 12 additions & 12 deletions lib/galaxy/util/image_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@

import imghdr
import logging
from typing import (
List,
Optional,
)

try:
import Image as PIL
from PIL import Image
except ImportError:
try:
from PIL import Image as PIL
except ImportError:
PIL = None
PIL = None

log = logging.getLogger(__name__)


def image_type(filename):
def image_type(filename: str) -> Optional[str]:
fmt = None
if PIL is not None:
if Image is not None:
try:
im = PIL.open(filename)
fmt = im.format
im.close()
with Image.open(filename) as im:
fmt = im.format
except Exception:
# We continue to try with imghdr, so this is a rare case of an
# exception we expect to happen frequently, so we're not logging
Expand All @@ -30,10 +30,10 @@ def image_type(filename):
if fmt:
return fmt.upper()
else:
return False
return None


def check_image_type(filename, types):
def check_image_type(filename: str, types: List[str]) -> bool:
fmt = image_type(filename)
if fmt in types:
return True
Expand Down

0 comments on commit 3d3ad72

Please sign in to comment.