Skip to content

Commit

Permalink
Attach transforms to model (ultralytics#9028)
Browse files Browse the repository at this point in the history
* Attach transforms to model

Signed-off-by: Glenn Jocher <[email protected]>

* Update val.py

Signed-off-by: Glenn Jocher <[email protected]>

* Update train.py

Signed-off-by: Glenn Jocher <[email protected]>

Signed-off-by: Glenn Jocher <[email protected]>
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 9c1878c commit 2ba2392
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
10 changes: 5 additions & 5 deletions classify/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,16 @@ def train(opt, device):
for p in model.parameters():
p.requires_grad = True # for training
model = model.to(device)
names = trainloader.dataset.classes # class names
model.names = names # attach class names

# Info
if RANK in {-1, 0}:
model.names = trainloader.dataset.classes # attach class names
model.transforms = testloader.dataset.torch_transforms # attach inference transforms
model_info(model)
if opt.verbose:
LOGGER.info(model)
images, labels = next(iter(trainloader))
file = imshow_cls(images[:25], labels[:25], names=names, f=save_dir / 'train_images.jpg')
file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
logger.log_images(file, name='Train Examples')
logger.log_graph(model, imgsz) # log model

Expand Down Expand Up @@ -254,8 +254,8 @@ def train(opt, device):

# Plot examples
images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
pred = torch.max(ema.ema((images.half() if cuda else images.float()).to(device)), 1)[1]
file = imshow_cls(images, labels, pred, names, verbose=False, f=save_dir / 'test_images.jpg')
pred = torch.max(ema.ema(images.to(device)), 1)[1]
file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')

# Log results
meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}
Expand Down
3 changes: 1 addition & 2 deletions classify/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def run(
project=ROOT / 'runs/val-cls', # save to project/name
name='exp', # save to project/name
exist_ok=False, # existing project/name ok, do not increment
half=True, # use FP16 half-precision inference
half=False, # use FP16 half-precision inference
dnn=False, # use OpenCV DNN for ONNX inference
model=None,
dataloader=None,
Expand Down Expand Up @@ -124,7 +124,6 @@ def run(
LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

model.float() # for training
return top1, top5, loss


Expand Down

0 comments on commit 2ba2392

Please sign in to comment.