Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #39

Merged
merged 4 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
rev: v0.3.4
hooks:
- id: ruff
args: [--fix]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.9.0
hooks:
- id: mypy

Expand All @@ -39,7 +39,7 @@ repos:
args: [--check-filenames]

- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
rev: 0.7.1
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
10 changes: 8 additions & 2 deletions examples/functorch_mlp_ensemble.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@
"\n",
" plt.scatter(*points.T, c=labels)\n",
" plt.gca().set(xlabel=\"x\", ylabel=\"y\", title=f\"{noise_std=}\")\n",
" plt.show()"
" plt.show()\n",
"\n"
]
},
{
Expand Down Expand Up @@ -297,6 +298,8 @@
],
"source": [
"# If so, the loss should decrease with step count.\n",
"\n",
"\n",
"metrics = {}\n",
"for step in range(n_train_steps):\n",
" loss, acc, weights = train_step_fn(weights, points, labels)\n",
Expand All @@ -310,7 +313,8 @@
"ax.set_ylabel(\"loss\", color=\"tab:blue\")\n",
"ax2 = ax.twinx()\n",
"ax2.plot(list(metrics), [v[\"acc\"] for v in metrics.values()], color=\"red\")\n",
"ax2.set_ylabel(\"accuracy\", color=\"tab:red\")"
"ax2.set_ylabel(\"accuracy\", color=\"tab:red\")\n",
"\n"
]
},
{
Expand Down Expand Up @@ -352,6 +356,8 @@
"source": [
"# at the same time! Note that metrics like losses and accuracies are all tuples here,\n",
"# one scalar per model.\n",
"\n",
"\n",
"parallel_train_step_fn = functorch.vmap(train_step_fn, in_dims=(0, None, None))\n",
"batched_weights = initialize_ensemble(n_models=5)\n",
"for step in tqdm(range(n_train_steps), desc=\"training MLP ensemble\"):\n",
Expand Down
14 changes: 10 additions & 4 deletions examples/wandb_integration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@
" \"learning_rate\": 0.005,\n",
" \"dataset\": \"MNIST\",\n",
" \"architecture\": \"CNN\",\n",
"}"
"}\n",
"\n"
]
},
{
Expand Down Expand Up @@ -281,7 +282,8 @@
" out = self.layer1(x)\n",
" out = self.layer2(out)\n",
" out = out.reshape(out.size(0), -1)\n",
" return self.fc(out)"
" return self.fc(out)\n",
"\n"
]
},
{
Expand Down Expand Up @@ -371,7 +373,8 @@
" wandb.log(\n",
" {\"epoch\": epoch, \"loss\": loss, \"accuracy\": accuracy},\n",
" step=sample_count,\n",
" )"
" )\n",
"\n"
]
},
{
Expand Down Expand Up @@ -474,6 +477,8 @@
"outputs": [],
"source": [
"# Make the data\n",
"\n",
"\n",
"train_set, test_set = (\n",
" torchvision.datasets.MNIST(\n",
" root=\".\",\n",
Expand Down Expand Up @@ -515,7 +520,8 @@
"\n",
" wandb.finish()\n",
"\n",
" return model"
" return model\n",
"\n"
]
},
{
Expand Down
8 changes: 4 additions & 4 deletions paper/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ @online{biewald_weights_2020
date = {2020},
url = {https://docs.wandb.ai/company/academics},
urldate = {2022-08-21},
organization = {{Use Weights \& Biases for free to track experiments, collaborate, and publish results}}
organization = {Use Weights \& Biases for free to track experiments, collaborate, and publish results}
}

@software{bradbury_jax_2018,
Expand All @@ -29,7 +29,7 @@ @software{developers_tensorflow_2022
doi = {10.5281/zenodo.6574269},
url = {https://zenodo.org/record/6574269},
urldate = {2022-08-21},
organization = {{Zenodo}}
organization = {Zenodo}
}

@online{fey_fast_2019,
Expand All @@ -55,7 +55,7 @@ @article{harris_array_2020
volume = {585},
number = {7825},
pages = {357--362},
publisher = {{Nature Publishing Group}},
publisher = {Nature Publishing Group},
issn = {1476-4687},
doi = {10.1038/s41586-020-2649-2},
url = {https://www.nature.com/articles/s41586-020-2649-2},
Expand All @@ -72,7 +72,7 @@ @article{jumper_highly_2021
volume = {596},
number = {7873},
pages = {583--589},
publisher = {{Nature Publishing Group}},
publisher = {Nature Publishing Group},
issn = {1476-4687},
doi = {10.1038/s41586-021-03819-2},
url = {https://www.nature.com/articles/s41586-021-03819-2},
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ warn_unused_ignores = true
[tool.ruff]
target-version = "py38"
include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]

[tool.ruff.lint]
select = ["ALL"]
ignore = [
"ANN101",
"ANN101", # Missing type annotation for self in method
"ANN401",
"ARG001",
"C901",
Expand All @@ -87,7 +89,7 @@ ignore = [
]
pydocstyle.convention = "google"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D103", "D104", "INP001", "S101"]
"__init__.py" = ["F401"]
"examples/*" = ["D102", "D103", "D107", "E402", "FA102"]
2 changes: 1 addition & 1 deletion tensorboard_reducer/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _rm_rf_or_raise(path: str, overwrite: bool) -> None:
if overwrite and (is_data_file or is_tb_dir):
os.system(f"rm -rf {path}") # noqa: S605
elif overwrite:
ValueError(
raise ValueError(
f"Received the overwrite flag but the content of '{path}' does not "
"look like it was written by this program. Please make sure you really "
f"want to delete '{path}' and then do so manually."
Expand Down
5 changes: 3 additions & 2 deletions tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ def generate_sample_data(
n_tags: int = 1, n_runs: int = 10, n_steps: int = 5
) -> dict[str, pd.DataFrame]:
events_dict = {}
rng = np.random.default_rng()
for idx in range(n_tags):
data = np.random.random((n_steps, n_runs))
data = rng.random((n_steps, n_runs))
df_rand = pd.DataFrame(data, columns=[f"run_{j}" for j in range(n_runs)])
events_dict[f"tag_{idx}"] = df_rand
return events_dict
Expand Down Expand Up @@ -92,4 +93,4 @@ def test_reduce_events_dimensions(n_tags: int, n_runs: int, n_steps: int) -> Non
@pytest.mark.parametrize("reduce_ops", [["mean"], ["max", "min"], ["std", "median"]])
def test_reduce_events_empty_input(reduce_ops: Sequence[str]) -> None:
reduced_events = reduce_events({}, reduce_ops)
assert reduced_events == dict.fromkeys(reduce_ops, {})
assert reduced_events == {op: {} for op in reduce_ops}