From 2e7fabd6e8f6aa6610efddef06b4f2038b74548e Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 1 Mar 2023 09:55:41 +0100 Subject: [PATCH 01/14] update code --- requirements/image.txt | 1 - src/torchmetrics/functional/image/lpips.py | 412 ++++++++++++++++++ .../functional/image/lpips_models/alex.pth | Bin 0 -> 6009 bytes .../functional/image/lpips_models/squeeze.pth | Bin 0 -> 10811 bytes .../functional/image/lpips_models/vgg.pth | Bin 0 -> 7289 bytes src/torchmetrics/image/lpip.py | 49 +-- src/torchmetrics/utilities/imports.py | 1 + 7 files changed, 422 insertions(+), 41 deletions(-) create mode 100644 src/torchmetrics/functional/image/lpips.py create mode 100644 src/torchmetrics/functional/image/lpips_models/alex.pth create mode 100644 src/torchmetrics/functional/image/lpips_models/squeeze.pth create mode 100644 src/torchmetrics/functional/image/lpips_models/vgg.pth diff --git a/requirements/image.txt b/requirements/image.txt index 13275bd1dcf..286b1f64c6b 100644 --- a/requirements/image.txt +++ b/requirements/image.txt @@ -4,4 +4,3 @@ scipy >1.0.0, <1.11.0 torchvision >=0.8, <=0.14.1 torch-fidelity <=0.3.0 -lpips <=0.1.4 diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py new file mode 100644 index 00000000000..940d944642d --- /dev/null +++ b/src/torchmetrics/functional/image/lpips.py @@ -0,0 +1,412 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
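For orientation before the pasted network code: the vendored module that follows implements the standard LPIPS recipe — run both images through a frozen backbone, unit-normalize the activations of a few layers along the channel dimension, square the per-pixel differences, weight them with learned 1x1 convolutions, and average over space and layers. A minimal illustrative sketch of that computation (hypothetical helper name, not part of the patch):

def _lpips_distance_sketch(feats0, feats1, lins, eps=1e-10):
    # feats0 / feats1: lists of per-layer activations (torch tensors) from the frozen backbone
    # lins: the learned 1x1 Conv2d layers (NetLinLayer in the patch below)
    total = 0.0
    for f0, f1, lin in zip(feats0, feats1, lins):
        f0 = f0 / (f0.pow(2).sum(dim=1, keepdim=True).sqrt() + eps)  # unit-normalize channels
        f1 = f1 / (f1.pow(2).sum(dim=1, keepdim=True).sqrt() + eps)
        diff = (f0 - f1) ** 2                                        # squared per-channel difference
        total = total + lin(diff).mean(dim=[2, 3], keepdim=True)     # learned weighting + spatial average
    return total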
+ +# Content copied from +# https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py +# and +# https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py +# and with adjustments from +# https://github.com/richzhang/PerceptualSimilarity/pull/114/files +from collections import namedtuple + +import torch +import torchvision +from packaging import version +from torch import Tensor, nn +from torchvision import models as tv +from typing_extensions import Literal + +from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_13 + +_weight_map = { + "squeezenet1_1": "SqueezeNet1_1_Weights", + "alexnet": "AlexNet_Weights", + "vgg16": "VGG16_Weights", +} + + +def _get_net(net: str, pretrained: bool): + if _TORCHVISION_GREATER_EQUAL_0_13: + if pretrained: + pretrained_features = getattr(tv, net)(weights=getattr(tv, _weight_map[net]).IMAGENET1K_V1).features + else: + pretrained_features = getattr(tv, net)(weights=None).features + else: + pretrained_features = getattr(tv, net)(pretrained=pretrained).features + return pretrained_features + + +class Squeezenet(torch.nn.Module): + """Squeezenet implementation.""" + + def __init__(self, requires_grad: int = False, pretrained: int = True): + super().__init__() + pretrained_features = _get_net("squeezenet1_1", pretrained) + + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + """Process input.""" + h = self.slice1(x) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple("SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"]) + out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) + + return out + + +class Alexnet(torch.nn.Module): + """Alexnet implementation.""" + + def __init__(self, requires_grad=False, pretrained=True): + super().__init__() + alexnet_pretrained_features = _get_net("alexnet", pretrained) + + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + 
self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + """Process input.""" + h = self.slice1(x) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple("AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class Vgg16(torch.nn.Module): + """Vgg16 implementation.""" + + def __init__(self, requires_grad=False, pretrained=True): + super().__init__() + vgg_pretrained_features = _get_net("vgg16", pretrained) + + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + """Process input.""" + h = self.slice1(x) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +def spatial_average(in_tens, keepdim=True): + """Spatial averaging over heigh and width of images.""" + return in_tens.mean([2, 3], keepdim=keepdim) + + +def upsample(in_tens, out_hw=(64, 64)): # assumes scale factor is same for H and W + """Upsample input with bilinear interpolation.""" + return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens) + + +def normalize_tensor(in_feat, eps=1e-10): + """Normalize tensors.""" + norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) + return in_feat / (norm_factor + eps) + + +class ScalingLayer(nn.Module): + """Scaling layer.""" + + def __init__(self): + super().__init__() + self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) + self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) + + def forward(self, inp): + """Process input.""" + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv.""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super().__init__() + + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + def forward(self, x): + """Process input.""" + return self.model(x) + + +class _LPIPS(nn.Module): + def __init__( + self, + pretrained=True, + net="alex", + spatial=False, + pnet_rand=False, + pnet_tune=False, + use_dropout=True, + 
model_path=None, + eval_mode=True, + ): + """Initializes a perceptual loss torch.nn.Module. + + Args: + pretrained: This flag controls the linear layers should be pretrained version or random + net: Indicate backbone to use, choose between ['alex','vgg','squeeze'] + spatial: If input should be spatial averaged + pnet_rand: If backbone should be random or use imagenet pre-trained weights + pnet_tune: If backprop should be enabled + use_dropout: If dropout layers should be added + model_path: Model path to load pretained models from + eval_mode: If network should be in evaluation mode + """ + super().__init__() + + self.pnet_type = net + self.pnet_tune = pnet_tune + self.pnet_rand = pnet_rand + self.spatial = spatial + self.version = version + self.scaling_layer = ScalingLayer() + + if self.pnet_type in ["vgg", "vgg16"]: + net_type = Vgg16 + self.chns = [64, 128, 256, 512, 512] + elif self.pnet_type == "alex": + net_type = Alexnet + self.chns = [64, 192, 384, 256, 256] + elif self.pnet_type == "squeeze": + net_type = Squeezenet + self.chns = [64, 128, 256, 384, 384, 512, 512] + self.L = len(self.chns) + + self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) + + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + if self.pnet_type == "squeeze": # 7 layers for squeezenet + self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) + self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) + self.lins += [self.lin5, self.lin6] + self.lins = nn.ModuleList(self.lins) + + if pretrained: + if model_path is None: + import inspect + import os + + model_path = os.path.abspath( + os.path.join(inspect.getfile(self.__init__), "..", f"lpips_models/{net}.pth") + ) + + self.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False) + + if eval_mode: + self.eval() + + def forward(self, in0, in1, retperlayer=False, normalize=False): + if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] + in0 = 2 * in0 - 1 + in1 = 2 * in1 - 1 + + # v0.0 - original release had a bug, where input was not scaled + in0_input, in1_input = ( + (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1) + ) + outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) + feats0, feats1, diffs = {}, {}, {} + + for kk in range(self.L): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + if self.lpips: + if self.spatial: + res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] + else: + res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] + else: + if self.spatial: + res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] + else: + res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] + + val = 0 + for layer in range(self.L): + val += res[layer] + + if retperlayer: + return (val, res) + return val + + +class _NoTrainLpips(_LPIPS): + """Wrapper to make sure LPIPS never leaves evaluation mode.""" + + def train(self, mode: bool) -> 
"_NoTrainLpips": + """Force network to always be in evaluation mode.""" + return super().train(False) + + +def _valid_img(img: Tensor, normalize: bool) -> bool: + """Check that input is a valid image to the network.""" + value_check = img.max() <= 1.0 and img.min() >= 0.0 if normalize else img.min() >= -1 + return img.ndim == 4 and img.shape[1] == 3 and value_check + + +def _lpips_update(img1, img2, net, normalize): + if not (_valid_img(img1, normalize) and _valid_img(img2, normalize)): + raise ValueError( + "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]." + f" Got input with shape {img1.shape} and {img2.shape} and values in range" + f" {[img1.min(), img1.max()]} and {[img2.min(), img2.max()]} when all values are" + f" expected to be in the {[0,1] if normalize else [-1,1]} range." + ) + loss = net(img1, img2, normalize=normalize).squeeze() + return loss, img1.shape[0] + + +def _lpips_compute(sum_scores, total, reduction): + return sum_scores / total if reduction == "mean" else sum_scores + + +def learned_perceptual_image_patch_similarity( + img1: Tensor, + img2: Tensor, + net_type: str = "alex", + reduction: Literal["sum", "mean"] = "mean", + normalize: bool = False, +) -> Tensor: + """The Learned Perceptual Image Patch Similarity (`LPIPS_`) calculates the perceptual similarity between two images. + + LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. + This measure has been shown to match human perception well. A low LPIPS score means that image patches are + perceptual similar. + + Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the + chosen backbone (see `net_type` arg). + + Args: + img1: first set of images + img2: second set of images + net_type: str indicating backbone network type to use. Choose between `'alex'`, `'vgg'` or `'squeeze'` + reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`. + normalize: by default this is ``False`` meaning that the input is expected to be in the [-1,1] range. If set + to ``True`` will instead expect input to be in the ``[0,1]`` range. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+    """
+    net = _NoTrainLpips(net_type)
+    loss, total = _lpips_update(img1, img2, net, normalize)
+    return _lpips_compute(loss, total, reduction)
diff --git a/src/torchmetrics/functional/image/lpips_models/alex.pth b/src/torchmetrics/functional/image/lpips_models/alex.pth
new file mode 100644
index 0000000000000000000000000000000000000000..1df9dfe62abb1fc89cc7f82b4e5fe886c979708e
GIT binary patch
literal 6009
[base85-encoded binary weight data for the bundled lpips_models/alex.pth, squeeze.pth and vgg.pth files omitted]
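With the series applied, the functional interface added above is used roughly as sketched below (this mirrors the doctest that patch 03 later adds to the docstring; inputs default to the [-1, 1] range, or pass normalize=True for [0, 1] inputs — the exact score depends on seed and backbone, so none is shown here):

import torch
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity

_ = torch.manual_seed(123)
img1 = (torch.rand(10, 3, 100, 100) * 2) - 1  # shape (N, 3, H, W), values in [-1, 1]
img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
score = learned_perceptual_image_patch_similarity(img1, img2, net_type="vgg", reduction="mean")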
diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py
index ae5216a1d0a..c632b6ccc78 100644
--- a/src/torchmetrics/image/lpip.py
+++ b/src/torchmetrics/image/lpip.py
@@ -19,40 +19,20 @@
 from torch.nn import Module
 from typing_extensions import Literal
 
+from torchmetrics.functional.image.lpips import _LPIPS, _lpips_compute, _lpips_update, _NoTrainLpips
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
 from torchmetrics.utilities.imports import _LPIPS_AVAILABLE
 
-if _LPIPS_AVAILABLE:
-    from lpips import LPIPS as _LPIPS
-
-    def _download_lpips() -> None:
-        _LPIPS(pretrained=True, net="vgg")
+
+def _download_lpips() -> None:
+    _LPIPS(pretrained=True, net="vgg")
 
-    if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips):
-        __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"]
-else:
-
-    class _LPIPS(Module):
-        pass
 
+if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips):
     __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"]
 
 
-class NoTrainLpips(_LPIPS):
-    """Wrapper to make sure LPIPS never leaves evaluation mode."""
-
-    def train(self, mode: bool) -> "NoTrainLpips":
-        """Force network to always be in evaluation mode."""
-        return super().train(False)
-
-
-def _valid_img(img: Tensor, normalize: bool) -> bool:
-    """Check that input is a valid image to the network."""
-    value_check = img.max() <= 1.0 and img.min() >= 0.0 if normalize else img.min() >= -1
-    return img.ndim == 4 and img.shape[1] == 3 and value_check
-
-
 class LearnedPerceptualImagePatchSimilarity(Metric):
     """The Learned Perceptual Image Patch Similarity (`LPIPS_`) calculates the perceptual similarity between two images.
 
@@ -117,7 +97,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
 
     def __init__(
         self,
-        net_type: str = "alex",
+        net_type: Literal["alex", "vgg", "squeeze"] = "alex",
        reduction: Literal["sum", "mean"] = "mean",
         normalize: bool = False,
         **kwargs: Any,
@@ -133,7 +113,7 @@ def __init__(
         valid_net_type = ("vgg", "alex", "squeeze")
         if net_type not in valid_net_type:
             raise ValueError(f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}.")
-        self.net = NoTrainLpips(net=net_type, verbose=False)
+        self.net = _NoTrainLpips(net=net_type)
 
         valid_reduction = ("mean", "sum")
         if reduction not in valid_reduction:
@@ -149,21 +129,10 @@ def __init__(
 
     def update(self, img1: Tensor, img2: Tensor) -> None:
         """Update internal states with lpips score."""
-        if not (_valid_img(img1, self.normalize) and _valid_img(img2, self.normalize)):
-            raise ValueError(
-                "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]."
-                f" Got input with shape {img1.shape} and {img2.shape} and values in range"
-                f" {[img1.min(), img1.max()]} and {[img2.min(), img2.max()]} when all values are"
-                f" expected to be in the {[0,1] if self.normalize else [-1,1]} range."
- ) - loss = self.net(img1, img2, normalize=self.normalize).squeeze() + loss, total = _lpips_update(img1, img2, net=self.net, normalize=self.normalize) self.sum_scores += loss.sum() - self.total += img1.shape[0] + self.total += total def compute(self) -> Tensor: """Compute final perceptual similarity metric.""" - if self.reduction == "mean": - return self.sum_scores / self.total - if self.reduction == "sum": - return self.sum_scores - return None + return _lpips_compute(self.sum_scores, self.total, self.reduction) diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 7a56f7db610..b5cdbab9165 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -34,6 +34,7 @@ _PYCOCOTOOLS_AVAILABLE: bool = package_available("pycocotools") _TORCHVISION_AVAILABLE: bool = package_available("torchvision") _TORCHVISION_GREATER_EQUAL_0_8: Optional[bool] = compare_version("torchvision", operator.ge, "0.8.0") +_TORCHVISION_GREATER_EQUAL_0_13: Optional[bool] = compare_version("torchvision", operator.ge, "0.13.0") _TQDM_AVAILABLE: bool = package_available("tqdm") _TRANSFORMERS_AVAILABLE: bool = package_available("transformers") _PESQ_AVAILABLE: bool = package_available("pesq") From fe8dbd34b180545517d1604d117a8ed5c7fd5fe2 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 1 Mar 2023 10:03:16 +0100 Subject: [PATCH 02/14] changelog --- CHANGELOG.md | 4 ++++ src/torchmetrics/functional/image/lpips.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d072b143188..3c003b16156 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485)) + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) @@ -74,6 +75,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed `__iter__` method from raising `NotImplementedError` to `TypeError` by setting to `None` ([#1538](https://github.com/Lightning-AI/metrics/pull/1538)) +- Changed `LPIPS` implementation to no more rely on third-party package ([#1575](https://github.com/Lightning-AI/metrics/pull/1575)) + + ### Deprecated - diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 940d944642d..10a7f26d36e 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -18,6 +18,9 @@ # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py # and with adjustments from # https://github.com/richzhang/PerceptualSimilarity/pull/114/files +# Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang +# All rights reserved. 
+ from collections import namedtuple import torch From d2eac43fbe7b6d7b0c5c59080846d142adc300f2 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 1 Mar 2023 10:50:43 +0100 Subject: [PATCH 03/14] try fix --- src/torchmetrics/functional/image/lpips.py | 30 +++++++++++++--------- src/torchmetrics/image/lpip.py | 12 +++++---- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 10a7f26d36e..84e9ffc3b72 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -30,7 +30,7 @@ from torchvision import models as tv from typing_extensions import Literal -from torchmetrics.utilities.imports import _TORCHVISION_GREATER_EQUAL_0_13 +from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13 _weight_map = { "squeezenet1_1": "SqueezeNet1_1_Weights", @@ -38,6 +38,9 @@ "vgg16": "VGG16_Weights", } +if not _TORCHVISION_AVAILABLE: + __doctest_skip__ = ["learned_perceptual_image_patch_similarity"] + def _get_net(net: str, pretrained: bool): if _TORCHVISION_GREATER_EQUAL_0_13: @@ -335,16 +338,10 @@ def forward(self, in0, in1, retperlayer=False, normalize=False): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - if self.lpips: - if self.spatial: - res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] - else: - res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] + if self.spatial: + res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] else: - if self.spatial: - res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] - else: - res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] + res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] val = 0 for layer in range(self.L): @@ -408,8 +405,17 @@ def learned_perceptual_image_patch_similarity( reduction: str indicating how to reduce over the batch dimension. Choose between `'sum'` or `'mean'`. normalize: by default this is ``False`` meaning that the input is expected to be in the [-1,1] range. If set to ``True`` will instead expect input to be in the ``[0,1]`` range. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ + Example: + >>> import torch + >>> _ = torch.manual_seed(123) + >>> from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity + >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 + >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 + >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='vgg') + tensor(0.3493, grad_fn=) + """ net = _NoTrainLpips(net_type) loss, total = _lpips_update(img1, img2, net, normalize) - return _lpips_compute(loss, total, reduction) + return _lpips_compute(loss.sum(), total, reduction) diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index c632b6ccc78..44a686be680 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -22,14 +22,16 @@ from torchmetrics.functional.image.lpips import _LPIPS, _lpips_compute, _lpips_update, _NoTrainLpips from torchmetrics.metric import Metric from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout -from torchmetrics.utilities.imports import _LPIPS_AVAILABLE +from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCHVISION_AVAILABLE +if _TORCHVISION_AVAILABLE: -def _download_lpips() -> None: - _LPIPS(pretrained=True, net="vgg") + def _download_lpips() -> None: + _LPIPS(pretrained=True, net="vgg") - -if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips): + if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_lpips): + __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"] +else: __doctest_skip__ = ["LearnedPerceptualImagePatchSimilarity", "LPIPS"] From e007f4e4367da2c30f76fd20ca3af29348d15dd2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Sat, 11 Mar 2023 15:09:25 +0100 Subject: [PATCH 04/14] Apply suggestions from code review Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- src/torchmetrics/functional/image/lpips.py | 58 +++++++--------------- 1 file changed, 17 insertions(+), 41 deletions(-) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 84e9ffc3b72..7cb5045e4a0 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -53,59 +53,35 @@ def _get_net(net: str, pretrained: bool): return pretrained_features -class Squeezenet(torch.nn.Module): - """Squeezenet implementation.""" +class SqueezeNet(torch.nn.Module): + """SqueezeNet implementation.""" - def __init__(self, requires_grad: int = False, pretrained: int = True): + def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None: super().__init__() pretrained_features = _get_net("squeezenet1_1", pretrained) - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - self.slice6 = torch.nn.Sequential() - self.slice7 = torch.nn.Sequential() self.N_slices = 7 - for x in range(2): - self.slice1.add_module(str(x), pretrained_features[x]) - for x in range(2, 5): - self.slice2.add_module(str(x), pretrained_features[x]) - for x in range(5, 8): - self.slice3.add_module(str(x), pretrained_features[x]) - for x in range(8, 10): - self.slice4.add_module(str(x), pretrained_features[x]) - for x in range(10, 11): - self.slice5.add_module(str(x), pretrained_features[x]) - for x in range(11, 12): - self.slice6.add_module(str(x), pretrained_features[x]) - for x in range(12, 13): - self.slice7.add_module(str(x), 
pretrained_features[x]) + feature_ranges = [range(2), range(2, 5), range(5, 8), range(8, 10), range(10, 11), range(11, 12), range(12, 13)] + for feature_range in feature_ranges: + slice = torch.nn.Sequential() + for i in feature_range: + slice.add_module(str(i), pretrained_features[i]) + slices.append(slice) + + self.slices = nn.ModuleList(slices) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, x): """Process input.""" - h = self.slice1(x) - h_relu1 = h - h = self.slice2(h) - h_relu2 = h - h = self.slice3(h) - h_relu3 = h - h = self.slice4(h) - h_relu4 = h - h = self.slice5(h) - h_relu5 = h - h = self.slice6(h) - h_relu6 = h - h = self.slice7(h) - h_relu7 = h vgg_outputs = namedtuple("SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"]) - out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) - - return out + + relus = [] + for slice in self.slices: + x = slice(x) + relus.append(x) + return vgg_outputs(*relus) class Alexnet(torch.nn.Module): From d746c7e0d73926b36f645dcbfabc44f6c9640572 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 11 Mar 2023 15:16:27 +0100 Subject: [PATCH 05/14] adjust copyright --- src/torchmetrics/functional/image/lpips.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 7cb5045e4a0..7a5804097e1 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -18,8 +18,10 @@ # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py # and with adjustments from # https://github.com/richzhang/PerceptualSimilarity/pull/114/files +# due to package no longer being maintained # Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang # All rights reserved. 
+# License under BSD 2-clause from collections import namedtuple @@ -61,6 +63,7 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None pretrained_features = _get_net("squeezenet1_1", pretrained) self.N_slices = 7 + slices = [] feature_ranges = [range(2), range(2, 5), range(5, 8), range(8, 10), range(10, 11), range(11, 12), range(12, 13)] for feature_range in feature_ranges: slice = torch.nn.Sequential() @@ -76,7 +79,7 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None def forward(self, x): """Process input.""" vgg_outputs = namedtuple("SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"]) - + relus = [] for slice in self.slices: x = slice(x) @@ -266,7 +269,7 @@ def __init__( net_type = Alexnet self.chns = [64, 192, 384, 256, 256] elif self.pnet_type == "squeeze": - net_type = Squeezenet + net_type = SqueezeNet self.chns = [64, 128, 256, 384, 384, 512, 512] self.L = len(self.chns) From 54bf257294a7b575232a17412c6c7508edb9dceb Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 11 Mar 2023 15:22:34 +0100 Subject: [PATCH 06/14] fixes --- src/torchmetrics/functional/image/lpips.py | 4 ++-- src/torchmetrics/image/lpip.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 7a5804097e1..37321b993d7 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -44,7 +44,7 @@ __doctest_skip__ = ["learned_perceptual_image_patch_similarity"] -def _get_net(net: str, pretrained: bool): +def _get_net(net: str, pretrained: bool) -> nn.Module: if _TORCHVISION_GREATER_EQUAL_0_13: if pretrained: pretrained_features = getattr(tv, net)(weights=getattr(tv, _weight_map[net]).IMAGENET1K_V1).features @@ -392,7 +392,7 @@ def learned_perceptual_image_patch_similarity( >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='vgg') - tensor(0.3493, grad_fn=) + tensor(0.1441, grad_fn=) """ net = _NoTrainLpips(net_type) diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index ea535290f4a..ab4ad73887a 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -88,7 +88,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric): >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1 >>> lpips(img1, img2) - tensor(0.3493, grad_fn=) + tensor(0.3747, grad_fn=) """ is_differentiable: bool = True From 65b9c06a9651f5bd84ed787d0ddf3de6ca671df7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Mar 2023 10:20:56 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/lpips.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 37321b993d7..4643c8dfc43 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -90,7 +90,7 @@ def forward(self, x): class Alexnet(torch.nn.Module): """Alexnet implementation.""" - def __init__(self, requires_grad=False, pretrained=True): + def __init__(self, requires_grad=False, pretrained=True) -> None: 
super().__init__() alexnet_pretrained_features = _get_net("alexnet", pretrained) @@ -135,7 +135,7 @@ def forward(self, x): class Vgg16(torch.nn.Module): """Vgg16 implementation.""" - def __init__(self, requires_grad=False, pretrained=True): + def __init__(self, requires_grad=False, pretrained=True) -> None: super().__init__() vgg_pretrained_features = _get_net("vgg16", pretrained) @@ -196,7 +196,7 @@ def normalize_tensor(in_feat, eps=1e-10): class ScalingLayer(nn.Module): """Scaling layer.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) @@ -209,7 +209,7 @@ def forward(self, inp): class NetLinLayer(nn.Module): """A single linear layer which does a 1x1 conv.""" - def __init__(self, chn_in, chn_out=1, use_dropout=False): + def __init__(self, chn_in, chn_out=1, use_dropout=False) -> None: super().__init__() layers = ( @@ -240,7 +240,7 @@ def __init__( use_dropout=True, model_path=None, eval_mode=True, - ): + ) -> None: """Initializes a perceptual loss torch.nn.Module. Args: From 36e141e0d363354264cced789f6d395aa0d6ce04 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 13 Apr 2023 14:43:31 +0200 Subject: [PATCH 08/14] fix tests --- src/torchmetrics/functional/image/lpips.py | 74 +++++++++------------- 1 file changed, 29 insertions(+), 45 deletions(-) diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py index 4643c8dfc43..025c93dad73 100644 --- a/src/torchmetrics/functional/image/lpips.py +++ b/src/torchmetrics/functional/image/lpips.py @@ -22,12 +22,12 @@ # Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang # All rights reserved. 
# License under BSD 2-clause - +import inspect +import os from collections import namedtuple +from typing import Optional, Tuple, Union import torch -import torchvision -from packaging import version from torch import Tensor, nn from torchvision import models as tv from typing_extensions import Literal @@ -76,7 +76,7 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None for param in self.parameters(): param.requires_grad = False - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """Process input.""" vgg_outputs = namedtuple("SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"]) @@ -90,7 +90,7 @@ def forward(self, x): class Alexnet(torch.nn.Module): """Alexnet implementation.""" - def __init__(self, requires_grad=False, pretrained=True) -> None: + def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None: super().__init__() alexnet_pretrained_features = _get_net("alexnet", pretrained) @@ -114,7 +114,7 @@ def __init__(self, requires_grad=False, pretrained=True) -> None: for param in self.parameters(): param.requires_grad = False - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """Process input.""" h = self.slice1(x) h_relu1 = h @@ -127,15 +127,13 @@ def forward(self, x): h = self.slice5(h) h_relu5 = h alexnet_outputs = namedtuple("AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]) - out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) - - return out + return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) class Vgg16(torch.nn.Module): """Vgg16 implementation.""" - def __init__(self, requires_grad=False, pretrained=True) -> None: + def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None: super().__init__() vgg_pretrained_features = _get_net("vgg16", pretrained) @@ -159,7 +157,7 @@ def __init__(self, requires_grad=False, pretrained=True) -> None: for param in self.parameters(): param.requires_grad = False - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """Process input.""" h = self.slice1(x) h_relu1_2 = h @@ -172,22 +170,20 @@ def forward(self, x): h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) - out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) - return out - -def spatial_average(in_tens, keepdim=True): +def spatial_average(in_tens: Tensor, keepdim: bool = True) -> Tensor: """Spatial averaging over heigh and width of images.""" return in_tens.mean([2, 3], keepdim=keepdim) -def upsample(in_tens, out_hw=(64, 64)): # assumes scale factor is same for H and W +def upsample(in_tens: Tensor, out_hw: Tuple[int] = (64, 64)) -> Tensor: # assumes scale factor is same for H and W """Upsample input with bilinear interpolation.""" return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens) -def normalize_tensor(in_feat, eps=1e-10): +def normalize_tensor(in_feat: Tensor, eps: float = 1e-10) -> Tensor: """Normalize tensors.""" norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) return in_feat / (norm_factor + eps) @@ -201,7 +197,7 @@ def __init__(self) -> None: self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) - def forward(self, inp): + def forward(self, inp: Tensor) -> 
Tensor: """Process input.""" return (inp - self.shift) / self.scale @@ -209,22 +205,16 @@ def forward(self, inp): class NetLinLayer(nn.Module): """A single linear layer which does a 1x1 conv.""" - def __init__(self, chn_in, chn_out=1, use_dropout=False) -> None: + def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) -> None: super().__init__() - layers = ( - [ - nn.Dropout(), - ] - if (use_dropout) - else [] - ) + layers = [nn.Dropout()] if use_dropout else [] layers += [ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] self.model = nn.Sequential(*layers) - def forward(self, x): + def forward(self, x: Tensor) -> Tensor: """Process input.""" return self.model(x) @@ -232,14 +222,14 @@ def forward(self, x): class _LPIPS(nn.Module): def __init__( self, - pretrained=True, - net="alex", - spatial=False, - pnet_rand=False, - pnet_tune=False, - use_dropout=True, - model_path=None, - eval_mode=True, + pretrained: bool = True, + net: Literal["alex", "vgg", "squeeze"] = "alex", + spatial: bool = False, + pnet_rand: bool = False, + pnet_tune: bool = False, + use_dropout: bool = True, + model_path: Optional[str] = None, + eval_mode: bool = True, ) -> None: """Initializes a perceptual loss torch.nn.Module. @@ -259,7 +249,6 @@ def __init__( self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial - self.version = version self.scaling_layer = ScalingLayer() if self.pnet_type in ["vgg", "vgg16"]: @@ -289,9 +278,6 @@ def __init__( if pretrained: if model_path is None: - import inspect - import os - model_path = os.path.abspath( os.path.join(inspect.getfile(self.__init__), "..", f"lpips_models/{net}.pth") ) @@ -301,15 +287,13 @@ def __init__( if eval_mode: self.eval() - def forward(self, in0, in1, retperlayer=False, normalize=False): + def forward(self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False) -> Tensor: if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 # v0.0 - original release had a bug, where input was not scaled - in0_input, in1_input = ( - (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1) - ) + in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} @@ -345,7 +329,7 @@ def _valid_img(img: Tensor, normalize: bool) -> bool: return img.ndim == 4 and img.shape[1] == 3 and value_check -def _lpips_update(img1, img2, net, normalize): +def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> Tuple[Tensor, Union[int, Tensor]]: if not (_valid_img(img1, normalize) and _valid_img(img2, normalize)): raise ValueError( "Expected both input arguments to be normalized tensors with shape [N, 3, H, W]." 
@@ -357,7 +341,7 @@ def _lpips_update(img1, img2, net, normalize):
     return loss, img1.shape[0]


-def _lpips_compute(sum_scores, total, reduction):
+def _lpips_compute(sum_scores: Tensor, total: Union[Tensor, int], reduction: Literal["sum", "mean"] = "mean") -> Tensor:
     return sum_scores / total if reduction == "mean" else sum_scores


From 54bf2d4df7fb10e17101c65f0cf24e546d6ef039 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Fri, 14 Apr 2023 09:37:24 +0200
Subject: [PATCH 09/14] conditional check tv

---
 src/torchmetrics/functional/image/lpips.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py
index 025c93dad73..cb05ef1c6fb 100644
--- a/src/torchmetrics/functional/image/lpips.py
+++ b/src/torchmetrics/functional/image/lpips.py
@@ -29,7 +29,6 @@

 import torch
 from torch import Tensor, nn
-from torchvision import models as tv
 from typing_extensions import Literal

 from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
@@ -42,6 +41,9 @@

 if not _TORCHVISION_AVAILABLE:
     __doctest_skip__ = ["learned_perceptual_image_patch_similarity"]
+    tv = None
+else:
+    from torchvision import models as tv


 def _get_net(net: str, pretrained: bool) -> nn.Module:

From 8d70c71d20f4da7f2a1be941fe1de964ab0b3c79 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Fri, 14 Apr 2023 10:45:25 +0200
Subject: [PATCH 10/14] readd req to tests

---
 requirements/image_test.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/requirements/image_test.txt b/requirements/image_test.txt
index a189111d9dd..144150f587b 100644
--- a/requirements/image_test.txt
+++ b/requirements/image_test.txt
@@ -5,3 +5,4 @@ scikit-image >=0.19.0, <=0.20.0
 kornia >=0.6.7, <0.6.12
 pytorch-msssim ==0.2.1
 sewar >=0.4.4, <=0.4.5
+lpips <=0.1.4

From 35c3dce2ae18b8f5fc8b1fc8527c8fe633ee508f Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Fri, 14 Apr 2023 14:47:28 +0200
Subject: [PATCH 11/14] fix mypy

---
 src/torchmetrics/functional/image/lpips.py | 53 ++++++++++++----------
 1 file changed, 28 insertions(+), 25 deletions(-)

diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py
index cb05ef1c6fb..021c8d87b11 100644
--- a/src/torchmetrics/functional/image/lpips.py
+++ b/src/torchmetrics/functional/image/lpips.py
@@ -25,7 +25,7 @@
 import inspect
 import os
 from collections import namedtuple
-from typing import Optional, Tuple, Union
+from typing import List, NamedTuple, Optional, Tuple, Union

 import torch
 from torch import Tensor, nn
@@ -41,12 +41,11 @@

 if not _TORCHVISION_AVAILABLE:
     __doctest_skip__ = ["learned_perceptual_image_patch_similarity"]
-    tv = None
 else:
     from torchvision import models as tv


-def _get_net(net: str, pretrained: bool) -> nn.Module:
+def _get_net(net: str, pretrained: bool) -> nn.modules.container.Sequential:
     if _TORCHVISION_GREATER_EQUAL_0_13:
         if pretrained:
             pretrained_features = getattr(tv, net)(weights=getattr(tv, _weight_map[net]).IMAGENET1K_V1).features
@@ -78,15 +77,15 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None
             for param in self.parameters():
                 param.requires_grad = False

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> NamedTuple:
         """Process input."""
-        vgg_outputs = namedtuple("SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"])
+        squeeze_output = namedtuple("squeeze_output", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"])

         relus = []
         for slice in self.slices:
             x = slice(x)
             relus.append(x)
-        return vgg_outputs(*relus)
+        return squeeze_output(*relus)  # type: ignore[return-value]


 class Alexnet(torch.nn.Module):
@@ -116,7 +115,7 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None
             for param in self.parameters():
                 param.requires_grad = False

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> NamedTuple:
         """Process input."""
         h = self.slice1(x)
         h_relu1 = h
@@ -128,8 +127,8 @@ def forward(self, x: Tensor) -> Tensor:
         h_relu4 = h
         h = self.slice5(h)
         h_relu5 = h
-        alexnet_outputs = namedtuple("AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"])
-        return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
+        alexnet_outputs = namedtuple("alexnet_outputs", ["relu1", "relu2", "relu3", "relu4", "relu5"])
+        return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)  # type: ignore[return-value]


 class Vgg16(torch.nn.Module):
@@ -159,7 +158,7 @@ def __init__(self, requires_grad: bool = False, pretrained: bool = True) -> None
             for param in self.parameters():
                 param.requires_grad = False

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: Tensor) -> NamedTuple:
         """Process input."""
         h = self.slice1(x)
         h_relu1_2 = h
@@ -171,8 +170,8 @@ def forward(self, x: Tensor) -> Tensor:
         h_relu4_3 = h
         h = self.slice5(h)
         h_relu5_3 = h
-        vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
-        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+        vgg_outputs = namedtuple("vgg_outputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
+        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)  # type: ignore[return-value]


 def spatial_average(in_tens: Tensor, keepdim: bool = True) -> Tensor:
@@ -180,7 +179,7 @@ def spatial_average(in_tens: Tensor, keepdim: bool = True) -> Tensor:
     return in_tens.mean([2, 3], keepdim=keepdim)


-def upsample(in_tens: Tensor, out_hw: Tuple[int] = (64, 64)) -> Tensor:  # assumes scale factor is same for H and W
+def upsam(in_tens: Tensor, out_hw: Tuple[int, int] = (64, 64)) -> Tensor:
     """Upsample input with bilinear interpolation."""
     return nn.Upsample(size=out_hw, mode="bilinear", align_corners=False)(in_tens)

@@ -212,7 +211,7 @@ def __init__(self, chn_in: int, chn_out: int = 1, use_dropout: bool = False) ->

         layers = [nn.Dropout()] if use_dropout else []
         layers += [
-            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+            nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),  # type: ignore[list-item]
         ]
         self.model = nn.Sequential(*layers)

@@ -257,10 +256,10 @@ def __init__(
             net_type = Vgg16
             self.chns = [64, 128, 256, 512, 512]
         elif self.pnet_type == "alex":
-            net_type = Alexnet
+            net_type = Alexnet  # type: ignore[assignment]
             self.chns = [64, 192, 384, 256, 256]
         elif self.pnet_type == "squeeze":
-            net_type = SqueezeNet
+            net_type = SqueezeNet  # type: ignore[assignment]
             self.chns = [64, 128, 256, 384, 384, 512, 512]
         self.L = len(self.chns)

@@ -276,12 +275,12 @@ def __init__(
             self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
             self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
             self.lins += [self.lin5, self.lin6]
-        self.lins = nn.ModuleList(self.lins)
+        self.lins = nn.ModuleList(self.lins)  # type: ignore[assignment]

         if pretrained:
             if model_path is None:
                 model_path = os.path.abspath(
-                    os.path.join(inspect.getfile(self.__init__), "..", f"lpips_models/{net}.pth")
+                    os.path.join(inspect.getfile(self.__init__), "..", f"lpips_models/{net}.pth")  # type: ignore[misc]
                 )

             self.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
@@ -289,7 +288,9 @@ def __init__(
         if eval_mode:
             self.eval()

-    def forward(self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False) -> Tensor:
+    def forward(
+        self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize: bool = False
+    ) -> Union[int, Tuple[int, List[Tensor]]]:
         if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
             in0 = 2 * in0 - 1
             in1 = 2 * in1 - 1
@@ -304,13 +305,15 @@ def forward(self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize
             diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

         if self.spatial:
-            res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
+            res = [
+                upsam(self.lins[kk](diffs[kk]), out_hw=in0.shape[2:]) for kk in range(self.L)  # type: ignore[arg-type]
+            ]
         else:
             res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]

         val = 0
         for layer in range(self.L):
-            val += res[layer]
+            val += res[layer]  # type: ignore[assignment]

         if retperlayer:
             return (val, res)
@@ -320,7 +323,7 @@ def forward(self, in0: Tensor, in1: Tensor, retperlayer: bool = False, normalize
 class _NoTrainLpips(_LPIPS):
     """Wrapper to make sure LPIPS never leaves evaluation mode."""

-    def train(self, mode: bool) -> "_NoTrainLpips":
+    def train(self, mode: bool) -> "_NoTrainLpips":  # type: ignore[override]
         """Force network to always be in evaluation mode."""
         return super().train(False)

@@ -328,7 +331,7 @@ def train(self, mode: bool) -> "_NoTrainLpips":
 def _valid_img(img: Tensor, normalize: bool) -> bool:
     """Check that input is a valid image to the network."""
     value_check = img.max() <= 1.0 and img.min() >= 0.0 if normalize else img.min() >= -1
-    return img.ndim == 4 and img.shape[1] == 3 and value_check
+    return img.ndim == 4 and img.shape[1] == 3 and value_check  # type: ignore[return-value]


 def _lpips_update(img1: Tensor, img2: Tensor, net: nn.Module, normalize: bool) -> Tuple[Tensor, Union[int, Tensor]]:
@@ -350,7 +353,7 @@ def _lpips_compute(sum_scores: Tensor, total: Union[Tensor, int], reduction: Lit
 def learned_perceptual_image_patch_similarity(
     img1: Tensor,
     img2: Tensor,
-    net_type: str = "alex",
+    net_type: Literal["alex", "vgg", "squeeze"] = "alex",
     reduction: Literal["sum", "mean"] = "mean",
     normalize: bool = False,
 ) -> Tensor:
@@ -381,6 +384,6 @@ def learned_perceptual_image_patch_similarity(
     tensor(0.1441, grad_fn=)

     """
-    net = _NoTrainLpips(net_type)
+    net = _NoTrainLpips(net=net_type)
     loss, total = _lpips_update(img1, img2, net, normalize)
     return _lpips_compute(loss.sum(), total, reduction)

From a55dd4046fab25304a4cc3c94313ad1392923a4b Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Fri, 14 Apr 2023 15:06:22 +0200
Subject: [PATCH 12/14] fix

---
 requirements/image.txt      | 1 +
 requirements/image_test.txt | 1 -
 2 files changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/image.txt b/requirements/image.txt
index 0f9518aa487..5807b55ac50 100644
--- a/requirements/image.txt
+++ b/requirements/image.txt
@@ -4,3 +4,4 @@
 scipy >1.0.0, <1.11.0
 torchvision >=0.8, <=0.15.1
 torch-fidelity <=0.3.0
+lpips <=0.1.4
diff --git a/requirements/image_test.txt b/requirements/image_test.txt
index 144150f587b..a189111d9dd 100644
--- a/requirements/image_test.txt
+++ b/requirements/image_test.txt
@@ -5,4 +5,3 @@ scikit-image >=0.19.0, <=0.20.0
 kornia >=0.6.7, <0.6.12
 pytorch-msssim ==0.2.1
 sewar >=0.4.4, <=0.4.5
-lpips <=0.1.4

From 90c08daba02c97f8ca4567f108be2e3b9526f2d3 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Fri, 14 Apr 2023 15:28:37 +0200
Subject: [PATCH 13/14] fix

---
 src/torchmetrics/functional/image/lpips.py | 8 ++++----
 src/torchmetrics/image/lpip.py             | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/torchmetrics/functional/image/lpips.py b/src/torchmetrics/functional/image/lpips.py
index 021c8d87b11..0e7b6cf3fec 100644
--- a/src/torchmetrics/functional/image/lpips.py
+++ b/src/torchmetrics/functional/image/lpips.py
@@ -85,7 +85,7 @@ def forward(self, x: Tensor) -> NamedTuple:
         for slice in self.slices:
             x = slice(x)
             relus.append(x)
-        return squeeze_output(*relus)  # type: ignore[return-value]
+        return squeeze_output(*relus)


 class Alexnet(torch.nn.Module):
@@ -128,7 +128,7 @@ def forward(self, x: Tensor) -> NamedTuple:
         h = self.slice5(h)
         h_relu5 = h
         alexnet_outputs = namedtuple("alexnet_outputs", ["relu1", "relu2", "relu3", "relu4", "relu5"])
-        return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)  # type: ignore[return-value]
+        return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)


 class Vgg16(torch.nn.Module):
@@ -171,7 +171,7 @@ def forward(self, x: Tensor) -> NamedTuple:
         h = self.slice5(h)
         h_relu5_3 = h
         vgg_outputs = namedtuple("vgg_outputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
-        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)  # type: ignore[return-value]
+        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)


 def spatial_average(in_tens: Tensor, keepdim: bool = True) -> Tensor:
@@ -381,7 +381,7 @@ def learned_perceptual_image_patch_similarity(
     >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
     >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
     >>> learned_perceptual_image_patch_similarity(img1, img2, net_type='vgg')
-    tensor(0.1441, grad_fn=)
+    tensor(0.3485, grad_fn=)

     """
     net = _NoTrainLpips(net=net_type)
diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py
index 3dec8fae7e4..d081e064e10 100644
--- a/src/torchmetrics/image/lpip.py
+++ b/src/torchmetrics/image/lpip.py
@@ -88,7 +88,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
         >>> img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
         >>> img2 = (torch.rand(10, 3, 100, 100) * 2) - 1
         >>> lpips(img1, img2)
-        tensor(0.3747, grad_fn=)
+        tensor(0.3493, grad_fn=)

     """
     is_differentiable: bool = True

From 93a85e9643bd6eb983d9b72825a30e7ffac9134f Mon Sep 17 00:00:00 2001
From: Jirka
Date: Mon, 17 Apr 2023 13:59:45 +0200
Subject: [PATCH 14/14] include models

---
 MANIFEST.in | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/MANIFEST.in b/MANIFEST.in
index 1dce3090745..81cf1a457dd 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -2,10 +2,12 @@ graft wheelhouse

 recursive-exclude __pycache__ *.py[cod] *.orig

+# include also models
+recursive-include src *.pth
+
 # Include the README and CHANGELOG
 include *.md
-recursive-include torchmetrics *.md
+recursive-include src *.md

 # Include the license file
 include LICENSE
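
Usage note: after this patch series, LPIPS is computed by torchmetrics' own port of the perceptual networks instead of the external lpips package (which is now only a test dependency). A minimal usage sketch follows, assuming the module paths introduced in these patches; whether the names are also re-exported at the torchmetrics.functional / torchmetrics.image package level is not shown in the diffs above, so the import paths here are an assumption based on the file locations.

    import torch

    # assumed import paths, derived from src/torchmetrics/functional/image/lpips.py
    # and src/torchmetrics/image/lpip.py in this PR
    from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
    from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

    # two batches of images in [-1, 1] with shape [N, 3, H, W], as in the doctests above
    img1 = (torch.rand(10, 3, 100, 100) * 2) - 1
    img2 = (torch.rand(10, 3, 100, 100) * 2) - 1

    # functional interface: picks the backbone via net_type and averages over the batch
    score = learned_perceptual_image_patch_similarity(img1, img2, net_type="vgg", reduction="mean")

    # module interface: same computation, stateful accumulation across update calls
    lpips = LearnedPerceptualImagePatchSimilarity(net_type="vgg")
    score = lpips(img1, img2)

The exact tensor values printed in the doctests depend on the random inputs and the pretrained weights shipped under lpips_models/, so they are not reproduced here.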