Skip to content

Commit

Permalink
Add use cuda if available for the training of yolo
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Nov 2, 2024
1 parent 9f6ad27 commit 0b51a80
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion hot_fair_utilities/training/yolo_v8_v1/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,15 @@ def main():
train(**vars(opt))


def train(data, weights, gpu, epochs, batch_size, pc, output_path=None):
def train(
data,
weights,
gpu=("cuda" if torch.cuda.is_available() else "cpu"),
epochs=20,
batch_size=8,
pc=2.0,
output_path=None,
):
back = (
"n"
if "yolov8n" in weights
Expand Down

0 comments on commit 0b51a80

Please sign in to comment.