diff --git a/atom3d/datasets/__main__.py b/atom3d/datasets/__main__.py index ad0688e..8ce7792 100644 --- a/atom3d/datasets/__main__.py +++ b/atom3d/datasets/__main__.py @@ -5,6 +5,7 @@ import atom3d.datasets.datasets as da import atom3d.util.file as fi import atom3d.util.formats as fo +import atom3d.util.rosetta as ro logger = logging.getLogger(__name__) @@ -31,9 +32,13 @@ def main(input_dir, output_lmdb, filetype, score_path, serialization_format): else: fileext = filetype file_list = da.get_file_list(input_dir, fileext) + if score_path is not None: + transform = ro.Scores([score_path]) + else: + transform = None logger.info(f'Found {len(file_list)} files.') - dataset = da.load_dataset(file_list, filetype) + dataset = da.load_dataset(file_list, filetype, transform=transform) da.make_lmdb_dataset( dataset, output_lmdb, serialization_format=serialization_format) diff --git a/atom3d/util/rosetta.py b/atom3d/util/rosetta.py index 156ceed..5f8cd43 100644 --- a/atom3d/util/rosetta.py +++ b/atom3d/util/rosetta.py @@ -23,12 +23,14 @@ def __init__(self, file_list): file_list = [Path(x).absolute() for x in file_list] for silent_file in file_list: key = self._key_from_silent_file(silent_file) + if len(file_list) == 1: + key = 'all' self._scores[key] = self._parse_scores(silent_file) self._scores = pd.concat(self._scores).sort_index() def _parse_scores(self, silent_file): - grep_cmd = f"grep ^SCORE: {silent_file}" + grep_cmd = f"grep ^SCORE: '{silent_file}'" out = subprocess.Popen( grep_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=os.getcwd(), shell=True) @@ -48,7 +50,11 @@ def _lookup_helper(self, key): # If there are multiple rows matching key, return only the first one. # Sometime pandas return single row pd.DataFrame, so we use .squeeze() # to ensure it always return a pd.Series. - return self._scores.loc[key].head(1).astype(np.float64).squeeze().to_dict() + tmp = self._scores.loc[key] + if type(tmp) == pd.DataFrame: + tmp = tmp.head(1) + + return tmp.astype(np.float64).squeeze().to_dict() def _lookup(self, file_path): file_path = Path(file_path) @@ -61,9 +67,16 @@ def _lookup(self, file_path): key = (file_path.parent.parent.stem, file_path.stem) if key in self._scores.index: return key, self._lookup_helper(key) + if len(self._scores.index.get_level_values(0).unique()): + key = ('all', file_path.name) + result = self._lookup_helper(key) + key = (file_path.stem.split('_')[0], + '_'.join(file_path.stem.split('_')[1:])) + return (key, result) + return file_path.parent.stem, None - def __call__(self, x, error_if_missing=False): + def __call__(self, x, error_if_missing=True): key, x['scores'] = self._lookup(x['file_path']) x['id'] = str(key) if x['scores'] is None and error_if_missing: