diff --git a/tools/detect.py b/tools/detect.py
new file mode 100644
index 00000000..2a497fb1
--- /dev/null
+++ b/tools/detect.py
@@ -0,0 +1,259 @@
+# Copyright (c) 2021, yolort team. All rights reserved.
+#
+# This source code is licensed under the GPL-3.0 license found in the
+# LICENSE file in the root directory of this source tree.
+#
+
+import argparse
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import cv2
+import torch
+from torchvision.ops import box_convert
+from yolort.runtime import PredictorTRT
+from yolort.utils.image_utils import to_numpy
+from yolort.v5.utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages
+from yolort.v5.utils.general import (
+    colorstr,
+    set_logging,
+    increment_path,
+    check_img_size,
+    check_file,
+    scale_coords,
+    strip_optimizer,
+)
+from yolort.v5.utils.plots import Annotator, colors, save_one_box
+from yolort.v5.utils.torch_utils import select_device, time_sync
+
+logger = set_logging(__name__)
+
+
+@torch.no_grad()
+def run(
+    weights: str = "yolort.engine",
+    source: str = "bus.jpg",
+    img_size: Tuple[int, int] = (640, 640),
+    conf_thres: float = 0.25,
+    iou_thres: float = 0.45,
+    max_det: int = 1000,
+    device: str = "",
+    view_img: bool = False,
+    save_txt: bool = False,
+    save_conf: bool = False,
+    save_crop: bool = False,
+    nosave: bool = False,
+    classes: Optional[List] = None,
+    visualize: bool = False,
+    update: bool = False,
+    project: str = "./runs/detect",
+    name: str = "exp",
+    exist_ok: bool = False,
+    line_thickness=3,
+    hide_labels: bool = False,
+    hide_conf: bool = False,
+    half: bool = False,
+):
+    """
+    The core function for running detection on an image file, directory, URL or glob source.
+
+    Adapted from https://github.com/ultralytics/yolov5/blob/db6ec66/detect.py
+
+    Args:
+        weights: Path of the engine
+        source: file/dir/URL/glob, 0 for webcam
+        img_size: inference size (height, width)
+        conf_thres: confidence threshold
+        iou_thres: NMS IOU threshold
+        max_det: maximum detections per image
+        device: cuda device, i.e. 0 or 0,1,2,3 or cpu
+        view_img: show results
+        save_txt: save results to *.txt
+        save_conf: save confidences in --save-txt labels
+        save_crop: save cropped prediction boxes
+        nosave: do not save images/videos
+        classes: filter by class: --class 0, or --class 0 2 3
+        visualize: visualize features
+        update: update all models
+        project: save results to project/name
+        name: save results to project/name
+        exist_ok: existing project/name ok, do not increment
+        line_thickness: bounding box thickness (pixels)
+        hide_labels: hide labels
+        hide_conf: hide confidences
+        half: use FP16 half-precision inference
+    """
+    source = str(source)
+    save_img = not nosave and not source.endswith(".txt")  # save inference images
+    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+    is_url = source.lower().startswith(("rtsp://", "rtmp://", "http://", "https://"))
+    webcam = source.isnumeric() or source.endswith(".txt") or (is_url and not is_file)
+    if is_url and is_file:
+        source = check_file(source)  # download
+
+    # Directories
+    # increment run
+    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)
+    # make dir
+    (save_dir / "labels" if save_txt else save_dir).mkdir(parents=True, exist_ok=True)
+
+    # Load model
+    device = select_device(device)
+    model = PredictorTRT(
+        weights,
+        device=device,
+        score_thresh=conf_thres,
+        iou_thresh=iou_thres,
+        detections_per_img=max_det,
+    )
+    stride, names = model.stride, model.names
+    img_size = check_img_size(img_size, stride=stride)  # check image size
+
+    # Dataloader
+    dataset = LoadImages(source, img_size=img_size, stride=stride, auto=False)
+
+    # Run inference
+    model.warmup(img_size=(1, 3, *img_size), half=half)
+    dt, seen = [0.0, 0.0, 0.0], 0
+    for path, im, im0s, _, s in dataset:
+        t1 = time_sync()
+        im = torch.from_numpy(im).to(device)
+        im = im.half() if half else im.float()  # uint8 to fp16/32
+        im /= 255  # 0 - 255 to 0.0 - 1.0
+        if len(im.shape) == 3:
+            im = im[None]  # expand for batch dim
+        t2 = time_sync()
+        dt[0] += t2 - t1
+
+        # Inference
+        # NB: kept for parity with upstream yolov5; PredictorTRT does not consume this flag
+        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+        pred_logits = model(im)
+        t3 = time_sync()
+        dt[1] += t3 - t2
+
+        # NMS
+        detections = model.postprocessing(pred_logits)
+        dt[2] += time_sync() - t3
+
+        # Process predictions
+        for i, det in enumerate(detections):  # per image
+            seen += 1
+            if webcam:  # batch_size >= 1
+                p, im0, frame = path[i], im0s[i].copy(), dataset.count
+                s += f"{i}: "
+            else:
+                p, im0, frame = path, im0s.copy(), getattr(dataset, "frame", 0)
+
+            p = Path(p)  # to Path
+            save_path = str(save_dir / p.name)  # im.jpg
+            txt_path = str(save_dir / "labels" / p.stem) + ("" if dataset.mode == "image" else f"_{frame}")
+            s += "%gx%g " % im.shape[2:]  # print string
+            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
+            imc = im0.copy() if save_crop else im0  # for save_crop
+            annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+            if len(det["boxes"]):
+                # Rescale boxes from img_size to im0 size
+                boxes = scale_coords(im.shape[2:], det["boxes"], im0.shape).round()
+                scores = det["scores"]
+                labels = det["labels"]
+
+                # Print results
+                for c in labels.unique():
+                    n = (labels == c).sum()  # detections per class
+                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
+
+                # Write results
+                for box, score, class_idx in zip(boxes, scores, labels):
+                    if save_txt:  # Write to file
+                        # normalized xywh
+                        xywh = box_convert(box.view(1, 4), in_fmt="xyxy", out_fmt="cxcywh")
+                        xywh = (xywh / gn).view(-1).tolist()
+                        # label format
+                        line = (class_idx, *xywh, score) if save_conf else (class_idx, *xywh)
+                        with open(txt_path + ".txt", "a") as f:
+                            f.write(("%g " * len(line)).rstrip() % line + "\n")
+
+                    xyxy = to_numpy(box)
+                    if save_img or save_crop or view_img:  # Add bbox to image
+                        cls = int(class_idx)  # integer class
+                        label = (
+                            None
+                            if hide_labels
+                            else (names[cls] if hide_conf else f"{names[cls]} {score:.2f}")
+                        )
+                        annotator.box_label(xyxy, label, color=colors(cls, True))
+                        if save_crop:
+                            crop_path = save_dir / "crops" / names[cls] / f"{p.stem}.jpg"
+                            save_one_box(xyxy, imc, file=crop_path, BGR=True)
+
+            # Print time (inference-only)
+            logger.info(f"{s}Done. ({t3 - t2:.3f}s)")
+
+            # Stream results
+            im0 = annotator.result()
+            if view_img:
+                cv2.imshow(str(p), im0)
+                cv2.waitKey(1)  # 1 millisecond
+
+            # Save results (image with detections)
+            if save_img:
+                if dataset.mode == "image":
+                    cv2.imwrite(save_path, im0)
+                else:
+                    raise NotImplementedError("Saving video results has not been implemented yet.")
+
+    # Print results
+    speeds_info = tuple(x / seen * 1e3 for x in dt)  # speeds per image
+    logger.info(
+        f"Speed: {speeds_info[0]:.1f}ms pre-process, {speeds_info[1]:.1f}ms inference, "
+        f"{speeds_info[2]:.1f}ms NMS per image at shape {(1, 3, *img_size)}",
+    )
+    if save_txt or save_img:
+        saved_info = (
+            f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}"
+            if save_txt
+            else ""
+        )
+        logger.info(f"Results saved to {colorstr('bold', save_dir)}{saved_info}")
+    if update:
+        # update model (to fix SourceChangeWarning)
+        strip_optimizer(weights)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser("CLI tool for detecting source.", add_help=True)
+    parser.add_argument("--weights", type=str, default="yolort.engine", help="path of the TensorRT engine")
+    parser.add_argument("--source", type=str, default="data/images", help="file/dir/URL/glob, 0 for webcam")
+    parser.add_argument("--img_size", nargs="+", type=int, default=[640], help="inference size h,w")
+    parser.add_argument("--conf_thres", type=float, default=0.25, help="confidence threshold")
+    parser.add_argument("--iou_thres", type=float, default=0.45, help="NMS IoU threshold")
+    parser.add_argument("--max_det", type=int, default=1000, help="maximum detections per image")
+    parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
+    parser.add_argument("--view_img", action="store_true", help="show results")
+    parser.add_argument("--save_txt", action="store_true", help="save results to *.txt")
+    parser.add_argument("--save_conf", action="store_true", help="save confidences in --save-txt labels")
+    parser.add_argument("--save_crop", action="store_true", help="save cropped prediction boxes")
+    parser.add_argument("--nosave", action="store_true", help="do not save images/videos")
+    parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0 2 3")
+    parser.add_argument("--visualize", action="store_true", help="visualize features")
+    parser.add_argument("--update", action="store_true", help="update all models")
+    parser.add_argument("--project", default="./runs/detect", help="save results to project/name")
+    parser.add_argument("--name", default="exp", help="save results to project/name")
+    parser.add_argument("--exist_ok", action="store_true", help="existing project/name ok, do not increment")
+    parser.add_argument("--line_thickness", default=3, type=int, help="bounding box thickness (pixels)")
+    parser.add_argument("--hide_labels", default=False, action="store_true", help="hide labels")
+    parser.add_argument("--hide_conf", default=False, action="store_true", help="hide confidences")
+    parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference")
+
+    return parser
+
+
+def cli_main():
+    parser = get_parser()
+    args = parser.parse_args()
+    logger.info(f"Command Line Args: {args}")
+    run(**vars(args))
+
+
+if __name__ == "__main__":
+    cli_main()
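
For reviewers who want to exercise the new tool end to end, a minimal programmatic invocation mirroring the CLI defaults is sketched below. The engine and image paths are illustrative placeholders, and the engine is assumed to have been built beforehand (e.g. with the trt_helper changes further down).

    # Minimal smoke test for tools/detect.py (placeholder paths).
    from tools.detect import run

    run(
        weights="yolort.engine",  # serialized TensorRT engine built in advance
        source="bus.jpg",         # a single image; a directory or glob also works
        img_size=(640, 640),
        conf_thres=0.25,
        iou_thres=0.45,
    )
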
0 or 0,1,2,3 or cpu") + parser.add_argument("--view_img", action="store_true", help="show results") + parser.add_argument("--save_txt", action="store_true", help="save results to *.txt") + parser.add_argument("--save_conf", action="store_true", help="save confidences in --save-txt labels") + parser.add_argument("--save_crop", action="store_true", help="save cropped prediction boxes") + parser.add_argument("--nosave", action="store_true", help="do not save images/videos") + parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0 2 3") + parser.add_argument("--visualize", action="store_true", help="visualize features") + parser.add_argument("--update", action="store_true", help="update all models") + parser.add_argument("--project", default="./runs/detect", help="save results to project/name") + parser.add_argument("--name", default="exp", help="save results to project/name") + parser.add_argument("--exist_ok", action="store_true", help="existing project/name ok, do not increment") + parser.add_argument("--line_thickness", default=3, type=int, help="bounding box thickness (pixels)") + parser.add_argument("--hide_labels", default=False, action="store_true", help="hide labels") + parser.add_argument("--hide_conf", default=False, action="store_true", help="hide confidences") + parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference") + + return parser + + +def cli_main(): + parser = get_parser() + args = parser.parse_args() + logger.info(f"Command Line Args: {args}") + run(**vars(args)) + + +if __name__ == "__main__": + cli_main() diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py index b044f616..cc9a8494 100644 --- a/yolort/models/box_head.py +++ b/yolort/models/box_head.py @@ -359,7 +359,8 @@ def _concat_pred_logits( return all_pred_logits - def _decode_pred_logits(self, pred_logits: Tensor): + @staticmethod + def _decode_pred_logits(pred_logits: Tensor): """ Decode the prediction logit from the PostPrecess. """ diff --git a/yolort/runtime/__init__.py b/yolort/runtime/__init__.py index fb1a661c..cc4f3a64 100644 --- a/yolort/runtime/__init__.py +++ b/yolort/runtime/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) 2021, yolort team. All Rights Reserved. from .y_onnxruntime import PredictorORT +from .y_tensorrt import PredictorTRT from .yolo_tensorrt_model import YOLOTRTModule -__all__ = ["PredictorORT", "YOLOTRTModule"] +__all__ = ["PredictorORT", "YOLOTRTModule", "PredictorTRT"] diff --git a/yolort/runtime/trt_helper.py b/yolort/runtime/trt_helper.py index 362e8a62..8b9443e8 100644 --- a/yolort/runtime/trt_helper.py +++ b/yolort/runtime/trt_helper.py @@ -28,22 +28,22 @@ class EngineBuilder: Parses an ONNX graph and builds a TensorRT engine from it. """ - def __init__(self, verbose=False, workspace=8): + def __init__(self, verbose=False, workspace=4): """ Args: verbose: If enabled, a higher verbosity level will be set on the TensorRT logger. workspace: Max memory workspace to allow, in Gb. 
""" - self.trt_logger = trt.Logger(trt.Logger.INFO) + self.logger = trt.Logger(trt.Logger.INFO) if verbose: - self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE + self.logger.min_severity = trt.Logger.Severity.VERBOSE - trt.init_libnvinfer_plugins(self.trt_logger, namespace="") + trt.init_libnvinfer_plugins(self.logger, namespace="") - self.builder = trt.Builder(self.trt_logger) + self.builder = trt.Builder(self.logger) self.config = self.builder.create_builder_config() - self.config.max_workspace_size = workspace * (2 ** 30) + self.config.max_workspace_size = workspace * 1 << 30 self.batch_size = None self.network = None @@ -56,19 +56,12 @@ def create_network(self, onnx_path: str): Args: onnx_path: The path to the ONNX graph to load. """ - network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - self.network = self.builder.create_network(network_flags) - self.parser = trt.OnnxParser(self.network, self.trt_logger) - - onnx_path = Path(onnx_path) - with open(onnx_path, "rb") as f: - if not self.parser.parse(f.read()): - err_message = f"Failed to load ONNX file: {onnx_path}" - log.error(err_message) - for error in range(self.parser.num_errors): - log.error(self.parser.get_error(error)) - raise OSError(err_message) + flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + self.network = self.builder.create_network(flag) + self.parser = trt.OnnxParser(self.network, self.logger) + if not self.parser.parse_from_file(onnx_path): + raise RuntimeError(f"Failed to load ONNX file: {onnx_path}") inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] @@ -120,5 +113,5 @@ def create_engine( raise NotImplementedError(f"Currently hasn't been implemented: {precision}.") with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f: - log.info(f"Serializing engine to file: {engine_path}") f.write(engine.serialize()) + log.info(f"Serialize engine success, saved as {engine_path}") diff --git a/yolort/runtime/y_tensorrt.py b/yolort/runtime/y_tensorrt.py new file mode 100644 index 00000000..d0ec368a --- /dev/null +++ b/yolort/runtime/y_tensorrt.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021, yolort team. All rights reserved. +# +# This source code is licensed under the GPL-3.0 license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +from collections import OrderedDict, namedtuple +from typing import Dict, List + +import numpy as np +import torch +from torch import Tensor +from torchvision.ops import box_convert, boxes as box_ops + +try: + import tensorrt as trt +except ImportError: + trt = None + +logging.basicConfig(level=logging.INFO) +logging.getLogger("PredictorTRT").setLevel(logging.INFO) +logger = logging.getLogger("PredictorTRT") + + +class PredictorTRT: + """ + Create a simple end-to-end predictor with the given checkpoint that runs on + single device for a single input image. + + Args: + engine_path (str): Path of the ONNX checkpoint. 
+
+    Examples:
+        >>> import torch
+        >>> from yolort.runtime import PredictorTRT
+        >>>
+        >>> engine_path = 'yolov5s.engine'
+        >>> device = torch.device("cuda")
+        >>> detector = PredictorTRT(engine_path, device)
+        >>>
+        >>> image = torch.rand(1, 3, 640, 640, device=device)
+        >>> detections = detector.run_on_image(image)
+    """
+
+    def __init__(
+        self,
+        engine_path: str,
+        device: torch.device = torch.device("cuda"),
+        score_thresh: float = 0.25,
+        iou_thresh: float = 0.45,
+        detections_per_img: int = 100,
+    ) -> None:
+        self.engine_path = engine_path
+        self.device = device
+        self.named_binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))
+        self.stride = 32
+        self.names = [f"class{i}" for i in range(1000)]  # assign defaults
+        self.score_thresh = score_thresh
+        self.iou_thresh = iou_thresh
+        self.detections_per_img = detections_per_img
+
+        self.engine = self._build_engine()
+        self._set_context()
+
+    def _build_engine(self):
+        logger.info(f"Loading {self.engine_path} for TensorRT inference...")
+        trt_logger = trt.Logger(trt.Logger.INFO)
+        with open(self.engine_path, "rb") as f, trt.Runtime(trt_logger) as runtime:
+            engine = runtime.deserialize_cuda_engine(f.read())
+
+        return engine
+
+    def _set_context(self):
+        self.bindings = OrderedDict()
+        for index in range(self.engine.num_bindings):
+            name = self.engine.get_binding_name(index)
+            dtype = trt.nptype(self.engine.get_binding_dtype(index))
+            shape = tuple(self.engine.get_binding_shape(index))
+            data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(self.device)
+            self.bindings[name] = self.named_binding(name, dtype, shape, data, int(data.data_ptr()))
+
+        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())
+        self.context = self.engine.create_execution_context()
+
+    def __call__(self, image: Tensor):
+        """
+        Args:
+            image (Tensor): a batch of images of shape (N, C, H, W).
+
+        Returns:
+            pred_logits (Tensor): the raw prediction logits, to be decoded
+                by ``postprocessing``.
+        """
+        assert image.shape == self.bindings["images"].shape, (image.shape, self.bindings["images"].shape)
+        self.binding_addrs["images"] = int(image.data_ptr())
+        self.context.execute_v2(list(self.binding_addrs.values()))
+        pred_logits = self.bindings["output"].data
+        return pred_logits
+
+    def run_on_image(self, image):
+        """
+        Run the TensorRT engine for one image only.
+
+        Args:
+            image (Tensor): the input image tensor to be predicted.
+        """
+        pred_logits = self(image)
+        detections = self.postprocessing(pred_logits)
+        return detections
+
+    @staticmethod
+    def _decode_pred_logits(pred_logits: Tensor):
+        """
+        Decode the prediction logits from the PostProcess.
+        """
+        # Compute conf
+        # box_conf x class_conf, w/ shape: num_anchors x num_classes
+        scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
+        boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")
+
+        return boxes, scores
+
+    def postprocessing(self, pred_logits: Tensor):
+        batch_size = pred_logits.shape[0]
+        detections: List[Dict[str, Tensor]] = []
+
+        for idx in range(batch_size):  # image idx, image inference
+            # Decode the predict logits
+            boxes, scores = self._decode_pred_logits(pred_logits[idx])
+
+            # remove low scoring boxes
+            inds, labels = torch.where(scores > self.score_thresh)
+            boxes, scores = boxes[inds], scores[inds, labels]
+
+            # non-maximum suppression, independently done per level
+            keep = box_ops.batched_nms(boxes, scores, labels, self.iou_thresh)
+            # Keep only topk scoring head_outputs
+            keep = keep[: self.detections_per_img]
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+
+            detections.append({"scores": scores, "labels": labels, "boxes": boxes})
+
+        return detections
+
+    def warmup(self, img_size=(1, 3, 320, 320), half=False):
+        # Warmup model by running inference once
+        # only warmup GPU models
+        if isinstance(self.device, torch.device) and self.device.type != "cpu":
+            image = torch.zeros(*img_size).to(self.device).type(torch.half if half else torch.float)
+            self(image)
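
To make the decode/NMS path above concrete: each row of pred_logits holds [cx, cy, w, h, obj_conf, class confidences...], scores are obj_conf times class_conf, and boxes are converted from cxcywh to xyxy before batched NMS. A self-contained numeric sketch with fabricated logits:

    # Fabricated logits: 2 anchors, 3 classes, rows = [cx, cy, w, h, obj, c0, c1, c2].
    import torch
    from torchvision.ops import box_convert

    pred_logits = torch.tensor(
        [
            [100.0, 100.0, 50.0, 50.0, 0.9, 0.8, 0.1, 0.1],
            [200.0, 200.0, 40.0, 80.0, 0.2, 0.5, 0.3, 0.2],
        ]
    )
    scores = pred_logits[:, 5:] * pred_logits[:, 4:5]  # [[0.72, 0.09, 0.09], [0.10, 0.06, 0.04]]
    boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy")
    inds, labels = torch.where(scores > 0.25)  # only anchor 0 / class 0 survives
    print(boxes[inds], scores[inds, labels], labels)  # [[75., 75., 125., 125.]], [0.72], [0]
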
+ """ + # Compute conf + # box_conf x class_conf, w/ shape: num_anchors x num_classes + scores = pred_logits[:, 5:] * pred_logits[:, 4:5] + boxes = box_convert(pred_logits[:, :4], in_fmt="cxcywh", out_fmt="xyxy") + + return boxes, scores + + def postprocessing(self, pred_logits: Tensor): + batch_size = pred_logits.shape[0] + detections: List[Dict[str, Tensor]] = [] + + for idx in range(batch_size): # image idx, image inference + # Decode the predict logits + boxes, scores = self._decode_pred_logits(pred_logits[idx]) + + # remove low scoring boxes + inds, labels = torch.where(scores > self.score_thresh) + boxes, scores = boxes[inds], scores[inds, labels] + + # non-maximum suppression, independently done per level + keep = box_ops.batched_nms(boxes, scores, labels, self.iou_thresh) + # Keep only topk scoring head_outputs + keep = keep[: self.detections_per_img] + boxes, scores, labels = boxes[keep], scores[keep], labels[keep] + + detections.append({"scores": scores, "labels": labels, "boxes": boxes}) + + return detections + + def warmup(self, img_size=(1, 3, 320, 320), half=False): + # Warmup model by running inference once + # only warmup GPU models + if isinstance(self.device, torch.device) and self.device.type != "cpu": + image = torch.zeros(*img_size).to(self.device).type(torch.half if half else torch.float) + self(image) diff --git a/yolort/runtime/yolo_tensorrt_model.py b/yolort/runtime/yolo_tensorrt_model.py index c7238b3e..6315980f 100644 --- a/yolort/runtime/yolo_tensorrt_model.py +++ b/yolort/runtime/yolo_tensorrt_model.py @@ -5,6 +5,7 @@ import torch from torch import nn, Tensor from yolort.models import YOLO +from yolort.models.anchor_utils import AnchorGenerator from yolort.models.backbone_utils import darknet_pan_backbone from yolort.models.box_head import LogitsDecoder from yolort.utils import load_from_ultralytics @@ -39,12 +40,13 @@ def __init__( version=version, use_p6=use_p6, ) + num_classes = model_info["num_classes"] + anchor_generator = AnchorGenerator(model_info["strides"], model_info["anchor_grids"]) post_process = LogitsDecoder(model_info["strides"]) model = YOLO( backbone, - model_info["num_classes"], - strides=model_info["strides"], - anchor_grids=model_info["anchor_grids"], + num_classes, + anchor_generator=anchor_generator, post_process=post_process, ) diff --git a/yolort/utils/image_utils.py b/yolort/utils/image_utils.py index 0eb9cff6..d2e78d9d 100644 --- a/yolort/utils/image_utils.py +++ b/yolort/utils/image_utils.py @@ -200,14 +200,14 @@ def cast_image_tensor_to_numpy(images): return images -def parse_images(images): +def parse_images(images: Tensor): images = images.permute(0, 2, 3, 1) images = cast_image_tensor_to_numpy(images) return images -def parse_single_image(image): +def parse_single_image(image: Tensor): image = image.permute(1, 2, 0) image = cast_image_tensor_to_numpy(image) return image @@ -221,7 +221,7 @@ def parse_single_target(target): return boxes -def to_numpy(tensor): +def to_numpy(tensor: Tensor): if tensor.requires_grad: return tensor.detach().cpu().numpy() else: diff --git a/yolort/v5/utils/datasets.py b/yolort/v5/utils/datasets.py index c30d1387..5666a5b3 100644 --- a/yolort/v5/utils/datasets.py +++ b/yolort/v5/utils/datasets.py @@ -2,8 +2,25 @@ """ Dataloaders and dataset utils """ +import glob +import os +from pathlib import Path + +import cv2 +import numpy as np from PIL import Image +from .augmentations import letterbox + + +# Parameters + +# acceptable image suffixes +IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", 
"webp", "mpo"] +# acceptable video suffixes +VID_FORMATS = ["mov", "avi", "mp4", "mpg", "mpeg", "m4v", "wmv", "mkv"] +WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) # DPP + def exif_transpose(image): """ @@ -34,3 +51,90 @@ def exif_transpose(image): del exif[0x0112] image.info["exif"] = exif.tobytes() return image + + +class LoadImages: + """ + YOLOv5 image/video dataloader. And we're using th CHW RGB format. + """ + + def __init__(self, path: str, img_size: int = 640, stride: int = 32, auto: bool = True): + path_source = str(Path(path).resolve()) # os-agnostic absolute path + if "*" in path_source: + files = sorted(glob.glob(path_source, recursive=True)) # glob + elif os.path.isdir(path_source): + files = sorted(glob.glob(os.path.join(path_source, "*.*"))) # dir + elif os.path.isfile(path_source): + files = [path_source] # files + else: + raise Exception(f"ERROR: {path_source} does not exist") + + images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS] + videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS] + num_images, num_videos = len(images), len(videos) + + self.img_size = img_size + self.stride = stride + self.files = images + videos + self.num_files = num_images + num_videos # number of files + self.video_flag = [False] * num_images + [True] * num_videos + self.mode = "image" + self.auto = auto + if any(videos): + self.new_video(videos[0]) # new video + else: + self.cap = None + assert self.num_files > 0, ( + f"No images or videos found in {path_source}. Supported formats " + f"are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" + ) + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + if self.count == self.num_files: + raise StopIteration + path = self.files[self.count] + + if self.video_flag[self.count]: + # Read video + self.mode = "video" + ret_val, img_origin = self.cap.read() + while not ret_val: + self.count += 1 + self.cap.release() + if self.count == self.num_files: # last video + raise StopIteration + else: + path = self.files[self.count] + self.new_video(path) + ret_val, img_origin = self.cap.read() + + self.frame += 1 + source_bar = f"video {self.count + 1}/{self.num_files} ({self.frame}/{self.frames}) {path}: " + + else: + # Read image + self.count += 1 + img_origin = cv2.imread(path) # opencv set the BGR order as the default + assert img_origin is not None, f"Not Found Image: {path}" + source_bar = f"image {self.count}/{self.num_files} {path}: " + + # Padded resize + img = letterbox(img_origin, self.img_size, stride=self.stride, auto=self.auto)[0] + + # Convert HWC to CHW, BGR to RGB + img = img.transpose((2, 0, 1))[::-1] + img = np.ascontiguousarray(img) + + return path, img, img_origin, self.cap, source_bar + + def new_video(self, path): + self.frame = 0 + self.cap = cv2.VideoCapture(path) + self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + def __len__(self): + return self.num_files # number of files diff --git a/yolort/v5/utils/general.py b/yolort/v5/utils/general.py index 6f23d613..d052895a 100644 --- a/yolort/v5/utils/general.py +++ b/yolort/v5/utils/general.py @@ -237,14 +237,19 @@ def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=Fals return result -def check_img_size(imgsz, s=32, floor=0): - # Verify image size is a multiple of stride s in each dimension - if isinstance(imgsz, int): # integer i.e. 
diff --git a/yolort/v5/utils/general.py b/yolort/v5/utils/general.py
index 6f23d613..d052895a 100644
--- a/yolort/v5/utils/general.py
+++ b/yolort/v5/utils/general.py
@@ -237,14 +237,19 @@ def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=Fals
     return result
 
-def check_img_size(imgsz, s=32, floor=0):
-    # Verify image size is a multiple of stride s in each dimension
-    if isinstance(imgsz, int):  # integer i.e. img_size=640
-        new_size = max(make_divisible(imgsz, int(s)), floor)
+def check_img_size(image_size, stride=32, floor=0):
+    """
+    Verify image size is a multiple of the stride in each dimension
+    """
+    if isinstance(image_size, int):  # integer i.e. img_size=640
+        new_size = max(make_divisible(image_size, int(stride)), floor)
     else:  # list i.e. img_size=[640, 480]
-        new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
-    if new_size != imgsz:
-        print(f"WARNING: --img-size {imgsz} must be multiple of " f"max stride {s}, updating to {new_size}")
+        new_size = [max(make_divisible(x, int(stride)), floor) for x in image_size]
+    if new_size != image_size:
+        print(
+            f"WARNING: --img-size {image_size} must be multiple of "
+            f"max stride {stride}, updating to {new_size}"
+        )
     return new_size
 
@@ -265,12 +270,18 @@ def check_yaml(file, suffix=(".yaml", ".yml")):
 
 def check_file(file, suffix=""):
-    # Search/download file (if necessary) and return path
+    """
+    Search/download file (if necessary) and return path
+    """
     check_suffix(file, suffix)  # optional
     file = str(file)  # convert to str()
-    if Path(file).is_file() or file == "":  # exists
+
+    if Path(file).is_file() or file == "":
+        # return the file if the file exists
         return file
-    elif file.startswith(("http:/", "https:/")):  # download
+
+    if file.startswith(("http:/", "https:/")):
+        # download the file if the image source is a link
         url = str(Path(file)).replace(":/", "://")  # Pathlib turns :// -> :/
         # '%2F' to '/', split https://url.com/file.txt?auth
         file = Path(urllib.parse.unquote(file).split("?")[0]).name
@@ -281,14 +292,15 @@ def check_file(file, suffix=""):
             torch.hub.download_url_to_file(url, file)
         assert Path(file).exists() and Path(file).stat().st_size > 0, f"File download failed: {url}"
         return file
-    else:  # search
-        files = []
-        for d in "data", "models", "utils":  # search directories
-            files.extend(glob.glob(str(ROOT / d / "**" / file), recursive=True))  # find file
-        assert len(files), f"File not found: {file}"  # assert file was found
-        # assert unique
-        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"
-        return files[0]  # return file
+
+    files = []
+    for d in "data", "models", "utils":
+        # search the directories
+        files.extend(glob.glob(str(ROOT / d / "**" / file), recursive=True))  # find file
+    assert len(files), f"File not found: {file}"  # assert file was found
+    # assert unique
+    assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"
+    return files[0]  # return file
 
 
 def url2file(url):
@@ -604,9 +616,12 @@ def non_max_suppression(
     return output
 
-def strip_optimizer(f="best.pt", s=""):
-    # Strip optimizer from 'f' to finalize training, optionally save as 's'
-    x = torch.load(f, map_location=torch.device("cpu"))
+def strip_optimizer(checkpoint_path="best.pt", saved_path=""):
+    """
+    Strip optimizer from 'checkpoint_path' to finalize training,
+    optionally save as 'saved_path'
+    """
+    x = torch.load(checkpoint_path, map_location=torch.device("cpu"))
     if x.get("ema"):
         x["model"] = x["ema"]  # replace model with ema
     for k in "optimizer", "training_results", "wandb_id", "ema", "updates":  # keys
@@ -615,9 +630,10 @@
         x["model"].half()  # to FP16
     for p in x["model"].parameters():
         p.requires_grad = False
-    torch.save(x, s or f)
-    mb = os.path.getsize(s or f) / 1e6  # filesize
-    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
+    torch.save(x, saved_path or checkpoint_path)
+    mb = os.path.getsize(saved_path or checkpoint_path) / 1e6  # filesize
+    saved_info = f" saved as {saved_path}," if saved_path else ""
+    print(f"Optimizer stripped from {checkpoint_path},{saved_info} {mb:.1f}MB")
 
 
 def print_mutation(results, hyp, save_dir, bucket):
@@ -703,7 +719,10 @@ def apply_classifier(x, model, img, im0):
 
 def increment_path(path, exist_ok=False, sep="", mkdir=False):
-    # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+    """
+    Increment file or directory path.
+    i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+    """
     path = Path(path)  # os-agnostic
     if path.exists() and not exist_ok:
         path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
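
As a closing note on the renamed general.py helpers, their observable behavior is sketched below (assuming make_divisible rounds up to the nearest multiple of the stride, as in upstream yolov5):

    from yolort.v5.utils.general import check_img_size, increment_path

    print(check_img_size(640, stride=32))         # 640, already a multiple of 32
    print(check_img_size([640, 481], stride=32))  # [640, 512], after printing a warning

    # increment_path keeps earlier runs intact: runs/detect/exp -> exp2 -> exp3 ...
    print(increment_path("runs/detect/exp", exist_ok=False))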