Concerns with GridGeoSampler for evaluation #1245

Closed
adamjstewart opened this issue Apr 14, 2023 · 11 comments
Labels
samplers Samplers for indexing datasets

@adamjstewart
Collaborator

adamjstewart commented Apr 14, 2023

Here are a couple of issues that may affect the use of GridGeoSampler for benchmarking.

Overlapping patches

Back in #630, we modified GridGeoSampler to ensure that every part of the image is sampled from, even if the height/width of the image is not a multiple of the stride. At the time, I decided that we should adjust the stride of the last row/col in order to avoid sampling outside of the bounds of the image. In hindsight, I think this was a mistake.

The problem is that we end up with the last row/col overlapping with the second-to-last row/col, resulting in some areas being double counted when computing performance metrics. It also makes stitching together prediction patches unnecessarily complicated.

I think we should modify GridGeoSampler to avoid adjusting stride and instead sample outside the bounds of the image. I believe rasterio will simply return nodata pixels for areas outside of the image. @remtav are you okay with this solution? I believe this was actually the first idea you implemented, apologies for pushing that PR in the wrong direction.
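To make the overlap concrete, here is a minimal 1-D sketch in plain Python that computes window offsets along one axis with and without adjusting the last stride. The helpers `window_starts` and `double_counted` are hypothetical, not TorchGeo's API, and the sketch assumes the number of rows/cols is `ceil((length - size) / stride) + 1` and that out-of-bounds reads are padded with nodata, as suggested above.

```python
import math


def window_starts(length, size, stride, clamp_last):
    """Return the start offset of every window along one axis (hypothetical helper).

    Assumes n = ceil((length - size) / stride) + 1 windows, so the last window
    reaches (or passes) the far edge of the image.
    """
    if length <= size:
        return [0]
    n = math.ceil((length - size) / stride) + 1
    starts = [i * stride for i in range(n)]
    if clamp_last:
        # Old behaviour: pull the last window back so it ends exactly at the edge.
        starts[-1] = length - size
    # Otherwise the stride stays fixed and the last window may run past the edge;
    # the reader is then assumed to fill the out-of-bounds part with nodata.
    return starts


def double_counted(length, size, starts):
    """Count in-bounds pixels that are covered by more than one window."""
    counts = [0] * length
    for s in starts:
        for p in range(s, min(s + size, length)):
            counts[p] += 1
    return sum(1 for c in counts if c > 1)


# A 250-pixel axis chipped into 100-pixel patches with stride 100.
old = window_starts(250, 100, 100, clamp_last=True)
new = window_starts(250, 100, 100, clamp_last=False)
print(old, double_counted(250, 100, old))  # [0, 100, 150] 50
print(new, double_counted(250, 100, new))  # [0, 100, 200] 0
```

With the adjusted stride, 50 pixels along this axis are scored twice; keeping the stride fixed removes the overlap and makes stitching a simple paste at `i * stride`, at the cost of reading some nodata past the edge.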

Technically this issue also occurs when multiple images in the dataset intersect, but this is harder to mitigate without storing all predictions in one giant tensor and computing performance only on the final predicted mask. I think we would run out of memory very quickly.

ignore_index weighting

This one may affect training with other GeoSamplers as well, although I'm most concerned about evaluation.

When sampling from large tiles, many patches will contain partial or complete nodata pixels. TorchMetrics allows us to ignore these areas using ignore_index. However, it's unclear to me if all patches are weighted equally when computing the final performance metrics with Lightning. Ideally, the overall reported accuracy would match regardless of whether we chip up the image into small patches or if we compute accuracy on the entire image/mask in one go.

We could peruse the internals of TorchMetrics and Lightning, but I think it's actually easier to construct a toy example to determine whether or not this issue occurs. Consider an image with width 200 and height 100. Let the first 99 columns of the ground truth mask be 0, the 100th column be 1, and the last 100 columns be 2. Let the predicted mask be a tensor of all 1s. If we use a GridGeoSampler with size 100 and stride 100, and let ignore_index=0, the correct performance should be ~1%. If the actual reported performance is 50%, we'll know we have an issue.
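The toy example can be checked without TorchMetrics or Lightning. The NumPy sketch below (variable and helper names are illustrative) computes accuracy on the full mask, then on the two 100x100 patches, aggregated either by pooling correct/valid pixel counts across patches or by averaging per-patch accuracies. Because the image height equals the patch size, the grid is a single row of two patches.

```python
import numpy as np

IGNORE_INDEX = 0
HEIGHT, WIDTH, PATCH = 100, 200, 100

# Ground truth from the toy example: columns 0-98 are 0 (ignored),
# column 99 (the 100th) is 1, and columns 100-199 are 2.
target = np.zeros((HEIGHT, WIDTH), dtype=np.int64)
target[:, 99] = 1
target[:, 100:] = 2
pred = np.ones_like(target)  # predict class 1 everywhere


def correct_and_valid(pred, target, ignore_index):
    """Return (#correct, #valid) pixels, skipping pixels labelled ignore_index."""
    valid = target != ignore_index
    correct = (pred == target) & valid
    return int(correct.sum()), int(valid.sum())


# Accuracy on the whole image in one go.
c, n = correct_and_valid(pred, target, IGNORE_INDEX)
whole_image_acc = c / n

# Chip into PATCH-wide tiles (size 100, stride 100) and aggregate two ways.
per_patch_acc = []
pooled_correct = pooled_valid = 0
for x0 in range(0, WIDTH, PATCH):
    pc, pv = correct_and_valid(pred[:, x0:x0 + PATCH], target[:, x0:x0 + PATCH], IGNORE_INDEX)
    pooled_correct += pc
    pooled_valid += pv
    if pv > 0:  # a fully ignored patch has no defined accuracy (0 / 0)
        per_patch_acc.append(pc / pv)

mean_of_patches = sum(per_patch_acc) / len(per_patch_acc)
pooled_acc = pooled_correct / pooled_valid

print(f"whole image:     {whole_image_acc:.4f}")  # 0.0099
print(f"pooled counts:   {pooled_acc:.4f}")       # 0.0099
print(f"mean of patches: {mean_of_patches:.4f}")  # 0.5000
```

Pooling counts reproduces the whole-image number (~1%), while averaging per-patch scores gives 50%. Metrics that accumulate counts across batches behave like the pooled version and averaging a per-batch scalar behaves like the mean of patches; this sketch does not show which of the two a particular Lightning/TorchMetrics setup actually uses.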

@adamjstewart adamjstewart added the samplers Samplers for indexing datasets label Apr 14, 2023
@adamjstewart adamjstewart added this to the 0.4.2 milestone Apr 14, 2023
@calebrob6
Member

I'm fine with this, and can take it.

@calebrob6 calebrob6 self-assigned this Apr 14, 2023
@adamjstewart
Collaborator Author

@calebrob6 any updates on this?

@remtav
Contributor

remtav commented Apr 27, 2023

With some delay, no trouble with this on my end. Thanks for asking.

@adamjstewart
Collaborator Author

Closed by #1329 and #1331

@adamjstewart adamjstewart modified the milestones: 0.4.2, 0.5.0 Sep 28, 2023
@FogDrip

FogDrip commented Sep 19, 2024

Hi, I know this is way after the fact, but I run into a problem when I use "ignore_index=0" inside my PixelClassificationModel:

import pytorch_lightning as pl
import torch.nn as nn
from torchvision import models

class PixelClassificationModel(pl.LightningModule):
    def __init__(self, num_classes, in_channels, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate
        # Example of using a larger model
        self.model = models.segmentation.deeplabv3_resnet50(weights=None, num_classes=num_classes)
        # self.model = models.segmentation.deeplabv3_mobilenet_v3_large(weights=None, num_classes=num_classes)
        # Replace the first conv so the backbone accepts in_channels bands instead of 3
        self.model.backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Pixels labelled 0 are excluded from the loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

I have the following issue:

Epoch 0: 100%|█████████████████████████████████████████████████████████████████| 17/17 [00:17<00:00,  0.95it/s, v_num=91, train_loss_step=2.240]
NaN or Inf found in input tensor.
████████████████████████████████████████████████████████████████████████████████| 17/17 [00:02<00:00,  6.63it/s]
Epoch 0: 100%|█████████████████████████| 17/17 [00:31<00:00,  0.54it/s, v_num=91, train_loss_step=2.240, val_loss=nan.0, train_loss_epoch=nan.0]
NaN or Inf found in input tensor.
Epoch 0: 100%|█████████████████████████| 17/17 [00:32<00:00,  0.52it/s, v_num=91, train_loss_step=2.240, val_loss=nan.0, train_loss_epoch=nan.0]

I study habitats along rivers, so in any given raster ~95% of the area is something I don't want to classify (e.g. urban areas, upland habitat). I've given these areas a label of 0 because I want to mask them. The ignore index of 0 is not working as expected, so I've tried other ways to omit these tiles from the model. I've spent a lot of time trying to filter the images I don't care about out of the batches, but I can't get it to work with a custom filter function for the dataloader, nor by ignoring these images in the training and validation steps; each of these options produces a model with very odd predictions. I can share code and some sample .tifs for labels and images if that would help, but I'm at a loss as to how to get around this.

@adamjstewart
Collaborator Author

Hi @FogDrip, it doesn't seem like you're using TorchGeo in your code snippet. Can you try using SemanticSegmentationTask instead of your custom PixelClassificationModel?

@FogDrip

FogDrip commented Sep 20, 2024

Hi @adamjstewart, thank you so much for your quick response. I'm eager to learn and use TorchGeo, as I'm impressed by its customizability relative to ArcGIS DL. I switched to using SemanticSegmentationTask:

from torchgeo.trainers import SemanticSegmentationTask

task = SemanticSegmentationTask(
    model="deeplabv3+",
    backbone="resnet50",
    in_channels=num_input_channels,
    num_classes=num_classes,
    ignore_index=0,
    lr=learning_rate,
    loss="ce"
)

I'm now running into issues around:

  File "...\apply_func.py", line 172, in _apply_to_collection_slow
    raise ValueError(
ValueError: A frozen dataclass was passed to `apply_to_collection` but this is not allowed.

I'm not sure if this relates to a custom label class I created for my labels that inherits from Chesapeake. I'm using the default NAIP class for the NAIP imagery in my study area so I'm less concerned there.

from torchgeo.datasets import Chesapeake

class ChesapeakeCA(Chesapeake):
    """Custom Chesapeake dataset class for California."""

    base_folder = "ARP"
    filename = "label_clipped.tif"
    filename_glob = filename
    ...

To solve this frozen dataclass problem, I've tried using a custom collate function to filter out the bounding box and CRS, but that leads me down a rabbit hole of mismatched tensor shapes, so that solution causes other problems.

@calebrob6
Member

ValueError: A frozen dataclass was passed to apply_to_collection but this is not allowed.

I get this error when training models with Lightning and GeoDatasets, as the BoundingBoxes can't be turned into a batch. I get around this by deleting the bbox key from each sample (in torchgeo < 0.6) or the bounds key (in torchgeo >= 0.6).

@FogDrip

FogDrip commented Sep 20, 2024

Hi @calebrob6, thanks for the quick reply. I saw your colleague's presentation at ESA on fence classification using DeepLab, which was excellent.

I'm able to get rid of the bounds with a custom collate function. I initially tried to do it within the custom class ChesapeakeCA(Chesapeake) by overriding __getitem__ with del sample['bbox'], but that didn't remove them permanently.

from torch.utils.data import default_collate

def custom_collate_fn(batch):
    """Custom collate function to remove 'bbox', 'bounds', and 'crs' fields before batching."""
    for sample in batch:
        # Drop fields that cause issues with collating; pop() ignores missing keys
        for key in ('bbox', 'bounds', 'crs'):
            sample.pop(key, None)
    return default_collate(batch)

I'm still getting the following output:

Epoch 1: 100%|█████████████████████| 7/7 [00:00<00:00,  8.02it/s, v_num=3]
NaN or Inf found in input tensor.
████████████| 2/2 [00:00<00:00, 40.00it/s]
Epoch 2: 100%|█████████████████████| 7/7 [00:00<00:00,  8.91it/s, v_num=3]
NaN or Inf found in input tensor.

I think this could be related to batches where every value in the label mask equals 0: the whole mask is then ignored via ignore_index=0, and NaNs are sent to the trainer. I tried adding logic to remove these batches from training and validation, which mostly removed the "NaN or Inf found in input tensor." message during training, but the output looked quite strange, whereas it looked more normal when I didn't use ignore_index=0.
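The all-ignored hypothesis is easy to reproduce in plain PyTorch. Assuming the "ce" loss ends up as `nn.CrossEntropyLoss(ignore_index=0)` with the default mean reduction, a batch whose mask is entirely 0 has no valid pixels, so the mean is 0 / 0 = NaN. The sketch below shows that, plus one possible guard (the hypothetical `ce_ignoring_empty`, not a TorchGeo function) that sums the loss and divides by the number of valid pixels clamped to at least 1, so a fully ignored batch contributes a zero loss instead of NaN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
NUM_CLASSES = 18  # the number of classes mentioned in this thread
IGNORE_INDEX = 0

logits = torch.randn(2, NUM_CLASSES, 8, 8, requires_grad=True)  # (batch, classes, H, W)
all_ignored = torch.zeros(2, 8, 8, dtype=torch.long)  # every label pixel is 0

criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
print(criterion(logits, all_ignored))  # NaN: no valid pixels, so the mean is 0 / 0


def ce_ignoring_empty(logits, target, ignore_index=IGNORE_INDEX):
    """Hypothetical guard: mean CE over valid pixels, or 0 if there are none."""
    # Ignored pixels contribute 0 to the sum, so an all-ignored batch sums to 0.
    total = F.cross_entropy(logits, target, ignore_index=ignore_index, reduction="sum")
    n_valid = (target != ignore_index).sum().clamp(min=1)
    return total / n_valid


print(ce_ignoring_empty(logits, all_ignored))  # zero loss, still differentiable
```

Without class weights this matches the default mean reduction whenever at least one pixel is valid; if `weight=` is set, PyTorch's mean divides by the summed weights instead, so the guard would need adjusting. It only removes the NaN loss and says nothing about whether the resulting predictions are good.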

@adamjstewart
Collaborator Author

How many classes do you have in your task? If you only have 2 classes and you ignore 1 of them, I wouldn't be surprised to see this.

@FogDrip

FogDrip commented Sep 21, 2024

Hi @adamjstewart, thanks for your quick response. I have 18 classes but I'm trying to ignore class 0. I hope this attachment provides context:

[Attached image: labels]

The prediction looks okay in some areas, so I'm wondering if I should just ignore the "NaN or Inf found in input tensor." message during training.

[Attached image: prediction_0014]

4 participants