Dataloader overrides Transforms to a boolean #1162
-
Hi! I'm trying to build a custom dataset inherited from RasterDataset.

    class MyCustomDataset(RasterDataset):
        filename_glob = "*_sen2.tif"
        #filename_regex = r"\d+_(?P<band>B[\d])_sen2"
        is_image = True
        separate_files = False
        all_bands = ["B1", "B11", "B12", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8a", "B9"]
        rgb_bands = ["B4", "B3", "B2"]
        targets = None

        def __init__(self, root="data", crs=None, res=None, transforms=None, cache=True, targets_csv=None):
            self.root = root
            super().__init__(root, crs, res, transforms, cache)
            if targets_csv is not None:
                self.targets = pd.read_csv(targets_csv, index_col=0)
                self.targets = torch.Tensor(self.targets["rh98"])

        def plot(self, sample):
            # Find the correct band index order
            rgb_indices = []
            for band in self.rgb_bands:
                rgb_indices.append(self.all_bands.index(band))
            # Reorder and rescale the image
            image = sample["image"][rgb_indices].permute(1, 2, 0)
            # Plot the image
            fig, ax = plt.subplots()
            ax.imshow(image)
            return fig

I followed the "Custom Datasets" tutorial in the TorchGeo docs. When I try to instantiate a DataLoader and iterate through the samples:

    custom = MyCustomDataset(data_dir, targets_csv=targets_csv)
    sampler = PreChippedGeoSampler(dataset) # Ds Already in Chips
    dataloader = DataLoader(dataset, sampler=sampler, collate_fn=stack_samples)

At this line:

    batch = next(iter(dataloader))

It gives me the following error:
Reaching line 436 means that the transforms variable is not None, so it has been overridden with something, and that something is a 'bool' object. In the RasterDataset class, the transforms argument also defaults to None. I also tried passing transforms=None explicitly when instantiating my class, but the same problem occurs. Is there something I'm missing in my custom class? Why is the transforms variable overridden with a boolean? I'd be really thankful for any tips and answers!
Replies: 2 comments 2 replies
-
I forgot the transforms=transforms keyword in the call to the parent constructor:

    super().__init__(root, crs, res, transforms, cache)

changed to:

    super().__init__(root, crs, res, transforms=transforms, cache=cache)
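For anyone hitting the same error, here is a minimal sketch of how a positional call can put a boolean into transforms. It uses a hypothetical stand-in class `Base`, not TorchGeo's actual code, and assumes the parent constructor takes an extra parameter (called `bands` here) between `res` and `transforms`; check the signature in your installed TorchGeo version.

```python
class Base:
    """Hypothetical stand-in for the parent dataset class (not TorchGeo's code)."""

    def __init__(self, root="data", crs=None, res=None, bands=None,
                 transforms=None, cache=True):
        # Assumption: `bands` sits between `res` and `transforms`
        self.bands = bands
        self.transforms = transforms
        self.cache = cache


class PositionalChild(Base):
    def __init__(self, root="data", crs=None, res=None, transforms=None, cache=True):
        # Positional call: `transforms` fills `bands`, and `cache` fills `transforms`
        super().__init__(root, crs, res, transforms, cache)


class KeywordChild(Base):
    def __init__(self, root="data", crs=None, res=None, transforms=None, cache=True):
        # Keyword call: each value reaches the parameter of the same name
        super().__init__(root, crs, res, transforms=transforms, cache=cache)


broken = PositionalChild()
fixed = KeywordChild()
print(type(broken.transforms).__name__)  # bool
print(fixed.transforms)                  # None
```

Passing everything after the first few arguments by keyword keeps a subclass working even if the parent later gains or reorders parameters.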
-
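Note that you may also need to override __getitem__ so that each sample carries its target:

    def __getitem__(self, query: BoundingBox) -> Dict[str, Any]:
        # Call the parent implementation to get the image sample
        sample = super().__getitem__(query)
        # idk how you plan on matching a particular geographic query to a target label
        target = self.targets[magic]  # `magic` is a placeholder for that lookup
        sample["label"] = target
        return sample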
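One Python detail when overriding __getitem__: the parent's implementation has to be called as `super().__getitem__(query)`, because subscripting the `super()` proxy directly raises a TypeError. A minimal sketch with hypothetical stand-in classes (`Parent`, `Child`, `BrokenChild` are illustrations, not TorchGeo classes, and the dict lookup stands in for whatever query-to-label matching you end up using):

```python
class Parent:
    """Hypothetical stand-in for a dataset class that defines __getitem__."""

    def __getitem__(self, query):
        return {"image": f"pixels for {query}"}


class Child(Parent):
    def __init__(self, targets):
        self.targets = targets  # hypothetical mapping from query to label

    def __getitem__(self, query):
        # Call the parent's method explicitly, then attach the label
        sample = super().__getitem__(query)
        sample["label"] = self.targets[query]
        return sample


class BrokenChild(Parent):
    def __getitem__(self, query):
        # Subscripting the super() proxy is not supported
        return super()[query]


sample = Child({"tile_0": 12.5})["tile_0"]
print(sample)

try:
    BrokenChild()["tile_0"]
except TypeError as err:
    print("TypeError:", err)
```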