Skip to content

Commit

Permalink
make it easier to get colors in a standard format from ClassConfig
Browse files Browse the repository at this point in the history
These can be passed directly to matplotlib.color.ListedColorMap.
  • Loading branch information
AdeelH committed Sep 19, 2022
1 parent 949c2fb commit 3613061
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
7 changes: 6 additions & 1 deletion rastervision_core/rastervision/core/data/class_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from rastervision.pipeline.config import (Config, register_config, ConfigError,
Field, validator)
from rastervision.core.data.utils import color_to_triple
from rastervision.core.data.utils import color_to_triple, normalize_color

DEFAULT_NULL_CLASS_NAME = 'null'
DEFAULT_NULL_CLASS_COLOR = 'black'
Expand Down Expand Up @@ -116,3 +116,8 @@ def ensure_null_class(self) -> None:

def __len__(self) -> int:
return len(self.names)

@property
def color_triples(self) -> List[Tuple[float, float, float]]:
color_triples = [normalize_color(c) for c in self.colors]
return color_triples
36 changes: 29 additions & 7 deletions rastervision_core/rastervision/core/data/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
from PIL import ImageColor


def color_to_triple(color: Optional[str] = None) -> Tuple[int, int, int]:
def color_to_triple(
color: Optional[Union[str, Sequence]] = None) -> Tuple[int, int, int]:
"""Given a PIL ImageColor string, return a triple of integers
representing the red, green, and blue values.
Expand All @@ -18,12 +19,14 @@ def color_to_triple(color: Optional[str] = None) -> Tuple[int, int, int]:
"""
if color is None:
r = np.random.randint(0, 0x100)
g = np.random.randint(0, 0x100)
b = np.random.randint(0, 0x100)
return (r, g, b)
else:
r, g, b = np.random.randint(0, 256, size=3).tolist()
return r, g, b
elif isinstance(color, str):
return ImageColor.getrgb(color)
elif isinstance(color, (tuple, list)):
return color
else:
raise TypeError(f'Unsupported type: {type(color)}')


def color_to_integer(color: str) -> int:
Expand All @@ -44,6 +47,25 @@ def color_to_integer(color: str) -> int:
return integer


def normalize_color(
color: Union[str, tuple, list]) -> Tuple[float, float, float]:
"""Convert color representation to a float 3-tuple with values in [0-1]."""
if isinstance(color, str):
color = color_to_triple(color)

if isinstance(color, (tuple, list)):
if all(isinstance(c, int) for c in color):
return tuple(c / 255. for c in color)
elif all(isinstance(c, float) for c in color):
return tuple(color)
else:
raise ValueError('RGB values must be either all ints (0-255) '
'or all floats (0.0-1.0)')

raise TypeError('Expected color to be a string or tuple or list, '
f'but found {type(color)}.')


def rgb_to_int_array(rgb_array):
r = np.array(rgb_array[:, :, 0], dtype=np.uint32) * (1 << 16)
g = np.array(rgb_array[:, :, 1], dtype=np.uint32) * (1 << 8)
Expand Down

0 comments on commit 3613061

Please sign in to comment.