diff --git a/paz/processors/draw.py b/paz/processors/draw.py index 991e771c2..41a92ad81 100644 --- a/paz/processors/draw.py +++ b/paz/processors/draw.py @@ -26,10 +26,19 @@ def __init__(self, class_names=None, colors=None, self.colors = colors self.weighted = weighted self.scale = scale + + if (self.class_names is not None and + not isinstance(self.class_names, list)): + raise TypeError("Class name should be of type 'List of strings'") + + if (self.colors is not None and + not all(isinstance(color, list) for color in self.colors)): + raise TypeError("Colors should be of type 'List of lists'") + if self.colors is None: self.colors = lincolor(len(self.class_names)) - if class_names is not None: + if self.class_names is not None: self.class_to_color = dict(zip(self.class_names, self.colors)) else: self.class_to_color = {None: self.colors, '': self.colors} diff --git a/tests/paz/processors/draw_test.py b/tests/paz/processors/draw_test.py new file mode 100644 index 000000000..7742445e7 --- /dev/null +++ b/tests/paz/processors/draw_test.py @@ -0,0 +1,16 @@ +import pytest +from paz import processors as pr + + +def test_DrawBoxes2D_with_invalid_class_names_type(): + with pytest.raises(TypeError): + class_names = 'Face' + colors = [[255, 0, 0]] + pr.DrawBoxes2D(class_names, colors) + + +def test_DrawBoxes2D_with_invalid_colors_type(): + with pytest.raises(TypeError): + class_names = ['Face'] + colors = [255, 0, 0] + pr.DrawBoxes2D(class_names, colors)