diff --git a/hot_fair_utilities/training/yolo_v8_v1/train.py b/hot_fair_utilities/training/yolo_v8_v1/train.py index 56d5d66..8fcdca8 100644 --- a/hot_fair_utilities/training/yolo_v8_v1/train.py +++ b/hot_fair_utilities/training/yolo_v8_v1/train.py @@ -80,7 +80,15 @@ def main(): train(**vars(opt)) -def train(data, weights, gpu, epochs, batch_size, pc, output_path=None): +def train( + data, + weights, + gpu=("cuda" if torch.cuda.is_available() else "cpu"), + epochs=20, + batch_size=8, + pc=2.0, + output_path=None, +): back = ( "n" if "yolov8n" in weights