Skip to content

Commit

Permalink
Update torch_utils.py to use gpu if available
Browse files Browse the repository at this point in the history
  • Loading branch information
yeelauren authored Jan 24, 2024
1 parent cb0b427 commit 48dfcf2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion spacer/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ def load_weights(model: Any,
:param weights_datastream: model weights, already loaded from storage
:return: well trained model
"""
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load weights
state_dicts = torch.load(weights_datastream,
map_location=torch.device('cpu'))
map_location=device)

with config.log_entry_and_exit('model initialization'):
new_state_dicts = OrderedDict()
Expand Down

0 comments on commit 48dfcf2

Please sign in to comment.