Skip to content

Commit

Permalink
Add weights_only=True to all torch.load calls (mllam#86)
Browse files Browse the repository at this point in the history
## Describe your changes

Currently running neural-lam with the latest version of pytorch gives a
warning:

```
FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models  for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
```

As we only use `torch.load` to load tensors and lists, we can just set
`weights_only=True` and get rid of this warning (and increase security I
suppose).

## Issue Link
None

## Type of change

- [x] 🐛 Bug fix (non-breaking change that fixes an issue)
- [ ] ✨ New feature (non-breaking change that adds functionality)
- [ ] 💥 Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] 📖 Documentation (Addition or improvements to documentation)

## Checklist before requesting a review

- [x] My branch is up-to-date with the target branch - if not update
your fork with the changes from the target branch (use `pull` with
`--rebase` option if possible).
- [x] I have performed a self-review of my code
- [x] For any new/modified functions/classes I have added docstrings
that clearly describe its purpose, expected inputs and returned values
- [x] I have placed in-line comments to clarify the intent of any
hard-to-understand passages of my code
- [x] I have updated the [README](README.MD) to cover introduced code
changes
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have given the PR a name that clearly describes the change,
written in imperative form
([context](https://www.gitkraken.com/learn/git/best-practices/git-commit-message#using-imperative-verb-form)).
- [x] I have requested a reviewer and an assignee (assignee is
responsible for merging). This applies only if you have write access to
the repo, otherwise feel free to tag a maintainer to add a reviewer and
assignee.

## Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should
check the following:
- [x] the code is readable
- [ ] the code is well tested
- [x] the code is documented (including return types and parameters)
- [x] the code is easy to maintain

## Author checklist after completed review

- [ ] I have added a line to the CHANGELOG describing this change, in a
section
  reflecting type of change (add section where missing):
  - *added*: when you have added new functionality
  - *changed*: when default behaviour of the code has been changed
  - *fixes*: when your contribution fixes a bug

## Checklist for assignee

- [ ] PR is up to date with the base branch
- [ ] the tests pass
- [ ] author has added an entry to the changelog (and designated the
change as *added*, *changed* or *fixed*)
- Once the PR is ready to be merged, squash commits and merge the PR.
  • Loading branch information
joeloskarsson authored Nov 18, 2024
1 parent 7112013 commit 2cc617e
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def load_dataset_stats(dataset_name, device="cpu"):

def loads_file(fn):
return torch.load(
os.path.join(static_dir_path, fn), map_location=device
os.path.join(static_dir_path, fn),
map_location=device,
weights_only=True,
)

data_mean = loads_file("parameter_mean.pt") # (d_features,)
Expand All @@ -42,7 +44,9 @@ def load_static_data(dataset_name, device="cpu"):

def loads_file(fn):
return torch.load(
os.path.join(static_dir_path, fn), map_location=device
os.path.join(static_dir_path, fn),
map_location=device,
weights_only=True,
)

# Load border mask, 1. if node is part of border, else 0.
Expand Down Expand Up @@ -116,7 +120,11 @@ def load_graph(graph_name, device="cpu"):
graph_dir_path = os.path.join("graphs", graph_name)

def loads_file(fn):
return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
return torch.load(
os.path.join(graph_dir_path, fn),
map_location=device,
weights_only=True,
)

# Load edges (edge_index)
m2m_edge_index = BufferList(
Expand Down

0 comments on commit 2cc617e

Please sign in to comment.