Skip to content

Commit

Permalink
Fixed bug with collator
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Nov 29, 2024
1 parent 4c6a73a commit 1e6f1bf
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions utils/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Optional
from typing import Optional, Callable

import torch
from torch.utils.data import Dataset, DataLoader, default_collate
Expand Down Expand Up @@ -96,6 +96,14 @@ def init_dataset(args):
return train_dataset, test_dataset


def custom_collate(cpu_transforms: Callable):
def collator(batch):
data, labels = default_collate(batch)
return cpu_transforms(data), labels

return collator


def init_loaders(
args, train_dataset: CachedDataset, test_dataset: CachedDataset, pin_memory
):
Expand All @@ -108,10 +116,9 @@ def init_loaders(
0 if not hasattr(args, "num_workers_val") else args.num_workers_val
)

train_collate_fn = default_collate
if train_dataset.batch_transforms_cpu is not None:
train_collate_fn = lambda batch: train_dataset.batch_transforms_cpu(default_collate(batch))
else:
train_collate_fn = default_collate
train_collate_fn = custom_collate(train_dataset.batch_transforms_cpu)

train_loader = DataLoader(
train_dataset,
Expand All @@ -123,10 +130,9 @@ def init_loaders(
collate_fn=train_collate_fn,
)

test_collate_fn = default_collate
if test_dataset.batch_transforms_cpu is not None:
test_collate_fn = lambda batch: test_dataset.batch_transforms_cpu(default_collate(batch))
else:
test_collate_fn = default_collate
test_collate_fn = custom_collate(test_dataset.batch_transforms_cpu)

test_loader = DataLoader(
test_dataset,
Expand Down

0 comments on commit 1e6f1bf

Please sign in to comment.