Add ICNet for image segmentation. #975

Merged · 3 commits · Jun 21, 2018
110 changes: 110 additions & 0 deletions fluid/icnet/README.md
@@ -0,0 +1,110 @@
Running the example programs in this directory requires the latest develop version of PaddlePaddle. If your installed version of PaddlePaddle is older than this, please update it following the instructions in the [installation documentation](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).


## Code structure
```
├── network.py    # Network architecture definition
├── train.py      # Training script
├── eval.py       # Evaluation script
├── infer.py      # Inference script
├── cityscape.py  # Data preprocessing script
└── utils.py      # Common utility functions
```

## Introduction

Image Cascade Network (ICNet) is designed for real-time semantic segmentation of images. Compared with other methods that compress computation, ICNet takes both speed and accuracy into account.
The key idea of ICNet is to transform the input image into several resolutions, process each resolution with a sub-network of matching computational complexity, and then fuse the results. ICNet consists of three sub-networks: the computationally heavy network processes the low-resolution input, while the lightweight networks process the high-resolution inputs. In this way, it strikes a balance between the accuracy obtained from high-resolution images and the efficiency of low-complexity networks.

The overall network architecture is shown below:

<p align="center">
<img src="images/icnet.png" width="620" hspace='10'/> <br/>
<strong>Figure 1</strong>
</p>
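To make the cascade concrete, the sketch below builds the three input resolutions that the sub-networks consume. This is a minimal illustration using OpenCV only; the 1, 1/2, and 1/4 scale factors follow the paper, and the function and variable names are ours, not part of this repository.

```
import cv2
import numpy as np


def cascade_inputs(image):
    """Build the three-resolution input pyramid used by ICNet.

    The full-resolution image feeds the lightweight branch, while the
    1/4-resolution image feeds the computationally heavy branch.
    """
    h, w = image.shape[:2]
    sub124 = image                               # full resolution, light branch
    sub24 = cv2.resize(image, (w // 2, h // 2))  # 1/2 resolution, medium branch
    sub4 = cv2.resize(image, (w // 4, h // 4))   # 1/4 resolution, heavy branch
    return sub4, sub24, sub124


# Example with a dummy Cityscape-sized image (1024 x 2048).
pyramid = cascade_inputs(np.zeros((1024, 2048, 3), dtype=np.float32))
print([p.shape for p in pyramid])
```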


## Data preparation



This example uses the Cityscape dataset. Please register at the [Cityscape website](https://www.cityscapes-dataset.com) to download it. After downloading, process the data with the instructions and tools described [here](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py#L3).
The processed data is organized as follows:
```
data/cityscape/
|-- gtFine
| |-- test
| |-- train
| `-- val
|-- leftImg8bit
| |-- test
| |-- train
| `-- val
|-- train.list
`-- val.list
```
Here, `train.list` and `val.list` are the list files for training and testing, respectively. The first column is the input image and the second column is the annotation, separated by a space. For example:
```
leftImg8bit/train/stuttgart/stuttgart_000021_000019_leftImg8bit.png gtFine/train/stuttgart/stuttgart_000021_000019_gtFine_labelTrainIds.png
leftImg8bit/train/stuttgart/stuttgart_000072_000019_leftImg8bit.png gtFine/train/stuttgart/stuttgart_000072_000019_gtFine_labelTrainIds.png
```
After downloading and preparing the data, modify the corresponding data paths in the `cityscape.py` script.
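If the two list files need to be regenerated, a sketch like the following produces the format shown above; the directory layout matches the tree above, but this helper script is ours and is not part of this repository.

```
import os


def make_list(split, out_file, root="data/cityscape"):
    """Write '<image> <label>' lines for one split (train or val)."""
    img_root = os.path.join(root, "leftImg8bit", split)
    with open(out_file, "w") as f:
        for city in sorted(os.listdir(img_root)):
            for name in sorted(os.listdir(os.path.join(img_root, city))):
                # e.g. stuttgart_000021_000019_leftImg8bit.png
                label = name.replace("leftImg8bit", "gtFine_labelTrainIds")
                f.write("leftImg8bit/%s/%s/%s gtFine/%s/%s/%s\n" %
                        (split, city, name, split, city, label))


make_list("train", "data/cityscape/train.list")
make_list("val", "data/cityscape/val.list")
```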

## Model training and inference

### Training
Run the following command to start training:
```
python train.py --batch_size=16 --use_gpu=True
```
Run the following command for more usage information:
```
python train.py --help
```
During training, the `loss` of each network branch on the training set is printed according to the user's settings, for example:
```
Iter[0]; train loss: 2.338; sub4_loss: 3.367; sub24_loss: 4.120; sub124_loss: 0.151
```
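The branch losses are combined into the reported `train loss` as a weighted sum. The sketch below uses the weights commonly seen in ICNet implementations (0.16, 0.4, and 1.0 for the 1/4, 1/2, and full-resolution branches); they reproduce the logged example above, but check `train.py` for the authoritative values.

```
# Branch weights assumed from common ICNet implementations, not read from train.py.
LAMBDA1, LAMBDA2, LAMBDA3 = 0.16, 0.4, 1.0  # sub4, sub24, sub124


def total_loss(sub4_loss, sub24_loss, sub124_loss):
    """Weighted sum of the per-branch losses."""
    return LAMBDA1 * sub4_loss + LAMBDA2 * sub24_loss + LAMBDA3 * sub124_loss


# Matches the logged example: 0.16 * 3.367 + 0.4 * 4.120 + 1.0 * 0.151 = 2.338
print(round(total_loss(3.367, 4.120, 0.151), 3))
```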
### Testing
Run the following command to test on the `Cityscape` test set:
```
python eval.py --model_path="./model/" --use_gpu=True
```
The model file is specified with the `--model_path` option.
The evaluation metric reported by the test script is [mean IoU]().
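As a reference for the metric itself, here is a minimal mean IoU computation over two NumPy label maps; this generic sketch follows the standard definition and is not the code `eval.py` uses.

```
import numpy as np


def mean_iou(pred, gt, num_classes=19, ignore_label=255):
    """Mean intersection-over-union, skipping ignored pixels and absent classes."""
    valid = gt != ignore_label
    pred, gt = pred[valid], gt[valid]
    ious = []
    for c in range(num_classes):
        inter = np.sum((pred == c) & (gt == c))
        union = np.sum((pred == c) | (gt == c))
        if union > 0:
            ious.append(inter / float(union))
    return np.mean(ious)
```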

### Inference
Run the following command to run inference on the specified data:
```
python infer.py \
--model_path="./model" \
--images_path="./data/cityscape/" \
--images_list="./data/cityscape/infer.list"
```
The `--images_list` option specifies a list file in which each line is the path of an image to predict.
By default, prediction results are saved to the `output` folder under the current path.
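Assuming the saved predictions are single-channel label maps with train IDs 0-18 (an assumption; check `infer.py` for the actual output format), a quick way to eyeball a result is to stretch the IDs across the grayscale range:

```
import cv2

# Hypothetical output file name; substitute one produced by infer.py.
pred = cv2.imread("output/example_leftImg8bit.png", cv2.IMREAD_GRAYSCALE)
# Spread class IDs 0..18 over 0..234 so classes become distinguishable.
cv2.imwrite("pred_vis.png", pred * (255 // 18))
```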

## Experimental results
Figure 2 shows the loss curve of training on the `Cityscape` training set:

<p align="center">
<img src="images/train_loss.png" width="620" hspace='10'/> <br/>
<strong>Figure 2</strong>
</p>
> **Reviewer (Collaborator):** Are there mean IoU results?
>
> **Reply:** Trained on the training set and evaluated on the validation set: mean_IoU = 67.0% (67.7% in the paper).

Figure 3 shows example results produced by the `infer.py` script; the first row is the original input image, the second row is the human annotation, and the third row is our model's output.
<p align="center">
<img src="images/result.png" width="620" hspace='10'/> <br/>
<strong>Figure 3</strong>
</p>

## Additional information
|Dataset | Pretrained model |
|---|---|
|Cityscape | [Model]()[md: ] |

## References

- [ICNet for Real-Time Semantic Segmentation on High-Resolution Images](https://arxiv.org/abs/1704.08545)
236 changes: 236 additions & 0 deletions fluid/icnet/cityscape.py
@@ -0,0 +1,236 @@
"""Reader for Cityscape dataset.
"""
import os
import cv2
import numpy as np
import paddle.v2 as paddle

DATA_PATH = "./data/cityscape"
TRAIN_LIST = DATA_PATH + "/train.list"
TEST_LIST = DATA_PATH + "/val.list"
IGNORE_LABEL = 255
NUM_CLASSES = 19
TRAIN_DATA_SHAPE = (3, 720, 720)
TEST_DATA_SHAPE = (3, 1024, 2048)
IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)  # BGR channel means


def train_data_shape():
return TRAIN_DATA_SHAPE


def test_data_shape():
return TEST_DATA_SHAPE


def num_classes():
return NUM_CLASSES


class DataGenerater:
    """Generates training and test data batches for the Cityscape dataset."""

    def __init__(self, data_list, mode="train", flip=True, scaling=True):
self.flip = flip
self.scaling = scaling
self.image_label = []
with open(data_list, 'r') as f:
for line in f:
image_file, label_file = line.strip().split(' ')
self.image_label.append((image_file, label_file))

def create_train_reader(self, batch_size):
"""
Create a reader for train dataset.
"""

def reader():
np.random.shuffle(self.image_label)
images = []
labels_sub1 = []
labels_sub2 = []
labels_sub4 = []
count = 0
for image, label in self.image_label:
image, label_sub1, label_sub2, label_sub4 = self.process_train_data(
image, label)
count += 1
images.append(image)
labels_sub1.append(label_sub1)
labels_sub2.append(label_sub2)
labels_sub4.append(label_sub4)
if count == batch_size:
yield self.mask(
np.array(images),
np.array(labels_sub1),
np.array(labels_sub2), np.array(labels_sub4))
images = []
labels_sub1 = []
labels_sub2 = []
labels_sub4 = []
count = 0
if images:
yield self.mask(
np.array(images),
np.array(labels_sub1),
np.array(labels_sub2), np.array(labels_sub4))

return reader

def create_test_reader(self):
"""
Create a reader for test dataset.
"""

def reader():
for image, label in self.image_label:
image, label = self.load(image, label)
image = paddle.image.to_chw(image)[np.newaxis, :]
label = label[np.newaxis, :, :, np.newaxis].astype("float32")
label_mask = np.where((label != IGNORE_LABEL).flatten())[
0].astype("int32")
yield image, label, label_mask

return reader

def process_train_data(self, image, label):
"""
Process training data.
"""
image, label = self.load(image, label)
if self.flip:
image, label = self.random_flip(image, label)
if self.scaling:
image, label = self.random_scaling(image, label)
image, label = self.resize(image, label, out_size=TRAIN_DATA_SHAPE[1:])
label = label.astype("float32")
label_sub1 = paddle.image.to_chw(self.scale_label(label, factor=4))
label_sub2 = paddle.image.to_chw(self.scale_label(label, factor=8))
label_sub4 = paddle.image.to_chw(self.scale_label(label, factor=16))
image = paddle.image.to_chw(image)
return image, label_sub1, label_sub2, label_sub4

def load(self, image, label):
"""
Load image from file.
"""
image = paddle.image.load_image(
DATA_PATH + "/" + image, is_color=True).astype("float32")
image -= IMG_MEAN
label = paddle.image.load_image(
DATA_PATH + "/" + label, is_color=False).astype("float32")
return image, label

def random_flip(self, image, label):
"""
Flip image and label randomly.
"""
r = np.random.rand(1)
if r > 0.5:
image = paddle.image.left_right_flip(image, is_color=True)
label = paddle.image.left_right_flip(label, is_color=False)
return image, label

def random_scaling(self, image, label):
"""
Scale image and label randomly.
"""
scale = np.random.uniform(0.5, 2.0, 1)[0]
h_new = int(image.shape[0] * scale)
w_new = int(image.shape[1] * scale)
image = cv2.resize(image, (w_new, h_new))
label = cv2.resize(
label, (w_new, h_new), interpolation=cv2.INTER_NEAREST)
return image, label

    def padding_as(self, image, h, w, is_color):
        """
        Zero-pad image on the bottom and right so it is at least h x w.
        """
pad_h = max(image.shape[0], h) - image.shape[0]
pad_w = max(image.shape[1], w) - image.shape[1]
if is_color:
return np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
else:
return np.pad(image, ((0, pad_h), (0, pad_w)), 'constant')

def resize(self, image, label, out_size):
"""
Resize image and label by padding or cropping.
"""
        ignore_label = IGNORE_LABEL
        # Shift labels down by ignore_label before zero-padding so that the
        # padded pixels become ignore_label after it is added back below.
        label = label - ignore_label
if len(label.shape) == 2:
label = label[:, :, np.newaxis]
combined = np.concatenate((image, label), axis=2)
combined = self.padding_as(
combined, out_size[0], out_size[1], is_color=True)
combined = paddle.image.random_crop(
combined, out_size[0], is_color=True)
image = combined[:, :, 0:3]
label = combined[:, :, 3:4] + ignore_label
return image, label

    def scale_label(self, label, factor):
        """
        Downscale label by factor using nearest-neighbor interpolation.
        """
        h = label.shape[0] // factor
        w = label.shape[1] // factor
        # cv2.resize expects the target size as (width, height).
        return cv2.resize(
            label, (w, h), interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis]

    def mask(self, image, label0, label1, label2):
        """
        Get flattened indices of the valid (non-ignored) pixels for each
        label scale, computed over the whole batch.
        """
mask_sub1 = np.where(((label0 < (NUM_CLASSES + 1)) & (
label0 != IGNORE_LABEL)).flatten())[0].astype("int32")
mask_sub2 = np.where(((label1 < (NUM_CLASSES + 1)) & (
label1 != IGNORE_LABEL)).flatten())[0].astype("int32")
mask_sub4 = np.where(((label2 < (NUM_CLASSES + 1)) & (
label2 != IGNORE_LABEL)).flatten())[0].astype("int32")
return image.astype(
"float32"), label0, mask_sub1, label1, mask_sub2, label2, mask_sub4


def train(batch_size=32, flip=True, scaling=True):
"""
Cityscape training set reader.
It returns a reader, in which each result is a batch with batch_size samples.

:param batch_size: The batch size of each result return by the reader.
:type batch_size: int
    :param flip: Whether to flip images randomly.
    :type flip: bool
    :param scaling: Whether to scale images randomly.
    :type scaling: bool
:return: Training reader.
:rtype: callable
"""
reader = DataGenerater(
TRAIN_LIST, flip=flip, scaling=scaling).create_train_reader(batch_size)
return reader
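# Example usage of train() (illustrative only):
#
#   batches = train(batch_size=4)()
#   image, label1, mask1, label2, mask2, label4, mask4 = next(batches)
#
# image has shape (4, 3, 720, 720); the three label arrays are at 1/4, 1/8
# and 1/16 of the crop size, and each mask holds the flattened indices of
# that label's valid (non-ignored) pixels.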


def test():
"""
Cityscape validation set reader.
It returns a reader, in which each result is a sample.

    :return: Test reader.
:rtype: callable
"""
reader = DataGenerater(TEST_LIST).create_test_reader()
return reader


def infer(image_list=TEST_LIST):
"""
Infer set reader.
It returns a reader, in which each result is a sample.

    :param image_list: The image list file in which each line is the path of an image to be inferred.
    :type image_list: str
    :return: Infer reader.
    :rtype: callable
    """
    reader = DataGenerater(image_list).create_test_reader()
    return reader