From 20406b77f539347a3b452535ef1aaeef60525730 Mon Sep 17 00:00:00 2001
From: nilsleh
Date: Wed, 9 Mar 2022 22:44:27 +0100
Subject: [PATCH 01/11] millionaid

---
 docs/api/datasets.rst                                  |   5 +
 tests/data/millionaid/data.py                          |  54 +++
 tests/data/millionaid/test.zip                         | Bin 0 -> 3244 bytes
 .../grassland/meadow/P0115918.jpg                      | Bin 0 -> 1240 bytes
 .../test/water_area/beach/P0060208.jpg                 | Bin 0 -> 1238 bytes
 tests/data/millionaid/train.zip                        | Bin 0 -> 3265 bytes
 .../grassland/meadow/P0115918.jpg                      | Bin 0 -> 1254 bytes
 .../train/water_area/beach/P0060208.jpg                | Bin 0 -> 1231 bytes
 torchgeo/datasets/__init__.py                          |   2 +
 torchgeo/datasets/millionaid.py                        | 397 ++++++++++++++++++
 10 files changed, 458 insertions(+)
 create mode 100644 tests/data/millionaid/data.py
 create mode 100644 tests/data/millionaid/test.zip
 create mode 100644 tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg
 create mode 100644 tests/data/millionaid/test/water_area/beach/P0060208.jpg
 create mode 100644 tests/data/millionaid/train.zip
 create mode 100644 tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg
 create mode 100644 tests/data/millionaid/train/water_area/beach/P0060208.jpg
 create mode 100644 torchgeo/datasets/millionaid.py

diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index b6f51307e5c..086040a936d 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -224,6 +224,11 @@ LoveDA
 
 .. autoclass:: LoveDA
 
+Million-AID
+^^^^^^^^^^^
+
+.. autoclass:: MillionAID
+
 NASA Marine Debris
 ^^^^^^^^^^^^^^^^^^
 
diff --git a/tests/data/millionaid/data.py b/tests/data/millionaid/data.py
new file mode 100644
index 00000000000..a531d409224
--- /dev/null
+++ b/tests/data/millionaid/data.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import hashlib
+import os
+import random
+import shutil
+
+import numpy as np
+from PIL import Image
+
+SIZE = 32
+
+np.random.seed(0)
+random.seed(0)
+
+PATHS = {
+    "train": [
+        os.path.join(
+            "train", "agriculture_land", "grassland", "meadow", "P0115918.jpg"
+        ),
+        os.path.join("train", "water_area", "beach", "P0060208.jpg"),
+    ],
+    "test": [
+        os.path.join("test", "agriculture_land", "grassland", "meadow", "P0115918.jpg"),
+        os.path.join("test", "water_area", "beach", "P0060208.jpg"),
+    ],
+}
+
+
+def create_file(path: str) -> None:
+    Z = np.random.rand(SIZE, SIZE, 3) * 255
+    img = Image.fromarray(Z.astype("uint8")).convert("RGB")
+    img.save(path)
+
+
+if __name__ == "__main__":
+    for split, paths in PATHS.items():
+        # remove old data
+        if os.path.isdir(split):
+            shutil.rmtree(split)
+        for path in paths:
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            create_file(path)
+
+        # compress data
+        shutil.make_archive(split, "zip", ".", split)
+
+        # Compute checksums
+        with open(split + ".zip", "rb") as f:
+            md5 = hashlib.md5(f.read()).hexdigest()
+            print(f"{split}: {md5}")
diff --git a/tests/data/millionaid/test.zip b/tests/data/millionaid/test.zip
new file mode 100644
index 0000000000000000000000000000000000000000..15a0bbd3f46d7dbabb14545ac9c186d55a68c6de
GIT binary patch
literal 3244
[binary data omitted]

literal 0
HcmV?d00001

diff --git a/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1d04bf8a52319264e9dabfeec47cd963b07d02fe
GIT binary patch
literal 1240
[binary data omitted]

literal 0
HcmV?d00001

diff --git a/tests/data/millionaid/test/water_area/beach/P0060208.jpg b/tests/data/millionaid/test/water_area/beach/P0060208.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..226ff70e36b3403435b268424de45b26eaafd8b6
GIT binary patch
literal 1238
[binary data omitted]

literal 0
HcmV?d00001

diff --git a/tests/data/millionaid/train.zip b/tests/data/millionaid/train.zip
new file mode 100644
index 0000000000000000000000000000000000000000..[hash lost in transcription]
GIT binary patch
literal 3265
[binary data omitted]

literal 0
HcmV?d00001

diff --git a/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..afc980b49ddbdb1fa90b5b2944bd2c2e52cc9418
GIT binary patch
literal 1254
[binary data omitted]

literal 0
HcmV?d00001
diff --git a/tests/data/millionaid/train/water_area/beach/P0060208.jpg b/tests/data/millionaid/train/water_area/beach/P0060208.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8d742f2c51b455ab2839b083325bbb563d95aee5
GIT binary patch
literal 1231
[binary data omitted]

literal 0
HcmV?d00001

diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
[two-line hunk registering MillionAID in torchgeo.datasets; lost in transcription]
diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py
new file mode 100644
index 00000000000..23701fbe645
--- /dev/null
+++ b/torchgeo/datasets/millionaid.py
@@ -0,0 +1,397 @@
+[license header, module docstring, and import block lost in transcription]
+class MillionAID(VisionDataset):
+    """Million-AID Dataset.
+
+    The `MillionAID `_ dataset consists
+    of one million aerial images from the Googl Earth Engine that offers
+    either `a mult-class learning task
+    `_
+    with 51 classess or a `mulit-label learning task
+    `_
+    with 73 different possible labels. For more details please consult
+    the accompanying `paper `_.
+
+    Dataset features:
+
+    * RGB aerial images with varying resolutions from 0.5m to 153m per pixel
+    * images within classes can have different pixel dimension
+
+    Dataset format:
+
+    * images are three-channel jpg
+
+    If you use this dataset in your research, please cite the following paper:
+
+    * https://ieeexplore.ieee.org/document/9393553
+
+    .. versionadded:: 0.3
+    """
+
+    multi_label_categories = [
+        "agriculture_land",
+        "airport_area",
+        "apartment",
+        "apron",
+        "arable_land",
+        "bare_land",
+        "baseball_field",
+        "basketball_court",
+        "beach",
+        "bridge",
+        "cemetery",
+        "church",
+        "commercial_area",
+        "commercial_land",
+        "dam",
+        "desert",
+        "detached_house",
+        "dry_field",
+        "factory_area",
+        "forest",
+        "golf_course",
+        "grassland",
+        "greenhouse",
+        "ground_track_field",
+        "helipad",
+        "highway_area",
+        "ice_land",
+        "industrial_land",
+        "intersection",
+        "island",
+        "lake",
+        "leisure_land",
+        "meadow",
+        "mine",
+        "mining_area",
+        "mobile_home_park",
+        "oil_field",
+        "orchard",
+        "paddy_field",
+        "parking_lot",
+        "pier",
+        "port_area",
+        "power_station",
+        "public_service_land",
+        "quarry",
+        "railway",
+        "railway_area",
+        "religious_land",
+        "residential_land",
+        "river",
+        "road",
+        "rock_land",
+        "roundabout",
+        "runway",
+        "solar_power_plant",
+        "sparse_shrub_land",
+        "special_land",
+        "sports_land",
+        "stadium",
+        "storage_tank",
+        "substation",
+        "swimming_pool",
+        "tennis_court",
+        "terraced_field",
+        "train_station",
+        "transportation_land",
+        "unutilized_land",
+        "viaduct",
+        "wastewater_plant",
+        "water_area",
+        "wind_turbine",
+        "woodland",
+        "works",
+    ]
+
+    multi_class_categories = [
+        "apartment",
+        "apron",
+        "bare_land",
+        "baseball_field",
+        "basketball_court",
+        "beach",
+        "bridge",
+        "cemetery",
+        "church",
+        "commercial_area",
+        "dam",
+        "desert",
+        "detached_house",
+        "dry_field",
+        "forest",
+        "golf_course",
+        "greenhouse",
+        "ground_track_field",
+        "helipad",
+        "ice_land",
+        "intersection",
+        "island",
+        "lake",
+        "meadow",
+        "mine",
+        "mobile_home_park",
+        "oil_field",
+        "orchard",
+        "paddy_field",
+        "parking_lot",
+        "pier",
+        "quarry",
+        "railway",
+        "river",
+        "road",
+        "rock_land",
+        "roundabout",
+        "runway",
"solar_power_plant", + "sparse_shrub_land", + "stadium", + "storage_tank", + "substation", + "swimming_pool", + "tennis_court", + "terraced_field", + "train_station", + "viaduct", + "wastewater_plant", + "wind_turbine", + "works", + ] + + md5s = { + "train": "1b40503cafa9b0601653ca36cd788852", + "test": "https://eastus1-mediap.svc.ms/transform/zip?cs=fFNQTw", + } + url = { + "train": "https://eastus1-mediap.svc.ms/transform/zip?cs=fFNQTw", + "test": "https://eastus1-mediap.svc.ms/transform/zip?cs=fFNQTw", + } + + filenames = {"train": "train.zip", "test": "test.zip"} + + tasks = ["multi-class", "multi-label"] + splits = ["train", "test"] + + def __init__( + self, + root: str = "data", + task: str = "multi-class", + split: str = "train", + transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, + download: bool = False, + checksum: bool = False, + ) -> None: + """Initialize a new MillionAID dataset instance. + + Args: + root: root directory where dataset can be found + task: whether to use multi-class or multi-label task + split: train or test split + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + download: if True, download dataset and store it in the root directory + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + RuntimeError: if ``download=False`` and data is not found, or checksums + don't match + """ + self.root = root + self.transforms = transforms + self.download = download + self.checksum = checksum + assert task in self.tasks + assert split in self.splits + self.task = task + self.split = split + + self._verify() + + self.files = self._load_files(self.root) + + self.classes = sorted(set(cls for f in self.files for cls in f["label"])) + self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} + + def __len__(self) -> int: + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ + return len(self.files) + + def __getitem__(self, index: int) -> Dict[str, Tensor]: + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ + files = self.files[index] + image = self._load_image(files["image"]) + cls_label = [self.class_to_idx[label] for label in files["label"]] + label = torch.tensor(cls_label, dtype=torch.long) # type: ignore[attr-defined] + sample = {"image": image, "label": label} + + if self.transforms is not None: + sample = self.transforms(sample) + + return sample + + def _load_files(self, root: str) -> List[Dict[str, Any]]: + """Return the paths of the files in the dataset. 
+
+        Args:
+            root: root dir of dataset
+
+        Returns:
+            list of dicts containing paths for each pair of image, and list of labels
+        """
+        imgs_no_subcat = list(
+            glob.glob(os.path.join(root, self.split, "**", "**", "*.jpg"))
+        )
+
+        imgs_subcat = list(
+            glob.glob(os.path.join(root, self.split, "**", "**", "**", "*.jpg"))
+        )
+
+        scenes = [p.split("/")[-3] for p in imgs_no_subcat] + [
+            p.split("/")[-4] for p in imgs_subcat
+        ]
+
+        subcategories = ["Missing" for p in imgs_no_subcat] + [
+            p.split("/")[-3] for p in imgs_subcat
+        ]
+
+        classes = [p.split("/")[-2] for p in imgs_no_subcat] + [
+            p.split("/")[-2] for p in imgs_subcat
+        ]
+
+        if self.task == "multi-label":
+            labels = [
+                [sc, sub, c] if sub != "Missing" else [sc, c]
+                for sc, sub, c in zip(scenes, subcategories, classes)
+            ]
+        else:
+            labels = [[c] for c in classes]
+
+        images = imgs_no_subcat + imgs_subcat
+
+        files = [dict(image=img, label=l) for img, l in zip(images, labels)]
+
+        return files
+
+    def _load_image(self, path: str) -> Tensor:
+        """Load a single image.
+
+        Args:
+            path: path to the image
+
+        Returns:
+            the image
+        """
+        with Image.open(path) as img:
+            array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
+            tensor: Tensor = torch.from_numpy(array)  # type: ignore[attr-defined]
+            # Convert from HxWxC to CxHxW
+            tensor = tensor.permute((2, 0, 1))
+            return tensor
+
+    def _verify(self) -> None:
+        """Checks the integrity of the dataset structure.
+
+        Raises:
+            RuntimeError: if dataset is not found or corrupted
+        """
+        filepath = os.path.join(self.root, self.split)
+        if os.path.isdir(filepath):
+            return
+
+        filepath = os.path.join(self.root, self.split + ".zip")
+        if os.path.isfile(filepath):
+            if self.checksum and not check_integrity(filepath, self.md5s[self.split]):
+                raise RuntimeError("Dataset found, but corrupted.")
+            extract_archive(filepath)
+            return
+
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise RuntimeError(
+                "Dataset not found in `root` directory, either specify a different"
+                + " `root` directory or manually download "
+                + "the dataset to this directory."
+            )
+
+        # else download the dataset
+        self._download()
+
+    def _download(self) -> None:
+        """Download the dataset and extract it.
+
+        Raises:
+            AssertionError: if the checksum does not match
+        """
+        download_and_extract_archive(
+            self.url[self.split],
+            self.root,
+            filename=self.filenames[self.split],
+            md5=self.md5s[self.split] if self.checksum else None,
+        )
+
+    def plot(
+        self,
+        sample: Dict[str, Tensor],
+        show_titles: bool = True,
+        suptitle: Optional[str] = None,
+    ) -> plt.Figure:
+        """Plot a sample from the dataset.
+ + Args: + sample: a sample returned by :meth:`__getitem__` + show_titles: flag indicating whether to show titles above each panel + suptitle: optional string to use as a suptitle + + Returns: + a matplotlib Figure with the rendered sample + + """ + image = np.rollaxis(sample["image"].numpy(), 0, 3) + labels = [self.classes[cast(int, label)] for label in sample["label"]] + + showing_predictions = "prediction" in sample + if showing_predictions: + prediction_labels = [ + self.classes[cast(int, label)] for label in sample["prediction"] + ] + + fig, ax = plt.subplots(figsize=(4, 4)) + ax.imshow(image) + ax.axis("off") + if show_titles: + title = f"Label: {labels}" + if showing_predictions: + title += f"\nPrediction: {prediction_labels}" + ax.set_title(title) + + if suptitle is not None: + plt.suptitle(suptitle) + return fig From 1c7951ab8ceec07db0ece8db24d193f0033e73d3 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 10 Mar 2022 22:08:43 +0100 Subject: [PATCH 02/11] test --- tests/datasets/test_millionaid.py | 104 ++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 tests/datasets/test_millionaid.py diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py new file mode 100644 index 00000000000..65229e77f89 --- /dev/null +++ b/tests/datasets/test_millionaid.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import shutil +from pathlib import Path +from typing import Generator + +import matplotlib.pyplot as plt +import pytest +import torch +import torch.nn as nn +from _pytest.fixtures import SubRequest +from _pytest.monkeypatch import MonkeyPatch + +import torchgeo.datasets.utils +from torchgeo.datasets import MillionAID + + +def download_url(url: str, root: str, *args: str) -> None: + shutil.copy(url, root) + + +class TestMillionAID: + @pytest.fixture( + params=zip( + ["train", "train", "test", "test"], + ["multi-class", "multi-label", "multi-class", "multi-label"], + ) + ) + def dataset( + self, + monkeypatch: Generator[MonkeyPatch, None, None], + request: SubRequest, + tmp_path: Path, + ) -> MillionAID: + monkeypatch.setattr( # type: ignore[attr-defined] + torchgeo.datasets.utils, "download_url", download_url + ) + data_dir = os.path.join("tests", "data", "millionaid") + + urls = { + "train": os.path.join(data_dir, "train.zip"), + "test": os.path.join(data_dir, "test.zip"), + } + + md5s = { + "train": "d5b7c0e90af70b4e6746c9d3a37471b2", + "test": "7309f19eca7f010d1af9a6adb396b7f8", + } + + monkeypatch.setattr(MillionAID, "url", urls) # type: ignore[attr-defined] + monkeypatch.setattr(MillionAID, "md5s", md5s) # type: ignore[attr-defined] + root = str(tmp_path) + split, task = request.param + transforms = nn.Identity() # type: ignore[attr-defined] + return MillionAID( + root=root, + split=split, + task=task, + transforms=transforms, + download=True, + checksum=True, + ) + + def test_already_downloaded(self, dataset: MillionAID) -> None: + MillionAID(root=dataset.root, split=dataset.split, download=True) + + def test_getitem(self, dataset: MillionAID) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert x["image"].shape[0] == 3 + assert x["image"].ndim == 3 + + def test_len(self, dataset: MillionAID) -> None: + assert len(dataset) == 2 + + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join("tests", "data", "millionaid", "train.zip") 
+ shutil.copy(url, tmp_path) + MillionAID(root=str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "train.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + MillionAID(root=str(tmp_path), checksum=True) + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in."): + MillionAID(str(tmp_path)) + + def test_plot(self, dataset: MillionAID) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + + def test_plot_prediction(self, dataset: MillionAID) -> None: + x = dataset[0].copy() + x["prediction"] = x["label"].clone() + dataset.plot(x) + plt.close() From 8903905bb666499dc387fa4e424aedb1bf50a794 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Fri, 11 Mar 2022 17:13:21 +0100 Subject: [PATCH 03/11] separator --- torchgeo/datasets/millionaid.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 23701fbe645..7fdb3d8d3c5 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -274,16 +274,16 @@ def _load_files(self, root: str) -> List[Dict[str, Any]]: glob.glob(os.path.join(root, self.split, "**", "**", "**", "*.jpg")) ) - scenes = [p.split("/")[-3] for p in imgs_no_subcat] + [ - p.split("/")[-4] for p in imgs_subcat + scenes = [p.split(os.sep)[-3] for p in imgs_no_subcat] + [ + p.split(os.sep)[-4] for p in imgs_subcat ] subcategories = ["Missing" for p in imgs_no_subcat] + [ - p.split("/")[-3] for p in imgs_subcat + p.split(os.sep)[-3] for p in imgs_subcat ] - classes = [p.split("/")[-2] for p in imgs_no_subcat] + [ - p.split("/")[-2] for p in imgs_subcat + classes = [p.split(os.sep)[-2] for p in imgs_no_subcat] + [ + p.split(os.sep)[-2] for p in imgs_subcat ] if self.task == "multi-label": From af42f5972551addcb2b09ddb030d4b12b8f3e77e Mon Sep 17 00:00:00 2001 From: nilsleh Date: Sun, 20 Mar 2022 20:21:03 +0100 Subject: [PATCH 04/11] remove type ignore --- torchgeo/datasets/millionaid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 7fdb3d8d3c5..04408487d6e 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -249,7 +249,7 @@ def __getitem__(self, index: int) -> Dict[str, Tensor]: files = self.files[index] image = self._load_image(files["image"]) cls_label = [self.class_to_idx[label] for label in files["label"]] - label = torch.tensor(cls_label, dtype=torch.long) # type: ignore[attr-defined] + label = torch.tensor(cls_label, dtype=torch.long) sample = {"image": image, "label": label} if self.transforms is not None: @@ -311,7 +311,7 @@ def _load_image(self, path: str) -> Tensor: """ with Image.open(path) as img: array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) - tensor: Tensor = torch.from_numpy(array) # type: ignore[attr-defined] + tensor: Tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) return tensor From 28c3cc32d90eb055fc1be38ce3fd5443f33d72dd Mon Sep 17 00:00:00 2001 From: nilsleh Date: Sun, 20 Mar 2022 20:46:51 +0100 Subject: [PATCH 05/11] type in test --- tests/datasets/test_millionaid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py index 65229e77f89..a995347f2a4 100644 --- 
a/tests/datasets/test_millionaid.py +++ b/tests/datasets/test_millionaid.py @@ -53,7 +53,7 @@ def dataset( monkeypatch.setattr(MillionAID, "md5s", md5s) # type: ignore[attr-defined] root = str(tmp_path) split, task = request.param - transforms = nn.Identity() # type: ignore[attr-defined] + transforms = nn.Identity() # type: ignore[no-untyped-call] return MillionAID( root=root, split=split, From 9af29b4d4d4e964b35ee1962ae7de64903b14613 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Sun, 3 Apr 2022 21:13:14 +0200 Subject: [PATCH 06/11] requested changes --- tests/data/millionaid/data.py | 2 -- tests/datasets/test_millionaid.py | 20 ++++++-------------- torchgeo/datasets/millionaid.py | 10 +++++----- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/tests/data/millionaid/data.py b/tests/data/millionaid/data.py index a531d409224..03ea05ad5df 100644 --- a/tests/data/millionaid/data.py +++ b/tests/data/millionaid/data.py @@ -5,7 +5,6 @@ import hashlib import os -import random import shutil import numpy as np @@ -14,7 +13,6 @@ SIZE = 32 np.random.seed(0) -random.seed(0) PATHS = { "train": [ diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py index a995347f2a4..9a6b9142f87 100644 --- a/tests/datasets/test_millionaid.py +++ b/tests/datasets/test_millionaid.py @@ -1,10 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import itertools import os import shutil from pathlib import Path -from typing import Generator import matplotlib.pyplot as plt import pytest @@ -23,20 +23,12 @@ def download_url(url: str, root: str, *args: str) -> None: class TestMillionAID: @pytest.fixture( - params=zip( - ["train", "train", "test", "test"], - ["multi-class", "multi-label", "multi-class", "multi-label"], - ) + params=itertools.product(["train", "test"], ["multi-class", "multi-label"]) ) def dataset( - self, - monkeypatch: Generator[MonkeyPatch, None, None], - request: SubRequest, - tmp_path: Path, + self, monkeypatch: MonkeyPatch, request: SubRequest, tmp_path: Path ) -> MillionAID: - monkeypatch.setattr( # type: ignore[attr-defined] - torchgeo.datasets.utils, "download_url", download_url - ) + monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) data_dir = os.path.join("tests", "data", "millionaid") urls = { @@ -49,8 +41,8 @@ def dataset( "test": "7309f19eca7f010d1af9a6adb396b7f8", } - monkeypatch.setattr(MillionAID, "url", urls) # type: ignore[attr-defined] - monkeypatch.setattr(MillionAID, "md5s", md5s) # type: ignore[attr-defined] + monkeypatch.setattr(MillionAID, "url", urls) + monkeypatch.setattr(MillionAID, "md5s", md5s) root = str(tmp_path) split, task = request.param transforms = nn.Identity() # type: ignore[no-untyped-call] diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 04408487d6e..ceefa206a9a 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -21,7 +21,7 @@ class MillionAID(VisionDataset): """Million-AID Dataset. 
The `MillionAID `_ dataset consists - of one million aerial images from the Googl Earth Engine that offers + of one million aerial images from the Google Earth Engine that offers either `a mult-class learning task `_ with 51 classess or a `mulit-label learning task @@ -226,7 +226,7 @@ def __init__( self.files = self._load_files(self.root) - self.classes = sorted(set(cls for f in self.files for cls in f["label"])) + self.classes = sorted({cls for f in self.files for cls in f["label"]}) self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)} def __len__(self) -> int: @@ -336,9 +336,9 @@ def _verify(self) -> None: # Check if the user requested to download the dataset if not self.download: raise RuntimeError( - "Dataset not found in `root` directory, either specify a different" - + " `root` directory or manually download " - + "the dataset to this directory." + f"Dataset not found in `root={self.root}` directory, either " + "specify a different `root` directory or manually download " + "the dataset to this directory." ) # else download the dataset From 0993e6083b97a2ec330dbe02151a9d12026cef48 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Mon, 4 Apr 2022 10:13:54 +0200 Subject: [PATCH 07/11] typos and glob pattern --- torchgeo/datasets/millionaid.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index ceefa206a9a..043a6a7ab74 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -21,10 +21,10 @@ class MillionAID(VisionDataset): """Million-AID Dataset. The `MillionAID `_ dataset consists - of one million aerial images from the Google Earth Engine that offers - either `a mult-class learning task + of one million aerial images from Google Earth Engine that offers + either `a multi-class learning task `_ - with 51 classess or a `mulit-label learning task + with 51 classes or a `multi-label learning task `_ with 73 different possible labels. For more details please consult the accompanying `paper `_. @@ -261,17 +261,17 @@ def _load_files(self, root: str) -> List[Dict[str, Any]]: """Return the paths of the files in the dataset. 
Args: - root: root dir of dataset + root: root directory of dataset Returns: list of dicts containing paths for each pair of image, and list of labels """ imgs_no_subcat = list( - glob.glob(os.path.join(root, self.split, "**", "**", "*.jpg")) + glob.glob(os.path.join(root, self.split, "*", "*", "*.jpg")) ) imgs_subcat = list( - glob.glob(os.path.join(root, self.split, "**", "**", "**", "*.jpg")) + glob.glob(os.path.join(root, self.split, "*", "*", "*", "*.jpg")) ) scenes = [p.split(os.sep)[-3] for p in imgs_no_subcat] + [ From ea6ad5b6d6d89361b6c7e52ccc740546960e9ea0 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Tue, 5 Apr 2022 21:02:42 +0200 Subject: [PATCH 08/11] task argument description --- torchgeo/datasets/millionaid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 043a6a7ab74..12e36115ca9 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -202,7 +202,7 @@ def __init__( Args: root: root directory where dataset can be found - task: whether to use multi-class or multi-label task + task: type of task, either "multi-class" or "multi-label" split: train or test split transforms: a function/transform that takes input sample and its target as entry and returns a transformed version From 5715315612032f4da143134c5f549a02e3919515 Mon Sep 17 00:00:00 2001 From: nilsleh Date: Thu, 7 Apr 2022 14:49:02 +0200 Subject: [PATCH 09/11] add test md5 hash --- torchgeo/datasets/millionaid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 12e36115ca9..8df5d731768 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -177,7 +177,7 @@ class MillionAID(VisionDataset): md5s = { "train": "1b40503cafa9b0601653ca36cd788852", - "test": "https://eastus1-mediap.svc.ms/transform/zip?cs=fFNQTw", + "test": "51a63ee3eeb1351889eacff349a983d8", } url = { "train": "https://eastus1-mediap.svc.ms/transform/zip?cs=fFNQTw", From 976ad3d177061d76a00511d0e346a34434d41f0e Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 9 Jul 2022 14:48:29 -0700 Subject: [PATCH 10/11] Remove download logic --- docs/api/non_geo_datasets.csv | 1 + tests/datasets/test_millionaid.py | 52 ++++++------------------------- torchgeo/datasets/millionaid.py | 40 +++++------------------- 3 files changed, 18 insertions(+), 75 deletions(-) diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 3ebfc26668e..c4ae78f037d 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -16,6 +16,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `LandCover.ai`_,S,Aerial,"10,674",5,512x512,0.25--0.5,RGB `LEVIR-CD+`_,CD,Google Earth,985,2,"1,024x1,024",0.5,RGB `LoveDA`_,S,Google Earth,"5,987",7,"1,024x1,024",0.3,RGB +`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB `NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB `OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI `PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py index 9a6b9142f87..437a1a93c51 100644 --- a/tests/datasets/test_millionaid.py +++ b/tests/datasets/test_millionaid.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-import itertools import os import shutil from pathlib import Path @@ -11,53 +10,22 @@ import torch import torch.nn as nn from _pytest.fixtures import SubRequest -from _pytest.monkeypatch import MonkeyPatch -import torchgeo.datasets.utils from torchgeo.datasets import MillionAID -def download_url(url: str, root: str, *args: str) -> None: - shutil.copy(url, root) - - class TestMillionAID: @pytest.fixture( - params=itertools.product(["train", "test"], ["multi-class", "multi-label"]) + scope="class", params=zip(["train", "test"], ["multi-class", "multi-label"]) ) - def dataset( - self, monkeypatch: MonkeyPatch, request: SubRequest, tmp_path: Path - ) -> MillionAID: - monkeypatch.setattr(torchgeo.datasets.utils, "download_url", download_url) - data_dir = os.path.join("tests", "data", "millionaid") - - urls = { - "train": os.path.join(data_dir, "train.zip"), - "test": os.path.join(data_dir, "test.zip"), - } - - md5s = { - "train": "d5b7c0e90af70b4e6746c9d3a37471b2", - "test": "7309f19eca7f010d1af9a6adb396b7f8", - } - - monkeypatch.setattr(MillionAID, "url", urls) - monkeypatch.setattr(MillionAID, "md5s", md5s) - root = str(tmp_path) + def dataset(self, request: SubRequest) -> MillionAID: + root = os.path.join("tests", "data", "millionaid") split, task = request.param transforms = nn.Identity() # type: ignore[no-untyped-call] return MillionAID( - root=root, - split=split, - task=task, - transforms=transforms, - download=True, - checksum=True, + root=root, split=split, task=task, transforms=transforms, checksum=True ) - def test_already_downloaded(self, dataset: MillionAID) -> None: - MillionAID(root=dataset.root, split=dataset.split, download=True) - def test_getitem(self, dataset: MillionAID) -> None: x = dataset[0] assert isinstance(x, dict) @@ -69,20 +37,20 @@ def test_getitem(self, dataset: MillionAID) -> None: def test_len(self, dataset: MillionAID) -> None: assert len(dataset) == 2 + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in"): + MillionAID(str(tmp_path)) + def test_not_extracted(self, tmp_path: Path) -> None: url = os.path.join("tests", "data", "millionaid", "train.zip") shutil.copy(url, tmp_path) - MillionAID(root=str(tmp_path)) + MillionAID(str(tmp_path)) def test_corrupted(self, tmp_path: Path) -> None: with open(os.path.join(tmp_path, "train.zip"), "w") as f: f.write("bad") with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): - MillionAID(root=str(tmp_path), checksum=True) - - def test_not_found(self, tmp_path: Path) -> None: - with pytest.raises(RuntimeError, match="Dataset not found in."): - MillionAID(str(tmp_path)) + MillionAID(str(tmp_path), checksum=True) def test_plot(self, dataset: MillionAID) -> None: x = dataset[0].copy() diff --git a/torchgeo/datasets/millionaid.py b/torchgeo/datasets/millionaid.py index 8df5d731768..5136907ff4d 100644 --- a/torchgeo/datasets/millionaid.py +++ b/torchgeo/datasets/millionaid.py @@ -14,7 +14,7 @@ from torchgeo.datasets import VisionDataset -from .utils import check_integrity, download_and_extract_archive, extract_archive +from .utils import check_integrity, extract_archive class MillionAID(VisionDataset): @@ -31,7 +31,7 @@ class MillionAID(VisionDataset): Dataset features: - * RGB aerial images with varying resolutions from 0.5m to 153m per pixel + * RGB aerial images with varying resolutions from 0.5 m to 153 m per pixel * images within classes can have different pixel dimension Dataset format: @@ -179,10 +179,6 @@ class MillionAID(VisionDataset): 
"train": "1b40503cafa9b0601653ca36cd788852", "test": "51a63ee3eeb1351889eacff349a983d8", } - url = { - "train": "https://eastus1-mediap.svc.ms/transform/zip?cs=fFNQTw", - "test": "https://eastus1-mediap.svc.ms/transform/zip?cs=fFNQTw", - } filenames = {"train": "train.zip", "test": "test.zip"} @@ -195,7 +191,6 @@ def __init__( task: str = "multi-class", split: str = "train", transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, - download: bool = False, checksum: bool = False, ) -> None: """Initialize a new MillionAID dataset instance. @@ -206,16 +201,13 @@ def __init__( split: train or test split transforms: a function/transform that takes input sample and its target as entry and returns a transformed version - download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: - RuntimeError: if ``download=False`` and data is not found, or checksums - don't match + RuntimeError: if dataset is not found """ self.root = root self.transforms = transforms - self.download = download self.checksum = checksum assert task in self.tasks assert split in self.splits @@ -333,28 +325,10 @@ def _verify(self) -> None: extract_archive(filepath) return - # Check if the user requested to download the dataset - if not self.download: - raise RuntimeError( - f"Dataset not found in `root={self.root}` directory, either " - "specify a different `root` directory or manually download " - "the dataset to this directory." - ) - - # else download the dataset - self._download() - - def _download(self) -> None: - """Download the dataset and extract it. - - Raises: - AssertionError: if the checksum does not match - """ - download_and_extract_archive( - self.url[self.split], - self.root, - filename=self.filenames[self.split], - md5=self.md5s[self.split] if self.checksum else None, + raise RuntimeError( + f"Dataset not found in `root={self.root}` directory, either " + "specify a different `root` directory or manually download " + "the dataset to this directory." ) def plot( From 1fce2bc3c54d6ef2053062cb664507c18b8d31e8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 9 Jul 2022 14:52:08 -0700 Subject: [PATCH 11/11] Type ignore no longer needed --- tests/datasets/test_millionaid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_millionaid.py b/tests/datasets/test_millionaid.py index 437a1a93c51..751567e28a8 100644 --- a/tests/datasets/test_millionaid.py +++ b/tests/datasets/test_millionaid.py @@ -21,7 +21,7 @@ class TestMillionAID: def dataset(self, request: SubRequest) -> MillionAID: root = os.path.join("tests", "data", "millionaid") split, task = request.param - transforms = nn.Identity() # type: ignore[no-untyped-call] + transforms = nn.Identity() return MillionAID( root=root, split=split, task=task, transforms=transforms, checksum=True )