Skip to content

Commit

Permalink
Re-enable rosetta in make_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Raphael Townshend committed Jul 20, 2022
1 parent 0783041 commit c9e1767
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
7 changes: 6 additions & 1 deletion atom3d/datasets/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import atom3d.datasets.datasets as da
import atom3d.util.file as fi
import atom3d.util.formats as fo
import atom3d.util.rosetta as ro

logger = logging.getLogger(__name__)

Expand All @@ -31,9 +32,13 @@ def main(input_dir, output_lmdb, filetype, score_path, serialization_format):
else:
fileext = filetype
file_list = da.get_file_list(input_dir, fileext)
if score_path is not None:
transform = ro.Scores([score_path])
else:
transform = None
logger.info(f'Found {len(file_list)} files.')

dataset = da.load_dataset(file_list, filetype)
dataset = da.load_dataset(file_list, filetype, transform=transform)
da.make_lmdb_dataset(
dataset, output_lmdb, serialization_format=serialization_format)

Expand Down
19 changes: 16 additions & 3 deletions atom3d/util/rosetta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ def __init__(self, file_list):
file_list = [Path(x).absolute() for x in file_list]
for silent_file in file_list:
key = self._key_from_silent_file(silent_file)
if len(file_list) == 1:
key = 'all'
self._scores[key] = self._parse_scores(silent_file)

self._scores = pd.concat(self._scores).sort_index()

def _parse_scores(self, silent_file):
grep_cmd = f"grep ^SCORE: {silent_file}"
grep_cmd = f"grep ^SCORE: '{silent_file}'"
out = subprocess.Popen(
grep_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
cwd=os.getcwd(), shell=True)
Expand All @@ -48,7 +50,11 @@ def _lookup_helper(self, key):
# If there are multiple rows matching key, return only the first one.
# Sometimes pandas returns a single-row pd.DataFrame, so we use .squeeze()
# to ensure it always returns a pd.Series.
return self._scores.loc[key].head(1).astype(np.float64).squeeze().to_dict()
tmp = self._scores.loc[key]
if type(tmp) == pd.DataFrame:
tmp = tmp.head(1)

return tmp.astype(np.float64).squeeze().to_dict()

def _lookup(self, file_path):
file_path = Path(file_path)
Expand All @@ -61,9 +67,16 @@ def _lookup(self, file_path):
key = (file_path.parent.parent.stem, file_path.stem)
if key in self._scores.index:
return key, self._lookup_helper(key)
if len(self._scores.index.get_level_values(0).unique()):
key = ('all', file_path.name)
result = self._lookup_helper(key)
key = (file_path.stem.split('_')[0],
'_'.join(file_path.stem.split('_')[1:]))
return (key, result)

return file_path.parent.stem, None

def __call__(self, x, error_if_missing=False):
def __call__(self, x, error_if_missing=True):
key, x['scores'] = self._lookup(x['file_path'])
x['id'] = str(key)
if x['scores'] is None and error_if_missing:
Expand Down

0 comments on commit c9e1767

Please sign in to comment.