Skip to content

Commit

Permalink
use tempfile to avoid name conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
geyang committed Sep 18, 2021
1 parent e184579 commit e7f754e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ml_logger/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.8.32
0.8.33
10 changes: 6 additions & 4 deletions ml_logger/ml_logger/ml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,14 +1605,16 @@ def download_file(self, *keys, path=None, to, relative=False):
f.write(buf.getbuffer())

def load_torch(self, *keys, path=None, map_location=None, **kwargs):
import torch
import torch, tempfile
path = pJoin(*keys, path)
if path.lower().startswith('s3://'):
fn_or_buff = os.path.basename(path)
self.download_s3(path[5:], to=fn_or_buff)
postfix = os.path.basename(path)
with tempfile.NamedTemporaryFile(suffix=f'.{postfix}') as ntp:
self.download_s3(path[5:], to=ntp.name)
return torch.load(ntp, map_location=map_location, **kwargs)
else:
fn_or_buff = self.load_file(path)
return torch.load(fn_or_buff, map_location=map_location, **kwargs)
return torch.load(fn_or_buff, map_location=map_location, **kwargs)

torch_load = load_torch

Expand Down

0 comments on commit e7f754e

Please sign in to comment.