diff --git a/src/cvdata/relabel.py b/src/cvdata/relabel.py index 6ccf122..ae255fb 100644 --- a/src/cvdata/relabel.py +++ b/src/cvdata/relabel.py @@ -20,6 +20,35 @@ _logger = logging.getLogger(__name__) +# ------------------------------------------------------------------------------ +def relabel_darknet( + file_path: str, + old_index: int, + new_index: int, +): + """ + Replaces the label index values of a Darknet (YOLO) annotation file. + + :param file_path: path of the Darknet (YOLO) file + :param old_index: label index value which if found will be replaced by the + new label index + :param new_index: new label index value + """ + + # arguments validation + if (old_index < 0) or (new_index < 0): + raise ValueError("Invalid label index argument, must be equal or greater than zero") + + # replace the label indices in-place + with fileinput.FileInput(file_path, inplace=True) as file_input: + for line in file_input: + line = line.rstrip("\r\n") + parts = line.split() + if (len(parts) > 0) and (parts[0] == str(old_index)): + parts[0] = str(new_index) + print(" ".join(parts)) + + # ------------------------------------------------------------------------------ def relabel_kitti( file_path: str, @@ -29,7 +58,7 @@ def relabel_kitti( """ Replaces the label values of a KITTI annotation file. - :param file_path: path of the KITTI file to have labels replaced + :param file_path: path of the KITTI file :param old_label: label value which if found will be replaced by the new label :param new_label: new label value """ @@ -53,7 +82,7 @@ def relabel_pascal( """ Replaces the label values of a PASCAL VOC annotation file. - :param file_path: path of the PASCAL VOC file to have labels replaced + :param file_path: path of the PASCAL VOC file :param old_label: label value which if found will be replaced by the new label :param new_label: new label value """ @@ -92,6 +121,21 @@ def _validate_args( raise ValueError(f"File path argument {file_path} is not a valid file path") +# ------------------------------------------------------------------------------ +def _relabel_darknet(arguments: Dict): + """ + Unpacks a dictionary of arguments and calls the function for replacing the + labels of a Darknet (YOLO) annotation file. + + :param arguments: dictionary of function arguments, should include: + "file_path": path of the Darknet (YOLO) file to have labels renamed + "old": label index which if found will be renamed + "new": new label index value + """ + + relabel_darknet(arguments["file_path"], arguments["old"], arguments["new"]) + + # ------------------------------------------------------------------------------ def _relabel_kitti(arguments: Dict): """ @@ -161,6 +205,9 @@ def main(): elif args["format"] == "pascal": file_ext = ".xml" relabel_function = _relabel_pascal + elif args["format"] == "darknet": + file_ext = ".txt" + relabel_function = _relabel_darknet else: raise ValueError("Only KITTI and PASCAL annotation files are supported")