Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add hash check to data download #284

Merged
merged 3 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from tqdm import tqdm

from anomalib.data.inference import InferenceDataset
from anomalib.data.utils import DownloadProgressBar, read_image
from anomalib.data.utils import DownloadProgressBar, hash_check, read_image
from anomalib.data.utils.split import (
create_validation_set_from_test_set,
split_normal_images_in_train_set,
Expand Down Expand Up @@ -359,7 +359,8 @@ def prepare_data(self) -> None:
filename=zip_filename,
reporthook=progress_bar.update_to,
) # nosec

logger.info("Checking hash")
hash_check(zip_filename, "c1fa4d56ac50dd50908ce04e81037a8e")
logger.info("Extracting the dataset.")
with zipfile.ZipFile(zip_filename, "r") as zip_file:
zip_file.extractall(self.root.parent)
Expand Down
11 changes: 7 additions & 4 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from torchvision.datasets.folder import VisionDataset

from anomalib.data.inference import InferenceDataset
from anomalib.data.utils import DownloadProgressBar, read_image
from anomalib.data.utils import DownloadProgressBar, hash_check, read_image
from anomalib.data.utils.split import (
create_validation_set_from_test_set,
split_normal_images_in_train_set,
Expand Down Expand Up @@ -378,19 +378,22 @@ def prepare_data(self) -> None:
logger.info("Downloading the Mvtec AD dataset.")
url = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094"
dataset_name = "mvtec_anomaly_detection.tar.xz"
zip_filename = self.root / dataset_name
with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc="MVTec AD") as progress_bar:
urlretrieve(
url=f"{url}/{dataset_name}",
filename=self.root / dataset_name,
filename=zip_filename,
reporthook=progress_bar.update_to,
)
logger.info("Checking hash")
hash_check(zip_filename, "eefca59f2cede9c3fc5b6befbfec275e")

logger.info("Extracting the dataset.")
with tarfile.open(self.root / dataset_name) as tar_file:
with tarfile.open(zip_filename) as tar_file:
tar_file.extractall(self.root)

logger.info("Cleaning the tar file")
(self.root / dataset_name).unlink()
(zip_filename).unlink()

def setup(self, stage: Optional[str] = None) -> None:
"""Setup train, validation and test data.
Expand Down
4 changes: 2 additions & 2 deletions anomalib/data/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .download import DownloadProgressBar
from .download import DownloadProgressBar, hash_check
from .image import get_image_filenames, read_image

__all__ = ["get_image_filenames", "read_image", "DownloadProgressBar"]
__all__ = ["get_image_filenames", "hash_check", "read_image", "DownloadProgressBar"]
25 changes: 19 additions & 6 deletions anomalib/data/utils/download.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Helper to show progress bars with `urlretrieve`.

Based on https://stackoverflow.com/a/53877507
"""
"""Helper to show progress bars with `urlretrieve`, check hash of file."""

# Copyright (C) 2020 Intel Corporation
#
Expand All @@ -17,7 +14,9 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import hashlib
import io
from pathlib import Path
from typing import Dict, Iterable, Optional, Union

from tqdm import tqdm
Expand Down Expand Up @@ -146,7 +145,7 @@ def __init__(
colour: Optional[str] = None,
delay: Optional[float] = 0,
gui: Optional[bool] = False,
**kwargs
**kwargs,
):
super().__init__(
iterable=iterable,
Expand Down Expand Up @@ -175,13 +174,14 @@ def __init__(
colour=colour,
delay=delay,
gui=gui,
**kwargs
**kwargs,
)
self.total: Optional[Union[int, float]]

def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size=None):
"""Progress bar hook for tqdm.

Based on https://stackoverflow.com/a/53877507
The implementor does not have to bother about passing parameters to this as it gets them from urlretrieve.
However the context needs a few parameters. Refer to the example.

Expand All @@ -193,3 +193,16 @@ def update_to(self, chunk_number: int = 1, max_chunk_size: int = 1, total_size=N
if total_size is not None:
self.total = total_size
self.update(chunk_number * max_chunk_size - self.n)


def hash_check(file_path: Path, expected_hash: str):
"""Raise assert error if hash does not match the calculated hash of the file.

Args:
file_path (Path): Path to file.
expected_hash (str): Expected hash of the file.
"""
with open(file_path, "rb") as hash_file:
assert (
hashlib.md5(hash_file.read()).hexdigest() == expected_hash
), f"Downloaded file {file_path} does not match the required hash."