Skip to content

Commit

Permalink
Define all C++ model constructors explicit (#2944)
Browse files Browse the repository at this point in the history
* Making all model constructors explicit.

* formatting.
  • Loading branch information
datumbox authored Nov 3, 2020
1 parent f95b053 commit 6b071be
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 45 deletions.
2 changes: 1 addition & 1 deletion torchvision/csrc/models/alexnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace models {
struct VISION_API AlexNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr}, classifier{nullptr};

AlexNetImpl(int64_t num_classes = 1000);
explicit AlexNetImpl(int64_t num_classes = 1000);

torch::Tensor forward(torch::Tensor x);
};
Expand Down
10 changes: 5 additions & 5 deletions torchvision/csrc/models/densenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
torch::nn::Sequential features{nullptr};
torch::nn::Linear classifier{nullptr};

DenseNetImpl(
explicit DenseNetImpl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16},
Expand All @@ -35,7 +35,7 @@ struct VISION_API DenseNetImpl : torch::nn::Module {
};

struct VISION_API DenseNet121Impl : DenseNetImpl {
DenseNet121Impl(
explicit DenseNet121Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 24, 16},
Expand All @@ -45,7 +45,7 @@ struct VISION_API DenseNet121Impl : DenseNetImpl {
};

struct VISION_API DenseNet169Impl : DenseNetImpl {
DenseNet169Impl(
explicit DenseNet169Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 32, 32},
Expand All @@ -55,7 +55,7 @@ struct VISION_API DenseNet169Impl : DenseNetImpl {
};

struct VISION_API DenseNet201Impl : DenseNetImpl {
DenseNet201Impl(
explicit DenseNet201Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 32,
const std::vector<int64_t>& block_config = {6, 12, 48, 32},
Expand All @@ -65,7 +65,7 @@ struct VISION_API DenseNet201Impl : DenseNetImpl {
};

struct VISION_API DenseNet161Impl : DenseNetImpl {
DenseNet161Impl(
explicit DenseNet161Impl(
int64_t num_classes = 1000,
int64_t growth_rate = 48,
const std::vector<int64_t>& block_config = {6, 12, 36, 24},
Expand Down
4 changes: 2 additions & 2 deletions torchvision/csrc/models/googlenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr};

BasicConv2dImpl(torch::nn::Conv2dOptions options);
explicit BasicConv2dImpl(torch::nn::Conv2dOptions options);

torch::Tensor forward(torch::Tensor x);
};
Expand Down Expand Up @@ -71,7 +71,7 @@ struct VISION_API GoogLeNetImpl : torch::nn::Module {
torch::nn::Dropout dropout{nullptr};
torch::nn::Linear fc{nullptr};

GoogLeNetImpl(
explicit GoogLeNetImpl(
int64_t num_classes = 1000,
bool aux_logits = true,
bool transform_input = false,
Expand Down
12 changes: 7 additions & 5 deletions torchvision/csrc/models/inception.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ struct VISION_API BasicConv2dImpl : torch::nn::Module {
torch::nn::Conv2d conv{nullptr};
torch::nn::BatchNorm2d bn{nullptr};

BasicConv2dImpl(torch::nn::Conv2dOptions options, double std_dev = 0.1);
explicit BasicConv2dImpl(
torch::nn::Conv2dOptions options,
double std_dev = 0.1);

torch::Tensor forward(torch::Tensor x);
};
Expand All @@ -30,7 +32,7 @@ struct VISION_API InceptionAImpl : torch::nn::Module {
struct VISION_API InceptionBImpl : torch::nn::Module {
BasicConv2d branch3x3, branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3;

InceptionBImpl(int64_t in_channels);
explicit InceptionBImpl(int64_t in_channels);

torch::Tensor forward(const torch::Tensor& x);
};
Expand All @@ -50,7 +52,7 @@ struct VISION_API InceptionDImpl : torch::nn::Module {
BasicConv2d branch3x3_1, branch3x3_2, branch7x7x3_1, branch7x7x3_2,
branch7x7x3_3, branch7x7x3_4;

InceptionDImpl(int64_t in_channels);
explicit InceptionDImpl(int64_t in_channels);

torch::Tensor forward(const torch::Tensor& x);
};
Expand All @@ -60,7 +62,7 @@ struct VISION_API InceptionEImpl : torch::nn::Module {
branch3x3dbl_1, branch3x3dbl_2, branch3x3dbl_3a, branch3x3dbl_3b,
branch_pool;

InceptionEImpl(int64_t in_channels);
explicit InceptionEImpl(int64_t in_channels);

torch::Tensor forward(const torch::Tensor& x);
};
Expand Down Expand Up @@ -110,7 +112,7 @@ struct VISION_API InceptionV3Impl : torch::nn::Module {

_inceptionimpl::InceptionAux AuxLogits{nullptr};

InceptionV3Impl(
explicit InceptionV3Impl(
int64_t num_classes = 1000,
bool aux_logits = true,
bool transform_input = false);
Expand Down
13 changes: 8 additions & 5 deletions torchvision/csrc/models/mnasnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,28 @@ struct VISION_API MNASNetImpl : torch::nn::Module {

void _initialize_weights();

MNASNetImpl(double alpha, int64_t num_classes = 1000, double dropout = .2);
explicit MNASNetImpl(
double alpha,
int64_t num_classes = 1000,
double dropout = .2);

torch::Tensor forward(torch::Tensor x);
};

struct MNASNet0_5Impl : MNASNetImpl {
MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet0_5Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet0_75Impl : MNASNetImpl {
MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet0_75Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet1_0Impl : MNASNetImpl {
MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet1_0Impl(int64_t num_classes = 1000, double dropout = .2);
};

struct MNASNet1_3Impl : MNASNetImpl {
MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
explicit MNASNet1_3Impl(int64_t num_classes = 1000, double dropout = .2);
};

TORCH_MODULE(MNASNet);
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/models/mobilenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
int64_t last_channel;
torch::nn::Sequential features, classifier;

MobileNetV2Impl(
explicit MobileNetV2Impl(
int64_t num_classes = 1000,
double width_mult = 1.0,
std::vector<std::vector<int64_t>> inverted_residual_settings = {},
Expand Down
30 changes: 20 additions & 10 deletions torchvision/csrc/models/resnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ struct ResNetImpl : torch::nn::Module {
int64_t blocks,
int64_t stride = 1);

ResNetImpl(
explicit ResNetImpl(
const std::vector<int>& layers,
int64_t num_classes = 1000,
bool zero_init_residual = false,
Expand Down Expand Up @@ -186,45 +186,55 @@ torch::Tensor ResNetImpl<Block>::forward(torch::Tensor x) {
}

struct VISION_API ResNet18Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet18Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet18Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet34Impl : ResNetImpl<_resnetimpl::BasicBlock> {
ResNet34Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet34Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet50Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet50Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet50Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet101Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet101Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet101Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNet152Impl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNet152Impl(int64_t num_classes = 1000, bool zero_init_residual = false);
explicit ResNet152Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNext50_32x4dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext50_32x4dImpl(
explicit ResNext50_32x4dImpl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
ResNext101_32x8dImpl(
explicit ResNext101_32x8dImpl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet50_2Impl(
explicit WideResNet50_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};

struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
WideResNet101_2Impl(
explicit WideResNet101_2Impl(
int64_t num_classes = 1000,
bool zero_init_residual = false);
};
Expand Down
8 changes: 4 additions & 4 deletions torchvision/csrc/models/shufflenetv2.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ struct VISION_API ShuffleNetV2Impl : torch::nn::Module {
};

struct VISION_API ShuffleNetV2_x0_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x0_5Impl(int64_t num_classes = 1000);
};

struct VISION_API ShuffleNetV2_x1_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x1_0Impl(int64_t num_classes = 1000);
};

struct VISION_API ShuffleNetV2_x1_5Impl : ShuffleNetV2Impl {
ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x1_5Impl(int64_t num_classes = 1000);
};

struct VISION_API ShuffleNetV2_x2_0Impl : ShuffleNetV2Impl {
ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
explicit ShuffleNetV2_x2_0Impl(int64_t num_classes = 1000);
};

TORCH_MODULE(ShuffleNetV2);
Expand Down
6 changes: 3 additions & 3 deletions torchvision/csrc/models/squeezenet.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
int64_t num_classes;
torch::nn::Sequential features{nullptr}, classifier{nullptr};

SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);
explicit SqueezeNetImpl(double version = 1.0, int64_t num_classes = 1000);

torch::Tensor forward(torch::Tensor x);
};
Expand All @@ -19,15 +19,15 @@ struct VISION_API SqueezeNetImpl : torch::nn::Module {
// accuracy with 50x fewer parameters and <0.5MB model size"
// <https://arxiv.org/abs/1602.07360> paper.
struct VISION_API SqueezeNet1_0Impl : SqueezeNetImpl {
SqueezeNet1_0Impl(int64_t num_classes = 1000);
explicit SqueezeNet1_0Impl(int64_t num_classes = 1000);
};

// SqueezeNet 1.1 model from the official SqueezeNet repo
// <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>.
// SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
// than SqueezeNet 1.0, without sacrificing accuracy.
struct VISION_API SqueezeNet1_1Impl : SqueezeNetImpl {
SqueezeNet1_1Impl(int64_t num_classes = 1000);
explicit SqueezeNet1_1Impl(int64_t num_classes = 1000);
};

TORCH_MODULE(SqueezeNet);
Expand Down
34 changes: 25 additions & 9 deletions torchvision/csrc/models/vgg.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct VISION_API VGGImpl : torch::nn::Module {

void _initialize_weights();

VGGImpl(
explicit VGGImpl(
const torch::nn::Sequential& features,
int64_t num_classes = 1000,
bool initialize_weights = true);
Expand All @@ -21,42 +21,58 @@ struct VISION_API VGGImpl : torch::nn::Module {

// VGG 11-layer model (configuration "A")
struct VISION_API VGG11Impl : VGGImpl {
VGG11Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG11Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 13-layer model (configuration "B")
struct VISION_API VGG13Impl : VGGImpl {
VGG13Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG13Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 16-layer model (configuration "D")
struct VISION_API VGG16Impl : VGGImpl {
VGG16Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG16Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 19-layer model (configuration "E")
struct VISION_API VGG19Impl : VGGImpl {
VGG19Impl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG19Impl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 11-layer model (configuration "A") with batch normalization
struct VISION_API VGG11BNImpl : VGGImpl {
VGG11BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG11BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 13-layer model (configuration "B") with batch normalization
struct VISION_API VGG13BNImpl : VGGImpl {
VGG13BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG13BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 16-layer model (configuration "D") with batch normalization
struct VISION_API VGG16BNImpl : VGGImpl {
VGG16BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG16BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

// VGG 19-layer model (configuration 'E') with batch normalization
struct VISION_API VGG19BNImpl : VGGImpl {
VGG19BNImpl(int64_t num_classes = 1000, bool initialize_weights = true);
explicit VGG19BNImpl(
int64_t num_classes = 1000,
bool initialize_weights = true);
};

TORCH_MODULE(VGG);
Expand Down

0 comments on commit 6b071be

Please sign in to comment.