perf: improved performance in various methods in `Image` and `ImageList` (#879)

### Summary of Changes

1. Improved memory usage and runtime in:

- `Image`
  - `convert_to_grayscale`
  - `adjust_brightness`
  - `add_noise`
  - `adjust_contrast`
  - `adjust_color_balance`
  - `find_edges`
- `ImageList`
  - `from_images`
  - `add_image`
  - `add_images`
  - `remove_image_by_index`
  - `remove_duplicate_images`
  - `convert_to_grayscale`
  - `resize`
  - `crop`
  - `adjust_brightness`
  - `add_noise`
  - `adjust_contrast`
  - `adjust_color_balance`
  - `find_edges`

2. Changed the `blur` algorithm in `Image` and `ImageList` from Gaussian blur to box blur (a sketch of the idea follows below)
3. Fixed a bug in `blur` and `sharpen` that made them fail on tensors with more than 2**31 elements
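
Conceptually, a box blur is a depthwise convolution with a uniform kernel, which is what the new `blur` implementation in the diff below does. Here is a minimal standalone sketch of the idea; the function name and the plain-`torch` setup are illustrative, not the library's public API:

```python
import torch

def box_blur(image_tensor: torch.Tensor, radius: int) -> torch.Tensor:
    """Box blur for a (channels, height, width) uint8 tensor -- a sketch."""
    channels = image_tensor.size(dim=-3)
    size = radius * 2 + 1
    # One uniform kernel per channel; every output pixel becomes the mean of
    # the size x size window around it. groups=channels keeps channels separate.
    kernel = torch.full((channels, 1, size, size), 1 / size**2, dtype=torch.float32)
    padded = torch.nn.functional.pad(
        image_tensor.to(torch.float32),
        (radius, radius, radius, radius),
        mode="replicate",  # repeat edge pixels so the output keeps the input size
    )
    blurred = torch.nn.functional.conv2d(padded, kernel, groups=channels)
    return torch.clamp(blurred, 0, 255).to(torch.uint8)
```

For example, `box_blur(img, 2)` averages over a 5×5 window. Compared to `torchvision`'s `gaussian_blur`, the weights are uniform rather than Gaussian, which makes the kernel trivial to construct and cheap to apply.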

### Details of the performance upgrades

_These details describe the performance upgrades in `ImageList`. All
performance upgrades and changes in `Image` were made to match the
corresponding changes in `ImageList`._

#### Early stopping

- `convert_to_grayscale` returns `self` when the `ImageList` has only one channel (in this case, it is already grayscale)
- `remove_duplicate_images` returns `self` if the tensor of unique images has the same size as the original (in this case, there are no duplicates)
- `adjust_brightness` returns an `ImageList` of completely black images if the factor is 0
- `adjust_color_balance` returns the result of `convert_to_grayscale` if the factor is 0
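
A minimal sketch of two of these guards, written as standalone tensor functions rather than the actual `ImageList` methods (the `torch.unique` call for deduplication is an assumption, since that method's diff is not shown here):

```python
import torch

def adjust_brightness(tensor: torch.Tensor, factor: float) -> torch.Tensor:
    if factor == 0:
        # Early exit: factor 0 always yields completely black images,
        # so skip the multiply/clamp work entirely.
        return torch.zeros(tensor.size(), dtype=torch.uint8)
    result = tensor * torch.tensor([factor], dtype=torch.float16)
    torch.clamp(result, 0, 255, out=result)
    return result.to(torch.uint8)

def remove_duplicate_images(tensor: torch.Tensor) -> torch.Tensor:
    unique = torch.unique(tensor, dim=0)
    if unique.size() == tensor.size():
        # Early exit: nothing was removed, so return the input unchanged
        # instead of materializing a copy.
        return tensor
    return unique
```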

#### General changes

- If a float tensor is needed during a computation, it is now allocated as `float16` instead of `float32`
- Improved the order of tensor allocations, so that there are fewer problems with tensors not fitting completely into VRAM
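
Both changes are visible in the new `add_noise` in the diff below; as a standalone sketch (assuming a `uint8` image tensor), a single `float16` buffer is allocated once and then reused via `out=` and in-place operations:

```python
import torch

def add_noise(image_tensor: torch.Tensor, standard_deviation: float) -> torch.Tensor:
    # Allocate one float16 buffer up front ...
    noise = torch.empty(image_tensor.size(), dtype=torch.float16)
    # ... and keep writing into it instead of creating new intermediates.
    torch.normal(0, standard_deviation, image_tensor.size(), out=noise)
    noise *= 255
    noise += image_tensor  # in-place add of the uint8 image onto the float16 noise
    torch.clamp(noise, 0, 255, out=noise)
    return noise.to(torch.uint8)
```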

#### Benchmark

_Only the transformation methods have runtime benchmarks. Their benchmarks
report only changes that differ from 1 by more than 0.25, rounded to one
decimal place, since runtime depends on multiple factors and fluctuates
heavily when the absolute change is only a few milliseconds._
_All differences are measured as a factor compared to the original results:
a factor below 1 is a worse result, while a factor above 1 is a better one.
For readability, factors equal to 1 are omitted._
_Due to the bug fix for `blur` and `sharpen` mentioned above, the
performance of these methods decreased in most cases._
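
For instance, under the (assumed) reading `factor = original_measurement / new_measurement`, the `add_noise` runtime factor of 42.5 in the first table means the method now runs roughly 42.5× faster, while `blur`'s 0.7 means it takes about 1.4× as long as before:

```python
# Hypothetical numbers, only to illustrate how the factors read:
def factor(original: float, new: float) -> float:
    return original / new

print(factor(425.0, 10.0))  # 42.5 -> above 1: better (add_noise-style speedup)
print(factor(70.0, 100.0))  # 0.7  -> below 1: worse (blur-style slowdown)
```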

##### Benchmark with RGB images of size 250×250
| method | result size difference | runtime difference | max memory allocation during runtime difference |
|--------|------------------------|--------------------|--------------------------------------------------|
| `from_images` | 8 | | 7.95 |
| `remove_image_by_index` | | | 1.5 |
| `adjust_brightness` | | 1.6 | 1.5 |
| `add_noise` | 4 | 42.5 | 2.25 |
| `adjust_contrast` | | 1.4 | 1.69 |
| `adjust_color_balance` | 4 | 6.3 | 3 |
| `blur` | | 0.7 | 0.9 |
| `sharpen` | | 0.6 | 0.91 |
| `find_edges` | | 1.3 | |

##### Benchmark with RGBA images (RGB images with a transparency channel) of size 256×256
| method | result size difference | runtime difference | max memory allocation during runtime difference |
|--------|------------------------|--------------------|--------------------------------------------------|
| `from_images` | 8 | | 6.29 |
| `remove_image_by_index` | | | 1.5 |
| `adjust_brightness` | | 1.7 | 1.36 |
| `add_noise` | 4 | 24.5 | 2.25 |
| `adjust_contrast` | | | 1.53 |
| `adjust_color_balance` | 4 | 1.7 | 2.6 |
| `blur` | | 1.6 | 0.9 |
| `sharpen` | | | 0.89 |

##### Benchmark with RGB images of multiple different sizes
| method | result size difference | runtime difference | max memory allocation during runtime difference |
|--------|------------------------|--------------------|--------------------------------------------------|
| `from_images` | 7.7 | | 6.3 |
| `add_images` | 8 | | 2.38 |
| `remove_image_by_index` | | | 1.28 |
| `resize` | | 1.3 | |
| `crop` | | 1.3 | 1.09 |
| `adjust_brightness` | | 1.3 | 1.09 |
| `add_noise` | 4 | 44.3 | 1.47 |
| `adjust_contrast` | | 2 | 2.17 |
| `adjust_color_balance` | 4 | 6.7 | 1.65 |
| `blur` | | | 2.87 |
| `sharpen` | | 0.7 | 0.9 |
| `find_edges` | | 1.9 | 0.91 |

---------

Co-authored-by: megalinter-bot <[email protected]>
Marsmaennchen221 and megalinter-bot authored Jul 12, 2024
1 parent a88a609 commit 134e7d8
Showing 89 changed files with 615 additions and 222 deletions.
@@ -13,20 +13,13 @@ def _check_resize_errors(new_width: int, new_height: int) -> None:
_check_bounds("new_height", new_height, lower_bound=_ClosedBound(1))


def _check_crop_errors_and_warnings(
def _check_crop_warnings(
x: int,
y: int,
width: int,
height: int,
min_width: int,
min_height: int,
plural: bool,
) -> None:
_check_bounds("x", x, lower_bound=_ClosedBound(0))
_check_bounds("y", y, lower_bound=_ClosedBound(0))
_check_bounds("width", width, lower_bound=_ClosedBound(1))
_check_bounds("height", height, lower_bound=_ClosedBound(1))

if x >= min_width or y >= min_height:
warnings.warn(
f"The specified bounding rectangle does not contain any content of {'at least one' if plural else 'the'} image. Therefore {'these images' if plural else 'the image'} will be blank.",
@@ -35,6 +28,18 @@ def _check_crop_errors_and_warnings(
)


def _check_crop_errors(
x: int,
y: int,
width: int,
height: int,
) -> None:
_check_bounds("x", x, lower_bound=_ClosedBound(0))
_check_bounds("y", y, lower_bound=_ClosedBound(0))
_check_bounds("width", width, lower_bound=_ClosedBound(1))
_check_bounds("height", height, lower_bound=_ClosedBound(1))


def _check_adjust_brightness_errors_and_warnings(factor: float, plural: bool) -> None:
_check_bounds("factor", factor, lower_bound=_ClosedBound(0))
if factor == 1:
12 changes: 2 additions & 10 deletions src/safeds/data/image/containers/_empty_image_list.py
@@ -10,7 +10,7 @@
_check_adjust_color_balance_errors_and_warnings,
_check_adjust_contrast_errors_and_warnings,
_check_blur_errors_and_warnings,
_check_crop_errors_and_warnings,
_check_crop_errors,
_check_remove_images_with_size_errors,
_check_resize_errors,
_check_sharpen_errors_and_warnings,
@@ -161,15 +161,7 @@ def convert_to_grayscale(self) -> ImageList:

def crop(self, x: int, y: int, width: int, height: int) -> ImageList:
_EmptyImageList._warn_empty_image_list()
_check_crop_errors_and_warnings(
x,
y,
width,
height,
x + 1,
y + 1,
plural=True,
) # Disable x|y >= min_width|min_height check with min_width|min_height=x|y+1
_check_crop_errors(x, y, width, height)
return _EmptyImageList()

def flip_vertically(self) -> ImageList:
157 changes: 105 additions & 52 deletions src/safeds/data/image/containers/_image.py
@@ -14,7 +14,8 @@
_check_adjust_color_balance_errors_and_warnings,
_check_adjust_contrast_errors_and_warnings,
_check_blur_errors_and_warnings,
_check_crop_errors_and_warnings,
_check_crop_errors,
_check_crop_warnings,
_check_resize_errors,
_check_sharpen_errors_and_warnings,
)
@@ -46,7 +47,7 @@ def _filter_edges_kernel() -> Tensor:

if Image._filter_edges_kernel_cache is None:
Image._filter_edges_kernel_cache = (
torch.tensor([[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]])
torch.tensor([[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]], dtype=torch.float16)
.unsqueeze(dim=0)
.unsqueeze(dim=0)
.to(_get_device())
@@ -118,7 +119,13 @@ def from_bytes(data: bytes) -> Image:
return Image(image_tensor=torchvision.io.decode_image(input_tensor).to(_get_device()))

def __init__(self, image_tensor: Tensor) -> None:
self._image_tensor: Tensor = image_tensor
import torch

self._image_tensor: Tensor
if image_tensor.dtype != torch.uint8:
self._image_tensor = torch.clamp(image_tensor, 0, 255).to(torch.uint8)
else:
self._image_tensor = image_tensor

def __eq__(self, other: object) -> bool:
"""
@@ -444,18 +451,20 @@ def convert_to_grayscale(self) -> Image:

_init_default_device()

if self.channel == 4:
if self.channel == 1:
return self
elif self.channel == 4:
return Image(
torch.cat(
[
func2.rgb_to_grayscale(self._image_tensor[0:3], num_output_channels=3),
self._image_tensor[3].unsqueeze(dim=0),
self._image_tensor[3:4],
],
),
)
else:
else: # channel == 3
return Image(
func2.rgb_to_grayscale(self._image_tensor[0:3], num_output_channels=self.channel),
func2.rgb_to_grayscale(self._image_tensor[0:3], num_output_channels=3),
)

def crop(self, x: int, y: int, width: int, height: int) -> Image:
@@ -489,7 +498,8 @@ def crop(self, x: int, y: int, width: int, height: int) -> Image:

_init_default_device()

_check_crop_errors_and_warnings(x, y, width, height, self.width, self.height, plural=False)
_check_crop_errors(x, y, width, height)
_check_crop_warnings(x, y, self.width, self.height, plural=False)
return Image(func2.crop(self._image_tensor, y, x, height, width))

def flip_vertically(self) -> Image:
@@ -552,22 +562,31 @@ def adjust_brightness(self, factor: float) -> Image:
If factor is smaller than 0.
"""
import torch
from torchvision.transforms.v2 import functional as func2

_init_default_device()

_check_adjust_brightness_errors_and_warnings(factor, plural=False)
if self.channel == 4:
return Image(
torch.cat(
[
func2.adjust_brightness(self._image_tensor[0:3], factor * 1.0),
self._image_tensor[3].unsqueeze(dim=0),
],
),
)
if self._image_tensor.size(dim=-3) != 4:
if factor == 0:
return Image(torch.zeros(self._image_tensor.size(), dtype=torch.uint8))
else:
temp_tensor = self._image_tensor * torch.tensor([factor * 1.0], dtype=torch.float16)
torch.clamp(temp_tensor, 0, 255, out=temp_tensor)
return Image(temp_tensor.to(torch.uint8))
else:
return Image(func2.adjust_brightness(self._image_tensor, factor * 1.0))
img_tensor = torch.empty(self._image_tensor.size(), dtype=torch.uint8)
img_tensor[3] = self._image_tensor[3]
if factor == 0:
torch.zeros(
(3, self._image_tensor.size(dim=-2), self._image_tensor.size(dim=-1)),
dtype=torch.uint8,
out=img_tensor[:, 0:3],
)
else:
temp_tensor = self._image_tensor[0:3] * torch.tensor([factor * 1.0], dtype=torch.float16)
torch.clamp(temp_tensor, 0, 255, out=temp_tensor)
img_tensor[0:3] = temp_tensor[:]
return Image(img_tensor)

def add_noise(self, standard_deviation: float) -> Image:
"""
@@ -595,9 +614,12 @@ def add_noise(self, standard_deviation: float) -> Image:
_init_default_device()

_check_add_noise_errors(standard_deviation)
return Image(
self._image_tensor + torch.normal(0, standard_deviation, self._image_tensor.size()).to(_get_device()) * 255,
)
float_tensor = torch.empty(self._image_tensor.size(), dtype=torch.float16)
torch.normal(0, standard_deviation, self._image_tensor.size(), out=float_tensor)
float_tensor *= 255
float_tensor += self._image_tensor
torch.clamp(float_tensor, 0, 255, out=float_tensor)
return Image(float_tensor.to(torch.uint8))

def adjust_contrast(self, factor: float) -> Image:
"""
@@ -624,22 +646,26 @@ def adjust_contrast(self, factor: float) -> Image:
If factor is smaller than 0.
"""
import torch
from torchvision.transforms.v2 import functional as func2

_init_default_device()

_check_adjust_contrast_errors_and_warnings(factor, plural=False)

factor *= 1.0
adjusted_factor = (1 - factor) / factor
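# (mean * (1 - factor) / factor + image) * factor == mean * (1 - factor) + image * factor,
# i.e. the usual contrast blend; pre-dividing by factor lets every later step run in-place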
gray_tensor = self.convert_to_grayscale()._image_tensor[0]
mean = torch.mean(gray_tensor, dim=(-2, -1), dtype=torch.float16)
del gray_tensor
mean *= torch.tensor(adjusted_factor, dtype=torch.float16)
tensor = mean.repeat(min(self.channel, 3), self._image_tensor.size(dim=-2), self._image_tensor.size(dim=-1))
tensor += self._image_tensor[0 : min(self.channel, 3)]
tensor *= factor
torch.clamp(tensor, 0, 255, out=tensor)

if self.channel == 4:
return Image(
torch.cat(
[
func2.adjust_contrast(self._image_tensor[0:3], factor * 1.0),
self._image_tensor[3].unsqueeze(dim=0),
],
),
)
return Image(torch.cat([tensor.to(torch.uint8), self._image_tensor[3:4]], dim=0))
else:
return Image(func2.adjust_contrast(self._image_tensor, factor * 1.0))
return Image(tensor.to(torch.uint8))

def adjust_color_balance(self, factor: float) -> Image:
"""
@@ -665,10 +691,20 @@ def adjust_color_balance(self, factor: float) -> Image:
OutOfBoundsError
If factor is smaller than 0.
"""
import torch

_check_adjust_color_balance_errors_and_warnings(factor, self.channel, plural=False)
return Image(
self.convert_to_grayscale()._image_tensor * (1.0 - factor * 1.0) + self._image_tensor * (factor * 1.0),
)

factor *= 1.0
if factor == 0:
return self.convert_to_grayscale()
else:
adjusted_factor = (1 - factor) / factor
tensor = self.convert_to_grayscale()._image_tensor * torch.tensor(adjusted_factor, dtype=torch.float16)
tensor += self._image_tensor
tensor *= factor
torch.clamp(tensor, 0, 255, out=tensor)
return Image(tensor.to(torch.uint8))

def blur(self, radius: int) -> Image:
"""
@@ -692,12 +728,30 @@ def blur(self, radius: int) -> Image:
OutOfBoundsError
If radius is smaller than 0 or equal or greater than the smaller size of the image.
"""
from torchvision.transforms.v2 import functional as func2
import torch

_init_default_device()

float_dtype = torch.float32 if _get_device() != torch.device("cuda") else torch.float16
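# (float16 convolutions are generally only available on CUDA, hence the float32 fallback above)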

_check_blur_errors_and_warnings(radius, min(self.width, self.height), plural=False)
return Image(func2.gaussian_blur(self._image_tensor, [radius * 2 + 1, radius * 2 + 1]))

kernel = torch.full(
(self._image_tensor.size(dim=-3), 1, radius * 2 + 1, radius * 2 + 1),
1 / (radius * 2 + 1) ** 2,
dtype=float_dtype,
)
tensor = torch.nn.functional.conv2d(
torch.nn.functional.pad(
self._image_tensor.to(float_dtype),
(radius, radius, radius, radius),
mode="replicate",
),
kernel,
padding="valid",
groups=self._image_tensor.size(dim=-3),
).to(torch.uint8)
return Image(tensor)

def sharpen(self, factor: float) -> Image:
"""
@@ -734,7 +788,7 @@ def sharpen(self, factor: float) -> Image:
torch.cat(
[
func2.adjust_sharpness(self._image_tensor[0:3], factor * 1.0),
self._image_tensor[3].unsqueeze(dim=0),
self._image_tensor[3:4],
],
),
)
@@ -759,7 +813,7 @@ def invert_colors(self) -> Image:

if self.channel == 4:
return Image(
torch.cat([func2.invert(self._image_tensor[0:3]), self._image_tensor[3].unsqueeze(dim=0)]),
torch.cat([func2.invert(self._image_tensor[0:3]), self._image_tensor[3:4]]),
)
else:
return Image(func2.invert(self._image_tensor))
@@ -813,20 +867,19 @@ def find_edges(self) -> Image:

_init_default_device()

edges_tensor = torch.clamp(
torch.nn.functional.conv2d(
self.convert_to_grayscale()._image_tensor.float()[0].unsqueeze(dim=0),
Image._filter_edges_kernel(),
padding="same",
).squeeze(dim=1),
0,
255,
).to(torch.uint8)
edges_tensor_float16 = torch.nn.functional.conv2d(
self.convert_to_grayscale()._image_tensor.to(torch.float16)[0:1],
Image._filter_edges_kernel(),
padding="same",
)
torch.clamp(edges_tensor_float16, 0, 255, out=edges_tensor_float16)
if self.channel == 1:
return Image(edges_tensor_float16.to(torch.uint8))
edges_tensor = edges_tensor_float16.to(torch.uint8)
del edges_tensor_float16
if self.channel == 3:
return Image(edges_tensor.repeat(3, 1, 1))
elif self.channel == 4:
else: # self.channel == 4
return Image(
torch.cat([edges_tensor.repeat(3, 1, 1), self._image_tensor[3].unsqueeze(dim=0)]),
torch.cat([edges_tensor.repeat(3, 1, 1), self._image_tensor[3:4]]),
)
else:
return Image(edges_tensor)