This repository demonstrates how to train a Mask-RCNN network to perform instance segmentation and how to run inference on a few test images. The network is trained on two classes, 'Person' and 'Car', using the COCO 2014 dataset.
This repository has the following minimum requirements:
Running the network training code from this repository requires the following:
Instance segmentation is a computer vision technique that detects and localizes objects while simultaneously generating a segmentation mask for each detected instance.
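To make "a segmentation mask for each detected instance" concrete, the sketch below collapses a stack of per-instance masks into a single label matrix in which each pixel holds the index of the instance it belongs to. It is an illustration only, not code from this repository. It assumes the masks are stored as an H-by-W-by-N logical array (the layout accepted by `insertObjectMask`), and the 4-by-4 masks are made-up toy data.

```matlab
% Toy stack of two 4-by-4 instance masks (made-up data for illustration).
masks = false(4, 4, 2);
masks(1:2, 1:2, 1) = true;   % instance 1: top-left block
masks(3:4, 3:4, 2) = true;   % instance 2: bottom-right block

% Collapse the stack into one label matrix:
% pixel value k means "belongs to instance k", 0 means background.
labelMap = zeros(size(masks, 1), size(masks, 2));
for k = 1:size(masks, 3)
    % Where masks overlap, later instances overwrite earlier ones.
    labelMap(masks(:, :, k)) = k;
end
disp(labelMap)
```

Unlike semantic segmentation, which would give both blocks the same class label, each instance keeps its own identity here.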
There are several deep learning algorithms for instance segmentation; one of the most popular is Mask-RCNN. Mask-RCNN belongs to the R-CNN family of networks and builds on Faster-RCNN to perform pixel-level segmentation of the detected objects. Mask-RCNN extends Faster-RCNN with:
- ROIAlign, a more accurate ROI pooling layer that samples features at sub-pixel locations instead of rounding ROI coordinates to the feature-map grid.
- A mask branch for pixel-level object segmentation.
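The core idea of ROIAlign can be sketched in a few lines. This is a simplified illustration, not the layer used in this repository. It assumes a toy single-channel feature map, an ROI that has already been mapped to feature-map coordinates, and one bilinear sample at the centre of each output bin (the Mask R-CNN paper samples several points per bin and aggregates them). Bilinear sampling is done with MATLAB's built-in `interp2`.

```matlab
% Toy 4-by-4 feature map (made-up values): F(r, c) = 10*r + c.
[C, R] = meshgrid(1:4, 1:4);
F = 10 * R + C;

% ROI in continuous feature-map coordinates [x1 y1 x2 y2].
% ROI pooling would round these corners to integers; ROIAlign keeps them.
roi = [1.2 1.2 3.2 3.2];
P = 2;   % output grid is P-by-P bins

% Centre of each bin along x and y (one sample per bin for simplicity).
binW = (roi(3) - roi(1)) / P;
binH = (roi(4) - roi(2)) / P;
xs = roi(1) + binW * ((1:P) - 0.5);
ys = roi(2) + binH * ((1:P) - 0.5);
[Xq, Yq] = meshgrid(xs, ys);

% Bilinear interpolation at the fractional sample points.
pooled = interp2(F, Xq, Yq, 'linear');
disp(pooled)
```

Because the sample points (1.7 and 2.7) are never snapped to integer pixels, the pooled features stay aligned with the ROI. This alignment matters for the mask branch, which predicts pixel-level masks.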
- Add the source directory to the MATLAB path:
- Download and load the pretrained network:
datadir = tempdir;
url = '';
helper.downloadTrainedMaskRCNN(url, datadir);
pretrained = load(fullfile(datadir, 'maskrcnn_pretrained_person_car.mat'));
net =;
% Extract Mask segmentation sub-network
maskSubnet = helper.extractMaskNetwork(net);
- Perform prediction using the detectMaskRCNN function:
img = imread('visionteam.jpg');
executionEnvironment = "gpu";
% params: Mask-RCNN configuration structure (see src/CreateMaskRCNNConfig)
[boxes, scores, labels, masks] = detectMaskRCNN(net, maskSubnet, img, params, executionEnvironment);
% Visualize results
overlayedImage = insertObjectMask(img, masks);
figure, imshow(overlayedImage)
showShape("rectangle", gather(boxes), "Label", labels, "LineColor",'r')
To train your own Mask-RCNN network, follow the steps in these examples:
- MaskRCNNTrainingExample.mlx - Example showing how to train the Mask-RCNN network.
- MaskRCNNParallelTrainingExample.mlx - Example showing how to train the Mask-RCNN network on multiple GPUs.
This repository provides the following files to help create and train Mask-RCNN networks:
- src/CreateMaskRCNNConfig - Function to create a configuration structure for Mask-RCNN.
- src/CreateMaskRCNN - Function to create a Mask-RCNN network with a ResNet-101 backbone.
- detectMaskRCNN - Function to run inference with a Mask-RCNN network.
K. He, G. Gkioxari, P. Dollár, and R. Girshick. Mask R-CNN. In ICCV, 2017.
- "This network was trained using the COCO dataset, which was collected by the COCO Consortium ("