diff --git a/.gitattributes b/.gitattributes
index ec4a626f..e69de29b 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1 +0,0 @@
-*.pth filter=lfs diff=lfs merge=lfs -text
diff --git a/README.md b/README.md
index 2f07f770..0958d783 100644
--- a/README.md
+++ b/README.md
@@ -10,18 +10,14 @@
-Python package for automatic tree crown delineation based on Mask R-CNN. Pre-trained models can be picked in the [`model_garden`](https://github.com/PatBall1/detectree2/tree/master/model_garden).
-A tutorial on how to prepare data, train models and make predictions is available [here](https://patball1.github.io/detectree2/tutorial.html). For questions, collaboration proposals and requests for data email [James Ball](mailto:ball.jgc@gmail.com). Some example data is available for download [here](https://doi.org/10.5281/zenodo.8136161).
+Python package for automatic tree crown delineation in aerial RGB and multispectral imagery based on Mask R-CNN. Pre-trained models are available in the [`model_garden`](https://github.com/PatBall1/detectree2/tree/master/model_garden).
+A tutorial on how to prepare data, train models and make predictions is available [here](https://patball1.github.io/detectree2/tutorial.html). For questions, collaboration proposals and requests for data, email [James Ball](mailto:ball.jgc@gmail.com). Some example data is available for download [here](https://doi.org/10.5281/zenodo.8136161).
Detectree2是一个基于Mask R-CNN的自动树冠检测与分割的Python包。您可以在[`model_garden`](https://github.com/PatBall1/detectree2/tree/master/model_garden)中选择预训练模型。[这里](https://patball1.github.io/detectree2/tutorial.html)提供了如何准备数据、训练模型和进行预测的教程。如果有任何问题,合作提案或者需要样例数据,可以邮件联系[James Ball](mailto:ball.jgc@gmail.com)。一些示例数据可以在[这里](https://doi.org/10.5281/zenodo.8136161)下载。
| | Code developed by James Ball, Seb Hickman, Thomas Koay, Oscar Jiang, Luran Wang, Panagiotis Ioannou, James Hinton and Matthew Archer in the [Forest Ecology and Conservation Group](https://coomeslab.org/) at the University of Cambridge. The Forest Ecology and Conservation Group is led by Professor David Coomes and is part of the University of Cambridge [Conservation Research Institute](https://www.conservation.cam.ac.uk/). |
| :---: | :--- |
-
-> [!NOTE]
-> To save bandwidth, trained models have been moved to [Zenodo](https://zenodo.org/records/10522461). Download models directly with `wget` or equivalent.
-
## Citation
diff --git a/detectree2/data_loading/custom.py b/detectree2/data_loading/custom.py
new file mode 100644
index 00000000..81a1becb
--- /dev/null
+++ b/detectree2/data_loading/custom.py
@@ -0,0 +1,74 @@
+import cv2
+import detectron2.data.transforms as T
+import numpy as np
+import rasterio
+import torch
+from detectron2.structures import BitMasks, BoxMode, Instances
+from torch.utils.data import Dataset
+
+
+class CustomTIFFDataset(Dataset):
+ def __init__(self, annotations, transforms=None):
+ """
+ Args:
+ annotations (list): List of dictionaries containing image file paths and annotations.
+ transforms (callable, optional): Optional transform to be applied on a sample.
+ """
+ self.annotations = annotations
+ self.transforms = transforms
+
+ def __len__(self):
+ return len(self.annotations)
+
+ def __getitem__(self, idx):
+ # Load the TIFF image
+ img_info = self.annotations[idx]
+ with rasterio.open(img_info['file_name']) as src:
+ # Read all bands (assuming they are all needed)
+ image = src.read()
+ # Normalize or rescale if necessary
+ image = image.astype(np.float32) / 255.0 # Example normalization
+            # If the number of bands is not 3, reduce to 3 or handle accordingly, e.g.
+            # image = image[:3, :, :]  # take the first 3 bands (RGB)
+            # Convert to HWC format expected by the transforms and by Detectron2
+            image = np.transpose(image, (1, 2, 0))
+
+ # Prepare annotations (this part needs to be adapted to your specific annotations)
+ target = {
+ "image_id": idx,
+ "annotations": img_info["annotations"],
+ "width": img_info["width"],
+ "height": img_info["height"],
+ }
+
+        if self.transforms is not None:
+            # Detectron2 augmentations operate on an AugInput, not on an (image, target)
+            # pair; geometric transforms would also need to update the target annotations.
+            aug_input = T.AugInput(image)
+            T.AugmentationList(self.transforms)(aug_input)
+            image = aug_input.image
+
+ # Convert to Detectron2-compatible format
+ image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+ instances = self.get_detectron_instances(target)
+
+ return image, instances
+
+ def get_detectron_instances(self, target):
+ """
+ Converts annotations into Detectron2's format.
+ This example assumes annotations are in COCO format, and you'll need to adapt it for your needs.
+ """
+ boxes = [obj["bbox"] for obj in target["annotations"]]
+ boxes = torch.as_tensor(boxes, dtype=torch.float32)
+ boxes = BoxMode.convert(boxes, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
+
+ # Create BitMasks from the binary mask data (assuming the mask is a binary numpy array)
+ masks = [obj["segmentation"] for obj in target["annotations"]] # Replace with actual mask loading
+ masks = BitMasks(torch.stack([torch.from_numpy(mask) for mask in masks]))
+
+ instances = Instances(
+ image_size=(target["height"], target["width"]),
+ gt_boxes=boxes,
+ gt_classes=torch.tensor([obj["category_id"] for obj in target["annotations"]], dtype=torch.int64),
+ gt_masks=masks
+ )
+ return instances
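+
+
+# Illustrative usage (hypothetical paths and annotation dicts, not part of the
+# package API; "annotations" entries are assumed to be COCO-style dicts):
+#
+#     annotations = [{
+#         "file_name": "tiles/tile_0_0.tif",
+#         "width": 1000,
+#         "height": 1000,
+#         "annotations": [...],  # bbox (XYWH_ABS), binary segmentation mask, category_id
+#     }]
+#     dataset = CustomTIFFDataset(annotations)
+#     image, instances = dataset[0]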
diff --git a/detectree2/models/train.py b/detectree2/models/train.py
index d7606793..5e404e4d 100644
--- a/detectree2/models/train.py
+++ b/detectree2/models/train.py
@@ -9,15 +9,18 @@
import logging
import os
import random
+import re
import time
from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
import cv2
import detectron2.data.transforms as T # noqa:N812
import detectron2.utils.comm as comm
import numpy as np
+import rasterio
import torch
+import torch.nn as nn
from detectron2 import model_zoo
from detectron2.checkpoint import DetectionCheckpointer # noqa:F401
from detectron2.config import get_cfg
@@ -28,6 +31,7 @@
build_detection_test_loader,
build_detection_train_loader,
)
+from detectron2.data import detection_utils as utils
from detectron2.engine import DefaultTrainer
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import COCOEvaluator, verify_results
@@ -38,29 +42,144 @@
from detectron2.utils.logger import log_every_n_seconds
from detectron2.utils.visualizer import ColorMode, Visualizer
-# from IPython.display import display
-# from PIL import Image
+from detectree2.preprocessing.tiling import load_class_mapping
-class LossEvalHook(HookBase):
- """Do inference and get the loss metric.
+class FlexibleDatasetMapper(DatasetMapper):
+ """
+ A flexible dataset mapper that extends the standard DatasetMapper to handle
+ multi-band images and custom augmentations.
- Class to:
- - Do inference of dataset like an Evaluator does
- - Get the loss metric like the trainer does
- https://github.com/facebookresearch/detectron2/blob/master/detectron2/evaluation/evaluator.py
- https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
- See https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b
+ This class is designed to work with datasets that may contain images with
+ more than three channels (e.g., multispectral images) and allows for custom
+ augmentations to be applied. It also handles semantic segmentation data if
+ provided in the dataset.
+
+ Args:
+ cfg (CfgNode): Configuration object containing dataset and model configurations.
+ is_train (bool): Flag indicating whether the mapper is being used for training. Default is True.
+ augmentations (list, optional): List of augmentations to be applied. Default is an empty list.
Attributes:
- model: model to train
- period: number of iterations between evaluations
- data_loader: data loader to use for evaluation
- patience: number of evaluation periods to wait for improvement
+ cfg (CfgNode): Stores the configuration object for later use.
+ is_train (bool): Indicates whether the mapper is in training mode.
+ logger (Logger): Logger instance for logging messages.
+ """
+ def __init__(self, cfg, is_train=True, augmentations=None):
+ if augmentations is None:
+ augmentations = []
+
+ # Initialize the base DatasetMapper class with provided parameters
+ super().__init__(
+ is_train=is_train,
+ augmentations=augmentations,
+ image_format=cfg.INPUT.FORMAT,
+ use_instance_mask=cfg.MODEL.MASK_ON,
+ use_keypoint=cfg.MODEL.KEYPOINT_ON,
+ instance_mask_format=cfg.INPUT.MASK_FORMAT,
+ keypoint_hflip_indices=None,
+ precomputed_proposal_topk=None,
+ recompute_boxes=False
+ )
+ self.cfg = cfg
+ self.is_train = is_train
+ self.logger = logging.getLogger(__name__)
+ mode = "training" if is_train else "inference"
+ self.logger.info(f"[FlexibleDatasetMapper] Augmentations used in {mode}: {augmentations}")
+
+ def __call__(self, dataset_dict):
+ """
+ Process a single dataset dictionary, applying the necessary transformations and augmentations.
+
+ Args:
+ dataset_dict (dict): A dictionary containing data for a single dataset item, including
+ file names and metadata.
+
+ Returns:
+ dict: The processed dataset dictionary, or None if there was an error.
+ """
+ if dataset_dict is None:
+ self.logger.warning("Received None for dataset_dict, skipping this entry.")
+ return None
+
+ if self.cfg.IMGMODE == "rgb":
+ return super().__call__(dataset_dict)
+
+ try:
+ # Handle multi-band image loading using rasterio
+ with rasterio.open(dataset_dict["file_name"]) as src:
+ img = src.read()
+ if img is None:
+ raise ValueError(f"Image data is None for file: {dataset_dict['file_name']}")
+ # Transpose image dimensions to match expected format (H, W, C)
+ img = np.transpose(img, (1, 2, 0)).astype("float32")
+
+ # Size check similar to utils.check_image_size
+ if img.shape[:2] != (dataset_dict.get("height"), dataset_dict.get("width")):
+                self.logger.warning(
+                    f"Image size {img.shape[:2]} does not match expected size "
+                    f"({dataset_dict.get('height')}, {dataset_dict.get('width')})."
+                )
+
+            # Apply the custom multi-band augmentation logic
+ aug_input = T.AugInput(img)
+ transforms = self.augmentations(aug_input) # Apply the augmentations
+ img = aug_input.image
+
+ dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(img.transpose(2, 0, 1)))
+
+ # Handle semantic segmentation if present
+ if "sem_seg_file_name" in dataset_dict:
+ sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
+ dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
+
+ if not self.is_train:
+ # If not in training mode, remove annotations and segmentation file names
+ dataset_dict.pop("annotations", None)
+ dataset_dict.pop("sem_seg_file_name", None)
+ return dataset_dict
+
+ if "annotations" in dataset_dict:
+ # Apply the transformations to the annotations
+ self._transform_annotations(dataset_dict, transforms, img.shape[:2])
+
+ return dataset_dict
+
+ except Exception as e:
+ file_name = dataset_dict.get('file_name', 'unknown') if dataset_dict else 'unknown'
+ self.logger.error(f"Error processing {file_name}: {e}")
+ return None
+
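+# A minimal sketch of how the mapper might be wired up outside the trainer
+# (assumes a cfg built by setup_cfg and a registered dataset; names illustrative):
+#
+#     mapper = FlexibleDatasetMapper(cfg, is_train=True, augmentations=[T.RandomFlip(prob=0.5)])
+#     loader = build_detection_train_loader(cfg, mapper=mapper)
+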
+
+class LossEvalHook(HookBase):
"""
+ A custom hook for evaluating loss during training and managing model checkpoints based on evaluation metrics.
+
+ This hook is designed to:
+ - Perform inference on a dataset similarly to an Evaluator.
+ - Calculate and log the loss metric during training.
+ - Save the best model checkpoint based on a specified evaluation metric (e.g., AP50).
+ - Implement early stopping if the evaluation metric does not improve over a specified number of evaluations.
+ Attributes:
+ _model: The model to evaluate.
+ _period: Number of iterations between evaluations.
+ _data_loader: The data loader used for evaluation.
+ patience: Number of evaluation periods to wait before early stopping.
+ iter: Tracks the number of evaluations since the last improvement in the evaluation metric.
+ max_ap: The best evaluation metric (e.g., AP50) achieved during training.
+ best_iter: The iteration at which the best evaluation metric was achieved.
+ """
def __init__(self, eval_period, model, data_loader, patience):
- """Inits LossEvalHook."""
+ """
+ Initialize the LossEvalHook.
+
+ Args:
+ eval_period (int): The number of iterations between evaluations.
+ model (torch.nn.Module): The model to evaluate.
+ data_loader (torch.utils.data.DataLoader): The data loader for evaluation.
+ patience (int): The number of evaluation periods to wait for improvement before early stopping.
+ """
self._model = model
self._period = eval_period
self._data_loader = data_loader
@@ -70,10 +189,14 @@ def __init__(self, eval_period, model, data_loader, patience):
self.best_iter = 0
def _do_loss_eval(self):
- """Copying inference_on_dataset from evaluator.py.
+ """
+ Perform inference on the dataset and calculate the average loss.
+
+ This method is adapted from `inference_on_dataset` in Detectron2's evaluator.
+ It also calculates and logs the AP50 metric and updates the best model checkpoint if needed.
Returns:
- _type_: _description_
+ list: A list of loss values for each batch in the dataset.
"""
total = len(self._data_loader)
num_warmup = min(5, total - 1)
@@ -83,6 +206,7 @@ def _do_loss_eval(self):
losses = []
for idx, inputs in enumerate(self._data_loader):
if idx == num_warmup:
+ # Reset the start time after the warm-up phase
start_time = time.perf_counter()
total_compute_time = 0
start_compute_time = time.perf_counter()
@@ -92,6 +216,7 @@ def _do_loss_eval(self):
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
seconds_per_img = total_compute_time / iters_after_start
if idx >= num_warmup * 2 or seconds_per_img > 5:
+ # Log progress and estimated time remaining
total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
log_every_n_seconds(
@@ -100,11 +225,13 @@ def _do_loss_eval(self):
str(eta)),
n=5,
)
+ # Calculate loss for the current batch
loss_batch = self._get_loss(inputs)
losses.append(loss_batch)
+
mean_loss = np.mean(losses)
- # print(self.trainer.cfg.DATASETS.TEST)
- # Combine the AP50s of the different datasets
+
+ # Calculate the average AP50 across datasets if multiple datasets are used for testing
if len(self.trainer.cfg.DATASETS.TEST) > 1:
APs = []
for dataset in self.trainer.cfg.DATASETS.TEST:
@@ -112,7 +239,10 @@ def _do_loss_eval(self):
AP = sum(APs) / len(APs)
else:
AP = self.trainer.test(self.trainer.cfg, self.trainer.model)["segm"]["AP50"]
- print("Av. AP50 =", AP)
+
+ print("Av. segm AP50 =", AP)
+
+ # Store the calculated loss and AP50 in the trainer's storage
self.trainer.APs.append(AP)
self.trainer.storage.put_scalar("validation_loss", mean_loss)
self.trainer.storage.put_scalar("validation_ap", AP)
@@ -121,15 +251,17 @@ def _do_loss_eval(self):
return losses
def _get_loss(self, data):
- """Calculate loss in train_loop.
+ """
+ Compute the loss for a given batch of data.
Args:
- data (_type_): _description_
+ data (dict): A batch of input data.
Returns:
- _type_: _description_
+ float: The total loss for the batch.
"""
metrics_dict = self._model(data)
+ # Detach and move to CPU for logging
metrics_dict = {
k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
for k, v in metrics_dict.items()
@@ -138,57 +270,74 @@ def _get_loss(self, data):
return total_losses_reduced
def after_step(self):
+ """
+ Hook to be called after each training iteration to evaluate the model and manage checkpoints.
+
+ - Evaluates the model at regular intervals.
+ - Saves the best model checkpoint based on the AP50 metric.
+ - Implements early stopping if the AP50 does not improve after a set number of evaluations.
+ """
next_iter = self.trainer.iter + 1
is_final = next_iter == self.trainer.max_iter
if is_final or (self._period > 0 and next_iter % self._period == 0):
self._do_loss_eval()
+ # Check if the current AP50 is the best so far
if self.max_ap < self.trainer.APs[-1]:
self.iter = 0
self.max_ap = self.trainer.APs[-1]
+ # Save the current best model
self.trainer.checkpointer.save("model_" + str(len(self.trainer.APs)))
self.best_iter = self.trainer.iter
else:
self.iter += 1
if self.iter == self.patience:
+ # Early stopping condition met
self.trainer.early_stop = True
print("Early stopping occurs in iter {}, max ap is {}".format(self.best_iter, self.max_ap))
self.trainer.storage.put_scalars(timetest=12)
def after_train(self):
+ """
+ Hook to be called after training is complete to load the best model checkpoint based on AP50.
+
+ - Selects and loads the model checkpoint with the best AP50.
+ """
+ if not self.trainer.APs:
+ print("No APs were recorded during training. Skipping model selection.")
+ return
# Select the model with the best AP50
index = self.trainer.APs.index(max(self.trainer.APs)) + 1
- # Error in demo:
- # AssertionError: Checkpoint /__w/detectree2/detectree2/detectree2-data/paracou-out/train_outputs-1/model_1.pth
- # not found!
- # Therefore sleep is attempt to allow CI to pass, but it often still fails.
+ # Error handling for checkpoint loading, with a sleep to ensure file availability in CI environments
time.sleep(15)
self.trainer.checkpointer.load(self.trainer.cfg.OUTPUT_DIR + '/model_' + str(index) + '.pth')
# See https://jss367.github.io/data-augmentation-in-detectron2.html for data augmentation advice
class MyTrainer(DefaultTrainer):
- """Summary.
+ """
+ Custom Trainer class that extends the DefaultTrainer.
- Args:
- DefaultTrainer (_type_): _description_
+ This trainer adds flexibility for handling different image types (e.g., RGB and multi-band images)
+ and custom training behavior, such as early stopping and specialized data augmentation strategies.
- Returns:
- _type_: _description_
+ Args:
+ cfg (CfgNode): Configuration object containing the model and dataset configurations.
+ patience (int): Number of evaluation periods to wait for improvement before early stopping.
"""
def __init__(self, cfg, patience): # noqa: D107
self.patience = patience
- # self.resize = resize
super().__init__(cfg)
def train(self):
- """Run training.
+ """
+ Run the training loop.
- Args:
- start_iter, max_iter (int): See docs above
+ This method overrides the DefaultTrainer's train method to include early stopping and
+ custom logging of Average Precision (AP) metrics.
Returns:
- OrderedDict of results, if evaluation is enabled. Otherwise None.
+ OrderedDict: Results from evaluation, if evaluation is enabled. Otherwise, None.
"""
start_iter = self.start_iter
@@ -219,6 +368,7 @@ def train(self):
raise
finally:
self.after_train()
+ # Verify the results if testing is enabled and this is the main process
if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
assert hasattr(self, "_last_eval_results"), "No evaluation results obtained during training!"
verify_results(self.cfg, self._last_eval_results)
@@ -226,16 +376,61 @@ def train(self):
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
+ """
+ Build the evaluator for the model.
+
+ Args:
+ cfg (CfgNode): Configuration object.
+ dataset_name (str): Name of the dataset to evaluate.
+ output_folder (str, optional): Directory to save evaluation results. Defaults to "eval".
+
+ Returns:
+ COCOEvaluator: An evaluator for COCO-style datasets.
+ """
if output_folder is None:
os.makedirs("eval", exist_ok=True)
output_folder = "eval"
return COCOEvaluator(dataset_name, cfg, True, output_folder)
def build_hooks(self):
+ """
+ Build the training hooks, including the custom LossEvalHook.
+
+ This method adds a custom hook for evaluating the model's loss during training, with support for
+ early stopping based on the AP50 metric.
+
+ Returns:
+ list: A list of hooks to be used during training.
+ """
hooks = super().build_hooks()
- # augmentations = [T.ResizeShortestEdge(short_edge_length=(1000, 1000),
- # max_size=1333,
- # sample_style='choice')]
+
+ # Determine the appropriate resize strategy based on the configuration
+ if self.cfg.RESIZE == "random":
+ size = None
+ # Attempt to determine the image size from the training dataset
+ for i, datas in enumerate(DatasetCatalog.get(self.cfg.DATASETS.TRAIN[0])):
+ location = datas['file_name']
+ try:
+ # Attempt to read the image with OpenCV (for RGB images)
+ img = cv2.imread(location)
+ if img is not None:
+ size = img.shape[0]
+ else:
+ # Fall back to rasterio for multi-band images
+ with rasterio.open(location) as src:
+ size = src.height # Assuming square images
+ except Exception as e:
+ # Handle any errors that occur during loading
+ print(f"Error loading image {location}: {e}")
+ continue
+ break
+        if size is None:
+            raise ValueError("Failed to determine image size for random resize")
+        # Define augmentation based on the determined size
+        augmentations = [T.ResizeShortestEdge([size, size], size + 300)]
+ else:
+ # Use fixed size resizing as a default
+ augmentations = [T.ResizeShortestEdge([1000, 1000], 1333)]
+
+ # Insert the custom LossEvalHook before the last hook (typically the evaluation hook)
hooks.insert(
-1,
LossEvalHook(
@@ -244,53 +439,102 @@ def build_hooks(self):
build_detection_test_loader(
self.cfg,
self.cfg.DATASETS.TEST,
- DatasetMapper(self.cfg, True)
+ FlexibleDatasetMapper(self.cfg, True, augmentations=augmentations)
),
self.patience,
),
)
return hooks
+ @classmethod
+ def build_train_loader(cls, cfg):
+ """
+ Build the training data loader with support for custom augmentations and image types.
-def build_train_loader(cls, cfg):
- """Summary.
+ This method configures the data loader to apply specific augmentations depending on the image mode
+ (RGB or multi-band) and resize strategy defined in the configuration.
- Args:
- cfg (_type_): _description_
+ Args:
+ cfg (CfgNode): Configuration object.
- Returns:
- _type_: _description_
- """
- augmentations = [
- T.RandomBrightness(0.8, 1.8),
- T.RandomContrast(0.6, 1.3),
- T.RandomSaturation(0.8, 1.4),
- T.RandomRotation(angle=[90, 90], expand=False),
- T.RandomLighting(0.7),
- T.RandomFlip(prob=0.4, horizontal=True, vertical=False),
- T.RandomFlip(prob=0.4, horizontal=False, vertical=True),
- ]
+ Returns:
+ DataLoader: A data loader for the training dataset.
+ """
- if cfg.RESIZE:
- augmentations.append(T.Resize((1000, 1000)))
- elif cfg.RESIZE == "random":
- for i, datas in enumerate(DatasetCatalog.get(cfg.DATASETS.TRAIN[0])):
- location = datas['file_name']
- size = cv2.imread(location).shape[0]
- break
- print("ADD RANDOM RESIZE WITH SIZE = ", size)
- augmentations.append(T.ResizeScale(0.6, 1.4, size, size))
- return build_detection_train_loader(
- cfg,
- mapper=DatasetMapper(
+ # Define basic augmentations including rotation and flipping
+ augmentations = [
+ T.RandomRotation(angle=[90, 90], expand=False),
+ T.RandomFlip(prob=0.4, horizontal=True, vertical=False),
+ T.RandomFlip(prob=0.4, horizontal=False, vertical=True),
+ ]
+
+ # Additional augmentations for RGB images
+ if cfg.IMGMODE == "rgb":
+ augmentations.extend([
+ T.RandomBrightness(0.7, 1.5),
+ T.RandomLighting(0.7),
+ T.RandomContrast(0.6, 1.3),
+ T.RandomSaturation(0.8, 1.4)
+ ])
+
+ # Add resizing augmentations based on the resize strategy
+ if cfg.RESIZE == "fixed":
+ augmentations.append(T.ResizeShortestEdge([1000, 1000], 1333))
+ elif cfg.RESIZE == "random":
+ size = None
+ for i, datas in enumerate(DatasetCatalog.get(cfg.DATASETS.TRAIN[0])):
+ location = datas['file_name']
+ try:
+ # Try to read with cv2 (for RGB images)
+ img = cv2.imread(location)
+ if img is not None:
+ size = img.shape[0]
+ else:
+ # Fall back to rasterio for multi-band images
+ with rasterio.open(location) as src:
+ size = src.height # Assuming square images
+ except Exception as e:
+ # Handle any errors that occur during loading
+ print(f"Error loading image {location}: {e}")
+ continue
+ break
+
+ if size:
+ print("ADD RANDOM RESIZE WITH SIZE = ", size)
+ augmentations.append(T.ResizeScale(0.6, 1.4, size, size))
+ else:
+ raise ValueError("Failed to determine image size for random resize")
+ elif cfg.RESIZE == "rand_fixed":
+ augmentations.append(T.ResizeScale(0.6, 1.4, 1000, 1000))
+
+ return build_detection_train_loader(
cfg,
- is_train=True,
- augmentations=augmentations,
- ),
- )
+ mapper=FlexibleDatasetMapper(
+ cfg,
+ is_train=True,
+ augmentations=augmentations,
+ ),
+ )
+
+ @classmethod
+ def build_test_loader(cls, cfg, dataset_name):
+ """
+ Build the test data loader.
+
+ This method configures the data loader for evaluation, using the FlexibleDatasetMapper
+ to handle custom augmentations and image types.
+
+ Args:
+ cfg (CfgNode): Configuration object.
+ dataset_name (str): Name of the dataset to load for testing.
+
+ Returns:
+ DataLoader: A data loader for the test dataset.
+ """
+ return build_detection_test_loader(cfg, dataset_name, mapper=FlexibleDatasetMapper(cfg, is_train=False))
-def get_tree_dicts(directory: str, classes: List[str] = None, classes_at: str = None) -> List[Dict]:
+def get_tree_dicts(directory: str, class_mapping: Optional[Dict[str, int]] = None) -> List[Dict[str, Any]]:
"""Get the tree dictionaries.
Args:
@@ -302,36 +546,24 @@ def get_tree_dicts(directory: str, classes: List[str] = None, classes_at: str =
List of dictionaries corresponding to segmentations of trees. Each dictionary includes
bounding box around tree and points tracing a polygon around a tree.
"""
- # filepath = '/content/drive/MyDrive/forestseg/paracou_data/Panayiotis_Outputs/220303_AllSpLabelled.gpkg'
- # datagpd = gpd.read_file(filepath)
- # List_Genus = datagpd.Genus_Species.to_list()
- # Genus_Species_UniqueList = list(set(List_Genus))
-
- #
- if classes is not None:
- # list_of_classes = crowns[variable].unique().tolist()
- classes = classes
- else:
- classes = ["tree"]
- # classes = Genus_Species_UniqueList #['tree'] # genus_species list
+
dataset_dicts = []
- # for root, dirs, files in os.walk(train_location):
- # for file in files:
- # if file.endswith(".geojson"):
- # print(os.path.join(root, file))
for filename in [file for file in os.listdir(directory) if file.endswith(".geojson")]:
json_file = os.path.join(directory, filename)
with open(json_file) as f:
img_anns = json.load(f)
- # Turn off type checking for annotations until we have a better solution
- record: Dict[str, Any] = {}
- # filename = os.path.join(directory, img_anns["imagePath"])
+ record: Dict[str, Any] = {}
filename = img_anns["imagePath"]
# Make sure we have the correct height and width
- height, width = cv2.imread(filename).shape[:2]
+        # Read dimensions with cv2 for .png images and with rasterio for .tif images
+        if filename.endswith(".png"):
+            height, width = cv2.imread(filename).shape[:2]
+        elif filename.endswith(".tif"):
+            with rasterio.open(filename) as src:
+                height, width = src.shape
+        else:
+            raise ValueError(f"Unsupported image format for {filename}. Expected .png or .tif.")
record["file_name"] = filename
record["height"] = height
@@ -343,65 +575,78 @@ def get_tree_dicts(directory: str, classes: List[str] = None, classes_at: str =
objs = []
for features in img_anns["features"]:
anno = features["geometry"]
- # pdb.set_trace()
- # GenusSpecies = features['properties']['Genus_Species']
px = [a[0] for a in anno["coordinates"][0]]
py = [np.array(height) - a[1] for a in anno["coordinates"][0]]
- # print("### HERE IS PY ###", py)
poly = [(x, y) for x, y in zip(px, py)]
poly = [p for x in poly for p in x]
- # print("#### HERE ARE SOME POLYS #####", poly)
- if classes != ["tree"]:
- obj = {
- "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
- "bbox_mode": BoxMode.XYXY_ABS,
- "segmentation": [poly],
- "category_id": classes.index(features["properties"][classes_at]), # id
- # "category_id": 0, #id
- "iscrowd": 0,
- }
+
+ # If class mapping is provided, use it; otherwise, default to "tree"
+ if class_mapping:
+ category_id = class_mapping[features["properties"]["status"]]
else:
- obj = {
- "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
- "bbox_mode": BoxMode.XYXY_ABS,
- "segmentation": [poly],
- "category_id": 0, # id
- "iscrowd": 0,
- }
- # pdb.set_trace()
+ category_id = 0 # Default to "tree" if no class mapping is provided
+
+ obj = {
+ "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
+ "bbox_mode": BoxMode.XYXY_ABS,
+ "segmentation": [poly],
+ "category_id": category_id,
+ "iscrowd": 0,
+ }
+
objs.append(obj)
- # print("#### HERE IS OBJS #####", objs)
- record["annotations"] = objs
+
+        record["annotations"] = objs
dataset_dicts.append(record)
+
return dataset_dicts
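+
+# Illustrative usage (hypothetical tile directory of per-tile .geojson files):
+#     dicts = get_tree_dicts("./tiles/train/fold_1", class_mapping={"alive": 0, "dead": 1})
+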
def combine_dicts(root_dir: str,
val_dir: int,
mode: str = "train",
- classes: List[str] = None,
- classes_at: str = None) -> List[Dict]:
- """Join tree dicts from different directories.
+ class_mapping: Optional[Dict[str, int]] = None) -> List[Dict[str, Any]]:
+ """
+ Combine dictionaries from different directories based on the specified mode.
+
+ This function aggregates tree dictionaries from multiple directories within a root directory.
+ Depending on the mode, it either combines dictionaries from all directories,
+ all except a specified validation directory, or only from the validation directory.
Args:
- root_dir:
- val_dir:
+ root_dir (str): The root directory containing subdirectories with tree dictionaries.
+ val_dir (int): The index (1-based) of the validation directory to exclude or use depending on the mode.
+ mode (str, optional): The mode of operation. Can be "train", "val", or "full".
+ "train" excludes the validation directory,
+ "val" includes only the validation directory,
+ and "full" includes all directories. Defaults to "train".
+ class_mapping: A dictionary mapping class labels to category indices (optional).
Returns:
- Concatenated array of dictionaries over all directories
+ List of combined dictionaries from the specified directories.
"""
- train_dirs = [os.path.join(root_dir, dir) for dir in os.listdir(root_dir)]
+ # Get the list of directories within the root directory
+ train_dirs = [
+ os.path.join(root_dir, dir)
+ for dir in os.listdir(root_dir)
+ if os.path.isdir(os.path.join(root_dir, dir))
+ ]
+ # Handle the different modes for combining dictionaries
if mode == "train":
+ # Exclude the validation directory from the list of directories
del train_dirs[(val_dir - 1)]
tree_dicts = []
for d in train_dirs:
- tree_dicts += get_tree_dicts(d, classes=classes, classes_at=classes_at)
+ # Combine dictionaries from all directories except the validation directory
+ tree_dicts += get_tree_dicts(d, class_mapping=class_mapping)
elif mode == "val":
- tree_dicts = get_tree_dicts(train_dirs[(val_dir - 1)], classes=classes, classes_at=classes_at)
+ # Use only the validation directory
+ tree_dicts = get_tree_dicts(train_dirs[(val_dir - 1)], class_mapping=class_mapping)
elif mode == "full":
+ # Combine dictionaries from all directories, including the validation directory
tree_dicts = []
for d in train_dirs:
- tree_dicts += get_tree_dicts(d, classes=classes, classes_at=classes_at)
+ tree_dicts += get_tree_dicts(d, class_mapping=class_mapping)
return tree_dicts
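+
+# For example, with fold directories fold_1 ... fold_4 under root_dir, mode="train"
+# and val_dir=2 combines folds 1, 3 and 4, while mode="val" returns only fold 2
+# (illustrative layout; any subdirectories returned by os.listdir work).
+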
@@ -427,41 +672,52 @@ def get_filenames(directory: str):
def register_train_data(train_location,
name: str = "tree",
val_fold=None,
- classes=None,
- classes_at=None):
+ class_mapping_file=None):
"""Register data for training and (optionally) validation.
Args:
- train_location: directory containing training folds
- name: string to name data
- val_fold: fold assigned for validation and tuning. If not given,
- will take place on all folds.
+ train_location: Directory containing training folds.
+ name: Name to register the dataset.
+ val_fold: Validation fold index (optional).
+ class_mapping_file: Path to the class mapping file (json or pickle).
"""
+ # Load the class mapping from file if provided
+ class_mapping = None
+ if class_mapping_file:
+ class_mapping = load_class_mapping(class_mapping_file)
+ thing_classes = list(class_mapping.keys()) # Convert dictionary to list of class names
+ print(f"Class mapping loaded: {class_mapping}") # Debugging step
+ else:
+ thing_classes = ["tree"]
+
if val_fold is not None:
for d in ["train", "val"]:
- DatasetCatalog.register(name + "_" + d, lambda d=d: combine_dicts(train_location,
- val_fold, d,
- classes=classes, classes_at=classes_at))
- if classes is None:
- MetadataCatalog.get(name + "_" + d).set(thing_classes=["tree"])
- else:
- MetadataCatalog.get(name + "_" + d).set(thing_classes=classes)
+ DatasetCatalog.register(
+ name + "_" + d,
+ lambda d=d: combine_dicts(train_location, val_fold, d, class_mapping=class_mapping)
+ )
+ MetadataCatalog.get(name + "_" + d).set(thing_classes=thing_classes)
else:
- DatasetCatalog.register(name + "_" + "full", lambda d=d: combine_dicts(train_location,
- 0, "full",
- classes=classes, classes_at=classes_at))
- if classes is None:
- MetadataCatalog.get(name + "_" + "full").set(thing_classes=["tree"])
- else:
- MetadataCatalog.get(name + "_" + "full").set(thing_classes=classes)
+        DatasetCatalog.register(
+            name + "_" + "full",
+            lambda: combine_dicts(train_location, 0, "full", class_mapping=class_mapping)
+        )
+ MetadataCatalog.get(name + "_" + "full").set(thing_classes=thing_classes)
+
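+# Illustrative usage (hypothetical paths):
+#     register_train_data("./tiles/train", name="Paracou", val_fold=1,
+#                         class_mapping_file="./class_to_idx.json")
+#     # registers "Paracou_train" and "Paracou_val" in the DatasetCatalog
+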
+def get_classes(out_dir):
+ """Function that will read the classes that are recorded during tiling.
-def read_data(out_dir):
- """Function that will read the classes that are recorded during tiling."""
+ Args:
+ out_dir: directory where classes.txt is located
+
+ Returns:
+ list of classes
+ """
list = []
- out_tif = out_dir + 'classes.txt'
+    classes_txt = os.path.join(out_dir, 'classes.txt')
# open file and read the content in a list
- with open(out_tif, 'r') as fp:
+ with open(classes_txt, 'r') as fp:
for line in fp:
# remove linebreak from a current name
# linebreak is the last character of each line
@@ -483,14 +739,32 @@ def remove_registered_data(name="tree"):
-def register_test_data(test_location, name="tree"):
-    """Register data for testing."""
+def register_test_data(test_location, name="tree", class_mapping_file=None):
+    """Register data for testing.
+
+    Args:
+        test_location: directory containing test data
+        name: string to name data
+        class_mapping_file: path to the class mapping file (json or pickle), optional
+    """
d = "test"
- DatasetCatalog.register(name + "_" + d, lambda d=d: get_tree_dicts(test_location))
- MetadataCatalog.get(name + "_" + d).set(thing_classes=["tree"])
+
+ class_mapping = None
+ if class_mapping_file:
+ class_mapping = load_class_mapping(class_mapping_file)
+ thing_classes = list(class_mapping.keys()) # Convert dictionary to list of class names
+ print(f"Class mapping loaded: {class_mapping}") # Debugging step
+ else:
+ thing_classes = ["tree"]
+
+ DatasetCatalog.register(name + "_" + d, lambda d=d: get_tree_dicts(test_location, class_mapping))
+ MetadataCatalog.get(name + "_" + d).set(thing_classes=thing_classes)
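+
+# Illustrative usage (hypothetical paths):
+#     register_test_data("./tiles/test", name="Paracou", class_mapping_file="./class_to_idx.json")
+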
def load_json_arr(json_path):
- """Load json array."""
+ """Load json array.
+
+ Args:
+ json_path: path to json file
+ """
lines = []
with open(json_path, "r") as f:
for line in f:
@@ -513,10 +787,12 @@ def setup_cfg(
base_lr=0.0003389,
weight_decay=0.001,
max_iter=1000,
- num_classes=1,
eval_period=100,
out_dir="./train_outputs",
- resize=True,
+ resize="fixed", # "fixed" or "random" or "rand_fixed"
+ imgmode="rgb",
+ num_bands=3,
+ class_mapping_file=None,
):
"""Set up config object # noqa: D417.
@@ -538,7 +814,23 @@ def setup_cfg(
-        num_classes: number of classes
eval_period: number of iterations between evaluations
out_dir: directory to save outputs
+ resize: resize strategy for images
+ imgmode: image mode (rgb or multispectral)
+ num_bands: number of bands in the image
+ class_mapping_file: path to class mapping file
"""
+
+ # Load the class mapping if provided
+ if class_mapping_file:
+ class_mapping = load_class_mapping(class_mapping_file)
+ num_classes = len(class_mapping) # Set the number of classes based on the mapping
+ else:
+ num_classes = 1 # Default to 1 class if no mapping is provided
+
+ # Validate the resize parameter
+ if resize not in {"fixed", "random", "rand_fixed"}:
+ raise ValueError(f"Invalid resize option '{resize}'. Must be 'fixed', 'random', or 'rand_fixed'.")
+
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(base_model))
cfg.DATASETS.TRAIN = trains
@@ -566,6 +858,20 @@ def setup_cfg(
cfg.TEST.EVAL_PERIOD = eval_period
cfg.RESIZE = resize
cfg.INPUT.MIN_SIZE_TRAIN = 1000
+ cfg.IMGMODE = imgmode # "rgb" or "ms" (multispectral)
+ if num_bands > 3:
+ # Adjust PIXEL_MEAN and PIXEL_STD for the number of bands
+ default_pixel_mean = cfg.MODEL.PIXEL_MEAN
+ default_pixel_std = cfg.MODEL.PIXEL_STD
+ # Extend or truncate the PIXEL_MEAN and PIXEL_STD based on num_bands
+ cfg.MODEL.PIXEL_MEAN = (
+ default_pixel_mean * (num_bands // len(default_pixel_mean))
+ + default_pixel_mean[:num_bands % len(default_pixel_mean)]
+ )
+ cfg.MODEL.PIXEL_STD = (
+ default_pixel_std * (num_bands // len(default_pixel_std))
+ + default_pixel_std[:num_bands % len(default_pixel_std)]
+ )
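+        # Worked example: with num_bands=5 and 3 default values [m0, m1, m2],
+        # 5 // 3 = 1 full repeat plus the first 5 % 3 = 2 values give
+        # PIXEL_MEAN = [m0, m1, m2, m0, m1].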
return cfg
@@ -576,7 +882,17 @@ def predictions_on_data(directory=None,
scale=1,
geos_exist=True,
num_predictions=0):
- """Prediction produced from a test folder and outputted to predictions folder."""
+ """Prediction produced from a test folder and outputted to predictions folder.
+
+ Args:
+ directory: directory containing test data
+ predictor: predictor object
+ trees_metadata: metadata for trees
+ save: boolean to save predictions
+ scale: scale of image
+ geos_exist: boolean to determine if geojson files exist
+ num_predictions: number of predictions to make
+ """
test_location = directory + "/test"
pred_dir = test_location + "/predictions"
@@ -623,46 +939,94 @@ def predictions_on_data(directory=None,
json.dump(evaluations, dest)
+def modify_conv1_weights(model, num_input_channels):
+ """
+ Modify the weights of the first convolutional layer (conv1) to accommodate a different number of input channels.
+
+ This function adjusts the weights of the `conv1` layer in the model's backbone to support a custom number
+ of input channels. It creates a new weight tensor with the desired number of input channels,
+ and initializes it by repeating the weights of the original channels.
+
+ Args:
+ model (torch.nn.Module): The model containing the convolutional layer to modify.
+ num_input_channels (int): The number of input channels for the new conv1 layer.
+
+ """
+ with torch.no_grad():
+ # Retrieve the original weights of the conv1 layer
+ old_weights = model.backbone.bottom_up.stem.conv1.weight
+
+ # Create a new weight tensor with the desired number of input channels
+ # The shape is (out_channels, in_channels, height, width)
+        new_weights = torch.zeros(
+            (old_weights.size(0), num_input_channels, *old_weights.shape[2:]),
+            dtype=old_weights.dtype, device=old_weights.device,
+        )
+
+ # Initialize the new weights by repeating the original weights across the new channels
+ # This example repeats the first 3 channels if num_input_channels > 3
+ for i in range(num_input_channels):
+ new_weights[:, i, :, :] = old_weights[:, i % 3, :, :]
+
+        # Create a new conv1 layer with the updated number of input channels,
+        # on the same device as the original weights
+        model.backbone.bottom_up.stem.conv1 = nn.Conv2d(
+            num_input_channels, old_weights.size(0), kernel_size=7, stride=2, padding=3, bias=False
+        ).to(old_weights.device)
+
+ # Copy the modified weights into the new conv1 layer
+ model.backbone.bottom_up.stem.conv1.weight.copy_(new_weights)
+
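+# Illustrative usage (assumes a Detectron2 Mask R-CNN with a ResNet-FPN backbone,
+# e.g. trainer.model; with num_input_channels=5 the new stem channels are
+# initialised from original channels 0, 1, 2, 0, 1, i.e. i % 3 for channel i):
+#
+#     modify_conv1_weights(trainer.model, num_input_channels=5)
+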
+
+def get_latest_model_path(output_dir: str) -> str:
+ """
+ Find the model file with the highest index in the specified output directory.
+
+ Args:
+ output_dir (str): The directory where the model files are stored.
+
+ Returns:
+ str: The path to the model file with the highest index.
+ """
+ # Regular expression to match model files with the pattern "model_X.pth"
+ model_pattern = re.compile(r"model_(\d+)\.pth")
+
+ # List all files in the output directory
+ files = os.listdir(output_dir)
+
+ # Find all files that match the pattern and extract their indices
+ model_files = []
+ for f in files:
+ match = model_pattern.search(f)
+ if match:
+ model_files.append((f, int(match.group(1))))
+
+ if not model_files:
+ raise FileNotFoundError(f"No model files found in the directory {output_dir}")
+
+ # Sort the files by index in descending order and select the highest one
+ latest_model_file = max(model_files, key=lambda x: x[1])[0]
+
+ # Return the full path to the latest model file
+ return os.path.join(output_dir, latest_model_file)
+
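+# Example (hypothetical directory): if the output dir holds model_1.pth ... model_7.pth,
+#     get_latest_model_path("./train_outputs")  # -> "./train_outputs/model_7.pth"
+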
+
if __name__ == "__main__":
- train_location = "/content/drive/Shareddrives/detectree2/data/Paracou/tiles/train/"
- register_train_data(train_location, "Paracou", 1) # folder, name, validation fold
-
- name = "Paracou2019"
- train_location = "/content/drive/Shareddrives/detectree2/data/Paracou/tiles2019/train/"
- dataset_dicts = combine_dicts(train_location, 1)
- trees_metadata = MetadataCatalog.get(name + "_train")
- # dataset_dicts = get_tree_dicts("./")
- for d in dataset_dicts:
- img = cv2.imread(d["file_name"])
- visualizer = Visualizer(img[:, :, ::-1], metadata=trees_metadata, scale=0.5)
- out = visualizer.draw_dataset_dict(d)
- image = cv2.cvtColor(out.get_image()[:, :, ::-1], cv2.COLOR_BGR2RGB)
- # display(Image.fromarray(image))
- # Set the base (pre-trained) model from the detectron2 model_zoo
- model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
- # Set the names of the registered train and test sets
- # pretrained model?
- # trained_model = "/content/drive/Shareddrives/detectree2/models/220629_ParacouSepilokDanum_JB.pth"
- trains = (
- "Paracou_train",
- "Paracou2019_train",
- "ParacouUAV_train",
- "Danum_train",
- "SepilokEast_train",
- "SepilokWest_train",
- )
- tests = (
- "Paracou_val",
- "Paracou2019_val",
- "ParacouUAV_val",
- "Danum_val",
- "SepilokEast_val",
- "SepilokWest_val",
+ # Define paths to training data and optional class mapping file
+ train_location = "/path/to/your/train/location"
+ class_mapping_file = "/path/to/your/class_to_idx.json" # Optional, can be None
+
+ # Register the training and validation datasets using the class mapping
+ # If class_mapping_file is not provided, defaults to "tree"
+ register_train_data(train_location, "MyDataset", val_fold=1, class_mapping_file=class_mapping_file)
+
+ # Set up model configuration, using the class mapping to determine the number of classes
+ cfg = setup_cfg(
+ base_model="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
+ trains=("MyDataset_train", ),
+ tests=("MyDataset_val", ),
+ max_iter=3000,
+ out_dir="/path/to/output",
+ class_mapping_file=class_mapping_file # Optional
)
- out_dir = "/content/drive/Shareddrives/detectree2/220703_train_outputs"
- # update_model arg can be used to load in trained model
- cfg = setup_cfg(model, trains, tests, eval_period=100, max_iter=3000, out_dir=out_dir)
+ # Train the model
trainer = MyTrainer(cfg, patience=4)
trainer.resume_or_load(resume=False)
trainer.train()
diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py
index 2ceeb8c4..2a9c1d7d 100644
--- a/detectree2/preprocessing/tiling.py
+++ b/detectree2/preprocessing/tiling.py
@@ -4,11 +4,14 @@
of models and making landscape predictions.
"""
+import concurrent.futures
import json
+import logging
import os
+import pickle
import random
import shutil
-import warnings
+import warnings # noqa: F401
from math import ceil
from pathlib import Path
@@ -17,11 +20,17 @@
import numpy as np
import rasterio
from fiona.crs import from_epsg # noqa: F401
-from rasterio.crs import CRS
-from rasterio.io import DatasetReader
+# from rasterio.crs import CRS
+from rasterio.errors import RasterioIOError
+# from rasterio.io import DatasetReader
from rasterio.mask import mask
+# from rasterio.windows import from_bounds
from shapely.geometry import box
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
# class img_data(DatasetReader):
# """
# Class for image data to be processed for tiling
@@ -47,87 +56,93 @@ def get_features(gdf: gpd.GeoDataFrame):
return [json.loads(gdf.to_json())["features"][0]["geometry"]]
-def tile_data(
- data: DatasetReader,
- out_dir: str,
- buffer: int = 30,
- tile_width: int = 200,
- tile_height: int = 200,
- dtype_bool: bool = False,
-) -> None:
- """Tiles up orthomosaic for making predictions on.
+def load_class_mapping(file_path: str):
+ """Function to load class-to-index mapping from a file.
- Tiles up full othomosaic into managable chunks to make predictions on. Use tile_data_train to generate tiled
- training data. A bug exists on some input raster types whereby outputed tiles are completely black - the dtype_bool
- argument should be switched if this is the case.
+ Args:
+ file_path: Path to the file (json or pickle)
+
+ Returns:
+ class_to_idx: Loaded class-to-index mapping
+ """
+ file_ext = Path(file_path).suffix
+
+ if file_ext == '.json':
+ with open(file_path, 'r') as f:
+ class_to_idx = json.load(f)
+ elif file_ext == '.pkl':
+ with open(file_path, 'rb') as f:
+ class_to_idx = pickle.load(f)
+ else:
+ raise ValueError("Unsupported file format. Use '.json' or '.pkl'.")
+
+ return class_to_idx
+
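+# Illustrative usage (hypothetical file; keys are class labels, values integer ids):
+#     class_to_idx = load_class_mapping("class_to_idx.json")  # e.g. {"alive": 0, "dead": 1}
+#     thing_classes = list(class_to_idx.keys())
+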
+
+def process_tile(
+ img_path: str,
+ out_dir: str,
+ buffer: int,
+ tile_width: int,
+ tile_height: int,
+ dtype_bool: bool,
+ minx,
+ miny,
+ crs,
+ tilename,
+ crowns: gpd.GeoDataFrame = None,
+ threshold: float = 0,
+ nan_threshold: float = 0,
+):
+ """Process a single tile for making predictions.
Args:
- data: Orthomosaic as a rasterio object in a UTM type projection
+ img_path: Path to the orthomosaic
+ out_dir: Output directory
buffer: Overlapping buffer of tiles in meters (UTM)
tile_width: Tile width in meters
tile_height: Tile height in meters
dtype_bool: Flag to edit dtype to prevent black tiles
+ minx: Minimum x coordinate of tile
+ miny: Minimum y coordinate of tile
+ crs: Coordinate reference system
+ tilename: Name of the tile
Returns:
-        None
+        A tuple (data, out_path_root, overlapping_crowns, minx, miny, buffer) for
+        accepted tiles, or None if the tile is rejected or an error occurs.
"""
- out_path = Path(out_dir)
- os.makedirs(out_path, exist_ok=True)
- crs = CRS.from_string(data.crs.wkt)
- crs = crs.to_epsg()
- tilename = Path(data.name).stem
-
- total_tiles = int(
- ((data.bounds[2] - data.bounds[0]) / tile_width) * ((data.bounds[3] - data.bounds[1]) / tile_height)
- )
-
- tile_count = 0
- print(f"Tiling to {total_tiles} total tiles")
-
- for minx in np.arange(data.bounds[0], data.bounds[2] - tile_width,
- tile_width, int):
- for miny in np.arange(data.bounds[1], data.bounds[3] - tile_height,
- tile_height, int):
-
- tile_count += 1
- # Naming conventions
+ try:
+ with rasterio.open(img_path) as data:
+ out_path = Path(out_dir)
out_path_root = out_path / f"{tilename}_{minx}_{miny}_{tile_width}_{buffer}_{crs}"
- # new tiling bbox including the buffer
- bbox = box(
- minx - buffer,
- miny - buffer,
- minx + tile_width + buffer,
- miny + tile_height + buffer,
- )
- # define the bounding box of the tile, excluding the buffer
- # (hence selecting just the central part of the tile)
- # bbox_central = box(minx, miny, minx + tile_width, miny + tile_height)
- # turn the bounding boxes into geopandas DataFrames
- geo = gpd.GeoDataFrame({"geometry": bbox}, index=[0], crs=data.crs)
- # geo_central = gpd.GeoDataFrame(
- # {"geometry": bbox_central}, index=[0], crs=from_epsg(4326)
- # ) # 3182
- # overlapping_crowns = sjoin(crowns, geo_central, how="inner")
+ minx_buffered = minx - buffer
+ miny_buffered = miny - buffer
+ maxx_buffered = minx + tile_width + buffer
+ maxy_buffered = miny + tile_height + buffer
- # here we are cropping the tiff to the bounding box of the tile we want
+ bbox = box(minx_buffered, miny_buffered, maxx_buffered, maxy_buffered)
+ geo = gpd.GeoDataFrame({"geometry": bbox}, index=[0], crs=data.crs)
coords = get_features(geo)
- # print("Coords:", coords)
- # define the tile as a mask of the whole tiff with just the bounding box
+ overlapping_crowns = None
+ if crowns is not None:
+ overlapping_crowns = gpd.clip(crowns, geo)
+ if overlapping_crowns.empty or (overlapping_crowns.dissolve().area[0] / geo.area[0]) < threshold:
+ return None
+
out_img, out_transform = mask(data, shapes=coords, crop=True)
- # Discard scenes with many out-of-range pixels
- out_sumbands = np.sum(out_img, 0)
+ out_sumbands = np.sum(out_img, axis=0)
zero_mask = np.where(out_sumbands == 0, 1, 0)
nan_mask = np.where(out_sumbands == 765, 1, 0)
sumzero = zero_mask.sum()
sumnan = nan_mask.sum()
totalpix = out_img.shape[1] * out_img.shape[2]
- if sumzero > 0.25 * totalpix:
- continue
- elif sumnan > 0.25 * totalpix:
- continue
+
+ # If the tile is mostly empty or mostly nan, don't save it
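+            # (e.g. nan_threshold=0.1 rejects tiles where more than 10% of pixels
+            # sum to 0 across bands, or to 765 = 3 * 255, the 8-bit RGB nodata value)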
+ if sumzero > nan_threshold * totalpix or sumnan > nan_threshold * totalpix:
+ return None
out_meta = data.meta.copy()
out_meta.update({
@@ -137,145 +152,103 @@ def tile_data(
"transform": out_transform,
"nodata": None,
})
- # dtype needs to be unchanged for some data and set to uint8 for others
if dtype_bool:
out_meta.update({"dtype": "uint8"})
- # print("Out Meta:",out_meta)
- # Saving the tile as a new tiff, named by the origin of the tile.
- # If tile appears blank in folder can show the image here and may
- # need to fix RGB data or the dtype
- # show(out_img)
- out_tif = out_path_root.with_suffix(out_path_root.suffix + ".tif")
+ out_tif = out_path_root.with_suffix(".tif")
with rasterio.open(out_tif, "w", **out_meta) as dest:
dest.write(out_img)
- # read in the tile we have just saved
- clipped = rasterio.open(out_tif)
- # read it as an array
- # show(clipped)
- arr = clipped.read()
+ with rasterio.open(out_tif) as clipped:
+ arr = clipped.read()
+ r, g, b = arr[0], arr[1], arr[2]
+                rgb = np.dstack((b, g, r))  # Reorder bands to BGR for cv2
- # each band of the tiled tiff is a colour!
- r = arr[0]
- g = arr[1]
- b = arr[2]
+ # Rescale to 0-255 if necessary
+ if np.max(g) > 255:
+ rgb_rescaled = 255 * rgb / 65535
+ else:
+ rgb_rescaled = rgb
- # stack up the bands in an order appropriate for saving with cv2,
- # then rescale to the correct 0-255 range for cv2
+ cv2.imwrite(str(out_path_root.with_suffix(".png").resolve()), rgb_rescaled)
- rgb = np.dstack((b, g, r)) # BGR for cv2
+ if overlapping_crowns is not None:
+ return data, out_path_root, overlapping_crowns, minx, miny, buffer
- if np.max(g) > 255:
- rgb_rescaled = 255 * rgb / 65535
- else:
- rgb_rescaled = rgb # scale to image
- # print("rgb rescaled", rgb_rescaled)
-
- # save this as jpg or png...we are going for png...again, named with the origin of the specific tile
- # here as a naughty method
- cv2.imwrite(
- str(out_path_root.with_suffix(out_path_root.suffix + ".png").resolve()),
- rgb_rescaled,
- )
- if tile_count % 50 == 0:
- print(f"Processed {tile_count} tiles of {total_tiles} tiles")
+ return data, out_path_root, None, minx, miny, buffer
- print("Tiling complete")
+ except RasterioIOError as e:
+        logger.error(f"RasterioIOError while processing tile {tilename} from {img_path}: {e}")
+ return None
+ except Exception as e:
+ logger.error(f"Error processing tile {tilename} at ({minx}, {miny}): {e}")
+ return None
-def tile_data_train( # noqa: C901
- data: DatasetReader,
+def process_tile_ms(
+ img_path: str,
out_dir: str,
- buffer: int = 30,
- tile_width: int = 200,
- tile_height: int = 200,
+ buffer: int,
+ tile_width: int,
+ tile_height: int,
+ dtype_bool: bool,
+ minx,
+ miny,
+ crs,
+ tilename,
crowns: gpd.GeoDataFrame = None,
threshold: float = 0,
- nan_threshold: float = 0.1,
- dtype_bool: bool = False,
-) -> None:
- """Tiles up orthomosaic and corresponding crowns into training tiles.
-
- A threshold can be used to ensure a good coverage of crowns across a tile. Tiles that do not have sufficient
- coverage are rejected.
+ nan_threshold: float = 0,
+):
+    """Process a single multispectral tile for making predictions.
Args:
- data: Orthomosaic as a rasterio object in a UTM type projection
+ img_path: Path to the orthomosaic
+ out_dir: Output directory
buffer: Overlapping buffer of tiles in meters (UTM)
tile_width: Tile width in meters
tile_height: Tile height in meters
- crowns: Crown polygons as a geopandas dataframe
- threshold: Min proportion of the tile covered by crowns to be accepted {0,1}
- nan_theshold: Max proportion of tile covered by nans
dtype_bool: Flag to edit dtype to prevent black tiles
+ minx: Minimum x coordinate of tile
+ miny: Minimum y coordinate of tile
+ crs: Coordinate reference system
+ tilename: Name of the tile
Returns:
-        None
+        A tuple (data, out_path_root, overlapping_crowns, minx, miny, buffer) for
+        accepted tiles, or None if the tile is rejected or an error occurs.
-
"""
-
- # TODO: Clip data to crowns straight away to speed things up
- # TODO: Tighten up epsg handling
- out_path = Path(out_dir)
- os.makedirs(out_path, exist_ok=True)
- tilename = Path(data.name).stem
- crs = CRS.from_string(data.crs.wkt)
- crs = crs.to_epsg()
- # out_img, out_transform = mask(data, shapes=crowns.buffer(buffer), crop=True)
- # Should start from data.bounds[0] + buffer, data.bounds[1] + buffer to avoid later complications
- for minx in np.arange(ceil(data.bounds[0]) + buffer, data.bounds[2] - tile_width - buffer, tile_width, int):
- for miny in np.arange(ceil(data.bounds[1]) + buffer, data.bounds[3] - tile_height - buffer, tile_height, int):
-
+ try:
+ with rasterio.open(img_path) as data:
+ out_path = Path(out_dir)
out_path_root = out_path / f"{tilename}_{minx}_{miny}_{tile_width}_{buffer}_{crs}"
- # Calculate the buffered tile dimensions
- # tile_width_buffered = tile_width + 2 * buffer
- # tile_height_buffered = tile_height + 2 * buffer
-
- # Calculate the bounding box coordinates with buffer
minx_buffered = minx - buffer
miny_buffered = miny - buffer
maxx_buffered = minx + tile_width + buffer
maxy_buffered = miny + tile_height + buffer
- # Create the affine transformation matrix for the tile
- # transform = from_bounds(minx_buffered, miny_buffered, maxx_buffered,
- # maxy_buffered, tile_width_buffered, tile_height_buffered)
-
bbox = box(minx_buffered, miny_buffered, maxx_buffered, maxy_buffered)
- geo = gpd.GeoDataFrame({"geometry": bbox}, index=[0], crs=data.crs)
- coords = get_features(geo)
+ geo = gpd.GeoDataFrame({"geometry": [bbox]}, index=[0], crs=data.crs)
+ coords = [geo.geometry[0].__geo_interface__]
- # Skip if insufficient coverage of crowns - good to have early on to save on unnecessary processing
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- # Warning:
- # _crs_mismatch_warn
+ overlapping_crowns = None
+ if crowns is not None:
overlapping_crowns = gpd.clip(crowns, geo)
+ if overlapping_crowns.empty or (overlapping_crowns.dissolve().area[0] / geo.area[0]) < threshold:
+ return None
- # Ignore tiles with no crowns
- if overlapping_crowns.empty:
- continue
-
- # Discard tiles that do not have a sufficient coverage of training crowns
- if (overlapping_crowns.dissolve().area[0] / geo.area[0]) < threshold:
- continue
-
- # define the tile as a mask of the whole tiff with just the bounding box
out_img, out_transform = mask(data, shapes=coords, crop=True)
- # Discard scenes with many out-of-range pixels
- out_sumbands = np.sum(out_img, 0)
+ out_sumbands = np.sum(out_img, axis=0)
zero_mask = np.where(out_sumbands == 0, 1, 0)
- nan_mask = np.where(out_sumbands == 765, 1, 0)
+ nan_mask = np.isnan(out_sumbands)
sumzero = zero_mask.sum()
sumnan = nan_mask.sum()
totalpix = out_img.shape[1] * out_img.shape[2]
- if sumzero > nan_threshold * totalpix: # reject tiles with many 0 cells
- continue
- elif sumnan > nan_threshold * totalpix: # reject tiles with many NaN cells
- continue
+
+ # If the tile is mostly empty or mostly nan, don't save it
+ if sumzero > nan_threshold * totalpix or sumnan > nan_threshold * totalpix:
+ return None
out_meta = data.meta.copy()
out_meta.update({
@@ -285,104 +258,171 @@ def tile_data_train( # noqa: C901
"transform": out_transform,
"nodata": None,
})
- # dtype needs to be unchanged for some data and set to uint8 for others to deal with black tiles
if dtype_bool:
out_meta.update({"dtype": "uint8"})
- # Saving the tile as a new tiff, named by the origin of the tile. If tile appears blank in folder can show
- # the image here and may need to fix RGB data or the dtype
- out_tif = out_path_root.with_suffix(out_path_root.suffix + ".tif")
+ out_tif = out_path_root.with_suffix(".tif")
with rasterio.open(out_tif, "w", **out_meta) as dest:
dest.write(out_img)
- # read in the tile we have just saved
- clipped = rasterio.open(out_tif)
+ if overlapping_crowns is not None:
+ return data, out_path_root, overlapping_crowns, minx, miny, buffer
- # read it as an array
- arr = clipped.read()
+ return data, out_path_root, None, minx, miny, buffer
- # each band of the tiled tiff is a colour!
- r = arr[0]
- g = arr[1]
- b = arr[2]
+ except RasterioIOError as e:
+        logger.error(f"RasterioIOError while processing tile {tilename} from {img_path}: {e}")
+ return None
+ except Exception as e:
+ logger.error(f"Error processing tile {tilename} at ({minx}, {miny}): {e}")
+ return None
- # stack up the bands in an order appropriate for saving with cv2, then rescale to the correct 0-255 range
- # for cv2. BGR ordering is correct for cv2 (and detectron2)
- rgb = np.dstack((b, g, r))
- # Some rasters need to have values rescaled to 0-255
- # TODO: more robust check
- if np.max(g) > 255:
- rgb_rescaled = 255 * rgb / 65535
- else:
- # scale to image
- rgb_rescaled = rgb
-
- # save this as png, named with the origin of the specific tile
- # potentially bad practice
- cv2.imwrite(
- str(out_path_root.with_suffix(out_path_root.suffix + ".png").resolve()),
- rgb_rescaled,
- )
+def process_tile_train(
+ img_path: str,
+ out_dir: str,
+ buffer: int,
+ tile_width: int,
+ tile_height: int,
+ dtype_bool: bool,
+ minx,
+ miny,
+ crs,
+ tilename,
+ crowns: gpd.GeoDataFrame,
+ threshold,
+ nan_threshold,
+ mode: str = "rgb",
+ class_column: str = None, # Allow user to specify class column
+) -> None:
+ """Process a single tile for training data.
+
+ Args:
+ img_path: Path to the orthomosaic
+ out_dir: Output directory
+ buffer: Overlapping buffer of tiles in meters (UTM)
+ tile_width: Tile width in meters
+ tile_height: Tile height in meters
+ dtype_bool: Flag to edit dtype to prevent black tiles
+ minx: Minimum x coordinate of tile
+ miny: Minimum y coordinate of tile
+ crs: Coordinate reference system
+ tilename: Name of the tile
+ crowns: Crown polygons as a geopandas dataframe
+ threshold: Min proportion of the tile covered by crowns to be accepted {0,1}
+        nan_threshold: Max proportion of tile covered by nans
+        mode: Type of imagery being tiled ("rgb" or "ms")
+        class_column: Column in crowns to use as the class label (optional)
+
+ Returns:
+ None
+ """
+ if mode == "rgb":
+ result = process_tile(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs,
+ tilename, crowns, threshold, nan_threshold)
+ elif mode == "ms":
+ result = process_tile_ms(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs,
+ tilename, crowns, threshold, nan_threshold)
+
+ if result is None:
+ # logger.warning(f"Skipping tile at ({minx}, {miny}) due to insufficient data.")
+ return
+
+ data, out_path_root, overlapping_crowns, minx, miny, buffer = result
+
+    if overlapping_crowns is None:
+        return  # No crowns supplied, so there is no geojson to write (prediction tiling)
+
+    overlapping_crowns = overlapping_crowns.explode(index_parts=True)
+ moved = overlapping_crowns.translate(-minx + buffer, -miny + buffer)
+ scalingx = 1 / (data.transform[0])
+ scalingy = -1 / (data.transform[4])
+ moved_scaled = moved.scale(scalingx, scalingy, origin=(0, 0))
+
+ if mode == "rgb":
+ impath = {"imagePath": out_path_root.with_suffix(".png").as_posix()}
+ elif mode == "ms":
+ impath = {"imagePath": out_path_root.with_suffix(".tif").as_posix()}
+
+ try:
+ filename = out_path_root.with_suffix(".geojson")
+ moved_scaled = overlapping_crowns.set_geometry(moved_scaled)
+
+ if class_column is not None:
+ # Ensure we map the selected column to the 'status' field
+ moved_scaled['status'] = moved_scaled[class_column]
+ # Keep only 'status' and geometry
+ moved_scaled = moved_scaled[['geometry', 'status']]
+ else:
+ # Keep only geometry to reduce file size
+ moved_scaled = moved_scaled[['geometry']]
+
+ # Save the result as GeoJSON
+ moved_scaled.to_file(driver="GeoJSON", filename=filename)
+
+ # Add image path info to the GeoJSON file
+ with open(filename, "r") as f:
+ shp = json.load(f)
+ shp.update(impath)
+ with open(filename, "w") as f:
+ json.dump(shp, f)
+ except ValueError:
+ logger.warning("Cannot write empty DataFrame to file.")
+ return
+
+
+# Top-level helper so the argument tuple can be unpacked and the function can
+# be pickled by ProcessPoolExecutor (nested functions and lambdas cannot be)
+def process_tile_train_helper(args):
+    return process_tile_train(*args)
+
+
+def tile_data(
+ img_path: str,
+ out_dir: str,
+ buffer: int = 30,
+ tile_width: int = 200,
+ tile_height: int = 200,
+ crowns: gpd.GeoDataFrame = None,
+ threshold: float = 0,
+ nan_threshold: float = 0.1,
+ dtype_bool: bool = False,
+ mode: str = "rgb",
+ class_column: str = None, # Allow class column to be passed here
+) -> None:
+ """Tiles up orthomosaic and corresponding crowns (if supplied) into training/prediction tiles.
- # select the crowns that intersect the non-buffered central
- # section of the tile using the inner join
- # TODO: A better solution would be to clip crowns to tile extent
- # overlapping_crowns = sjoin(crowns, geo_central, how="inner")
- # Maybe left join to keep information of crowns?
-
- overlapping_crowns = overlapping_crowns.explode(index_parts=True)
-
- # Translate to 0,0 to overlay on png
- moved = overlapping_crowns.translate(-minx + buffer, -miny + buffer)
-
- # scale to deal with the resolution
- scalingx = 1 / (data.transform[0])
- scalingy = -1 / (data.transform[4])
- moved_scaled = moved.scale(scalingx, scalingy, origin=(0, 0))
-
- impath = {"imagePath": out_path_root.with_suffix(out_path_root.suffix + ".png").as_posix()}
-
- # Save as a geojson, a format compatible with detectron2, again named by the origin of the tile.
- # If the box selected from the image is outside of the mapped region due to the image being on a slant
- # then the shp file will have no info on the crowns and hence will create an empty gpd Dataframe.
- # this causes an error so skip creating geojson. The training code will also ignore png so no problem.
- try:
- filename = out_path_root.with_suffix(out_path_root.suffix + ".geojson")
- moved_scaled = overlapping_crowns.set_geometry(moved_scaled)
- moved_scaled.to_file(
- driver="GeoJSON",
- filename=filename,
- )
- with open(filename, "r") as f:
- shp = json.load(f)
- shp.update(impath)
- with open(filename, "w") as f:
- json.dump(shp, f)
- except ValueError:
- print("Cannot write empty DataFrame to file.")
- continue
- # Repeat and want to save crowns before being moved as overlap with lidar data to get the heights
- # can try clean up the code here as lots of reprojecting and resaving but just going to get to
- # work for now
- out_geo_file = out_path_root.parts[-1] + "_geo"
- out_path_geo = out_path / Path(out_geo_file)
- try:
- filename_unmoved = out_path_geo.with_suffix(out_path_geo.suffix + ".geojson")
- overlapping_crowns.to_file(
- driver="GeoJSON",
- filename=filename_unmoved,
- )
- with open(filename_unmoved, "r") as f:
- shp = json.load(f)
- shp.update(impath)
- with open(filename_unmoved, "w") as f:
- json.dump(shp, f)
- except ValueError:
- print("Cannot write empty DataFrame to file.")
- continue
-
- print("Tiling complete")
+    Tiles up large rasters into manageable tiles for training and prediction. If crowns are not supplied, the function
+    will tile up the entire landscape for prediction. If crowns are supplied, the function will tile the image alongside
+    the crowns and skip tiles without a minimum coverage of crowns. The 'threshold' can be varied to ensure a good
+    coverage of crowns across a training tile. Tiles that do not have sufficient coverage are skipped.
+
+ Args:
+ img_path: Path to the orthomosaic
+ out_dir: Output directory
+ buffer: Overlapping buffer of tiles in meters (UTM)
+ tile_width: Tile width in meters
+ tile_height: Tile height in meters
+ crowns: Crown polygons as a geopandas dataframe
+ threshold: Min proportion of the tile covered by crowns to be accepted {0,1}
+        nan_threshold: Max proportion of tile covered by nans
+        dtype_bool: Flag to edit dtype to prevent black tiles
+        mode: Type of imagery being tiled ("rgb" or "ms")
+        class_column: Column in crowns to use as the class label (optional)
+
+ Returns:
+ None
+ """
+ out_path = Path(out_dir)
+ os.makedirs(out_path, exist_ok=True)
+ tilename = Path(img_path).stem
+ with rasterio.open(img_path) as data:
+        crs = data.crs.to_epsg()  # Get the CRS of the raster as an EPSG code
+
+        # Generate the grid of tile origins (in map units) covering the raster extent
+        tile_args = [
+            (img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns,
+             threshold, nan_threshold, mode, class_column)
+            for minx in np.arange(ceil(data.bounds[0]) + buffer, data.bounds[2] - tile_width - buffer, tile_width,
+                                  dtype=int)
+            for miny in np.arange(ceil(data.bounds[1]) + buffer, data.bounds[3] - tile_height - buffer, tile_height,
+                                  dtype=int)
+        ]
+
+        # Tiles are independent, so process them in parallel across CPU cores
+        with concurrent.futures.ProcessPoolExecutor() as executor:
+            list(executor.map(process_tile_train_helper, tile_args))
+
+ logger.info("Tiling complete")
def image_details(fileroot):
@@ -429,32 +469,41 @@ def is_overlapping_box(test_boxes_array, train_box):
return False
-def record_data(crowns,
- out_dir,
- column='status'):
- """Function that will record a list of classes into a file that can be readed during training.
+def record_classes(crowns: gpd.GeoDataFrame, out_dir: str, column: str = 'status', save_format: str = 'json'):
+ """Function that records a list of classes into a file that can be read during training.
Args:
crowns: gpd dataframe with the crowns
out_dir: directory to save the file
column: column name to get the classes from
+ save_format: format to save the file ('json' or 'pickle')
Returns:
None
"""
-
+ # Extract unique class names from the specified column
list_of_classes = crowns[column].unique().tolist()
- print("**The list of classes are:**")
- print(list_of_classes)
- print("**The list has been saved to the out_dir**")
+ # Sort the list of classes in alphabetical order
+ list_of_classes.sort()
- # Write it into file "classes.txt"
- out_tif = out_dir + 'classes.txt'
- f = open(out_tif, "w")
- for i in list_of_classes:
- f.write("%s\n" % i)
- f.close()
+ # Create a dictionary for class-to-index mapping
+ class_to_idx = {class_name: idx for idx, class_name in enumerate(list_of_classes)}
+
+ # Save the class-to-index mapping to disk
+ out_path = Path(out_dir)
+ os.makedirs(out_path, exist_ok=True)
+
+ if save_format == 'json':
+ with open(out_path / 'class_to_idx.json', 'w') as f:
+ json.dump(class_to_idx, f)
+ elif save_format == 'pickle':
+ with open(out_path / 'class_to_idx.pkl', 'wb') as f:
+ pickle.dump(class_to_idx, f)
+ else:
+ raise ValueError("Unsupported save format. Use 'json' or 'pickle'.")
+
+ print(f"Classes saved as {save_format} file: {class_to_idx}")
def to_traintest_folders( # noqa: C901
@@ -491,7 +540,8 @@ def to_traintest_folders( # noqa: C901
Path(out_dir / "train").mkdir(parents=True, exist_ok=True)
Path(out_dir / "test").mkdir(parents=True, exist_ok=True)
- file_names = tiles_dir.glob("*.png")
+ # file_names = tiles_dir.glob("*.png")
+ file_names = tiles_dir.glob("*.geojson")
file_roots = [item.stem for item in file_names]
num = list(range(0, len(file_roots)))
@@ -540,29 +590,56 @@ def to_traintest_folders( # noqa: C901
if __name__ == "__main__":
- # Right let"s test this first with Sepilok 10cm resolution, then I need to try it with 50cm resolution.
- img_path = "/content/drive/Shareddrives/detectreeRGB/benchmark/Ortho2015_benchmark/P4_Ortho_2015.tif"
- crown_path = "gdrive/MyDrive/JamesHirst/NY/Buffalo/Buffalo_raw_data/all_crowns.shp"
- out_dir = "./"
- # Read in the tiff file
- # data = img_data.open(img_path)
- # Read in crowns
- data = rasterio.open(img_path)
+ # Define paths to the input data
+ img_path = "/path/to/your/orthomosaic.tif" # Path to your input orthomosaic file
+ crown_path = "/path/to/your/crown_shapefile.shp" # Path to the shapefile containing crowns
+ out_dir = "/path/to/output/directory" # Directory where you want to save the tiled output
+
+ # Optional parameters for tiling and processing
+ buffer = 30 # Overlap between tiles (in meters)
+ tile_width = 200 # Tile width (in meters)
+ tile_height = 200 # Tile height (in meters)
+ nan_threshold = 0.1 # Max proportion of tile that can be NaN before it's discarded
+ threshold = 0.5 # Minimum crown coverage per tile for it to be kept (0-1)
+ dtype_bool = False # Change dtype to uint8 to avoid black tiles
+ mode = "rgb" # Use 'rgb' for regular 3-channel imagery, 'ms' for multispectral
+ class_column = "species" # Column in the crowns file to use as the class label
+
+ # Read in the crowns
crowns = gpd.read_file(crown_path)
- print(
- "shape =",
- data.shape,
- ",",
- data.bounds,
- "and number of bands =",
- data.count,
- ", crs =",
- data.crs,
+
+ # Record the classes and save the class mapping
+ record_classes(
+ crowns=crowns, # Geopandas dataframe with crowns
+ out_dir=out_dir, # Output directory to save class mapping
+ column=class_column, # Column used for classes
+ save_format='json' # Choose between 'json' or 'pickle'
)
- buffer = 20
- tile_width = 200
- tile_height = 200
+ # Perform the tiling, ensuring the selected class column is used
+ tile_data(
+ img_path=img_path,
+ out_dir=out_dir,
+ buffer=buffer,
+ tile_width=tile_width,
+ tile_height=tile_height,
+ crowns=crowns,
+ threshold=threshold,
+ nan_threshold=nan_threshold,
+ dtype_bool=dtype_bool,
+ mode=mode,
+ class_column=class_column # Use the selected class column (e.g., 'species', 'status')
+ )
+
+ # Split the data into training and validation sets (optional)
+ # This can be used for train/test folder creation based on the generated tiles
+ to_traintest_folders(
+ tiles_folder=out_dir, # Directory where tiles are saved
+ out_folder="/path/to/final/data/output", # Final directory for train/test data
+ test_frac=0.2, # Fraction of data to be used for testing
+ folds=5, # Number of folds (optional, can be set to 1 for no fold splitting)
+ strict=True, # Ensure no overlap between train/test tiles
+ seed=42 # Set seed for reproducibility
+ )
- tile_data_train(data, out_dir, buffer, tile_width, tile_height, crowns)
- to_traintest_folders(folds=5)
+ logger.info("Tiling process completed successfully!")
diff --git a/detectree2/tests/test_preprocessing.py b/detectree2/tests/test_preprocessing.py
index 6f2eb9cf..49a7d383 100644
--- a/detectree2/tests/test_preprocessing.py
+++ b/detectree2/tests/test_preprocessing.py
@@ -62,11 +62,11 @@ def test_tiling(self):
tile_height = 40
threshold = 0.2
- from detectree2.preprocessing.tiling import tile_data_train
+ from detectree2.preprocessing.tiling import tile_data
out_dir = os.path.join(out_dir, "tiles")
- tile_data_train(data, out_dir, buffer, tile_width, tile_height, crowns, threshold)
+ tile_data(img_path, out_dir, buffer, tile_width, tile_height, crowns, threshold)
# TODO: install pytest-depends to automatically order
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 255eefc3..545790ab 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,6 +42,7 @@ Accurate delineation of individual tree crowns in tropical forests from aerial R
installation
tutorial
+ tutorial_multi
contributing
using-git
.. _notebooks/contributing_guide
diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst
index 1af9c89a..369d71ec 100644
--- a/docs/source/tutorial.rst
+++ b/docs/source/tutorial.rst
@@ -1,10 +1,12 @@
Tutorial
========
+
This tutorial goes through the steps of single class (tree) detection and
-delineation. A guide to multiclass prediction (e.g. species mapping,
-disease mapping) is coming soon. Example data that can be used in
-this tutorial is available `here `_.
+delineation from RGB and multispectral data. A guide to multiclass prediction
+(e.g. species mapping, disease mapping) is available in the multiclass
+tutorial. Example data that can be used in this tutorial is available
+`here `_.
The key steps are:
@@ -40,9 +42,15 @@ If you would just like to make predictions on an orthomosaic with a pre-trained
model from the ``model_garden``, skip to part 4 (Generating landscape
predictions).
+The data preparation and training process for both RGB and multispectral data
+is presented here. The process is similar for both data types but there are
+some key differences that are highlighted. Training a single model on both RGB
+and multispectral data at the same time is not currently supported. Stick to
+one data type per model (or stack the RGB bands with the multispectral bands
+and treat the combined raster as multispectral data, as sketched below).
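+
+If you do stack bands, it could be done along these lines (a minimal sketch,
+assuming the two orthomosaics are already co-registered with identical
+extents, resolutions and CRS; the paths are hypothetical):
+
+.. code-block:: python
+
+   import numpy as np
+   import rasterio
+
+   with rasterio.open("site/rgb/ortho_rgb.tif") as rgb, rasterio.open("site/ms/ortho_ms.tif") as ms:
+       # Stack the 3 RGB bands on top of the (e.g.) 5 multispectral bands;
+       # assumes both rasters share the same grid and dtype
+       stacked = np.concatenate([rgb.read(), ms.read()], axis=0)
+       meta = ms.meta.copy()
+       meta.update(count=stacked.shape[0])
+
+   with rasterio.open("site/ms/ortho_rgbms.tif", "w", **meta) as dst:
+       dst.write(stacked)
+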
-Preparing data
---------------
+Preparing data (RGB and multispectral)
+--------------------------------------
An example of the recommended file structure when training a new model is as follows:
@@ -58,6 +66,8 @@ An example of the recommended file structure when training a new model is as fol
├── rgb
│ ├── Paracou_RGB_2016_10cm.tif (RGB orthomosaic in local UTM CRS)
│ └── Paracou_RGB_2019.tif (RGB orthomosaic in local UTM CRS)
+ ├── ms
+ │ └── Paracou_MS_2016.tif (Multispectral orthomosaic in local UTM CRS)
└── crowns
└── UpdatedCrowns8.gpkg (Crown polygons readable by geopandas e.g. Geopackage, shapefile)
@@ -65,11 +75,16 @@ Here we have two sites available to train on (Danum and Paracou). Several site d
included in the training and testing phase (but only a single site directory is required).
If available, several RGB orthomosaics can be included in a single site directory (see e.g ``Paracou -> RGB``).
+For Paracou, we also have a multispectral scan available (5 bands). For this data, the ``mode`` parameter in the
+``tile_data`` function should be set to ``"ms"``. This calls a different routine for tiling the data that retains the
+``.tif`` format instead of converting to ``.png`` as in the case of ``rgb``. This comes at a slight cost in speed
+later on but is necessary to retain all the multispectral information.
+
We call functions from ``detectree2``'s tiling and training modules.
.. code-block:: python
- from detectree2.preprocessing.tiling import tile_data_train, to_traintest_folders
+ from detectree2.preprocessing.tiling import tile_data, to_traintest_folders
from detectree2.models.train import register_train_data, MyTrainer, setup_cfg
import rasterio
import geopandas as gpd
@@ -83,9 +98,6 @@ Set up the paths to the orthomosaic and corresponding manual crown data.
img_path = site_path + "/rgb/2016/Paracou_RGB_2016_10cm.tif"
crown_path = site_path + "/crowns/220619_AllSpLabelled.gpkg"
- # Read in the tiff file
- data = rasterio.open(img_path)
-
# Read in crowns (then filter by an attribute if required)
crowns = gpd.read_file(crown_path)
    crowns = crowns.to_crs(rasterio.open(img_path).crs)  # making sure CRS match
@@ -98,6 +110,8 @@ The tile size will depend on:
* Available computational resources.
* The detail required on the crown outline.
* If using a pre-trained model, the tile size used in training should roughly match the tile size of predictions.
+* The ``mode`` depends on whether you are tiling 3-band RGB (``mode="rgb"``) data or multispectral data of 4 or more
+  bands (``mode="ms"``).
.. code-block:: python
@@ -112,24 +126,28 @@ The tile size will depend on:
The total tile size here is 100 m x 100 m (a 40 m x 40 m core area with a surrounding 30 m buffer that overlaps with
surrounding tiles). Including a buffer is recommended as it allows for tiles that include more training crowns.
-Next we tile the data. The ``tile_data_train`` function will only retain tiles that contain more than the given
-``threshold`` coverage of training data (here 60%). This helps to reduce the chance that the network is trained with
-tiles that contain a large number of unlabelled crowns (which would reduce its sensitivity).
+Next we tile the data. The ``tile_data`` function, when ``crowns`` is supplied, will only retain tiles that contain more
+than the given ``threshold`` coverage of training data (here 60%). This helps to reduce the chance that the network is
+trained with tiles that contain a large number of unlabelled crowns (which would reduce its sensitivity). This value
+should be adjusted depending on the density of crowns in the landscape (e.g. 10% may be more appropriate for
+savannah-type systems or urban environments).
.. code-block:: python
- tile_data_train(data, out_dir, buffer, tile_width, tile_height, crowns, threshold)
+ tile_data(img_path, out_dir, buffer, tile_width, tile_height, crowns, threshold, mode="rgb")
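+
+For multispectral data the call is the same but with ``mode="ms"``, so that the
+tiles are written as multi-band ``.tif`` files rather than ``.png`` (a sketch,
+with the other parameters as defined above):
+
+.. code-block:: python
+
+   tile_data(img_path, out_dir, buffer, tile_width, tile_height, crowns, threshold, mode="ms")
+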
.. warning::
- If tiles are outputing as blank images set ``dtype_bool = True`` in the ``tile_data_train`` function. This is a bug
- and we are working on fixing it.
+   If tiles are outputting as blank images, set ``dtype_bool = True`` in the ``tile_data`` function. This is a bug
+   and we are working on fixing it. Supplying crown polygons will cause the function to tile for
+   training (as opposed to landscape prediction, which is described below).
.. note::
- You will want to relax the ``threshold`` value if your trees are sparsely distributed across your landscape.
- Remember, ``detectree2`` was initially designed for dense, closed canopy forests so some of the default assumptions
- will reflect that.
+   You will want to relax the ``threshold`` value if your trees are sparsely distributed across your landscape or if you
+   want to include non-forest areas (e.g. rivers, roads). Remember, ``detectree2`` was initially designed for dense,
+ closed canopy forests so some of the default assumptions will reflect that and parameters will need to be adjusted
+ for different systems.
-Send geojsons to train folder (with sub-folders for k-fold cross validation) and test folder.
+Send geojsons to train folder (with sub-folders for k-fold cross validation) and a test folder.
.. code-block:: python
@@ -141,7 +159,7 @@ Send geojsons to train folder (with sub-folders for k-fold cross validation) and
that have any overlap with test tiles (including the buffers), ensuring strict spatial separation of the test data.
However, this can remove a significant proportion of the data available to train on so if validation accuracy is a
sufficient test of model performance ``test_frac`` can be set to ``0`` or set ``strict=False`` (which allows for
- some overlap in the buffers between test and train/val tiles).
+ overlap in the buffers between test and train/val tiles).
The data has now been tiled and partitioned for model training, tuning and evaluation.
@@ -161,8 +179,9 @@ The data has now been tiled and partitioned for model training, tuning and evalu
└── test (test data folder)
-It is advisable to do a visual inspection on the tiles to ensure that the tiling has worked as expected and that crowns
-and images align. This can be done quickly with the inbuilt ``detectron2`` visualisation tools.
+It is recommended to visually inspect the tiles before training to ensure that the tiling has worked as expected and
+that crowns and images align. This can be done with the inbuilt ``detectron2`` visualisation tools. For RGB tiles
+(``.png``), the following code can be used to visualise the training data.
.. code-block:: python
@@ -199,8 +218,61 @@ and images align. This can be done quickly with the inbuilt ``detectron2`` visua
|
-Training a model
-----------------
+Alternatively, with some adaptation the ``detectron2`` visualisation tools can also be used to visualise the
+multispectral (``.tif``) tiles.
+
+.. code-block:: python
+
+ import rasterio
+ from detectron2.utils.visualizer import Visualizer
+ from detectree2.models.train import combine_dicts
+ from detectron2.data import DatasetCatalog, MetadataCatalog
+ from PIL import Image
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from IPython.display import display
+
+ val_fold = 1
+ name = "Paracou"
+    appends = "40_30_0.6"  # tile_width, buffer and threshold used when tiling
+    tiles = "/tilesMS_" + appends + "/train"
+ train_location = "/content/drive/MyDrive/WORK/detectree2/data/" + name + tiles
+ dataset_dicts = combine_dicts(train_location, val_fold)
+ trees_metadata = MetadataCatalog.get(name + "_train")
+
+ # Function to normalize and convert multi-band image to RGB if needed
+ def prepare_image_for_visualization(image):
+ if image.shape[2] == 3:
+ # If the image has 3 bands, assume it's RGB
+ image = np.stack([
+ cv2.normalize(image[:, :, i], None, 0, 255, cv2.NORM_MINMAX)
+ for i in range(3)
+ ], axis=-1).astype(np.uint8)
+ else:
+ # If the image has more than 3 bands, choose the first 3 for visualization
+ image = image[:, :, :3] # Or select specific bands
+ image = np.stack([
+ cv2.normalize(image[:, :, i], None, 0, 255, cv2.NORM_MINMAX)
+ for i in range(3)
+ ], axis=-1).astype(np.uint8)
+
+ return image
+
+ # Visualize each image in the dataset
+ for d in dataset_dicts:
+ with rasterio.open(d["file_name"]) as src:
+ img = src.read() # Read all bands
+ img = np.transpose(img, (1, 2, 0)) # Convert to HWC format
+ img = prepare_image_for_visualization(img) # Normalize and prepare for visualization
+
+        visualizer = Visualizer(img[:, :, ::-1], metadata=trees_metadata, scale=0.5)
+ out = visualizer.draw_dataset_dict(d)
+ image = out.get_image()[:, :, ::-1]
+ display(Image.fromarray(image))
+
+
+Training a model (RGB)
+----------------------
Before training can commence, it is necessary to register the training data. It is possible to set a validation fold for
model evaluation (which can be helpful for tuning models). The validation fold can be changed over different training
@@ -231,7 +303,7 @@ datasets should be tuples containing strings. If just a single site is being use
trains = ("Paracou_train", "Danum_train", "SepilokEast_train", "SepilokWest_train") # Registered train data
tests = ("Paracou_val", "Danum_val", "SepilokEast_val", "SepilokWest_val") # Registered validation data
- out_dir = "/content/drive/Shareddrives/detectree2/220809_train_outputs"
+ out_dir = "/content/drive/Shareddrives/detectree2/240809_train_outputs"
cfg = setup_cfg(base_model, trains, tests, workers = 4, eval_period=100, max_iter=3000, out_dir=out_dir) # update_model arg can be used to load in trained model
@@ -257,7 +329,7 @@ Then set up the configurations as before but with the trained model also supplie
trains = ("Paracou_train", "Danum_train", "SepilokEast_train", "SepilokWest_train") # Registered train data
tests = ("Paracou_val", "Danum_val", "SepilokEast_val", "SepilokWest_val") # Registered validation data
- out_dir = "/content/drive/Shareddrives/detectree2/220809_train_outputs"
+ out_dir = "/content/drive/Shareddrives/detectree2/240809_train_outputs"
cfg = setup_cfg(base_model, trains, tests, trained_model, workers = 4, eval_period=100, max_iter=3000, out_dir=out_dir) # update_model arg used to load in trained model
@@ -267,7 +339,11 @@ Then set up the configurations as before but with the trained model also supplie
model training will converge given the particularities of the data supplied and computational resources available.
Once we are all set up, we can commence model training. Training will continue until a specified number of
-iterations (``max_iter``) or until model performance is no longer improving ("early stopping" via ``patience``).
+iterations (``max_iter``) or until model performance is no longer improving ("early stopping" via ``patience``). The
+``patience`` parameter sets the number of training epochs to wait for an improvement in validation accuracy before
+stopping training. This is useful for preventing overfitting and saving time. Each time an improved model is found it is
+saved to the output directory.
+
Training outputs, including model weights and training metrics, will be stored in ``out_dir``.
.. code-block::
@@ -281,12 +357,306 @@ Training outputs, including model weights and training metrics, will be stored i
Early stopping is implemented and will be triggered by a sustained failure to improve on the performance of
predictions on the validation fold. This is measured as the AP50 score of the validation predictions.
+Training a model (multispectral)
+--------------------------------
+
+The process for training a multispectral model is similar to that for RGB data but there are some key steps that are
+different. Data will be read from ``.tif`` files of 4 or more bands instead of the 3-band ``.png`` files.
+
+Data should be registered as before:
+
+.. code-block:: python
+
+ from detectree2.models.train import register_train_data, remove_registered_data
+ val_fold = 5
+ appends = "40_30_0.6"
+ site_path = "/content/drive/SharedDrive/detectree2/data/Paracou"
+ train_location = site_path + "/tilesMS_" + appends + "/train/"
+ register_train_data(train_location, "ParacouMS", val_fold)
+
+The number of bands can be checked with rasterio:
+
+.. code-block:: python
+
+ import rasterio
+ import os
+ import glob
+
+    # Check how many bands the tiles contain
+ #site_path = "/content/drive/MyDrive/WORK/detectree2/data/Paracou"
+ folder_path = site_path + "/tilesMS_" + appends + "/"
+
+ # Select path of first .tif file
+ img_paths = glob.glob(folder_path + "*.tif")
+ img_path = img_paths[0]
+
+ # Open the raster file
+ with rasterio.open(img_path) as dataset:
+ # Get the number of bands
+ num_bands = dataset.count
+
+ # Print the number of bands
+ print(f'The raster has {num_bands} bands.')
+
+
+Due to the additional bands, we must modify the weights of the first convolutional layer (conv1) to accommodate a
+different number of input channels. This is done with the ``modify_conv1_weights`` function. The extension of the
+``cfg.MODEL.PIXEL_MEAN`` and ``cfg.MODEL.PIXEL_STD`` lists to include the additional bands happens within the
+``setup_cfg`` function when ``num_bands`` is set to a value greater than 3. ``imgmode`` should be set to ``"ms"`` to
+ensure the correct training routines are called.
+
+.. code-block:: python
+
+ from datetime import date
+ from detectron2.modeling import build_model
+ import torch.nn as nn
+ import torch.nn.init as init
+ from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
+ import numpy as np
+ from detectree2.models.train import modify_conv1_weights, MyTrainer, setup_cfg
+
+ # Good idea to keep track of the date if producing multiple models
+ today = date.today()
+ today = today.strftime("%y%m%d")
+
+ names = ["ParacouMS",]
+
+ trains = (names[0] + "_train",)
+ tests = (names[0] + "_val",)
+ out_dir = "/content/drive/SharedDrive/detectree2/models/" + today + "_ParacouMS"
+
+ base_model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml" # Path to the model config
+
+ # Set up the configuration
+ cfg = setup_cfg(base_model, trains, tests, workers = 2, eval_period=50,
+ base_lr = 0.0003, backbone_freeze=0, gamma = 0.9,
+ max_iter=500000, out_dir=out_dir, resize = "rand_fixed", imgmode="ms",
+ num_bands= num_bands) # update_model arg can be used to load in trained model
+
+ # Build the model
+ model = build_model(cfg)
+
+ # Adjust input layer to accept correct number of channels
+ modify_conv1_weights(model, num_input_channels=num_bands)
+
+
+With additional bands, more data is being passed through the network per image so it may be necessary to reduce the
+number of images per batch. Only do this if you are getting warnings/errors about memory usage (e.g.
+``CUDA out of memory``) as it will slow down training.
+
+.. code-block:: python
+
+ cfg.SOLVER.IMS_PER_BATCH = 1
+
+
+Training can now commence as before:
+
+.. code-block::
+
+ trainer = MyTrainer(cfg, patience = 5)
+ trainer.resume_or_load(resume=False)
+ trainer.train()
+
+
+Data augmentation
+-----------------
+
+Data augmentation is a technique used to artificially increase the size of the training dataset by applying random
+transformations to the input data. This can help improve the generalization of the model and reduce overfitting. The
+``detectron2`` library provides a range of data augmentation options that can be used during training. These include
+random flipping, scaling, rotation, and color jittering.
+
+Additionally, resizing of the input data can be applied as an augmentation technique. This can be useful when training
+a model that should be flexible with respect to tile size and resolution.
+
+By default, random rotations and flips will be performed on input images.
+
+.. code-block:: python
+
+ augmentations = [
+ T.RandomRotation(angle=[90, 90], expand=False),
+ T.RandomFlip(prob=0.4, horizontal=True, vertical=False),
+ T.RandomFlip(prob=0.4, horizontal=False, vertical=True),
+ ]
+
+If the input data is RGB, additional augmentations will be applied to adjust the brightness, contrast, saturation, and
+lighting of the images. These augmentations are only available for RGB images and will not be applied to multispectral
+data.
+
+.. code-block:: python
+
+ # Additional augmentations for RGB images
+ if cfg.IMGMODE == "rgb":
+ augmentations.extend([
+ T.RandomBrightness(0.7, 1.5),
+ T.RandomLighting(0.7),
+ T.RandomContrast(0.6, 1.3),
+ T.RandomSaturation(0.8, 1.4)
+ ])
+
+There are three resizing modes for the input data: (1) ``fixed``, (2) ``random``, and (3) ``rand_fixed``. These are set
+in the configuration (``cfg``) with the ``setup_cfg`` function.
+
+The ``fixed`` mode will resize the input data to an image width/height of 1000 pixels. This is efficient but may not
+lead to models that transfer well across scales (e.g. if the model is to be used on a range of different resolutions).
+
+.. code-block:: python
+
+ if cfg.RESIZE == "fixed":
+ augmentations.append(T.ResizeShortestEdge([1000, 1000], 1333))
+
+The ``random`` mode will randomly resize (and resample, changing the resolution of) the input data to between 0.6 and
+1.4 times the original height/width. This can help the model learn to detect objects at different scales and from
+images of different resolutions (and sensors).
+
+.. code-block:: python
+
+ elif cfg.RESIZE == "random":
+ size = None
+ for i, datas in enumerate(DatasetCatalog.get(cfg.DATASETS.TRAIN[0])):
+ location = datas['file_name']
+ try:
+ # Try to read with cv2 (for RGB images)
+ img = cv2.imread(location)
+ if img is not None:
+ size = img.shape[0]
+ else:
+ # Fall back to rasterio for multi-band images
+ with rasterio.open(location) as src:
+ size = src.height # Assuming square images
+ except Exception as e:
+ # Handle any errors that occur during loading
+ print(f"Error loading image {location}: {e}")
+ continue
+ break
+
+ if size:
+ print("ADD RANDOM RESIZE WITH SIZE = ", size)
+ augmentations.append(T.ResizeScale(0.6, 1.4, size, size))
+
+The ``rand_fixed`` mode constrains the random resizing to a fixed pixel width/height range (regardless of the resolution
+of the input data). This can help to speed up training if the input tiles are high resolution and pushing up against
+available memory limits. It retains the benefits of random resizing but constrains the range of possible sizes.
+
+.. code-block:: python
+
+ elif cfg.RESIZE == "rand_fixed":
+ augmentations.append(T.ResizeScale(0.6, 1.4, 1000, 1000))
+
+Which resizing option is selected depends on the problem at hand. A more precise delineation can be generated if high
+resolution images are retained but this comes at the cost of increased memory usage and slower training times. If the
+model is to be used on a range of different resolutions, random resizing can help the model learn to detect objects at
+different scales.
+
+
+Post-training (check training convergence)
+------------------------------------------
+
+It is important to check that the model has converged and is not overfitting. This can be done by plotting the training
+and validation loss over time. The ``detectron2`` training routine will output a ``metrics.json`` file that can be used
+to plot the training and validation loss. The following code can be used to plot the loss:
+
+.. code-block:: python
+
+ import json
+ import matplotlib.pyplot as plt
+ from detectree2.models.train import load_json_arr
+
+ #out_dir = "/content/drive/Shareddrives/detectree2/models/230103_resize_full"
+ experiment_folder = out_dir
+
+ experiment_metrics = load_json_arr(experiment_folder + '/metrics.json')
+
+ plt.plot(
+ [x['iteration'] for x in experiment_metrics if 'validation_loss' in x],
+ [x['validation_loss'] for x in experiment_metrics if 'validation_loss' in x], label='Total Validation Loss', color='red')
+ plt.plot(
+ [x['iteration'] for x in experiment_metrics if 'total_loss' in x],
+ [x['total_loss'] for x in experiment_metrics if 'total_loss' in x], label='Total Training Loss')
+
+ plt.legend(loc='upper right')
+ plt.title('Comparison of the training and validation loss of detectree2')
+ plt.ylabel('Total Loss')
+ plt.xlabel('Number of Iterations')
+ plt.show()
+
+.. image:: ../../report/figures/train_val_loss.png
+ :width: 400
+ :alt: Train and validation loss
+ :align: center
+
+|
+Training loss and validation loss decreased over time. As training continued, the validation loss flattened whereas the
+training loss continued to decrease. The ``patience`` mechanism stopped training after about 3000 iterations, preventing
+overfitting. If validation loss is substantially higher than training loss, the model may be overfitted.
+
+To understand how the segmentation performance improves through training, it is also possible to plot the AP50 score
+(see below for definition) over the iterations. This can be done with the following code:
+
+
+.. code-block:: python
+
+    # The COCO evaluator logs the segmentation AP50 as 'segm/AP50' in metrics.json
+    plt.plot(
+        [x['iteration'] for x in experiment_metrics if 'segm/AP50' in x],
+        [x['segm/AP50'] for x in experiment_metrics if 'segm/AP50' in x], label='Validation AP50', color='green')
+
+    plt.legend(loc='lower right')
+    plt.title('Validation AP50 of detectree2 over training')
+    plt.ylabel('AP50')
+    plt.xlabel('Number of Iterations')
+    plt.show()
+
+.. image:: ../../report/figures/val_AP50.png
+ :width: 400
+ :alt: AP50 score
+ :align: center
+
+|
+
+Performance metrics
+-------------------
+
+In instance segmentation, **AP50** refers to the **Average Precision** at an Intersection over Union (IoU) threshold of
+**50%**.
+
+- **Precision**: Precision is the ratio of correctly predicted positive objects (true positives) to all predicted
+  objects (both true positives and false positives).
+
+ - Formula: :math:`\text{Precision} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Positives}}`
+
+- **Recall**: Recall is the ratio of correctly predicted positive objects (true positives) to all actual positive
+  objects in the ground truth (true positives and false negatives).
+
+ - Formula: :math:`\text{Recall} = \frac{\text{True Positives}}{\text{True Positives} + \text{False Negatives}}`
+
+- **Average Precision (AP)**: AP is a common metric used to evaluate the performance of object detection and instance
+  segmentation models. It represents the precision of the model across various recall levels. In simpler terms, it is a
+  combination of the model's ability to correctly detect objects and how complete those detections are.
+
+- **IoU (Intersection over Union)**: IoU measures the overlap between the predicted segmentation mask (or bounding box
+  in object detection) and the ground truth mask. It is calculated as the area of overlap divided by the area of union
+  between the predicted and true masks.
+
+- **AP50**: Specifically, **AP50** computes the average precision for all object classes at a threshold of **50% IoU**.
+  This means that a predicted object is considered correct (a true positive) if the IoU between the predicted and ground
+  truth masks is greater than or equal to 0.5 (50%). It is a relatively lenient threshold, focusing on whether the
+  detected objects overlap reasonably with the ground truth, even if the boundaries aren't perfectly aligned.
+
+In summary, AP50 evaluates how well a model detects objects with a 50% overlap between the predicted and ground truth
+masks in instance segmentation tasks.
+
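+To make this concrete, the following minimal sketch (toy masks and a
+hypothetical ``mask_iou`` helper - not part of ``detectree2``) shows how a
+single predicted mask would be scored against a ground truth mask at the 50%
+IoU threshold.
+
+.. code-block:: python
+
+   import numpy as np
+
+   def mask_iou(pred, truth):
+       """Compute IoU between two boolean masks."""
+       intersection = np.logical_and(pred, truth).sum()
+       union = np.logical_or(pred, truth).sum()
+       return intersection / union if union > 0 else 0.0
+
+   # Toy 4 x 4 masks: each crown covers 6 pixels, 4 of which overlap
+   pred = np.zeros((4, 4), dtype=bool)
+   truth = np.zeros((4, 4), dtype=bool)
+   pred[0:2, 0:3] = True
+   truth[0:2, 1:4] = True
+
+   iou = mask_iou(pred, truth)
+   print(f"IoU = {iou:.2f}")  # intersection 4 / union 8 = 0.50
+   print("True positive at AP50?", iou >= 0.5)  # True
+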
+.. image:: ../../report/figures/IoU_AP.png
+ :width: 400
+ :alt: IoU and AP illustration
+ :align: center
+
Evaluating model performance
----------------------------
Coming soon! See Colab notebook for example routine (``detectree2/notebooks/colab/evaluationJB.ipynb``).
+
Generating landscape predictions
--------------------------------
@@ -312,8 +682,7 @@ can discard partial the crowns predicted at the edge of tiles.
site_path = "/content/drive/Shareddrives/detectree2/data/BCI_50ha"
img_path = site_path + "/rgb/2015.06.10_07cm_ORTHO.tif"
tiles_path = site_path + "/tilespred/"
- # Read in the geotiff
- data = rasterio.open(img_path)
+
# Location of trained model
model_path = "/content/drive/Shareddrives/detectree2/models/220629_ParacouSepilokDanum_JB.pth"
@@ -321,11 +690,12 @@ can discard partial the crowns predicted at the edge of tiles.
buffer = 30
tile_width = 40
tile_height = 40
- tile_data(data, tiles_path, buffer, tile_width, tile_height, dtype_bool = True)
+ tile_data(img_path, tiles_path, buffer, tile_width, tile_height, dtype_bool = True)
.. warning::
If tiles are outputting as blank images, set ``dtype_bool = True`` in the ``tile_data`` function. This is a bug
- and we are working on fixing it.
+   and we are working on fixing it. Avoid supplying crown polygons, otherwise the function will run as if it is
+   tiling for training.
To download a pre-trained model from the ``model_garden`` you can run ``wget`` on the package repo
diff --git a/docs/source/tutorial_multi.rst b/docs/source/tutorial_multi.rst
index 21b1fc72..125e06ef 100644
--- a/docs/source/tutorial_multi.rst
+++ b/docs/source/tutorial_multi.rst
@@ -4,9 +4,10 @@ Tutorial (multiclass)
This tutorial goes through the steps of multiclass detection and
delineation (e.g. species mapping, disease mapping). A guide to single
class prediction is available
-`here `_. The multiclass
-process is more complicated than single class prediction as the classes need to
-be correctly encoded in the data.
+`here `_ - it covers the
+fundamentals of training in more detail and should be reviewed before this
+tutorial. The multiclass process is slightly more intricate than single class
+prediction as the classes need to be correctly encoded and carried throughout
+the pipeline.
The key steps are:
@@ -15,4 +16,161 @@ The key steps are:
3. Evaluating model performance
4. Making landscape level predictions
-THE REST OF THIS TUTORIAL IS UNDER CONSTRUCTION
+
+Preparing data (RGB and multispectral)
+--------------------------------------
+
+Data can be prepared in a similar way to the single class case but the classes
+and their order (mapping) need to be saved so that they can be accessed
+consistently across training and prediction. The classes are saved in a json
+file with the class names and their indices. The indices are used to encode
+the classes during training.
+
+.. code-block:: python
+
+    import rasterio
+    import geopandas as gpd
+    from detectree2.preprocessing.tiling import record_classes, tile_data, to_traintest_folders
+
+ # Load the data
+ base_dir = "/content/drive/MyDrive/SHARED/detectree2"
+
+ site_path = base_dir + "/data/Danum_lianas"
+
+ # Set the path to the orthomosaic and the crown shapefile
+ img_path = site_path + "/rgb/2017_50ha_Ortho_reproject.tif"
+ crown_path = site_path + "/crowns/Danum_lianas_full2017.gpkg"
+
+    # Set tiling parameters and name the output folder accordingly
+ buffer = 30
+ tile_width = 40
+ tile_height = 40
+ threshold = 0.6
+ appends = str(tile_width) + "_" + str(buffer) + "_" + str(threshold)
+
+ out_dir = site_path + "/tilesClass_" + appends + "/"
+
+ # Read in the tiff file
+ data = rasterio.open(img_path)
+
+    # Read in crowns (then filter by an attribute if required)
+ crowns = gpd.read_file(crown_path)
+ crowns = crowns.to_crs(data.crs.data)
+ print(crowns.head())
+
+ class_column = 'status'
+
+ # Record the classes and save the class mapping
+ record_classes(
+ crowns=crowns, # Geopandas dataframe with crowns
+ out_dir=out_dir, # Output directory to save class mapping
+ column=class_column, # Column to be used for classes
+ save_format='json' # Choose between 'json' or 'pickle'
+ )
+
+
+The class mapping has been saved in the output directory as a json file called
+``class_to_idx.json``. This file can now be accessed to encode the classes in
+training and prediction steps.
+
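+As a quick check, the mapping can be loaded back with the
+``load_class_mapping`` helper (assuming it takes the path to the json file;
+the printed classes below are illustrative and depend on your
+``class_column``).
+
+.. code-block:: python
+
+    from detectree2.preprocessing.tiling import load_class_mapping
+
+    class_to_idx = load_class_mapping(out_dir + "class_to_idx.json")
+    print(class_to_idx)  # e.g. {'alive': 0, 'liana': 1}
+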
+To tile the data, we call the ``tile_data`` function as we did in the single
+class case except now we point to the column name of the classes.
+
+.. code-block:: python
+
+ # Tile the data
+ tile_data(
+ img_path=img_path, # Path to the orthomosaic
+ out_dir=out_dir, # Output directory to save tiles
+        buffer=buffer,                   # Overlapping buffer of tiles in meters
+        tile_width=tile_width,           # Width of the tiles
+        tile_height=tile_height,         # Height of the tiles
+        crowns=crowns,                   # Geopandas dataframe with crowns
+        threshold=threshold,             # Min proportion of the tile covered by crowns
+ class_column=class_column, # Column to be used for classes
+ )
+
+ # Split the data into training and validation sets
+ to_traintest_folders(
+ tiles_folder=out_dir, # Directory where tiles are saved
+ out_folder=out_dir, # Final directory for train/test data
+ test_frac=0, # Fraction of data to be used for testing
+ folds=5, # Number of folds (optional, can be set to 1 for no fold splitting)
+        strict=False,                    # Allow some buffer overlap between train/test tiles
+ seed=42 # Set seed for reproducibility
+ )
+
+
+Training models
+---------------
+
+To train with multiple classes, we need to ensure that the classes are
+registered correctly in the dataset catalogue. This can be done with the class
+mapping file that was saved in the previous step. The class mapping file will
+set the classes and their indices.
+
+.. code-block:: python
+
+ from detectree2.models.train import register_train_data, remove_registered_data, setup_cfg, MyTrainer
+ from detectree2.preprocessing.tiling import load_class_mapping
+
+ # Set validation fold
+ val_fold = 5
+
+ site_path = base_dir + "/data/Danum_lianas"
+ train_dir = site_path + "/tilesClass_40_30_0.6/train"
+ class_mapping_file = site_path + "/tilesClass_40_30_0.6/" + "/class_to_idx.json"
+ data_name = "DanumLiana"
+
+ register_train_data(train_dir, data_name, val_fold=val_fold, class_mapping_file=class_mapping_file)
+
+
+Now the data is registered, we can generate the configuration (``cfg``) and
+train the model. By passing the class mapping file to the configuration setup,
+``cfg`` will register the correct number of classes.
+
+.. code-block:: python
+
+ from detectron2.modeling import build_model
+ from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
+ import numpy as np
+ from datetime import date
+
+
+ today = date.today()
+ today = today.strftime("%y%m%d")
+
+ names = [data_name,]
+
+ trains = (names[0] + "_train",)
+ tests = (names[0] + "_val",)
+ out_dir = "/content/drive/MyDrive/WORK/detectree2/models/" + today + "_Danum_lianas"
+
+ base_model = "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml" # Path to the model config
+
+    # When you increase the number of channels (i.e. filters) in a CNN, the
+    # general recommendation is to decrease the learning rate
+ lrs = [0.03, 0.003, 0.0003, 0.00003]
+
+ # Set up model configuration, using the class mapping to determine the number of classes
+ cfg = setup_cfg(
+ base_model="COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml",
+ trains=trains,
+ tests=tests,
+ max_iter=500000,
+ eval_period=50,
+ base_lr=lrs[0],
+ out_dir=out_dir,
+ resize="rand_fixed",
+ class_mapping_file=class_mapping_file # Optional
+ )
+
+ # Train the model
+ trainer = MyTrainer(cfg, patience=5)
+ trainer.resume_or_load(resume=False)
+ trainer.train()
+
+
+Landscape predictions
+---------------------
+
+COMING SOON
\ No newline at end of file
diff --git a/report/figures/IoU_AP.png b/report/figures/IoU_AP.png
new file mode 100644
index 00000000..603778b1
Binary files /dev/null and b/report/figures/IoU_AP.png differ
diff --git a/report/figures/train_val_loss.png b/report/figures/train_val_loss.png
new file mode 100644
index 00000000..2f9d0d18
Binary files /dev/null and b/report/figures/train_val_loss.png differ
diff --git a/report/figures/val_AP50.png b/report/figures/val_AP50.png
new file mode 100644
index 00000000..ec5d8ec9
Binary files /dev/null and b/report/figures/val_AP50.png differ
diff --git a/setup.py b/setup.py
index 176f2b86..bd80e31d 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
setup(
name="detectree2",
- version="1.0.8",
+ version="1.1.0",
author="James G. C. Ball",
author_email="ball.jgc@gmail.com",
description="Detectree packaging",
@@ -22,7 +22,7 @@
"shapely",
"geopandas",
"rasterio==1.3a3",
- "fiona",
+ "fiona==1.9.6",
"pycrs",
"descartes",
"detectron2@git+https://github.com/facebookresearch/detectron2.git",