From 3e48302a5de6d99eebdcb6c05f668a0dff904165 Mon Sep 17 00:00:00 2001 From: Hamza Farhan Date: Tue, 8 Oct 2019 14:52:12 +0500 Subject: [PATCH] Update model.py An easy way to train the models on non RGB images with channels != 3 --- efficientnet_pytorch/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py index 3a42217..9820379 100644 --- a/efficientnet_pytorch/model.py +++ b/efficientnet_pytorch/model.py @@ -191,9 +191,13 @@ def from_name(cls, model_name, override_params=None): return cls(blocks_args, global_params) @classmethod - def from_pretrained(cls, model_name, num_classes=1000): + def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3): model = cls.from_name(model_name, override_params={'num_classes': num_classes}) load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) + out_channels = round_filters(32, model._global_params) + model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) return model @classmethod