From d30836fb7c4018f2c07dd0f2584bed7c05bb96a1 Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 13 Apr 2024 15:54:03 -0400 Subject: [PATCH 1/8] add PR template inspired by https://github.com/neuropsychology/NeuroKit/blob/97b5a97e660c867d01ffc683031e38c35ff7d034/.github/PULL_REQUEST_TEMPLATE.md --- .github/pull_request_template.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 .github/pull_request_template.md diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000..7631101 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,20 @@ +This is a template for making a pull-request. You can remove the text and sections and write your own thing if you wish, just make sure you give enough information about how and why. If you have any issues or difficulties, don't hesitate to open an issue. + + +# Description + +The aim is to add this feature ... + +# Proposed Changes + +I changed the `foo()` function so that ... + + +# Checklist + +Here are some things to check before creating the pull request. If you encounter any issues, don't hesitate to ask for help :) + +- [ ] I have read the [contributor's guide](https://github.com/arnab39/equiadapt/blob/main/CONTRIBUTING.md). +- [ ] The base branch of my pull request is the `dev` branch, not the `main` branch. +- [ ] I ran the [code checks](https://github.com/arnab39/equiadapt/blob/main/CONTRIBUTING.md#implement-your-changes) on the files I added or modified and fixed the errors. +- [ ] I updated the [changelog](https://github.com/arnab39/equiadapt/blob/main/CHANGELOG.md). 
\ No newline at end of file From 71deebc0fcdbec766d39e058f20ea814a3a9eaab Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 13 Apr 2024 15:55:07 -0400 Subject: [PATCH 2/8] update format of changelog --- CHANGELOG.md | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd0325b..f02a1bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,16 @@ # Changelog -## Version 0.1 (development) +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Added + +### Fixed + +### Changed + +### Removed \ No newline at end of file From d52312f326b1f7dde394c43dcb19467f221de26b Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 13 Apr 2024 15:57:31 -0400 Subject: [PATCH 3/8] update links in template --- .github/pull_request_template.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7631101..bc5b355 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -14,7 +14,7 @@ I changed the `foo()` function so that ... Here are some things to check before creating the pull request. If you encounter any issues, don't hesitate to ask for help :) -- [ ] I have read the [contributor's guide](https://github.com/arnab39/equiadapt/blob/main/CONTRIBUTING.md). +- [ ] I have read the [contributor's guide](https://github.com/arnab39/reptrix/blob/main/CONTRIBUTING.md). - [ ] The base branch of my pull request is the `dev` branch, not the `main` branch. 
-- [ ] I ran the [code checks](https://github.com/arnab39/equiadapt/blob/main/CONTRIBUTING.md#implement-your-changes) on the files I added or modified and fixed the errors. -- [ ] I updated the [changelog](https://github.com/arnab39/equiadapt/blob/main/CHANGELOG.md). \ No newline at end of file +- [ ] I ran the [code checks](https://github.com/arnab39/reptrix/blob/main/CONTRIBUTING.md#implement-your-changes) on the files I added or modified and fixed the errors. +- [ ] I updated the [changelog](https://github.com/arnab39/reptrix/blob/main/CHANGELOG.md). \ No newline at end of file From e73ebcc8d85d0c16fcf80718afda962592fb1509 Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 13 Apr 2024 15:57:42 -0400 Subject: [PATCH 4/8] add information about code checks to contributor's guide --- CONTRIBUTING.md | 47 ++++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 884b817..7bd206f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -183,12 +183,10 @@ conda activate reptrix and start making changes. Never work on the main branch! -2. Start your work on this branch. Don't forget to add [docstrings] to new - functions, modules and classes, especially if they are part of public APIs. +2. Start your work on this branch. Don't forget to add [docstrings] to the new + functions, modules and classes, especially if they are part of the public reptrix API. -3. Add yourself to the list of contributors in `AUTHORS.rst`. - -4. When you’re done editing, do: +3. When you’re done editing, do: ``` git add @@ -197,38 +195,41 @@ conda activate reptrix to record your changes in [git]. - ```{todo} if you are not using pre-commit, please remove the following item: - ``` - Please make sure to see the validation messages from [pre-commit] and fix any eventual issues. 
This should automatically use [flake8]/[black] to check/fix the code style in a way that is compatible with the project. - :::{important} - Don't forget to add unit tests and documentation in case your +> **Note**: + Please add unit tests and documentation in case your contribution adds an additional feature and is not just a bugfix. - Moreover, writing a [descriptive commit message] is highly recommended. In case of doubt, you can check the commit history with: + `git log --graph --decorate --pretty=oneline --abbrev-commit --all` + to look for recurring communication patterns. - ``` - git log --graph --decorate --pretty=oneline --abbrev-commit --all - ``` +#### Run code checks - to look for recurring communication patterns. - ::: +Please make sure to see the validation messages from pre-commit and fix any +eventual issues. This should automatically use [flake8]/[black] to check/fix +the code style in a way that is compatible with the project. -5. Please check that your changes don't break any unit tests with: +To run pre-commit manually, you can use: - ``` - tox - ``` +``` +pre-commit run --all-files +``` + +Please also check that your changes don't break any unit tests with: + +``` +tox +``` - (after having installed [tox] with `pip install tox` or `pipx`). +(after having installed [tox] with `pip install tox` or `pipx`). - You can also use [tox] to run several other pre-configured tasks in the - repository. Try `tox -av` to see a list of the available checks. +You can also use [tox] to run several other pre-configured tasks in the +repository. Try `tox -av` to see a list of the available checks. 
### Submit your contribution From d9ae9bcca51118f7a589ea4c2d799be2a8b2ca08 Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 13 Apr 2024 16:01:37 -0400 Subject: [PATCH 5/8] fix formatting in readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7f8ba72..7c8ef62 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,8 @@ Representation Metrics for assessing quality in pretrained deep models You can check out the [contributor's guide](CONTRIBUTING.md). -This project uses `pre-commit`_, you can install it before making any -changes:: +This project uses `pre-commit`; please install it before making any +changes: pip install pre-commit cd reptrix From c45b2b0295a4291d61a7999020e6a95582cd75b1 Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 13 Apr 2024 16:05:12 -0400 Subject: [PATCH 6/8] run pre-commit autoupdate --- .pre-commit-config.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3099a08..128fe68 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ exclude: '^docs/conf.py' repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: check-added-large-files @@ -41,7 +41,7 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 24.2.0 + rev: 24.4.0 hooks: - id: black language_version: python3 @@ -67,7 +67,7 @@ repos: # - id: codespell - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.8.0' + rev: 'v1.9.0' hooks: - id: mypy args: [--disallow-untyped-defs, --ignore-missing-imports] From 2f10cc5a410dcf57effe314913d879838f85333b Mon Sep 17 00:00:00 2001 From: danibene <34680344+danibene@users.noreply.github.com> Date: Sat, 13 Apr 2024 16:05:44 -0400 Subject: [PATCH 7/8] run pre-commit run --all-files --- 
.github/pull_request_template.md | 2 +- CHANGELOG.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index bc5b355..5c609ac 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -17,4 +17,4 @@ Here are some things to check before creating the pull request. If you encounter - [ ] I have read the [contributor's guide](https://github.com/arnab39/reptrix/blob/main/CONTRIBUTING.md). - [ ] The base branch of my pull request is the `dev` branch, not the `main` branch. - [ ] I ran the [code checks](https://github.com/arnab39/reptrix/blob/main/CONTRIBUTING.md#implement-your-changes) on the files I added or modified and fixed the errors. -- [ ] I updated the [changelog](https://github.com/arnab39/reptrix/blob/main/CHANGELOG.md). \ No newline at end of file +- [ ] I updated the [changelog](https://github.com/arnab39/reptrix/blob/main/CHANGELOG.md). diff --git a/CHANGELOG.md b/CHANGELOG.md index f02a1bb..ab086d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,4 +13,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -### Removed \ No newline at end of file +### Removed From 54cc611d0abe2553b745ec0593abc2004cd888c2 Mon Sep 17 00:00:00 2001 From: Arna Ghosh Date: Thu, 30 May 2024 20:56:45 -0400 Subject: [PATCH 8/8] Updated tutorial notebook: fixed get_features, added time and LiDAR in description --- tutorial.ipynb | 171 ++++++++++++++++++++++++++++--------------------- 1 file changed, 98 insertions(+), 73 deletions(-) diff --git a/tutorial.ipynb b/tutorial.ipynb index 9ff639c..031de0d 100644 --- a/tutorial.ipynb +++ b/tutorial.ipynb @@ -14,9 +14,11 @@ "\n", "To assess the quality of the learned representations, we will use various metrics, including:\n", "\n", - "- **Alpha**: This metric computes the eigenvalues of the covariance matrix of the representations and fits a power-law distribution to them. 
The exponent of the power-law distribution is called the alpha exponent, which measures the heavy-tailedness of the distribution. A lower alpha exponent indicates that the representations are more discriminative.\n", + "- [**Alpha**](https://proceedings.neurips.cc/paper_files/paper/2022/hash/70596d70542c51c8d9b4e423f4bf2736-Abstract-Conference.html): This metric computes the eigenvalues of the covariance matrix of the representations and fits a power-law distribution to them. The exponent of the power-law distribution is called the alpha exponent, which measures the heavy-tailedness of the distribution. A lower alpha exponent indicates that the representations are more discriminative.\n", "\n", - "- **RankMe**: This metric computes the rank of the covariance matrix of the representations. A higher rank indicates representations of higher capacity.\n", + "- [**RankMe**](https://proceedings.mlr.press/v202/garrido23a): This metric computes the rank of the covariance matrix of the representations. A higher rank indicates representations of higher capacity.\n", + "\n", + "- [**LiDAR**](https://openreview.net/forum?id=f3g5XpL9Kb): This metric computes the rank of the Linear Discriminant Analysis (LDA) matrix constructed using representations of augmented versions of images. A higher rank indicates representations of higher discriminability.\n", "\n", "\n", "We will compute these metrics using the Reptrix library, which provides a convenient interface for representation analysis. 
Let's dive into the code and explore the evaluation process in detail.\n", @@ -40,7 +42,8 @@ "import torchvision\n", "from tqdm import tqdm\n", "from reptrix import alpha, rankme, lidar\n", - "import reptrix" + "import reptrix\n", + "import time" ] }, { @@ -79,14 +82,14 @@ " # Loop over the dataset and collect the representations\n", " for i, data in enumerate(tqdm(dataloader, 0)):\n", " inputs, _ = data\n", - " # apply 10 random augmentations for each image\n", + " # apply multiple augmentations for each image\n", " if transform:\n", " inputs = torch.cat([transform(inputs) for _ in range(num_augmentations)], dim=0)\n", " with torch.no_grad():\n", " features = encoder_function(inputs.to(device))\n", " if transform:\n", " # put the augmentations in an additonal dimension\n", - " features = features.reshape(-1, num_augmentations, features.shape[1])\n", + " features = features.reshape(num_augmentations, -1, features.shape[1]).transpose(1,0)\n", " all_features.append(features)\n", " \n", " \n", @@ -108,11 +111,11 @@ "metadata": {}, "outputs": [], "source": [ - "transform = torchvision.transforms.Compose([\n", - " torchvision.transforms.ToTensor(),\n", - " torchvision.transforms.Normalize((0.4467, 0.4398, 0.4066), \n", - " (0.2242, 0.2215, 0.2239))\n", - "])\n", + "transform_to_tensor = torchvision.transforms.ToTensor()\n", + "\n", + "STL_MEAN = (0.4467, 0.4398, 0.4066)\n", + "STL_STD = (0.2242, 0.2215, 0.2239)\n", + "transform_base = torchvision.transforms.Normalize(STL_MEAN, STL_STD)\n", "\n", "# Define additional SSL transformations for the Lidar metric evaluation\n", "transform_ssl = torchvision.transforms.Compose([\n", @@ -120,14 +123,16 @@ " torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),\n", " torchvision.transforms.RandomGrayscale(p=0.2),\n", " torchvision.transforms.RandomResizedCrop(\n", - " 96, scale=(0.2, 1.0), \n", + " 96, scale=(0.8, 1.0), \n", " ratio=(0.75, (4/3)), \n", " 
interpolation=torchvision.transforms.InterpolationMode.BICUBIC),\n", + " torchvision.transforms.Normalize(STL_MEAN, STL_STD)\n", "])\n", " \n", - "\n", + "dataset_folder = '/network/datasets/stl10.var/stl10_torchvision'\n", "# Get the STL10 test dataset to measure the quality of the representations learned by the model\n", - "testset = torchvision.datasets.STL10(root='./data', split='test', download=False, transform=transform)\n", + "# testset = torchvision.datasets.STL10(root='./data', split='test', download=False, transform=transform)\n", + "testset = torchvision.datasets.STL10(root=dataset_folder, split='test', download=False, transform=transform_to_tensor)\n", "\n", "# Define a dataloader to load the test dataset\n", "testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)" @@ -158,10 +163,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "Using cache found in /home/mila/a/arnab.mondal/.cache/torch/hub/facebookresearch_barlowtwins_main\n", - "/home/mila/a/arnab.mondal/.conda/envs/equiadapt/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + "Using cache found in /home/mila/g/ghosharn/.cache/torch/hub/facebookresearch_barlowtwins_main\n", + "/network/scratch/g/ghosharn/conda_envs/ffcv_ssl_fastssl/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", " warnings.warn(\n", - "/home/mila/a/arnab.mondal/.conda/envs/equiadapt/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. 
The current behavior is equivalent to passing `weights=None`.\n", + "/network/scratch/g/ghosharn/conda_envs/ffcv_ssl_fastssl/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n", " warnings.warn(msg)\n" ] } @@ -183,7 +188,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 32/32 [00:02<00:00, 14.42it/s]\n" + "100%|██████████| 32/32 [00:03<00:00, 9.16it/s]\n" ] } ], @@ -194,27 +199,52 @@ "# Set the model to evaluation mode\n", "encoder.eval()\n", "\n", - "all_representations = get_features(encoder, testloader)" + "all_representations = get_features(encoder, testloader, transform=transform_base, num_augmentations=1)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([8000, 1, 2048])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_representations.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, "outputs": [], "source": [ + "start_time = time.time()\n", "metric_alpha = alpha.get_alpha(all_representations)\n", - "metric_rankme = rankme.get_rankme(all_representations)" + "alpha_time = time.time()\n", + "metric_rankme = rankme.get_rankme(all_representations)\n", + "rankme_time = time.time()\n", + "alpha_compute_time = alpha_time - start_time\n", + "rankme_compute_time = rankme_time - alpha_time" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -233,21 +263,24 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 32/32 [01:22<00:00, 2.58s/it]\n" + " 0%| | 0/32 [00:00" ] @@ -379,14 +414,16 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 32/32 [01:21<00:00, 2.55s/it]\n" + " 0%| | 0/32 [00:00