Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PaddlePaddle Hackathon] add DenseNet #36077

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/paddle/tests/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ def test_resnet101(self):
def test_resnet152(self):
self.models_infer('resnet152')

def test_densenet121(self):
self.models_infer("densenet121")

def test_densenet161(self):
self.models_infer("densenet161")

def test_densenet169(self):
self.models_infer("densenet169")

def test_densenet201(self):
self.models_infer("densenet201")

def test_vgg16_num_classes(self):
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)

Expand Down
5 changes: 5 additions & 0 deletions python/paddle/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
from .models import vgg13 # noqa: F401
from .models import vgg16 # noqa: F401
from .models import vgg19 # noqa: F401
from .models import DenseNet # noqa: F401
from .models import densenet121 # noqa: F401
from .models import densenet161 # noqa: F401
from .models import densenet169 # noqa: F401
from .models import densenet201 # noqa: F401
from .models import LeNet # noqa: F401
from .transforms import BaseTransform # noqa: F401
from .transforms import Compose # noqa: F401
Expand Down
12 changes: 11 additions & 1 deletion python/paddle/vision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from .vgg import vgg16 # noqa: F401
from .vgg import vgg19 # noqa: F401
from .lenet import LeNet # noqa: F401
from .densenet import DenseNet # noqa: F401
from .densenet import densenet121 # noqa: F401
from .densenet import densenet161 # noqa: F401
from .densenet import densenet169 # noqa: F401
from .densenet import densenet201 # noqa: F401

__all__ = [ #noqa
'ResNet',
Expand All @@ -45,5 +50,10 @@
'mobilenet_v1',
'MobileNetV2',
'mobilenet_v2',
'LeNet'
'LeNet',
'DenseNet',
'densenet121',
'densenet161',
'densenet169',
'densenet201'
]
245 changes: 245 additions & 0 deletions python/paddle/vision/models/densenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import paddle
import paddle.fluid as fluid
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url

__all__ = []

model_urls = {
'densenet121': ('', ''),
'densenet161': ('', ''),
'densenet169': ('', ''),
'densenet201': ('', '')
}


class _DenseLayer(nn.Sequential):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
super(_DenseLayer, self).__init__()
self.add_sublayer('norm1', nn.BatchNorm2D(num_input_features)),
self.add_sublayer('relu1', nn.ReLU()),
self.add_sublayer(
'conv1',
nn.Conv2D(
num_input_features,
bn_size * growth_rate,
kernel_size=1,
stride=1)),
self.add_sublayer('norm2', nn.BatchNorm2D(bn_size * growth_rate)),
self.add_sublayer('relu2', nn.ReLU()),
self.add_sublayer(
'conv2',
nn.Conv2D(
bn_size * growth_rate,
growth_rate,
kernel_size=3,
stride=1,
padding=1)),
self.drop_rate = drop_rate

def forward(self, x):
new_features = super(_DenseLayer, self).forward(x)
if self.drop_rate > 0:
new_features = nn.Dropout(
new_features, p=self.drop_rate, training=self.training)
return fluid.layers.concat([x, new_features], axis=1)


class _DenseBlock(nn.Sequential):
def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
drop_rate):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(num_input_features + i * growth_rate,
growth_rate, bn_size, drop_rate)
self.add_sublayer('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features):
super(_Transition, self).__init__()
self.add_sublayer('norm', nn.BatchNorm2D(num_input_features))
self.add_sublayer('relu', nn.ReLU())
self.add_sublayer(
'conv',
nn.Conv2D(
num_input_features,
num_output_features,
kernel_size=1,
stride=1))
self.add_sublayer('pool', nn.AvgPool2D(kernel_size=2, stride=2))


class DenseNet(nn.Layer):
"""Densenet-BC model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
growth_rate (int) - how many filters to add each layer
block_config (list of 4 ints) - how many layers in each pooling block
num_init_features (int) - the number of filters to learn in the first convolution layer
bn_size (int) - multiplicative factor for number of bottle neck layers
drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes

Examples:
.. code-block:: python

from paddle.vision.models import DenseNet

config = (6,12,32,32)

densenet = DenseNet(block_config=config, num_classes=10)
"""

def __init__(self,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充一个with_pool,同时num_classes保持一致的逻辑,在小于等于0时不创建最后的全连接层

growth_rate=32,
block_config=(6, 12, 24, 16),
num_init_features=64,
bn_size=4,
drop_rate=0,
num_classes=1000):

super(DenseNet, self).__init__()

self.features = nn.Sequential(
nn.Conv2D(
3, num_init_features, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2D(num_init_features),
nn.ReLU(),
nn.MaxPool2D(
kernel_size=3, stride=2, padding=1), )

num_features = num_init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
num_layers=num_layers,
num_input_features=num_features,
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate)
self.features.add_sublayer('denseblock%d' % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(
num_input_features=num_features,
num_output_features=num_features // 2)
self.features.add_sublayer('transition%d' % (i + 1), trans)
num_features = num_features // 2
self.features.add_sublayer('norm5', nn.BatchNorm2D(num_features))
self.classifier = nn.Linear(num_features, num_classes)

def forward(self, x):
features = self.features(x)
out = nn.ReLU(features, )
out = nn.AvgPool2D(
out, kernel_size=7, stride=1).view(features.size(0), -1)
out = self.classifier(out)
return out


def _densenet(arch, block_cfg, pretrained, **kwargs):
model = DenseNet(block_config=block_cfg, **kwargs)

if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
arch)
weight_path = get_weights_path_from_url(model_urls[arch][0],
model_urls[arch][1])

param = paddle.load(weight_path)
model.load_dict(param)
return model


def densenet121(pretrained=False, **kwargs):
"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet

Examples:
.. code-block:: python

from paddle.vision.models import densenet121

# build model
model = densenet121()
"""
model_name = 'DenseNet121'
return _densenet(model_name, (6, 12, 24, 16), pretrained, **kwargs)


def densenet161(pretrained=False, **kwargs):
"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet

Examples:
.. code-block:: python

from paddle.vision.models import densenet161

# build model
model = densenet161()
"""
model_name = 'DenseNet161'
return _densenet(model_name, (6, 12, 32, 32), pretrained, **kwargs)


def densenet169(pretrained=False, **kwargs):
"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet

Examples:
.. code-block:: python

from paddle.vision.models import densenet169

# build model
model = densenet169()
"""
model_name = 'DenseNet169'
return _densenet(model_name, (6, 12, 48, 32), pretrained, **kwargs)


def densenet201(pretrained=False, **kwargs):
"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet

Examples:
.. code-block:: python

from paddle.vision.models import densenet201

# build model
model = densenet201()
"""
model_name = 'DenseNet201'
return _densenet(model_name, (6, 12, 64, 48), pretrained, **kwargs)