Skip to content

Commit

Permalink
Update references/classification/train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored and kit1980 committed Dec 14, 2023
1 parent a708255 commit cd17926
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def load_data(traindir, valdir, args):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path, weights_only=True)
# TODO: this could probably be weights_only=True
dataset_test, _ = torch.load(cache_path, weights_only=False)
else:
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
Expand Down

0 comments on commit cd17926

Please sign in to comment.