Add support for loading a dataset from a list file.

get_file_list() gains a branch for the case where input_path is a regular
file and filetype is not 'lmdb'. The file is read line by line; each line is
stripped, entries for which fo.is_type(x, filetype) is falsy are skipped, the
remaining entries are joined onto the directory of the list file, and a
RuntimeError is raised if a resulting path does not exist. The directory
search via fi.find_files() remains the fallback branch.

Also in this patch: is_type(f, filetype) is added to atom3d/util/formats.py,
the load_dataset docstring no longer says "silent files", and a test plus a
four-entry fixture list (tests/test_data/list/pdbs.txt) are added.

diff --git a/atom3d/datasets/datasets.py b/atom3d/datasets/datasets.py
index f7d9cc8..3ef7c31 100644
--- a/atom3d/datasets/datasets.py
+++ b/atom3d/datasets/datasets.py
@@ -5,6 +5,7 @@
 import io
 import logging
 import msgpack
+import os
 from pathlib import Path
 import pickle as pkl
 import tqdm
@@ -347,6 +348,19 @@ def deserialize(x, serialization_format):
 def get_file_list(input_path, filetype):
     if filetype == 'lmdb':
         file_list = [input_path]
+    elif os.path.isfile(input_path):
+        with open(input_path) as f:
+            all_paths = f.readlines()
+        input_dir = os.path.dirname(input_path)
+        file_list = []
+        for x in all_paths:
+            x = x.strip()
+            if not fo.is_type(x, filetype):
+                continue
+            x = os.path.join(input_dir, x)
+            if not os.path.exists(x):
+                raise RuntimeError(f'{x} does not exist!')
+            file_list.append(x)
     else:
         file_list = fi.find_files(input_path, fo.patterns[filetype])
     return file_list
@@ -356,7 +370,7 @@ def load_dataset(file_list, filetype, transform=None, include_bonds=False):
     """
     Load files in file_list into corresponding dataset object. All files should be of type filetype.
 
-    :param file_list: List containing paths to silent files. Assumes one structure per file.
+    :param file_list: List containing paths to files. Assumes one structure per file.
     :type file_list: list[Union[str, Path]]
     :param filetype: Type of dataset. Allowable types are 'lmdb', 'pdb', 'silent', 'sdf', 'xyz', 'xyz-gdb'.
     :type filetype: str
diff --git a/atom3d/util/formats.py b/atom3d/util/formats.py
index acb6c72..584f319 100755
--- a/atom3d/util/formats.py
+++ b/atom3d/util/formats.py
@@ -183,6 +183,10 @@ def read_any(f, name=None):
 _regexes = {k: re.compile(v) for k, v in patterns.items()}
 
 
+def is_type(f, filetype):
+    return _regexes[filetype].search(str(f))
+
+
 def is_pdb(f):
     """Check if file is in pdb format."""
     return _regexes['pdb'].search(str(f))
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 22f1d3b..144107a 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -18,6 +18,15 @@ def test_load_dataset_lmdb():
         assert df['atoms'].z.dtype == 'float'
 
 
+def test_load_dataset_list():
+    dataset = da.load_dataset('tests/test_data/list/pdbs.txt', 'pdb')
+    assert len(dataset) == 4
+    for df in dataset:
+        print(df)
+        assert df['atoms'].x.dtype == 'float'
+        assert df['atoms'].y.dtype == 'float'
+        assert df['atoms'].z.dtype == 'float'
+
 #def test_load_dataset_sharded():
 #    dataset = da.load_dataset('tests/test_data/sharded', 'sharded')
 #    assert len(dataset) == 4
diff --git a/tests/test_data/list/pdbs.txt b/tests/test_data/list/pdbs.txt
new file mode 100644
index 0000000..c02f6f2
--- /dev/null
+++ b/tests/test_data/list/pdbs.txt
@@ -0,0 +1,4 @@
+../pdb/2olx.pdb
+../pdb/11as.pdb
+../pdb/103l.pdb
+../pdb/117e.pdb
\ No newline at end of file