Skip to content

Commit

Permalink
Add support for loading dataset from list.
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Townshend committed Feb 3, 2021
1 parent b4c1a60 commit 099bf76
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 1 deletion.
16 changes: 15 additions & 1 deletion atom3d/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io
import logging
import msgpack
import os
from pathlib import Path
import pickle as pkl
import tqdm
Expand Down Expand Up @@ -347,6 +348,19 @@ def deserialize(x, serialization_format):
def get_file_list(input_path, filetype):
if filetype == 'lmdb':
file_list = [input_path]
elif os.path.isfile(input_path):
with open(input_path) as f:
all_paths = f.readlines()
input_dir = os.path.dirname(input_path)
file_list = []
for x in all_paths:
x = x.strip()
if not fo.is_type(x, filetype):
continue
x = os.path.join(input_dir, x)
if not os.path.exists(x):
raise RuntimeError(f'{x} does not exist!')
file_list.append(x)
else:
file_list = fi.find_files(input_path, fo.patterns[filetype])
return file_list
Expand All @@ -356,7 +370,7 @@ def load_dataset(file_list, filetype, transform=None, include_bonds=False):
"""
Load files in file_list into corresponding dataset object. All files should be of type filetype.
:param file_list: List containing paths to silent files. Assumes one structure per file.
:param file_list: List containing paths to files. Assumes one structure per file.
:type file_list: list[Union[str, Path]]
:param filetype: Type of dataset. Allowable types are 'lmdb', 'pdb', 'silent', 'sdf', 'xyz', 'xyz-gdb'.
:type filetype: str
Expand Down
4 changes: 4 additions & 0 deletions atom3d/util/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def read_any(f, name=None):
_regexes = {k: re.compile(v) for k, v in patterns.items()}


def is_type(f, filetype):
return _regexes[filetype].search(str(f))


def is_pdb(f):
"""Check if file is in pdb format."""
return _regexes['pdb'].search(str(f))
Expand Down
9 changes: 9 additions & 0 deletions tests/datasets/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ def test_load_dataset_lmdb():
assert df['atoms'].z.dtype == 'float'


def test_load_dataset_list():
dataset = da.load_dataset('tests/test_data/list/pdbs.txt', 'pdb')
assert len(dataset) == 4
for df in dataset:
print(df)
assert df['atoms'].x.dtype == 'float'
assert df['atoms'].y.dtype == 'float'
assert df['atoms'].z.dtype == 'float'

#def test_load_dataset_sharded():
# dataset = da.load_dataset('tests/test_data/sharded', 'sharded')
# assert len(dataset) == 4
Expand Down
4 changes: 4 additions & 0 deletions tests/test_data/list/pdbs.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
../pdb/2olx.pdb
../pdb/11as.pdb
../pdb/103l.pdb
../pdb/117e.pdb

0 comments on commit 099bf76

Please sign in to comment.