Skip to content

Commit

Permalink
Merge pull request #31 from nisse3000/master
Browse files Browse the repository at this point in the history
Changed site-descriptor data structure & adapted structure similarity builder accordingly
  • Loading branch information
shyamd authored Apr 2, 2018
2 parents 4c4743d + cc34f50 commit f026e12
Show file tree
Hide file tree
Showing 5 changed files with 835 additions and 756 deletions.
137 changes: 92 additions & 45 deletions emmet/materials/site_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
CrystalSiteFingerprint, CoordinationNumber

# TODO:
# AGNIFingerprints, EwaldSiteEnergy, \
# VoronoiFingerprint, ChemEnvSiteFingerprint, \
# ChemicalSRO

# 1) Add checking OPs present in current implementation of site fingerprints.
# 2) Complete documentation!!!

from maggma.builder import Builder

Expand All @@ -25,19 +23,26 @@

class SiteDescriptorsBuilder(Builder):

def __init__(self, materials, site_descriptors, query=None, **kwargs):
def __init__(self, materials, site_descriptors, mat_query=None, **kwargs):
"""
Calculates site descriptors for materials
Calculates site-based descriptors (e.g., coordination numbers
with different near-neighbor finding approaches) for materials and
runs statistics analysis on selected descriptor types
(order parameter-based site fingerprints). The latter is
useful as a definition of a structure fingerprint
on the basis of local coordination information.
Args:
materials (Store): Store of materials documents
site_descriptors (Store): Store of site-descriptors data such as tetrahedral order parameter or percentage of 8-fold coordination
query (dict): dictionary to limit materials to be analyzed
materials (Store): Store of materials documents.
site_descriptors (Store): Store of site-descriptors data such
as tetrahedral order parameter or
fraction of being 8-fold coordinated.
mat_query (dict): dictionary to limit materials to be analyzed.
"""

self.materials = materials
self.site_descriptors = site_descriptors
self.query = query if query else {}
self.mat_query = mat_query if mat_query else {}

# Set up all targeted site descriptors.
self.sds = {}
Expand All @@ -50,8 +55,10 @@ def __init__(self, materials, site_descriptors, query=None, **kwargs):
self.sds[k] = CoordinationNumber(nn_(), use_weights=False)
k = 'cn_wt_{}'.format(t)
self.sds[k] = CoordinationNumber(nn_(), use_weights=True)
self.all_output_pieces = {'site_descriptors': [k for k in self.sds.keys()]}
self.sds['opsf'] = OPSiteFingerprint()
#self.sds['csf'] = CrystalSiteFingerprint.from_preset('ops')
self.sds['csf'] = CrystalSiteFingerprint.from_preset('ops')
self.all_output_pieces['statistics'] = ['opsf', 'csf']

super().__init__(sources=[materials],
targets=[site_descriptors],
Expand All @@ -60,9 +67,16 @@ def __init__(self, materials, site_descriptors, query=None, **kwargs):
def get_items(self):
"""
Gets all materials that need new site descriptors.
For example, entirely new materials and materials
for which certain descriptor in the current Store
are still missing.
Returns:
generator of materials to calculate site descriptors.
generator of materials to calculate site descriptors
and of the target quantities to be calculated
(e.g., CN with the minimum distance near neighbor
(MinimumDistanceNN) finding class from pymatgen which has label
"cn_mdnn").
"""

self.logger.info("Site-Descriptors Builder Started")
Expand All @@ -71,19 +85,49 @@ def get_items(self):

# All relevant materials that have been updated since site-descriptors
# were last calculated
q = dict(self.query)

q = dict(self.mat_query)
all_task_ids = list(self.materials.distinct(self.materials.key, q))
q.update(self.materials.lu_filter(self.site_descriptors))
task_ids = list(self.materials.distinct(self.materials.key, q))
new_task_ids = list(self.materials.distinct(self.materials.key, q))
self.logger.info(
"Found {} new materials for site-descriptors data".format(len(task_ids)))
for task_id in task_ids:
yield self.materials.query(
properties=[self.materials.key, "structure"],
criteria={self.materials.key: task_id}).limit(1)[0]
"Found {} entirely new materials for site-descriptors data".format(
len(new_task_ids)))
for task_id in all_task_ids:
if task_id in new_task_ids:
any_piece = True

else: # Any piece of info missing?
data_present = self.site_descriptors.query(
properties=[self.site_descriptors.key, "site_descriptors", "statistics"],
criteria={self.site_descriptors.key: task_id}).limit(1)[0]
any_piece = False
for k, v in self.all_output_pieces.items():
if k not in list(data_present.keys()):
any_piece = True
break
else:
any_piece = False
for e in v:
if e not in data_present[k]:
any_piece = True
break
if not any_piece:
for fp in ['opsf', 'csf']:
for l in self.sds[fp].feature_labels():
for fpi in data_present['site_descriptors'][fp]:
if l not in fpi.keys():
any_piece = True
break
if any_piece:
yield self.materials.query(
properties=[self.materials.key, "structure"],
criteria={self.materials.key: task_id}).limit(1)[0]

def process_item(self, item):
"""
Calculates site descriptors for the structures
Calculates all site descriptors for the structures
Args:
item (dict): a dict with a task_id and a structure
Expand All @@ -100,8 +144,8 @@ def process_item(self, item):
site_descr_doc['site_descriptors'] = \
self.get_site_descriptors_from_struct(
site_descr_doc['structure'])
site_descr_doc['opsf_statistics'] = \
self.get_opsf_statistics(
site_descr_doc['statistics'] = \
self.get_statistics(
site_descr_doc['site_descriptors'])
site_descr_doc[self.site_descriptors.key] = item[self.materials.key]

Expand All @@ -128,10 +172,10 @@ def get_site_descriptors_from_struct(self, structure):
# Compute descriptors.
for k, sd in self.sds.items():
try:
d = {}
d = []
l = sd.feature_labels()
for i, s in enumerate(structure.sites):
d[i] = {}
d.append({'site': i})
for j, desc in enumerate(sd.featurize(structure, i)):
d[i][l[j]] = desc
doc[k] = d
Expand All @@ -142,29 +186,32 @@ def get_site_descriptors_from_struct(self, structure):

return doc

def get_opsf_statistics(self, site_descr):
def get_statistics(self, site_descr, fps=('opsf', 'csf')):
doc = {}

# Compute site-descriptor statistics.
try:
n_site = len(list(site_descr['opsf'].keys()))
tmp = {}
for isite in range(n_site):
for l, v in site_descr['opsf'][isite].items():
if l not in list(tmp.keys()):
tmp[l] = []
tmp[l].append(v)
d = {}
for k, l in tmp.items():
d[k] = {}
d[k]['min'] = min(tmp[k])
d[k]['max'] = max(tmp[k])
d[k]['mean'] = np.mean(tmp[k])
d[k]['std'] = np.std(tmp[k])
doc = d

except Exception as e:
self.logger.error("Failed calculating statistics of site "
"descriptors: {}".format(e))
for fp in fps:
doc[fp] = {}
try:
n_site = len(site_descr[fp])
tmp = {}
for isite in range(n_site):
for l, v in site_descr[fp][isite].items():
if l not in list(tmp.keys()):
tmp[l] = []
tmp[l].append(v)
d = []
for k, l in tmp.items():
dtmp = {'name': k}
dtmp['min'] = min(tmp[k])
dtmp['max'] = max(tmp[k])
dtmp['mean'] = np.mean(tmp[k])
dtmp['std'] = np.std(tmp[k])
d.append(dtmp)
doc[fp] = d

except Exception as e:
self.logger.error("Failed calculating statistics of site "
"descriptors: {}".format(e))

return doc
29 changes: 19 additions & 10 deletions emmet/materials/structure_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class StructureSimilarityBuilder(Builder):

def __init__(self, site_descriptors, structure_similarity,
**kwargs):
fp_type='csf', **kwargs):
"""
Calculates similarity metrics between structures on the basis
of site descriptors.
Expand All @@ -20,10 +20,17 @@ def __init__(self, site_descriptors, structure_similarity,
or percentage of 8-fold coordination.
structure_similarity (Store): storage of structure similarity
metrics.
fp_type (str): target site fingerprint type to be
used for similarity computation
("csf" (based on matminer's
CrystalSiteFingerprint class)
or "opsf" (based on matminer's
OPSiteFingerprint class)).
"""

self.site_descriptors = site_descriptors
self.structure_similarity = structure_similarity
self.fp_type = fp_type

super().__init__(sources=[site_descriptors],
targets=[structure_similarity],
Expand All @@ -46,12 +53,12 @@ def get_items(self):
n_task_ids = len(task_ids)
for i in range(n_task_ids-1):
d1 = self.site_descriptors.query(
properties=[self.site_descriptors.key, "opsf_statistics"],
properties=[self.site_descriptors.key, "statistics"],
criteria={self.site_descriptors.key: task_ids[i]}).limit(1)[0]
for j in range(i+1, n_task_ids):
d2 = self.site_descriptors.query(
properties=[
self.site_descriptors.key, "opsf_statistics"],
self.site_descriptors.key, "statistics"],
criteria={self.site_descriptors.key: task_ids[j]}).limit(1)[0]
yield list([d1, d2])

Expand All @@ -63,7 +70,7 @@ def process_item(self, item):
item (list): a list (length 2) with each one document that
carries a task ID in "task_id" and a statistics
vector from OP site-fingerprints in
"opsf_statistics".
"statistics".
Returns:
dict: similarity measures.
Expand Down Expand Up @@ -102,14 +109,16 @@ def get_similarities(self, d1, d2):
dout = {}
l = {}
v = {}
for i, d in enumerate([d1['opsf_statistics'],
d2['opsf_statistics']]):
for i, li in enumerate([d1['statistics'][self.fp_type],
d2['statistics'][self.fp_type]]):
v[i] = []
l[i] = []
for optype, stats in d.items():
for stattype, val in stats.items():
v[i].append(val)
l[i].append('{} {}'.format(optype, stattype))
#for optype, stats in d.items():
for opdict in li:
for stattype, val in opdict.items():
if stattype != 'name':
v[i].append(val)
l[i].append('{} {}'.format(opdict['name'], stattype))
if len(l[0]) != len(l[1]):
raise RuntimeError('Site-fingerprint statistics dictionaries'
' have different sizes ({}, {})'.format(
Expand Down
55 changes: 35 additions & 20 deletions emmet/materials/tests/test_site_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ def test_builder(self):
for t in sd_builder.get_items():
processed = sd_builder.process_item(t)
if processed:
pass
sd_builder.update_targets([processed])
else:
import nose
nose.tools.set_trace()
self.assertEqual(len([t for t in sd_builder.get_items()]), 0)

# Remove one data piece in diamond entry and test partial update.
test_site_descriptors.collection.find_one_and_update(
{'task_id': 'mp-66'}, {'$unset': {'site_descriptors': 1}})
items = [e for e in list(sd_builder.get_items())]
self.assertEqual(len(items), 1)

def test_get_all_site_descriptors(self):
test_site_descriptors = MemoryStore("test_site_descriptors")
Expand All @@ -49,7 +56,7 @@ def test_get_all_site_descriptors(self):
# Diamond.
d = sd_builder.get_site_descriptors_from_struct(Structure.from_dict(C["structure"]))
for di in d.values():
self.assertEqual(len([k for k in di.keys()]), 2)
self.assertEqual(len(di), 2)
self.assertEqual(d['cn_vnn'][0]['CN_VoronoiNN'], 18)
self.assertAlmostEqual(d['cn_wt_vnn'][0]['CN_VoronoiNN'], 4.5381162)
self.assertEqual(d['cn_jmnn'][0]['CN_JMolNN'], 4)
Expand All @@ -64,34 +71,42 @@ def test_get_all_site_descriptors(self):
self.assertAlmostEqual(d['cn_wt_bnn'][0]['CN_BrunnerNN'], 4)
self.assertAlmostEqual(d['opsf'][0]['tetrahedral CN_4'], 0.9995)
#self.assertAlmostEqual(d['csf'][0]['tetrahedral CN_4'], 0.9886777)
ds = sd_builder.get_opsf_statistics(d)
for di in ds.values():
self.assertEqual(len(list(di.keys())), 4)
self.assertAlmostEqual(ds['tetrahedral CN_4']['max'], 0.9995)
self.assertAlmostEqual(ds['tetrahedral CN_4']['min'], 0.9995)
self.assertAlmostEqual(ds['tetrahedral CN_4']['mean'], 0.9995)
self.assertAlmostEqual(ds['tetrahedral CN_4']['std'], 0)
self.assertAlmostEqual(ds['octahedral CN_6']['mean'], 0.0005)
ds = sd_builder.get_statistics(d)
self.assertTrue('opsf' in list(ds.keys()))
self.assertTrue('csf' in list(ds.keys()))
for k, dsk in ds.items():
for di in dsk:
self.assertEqual(len(list(di.keys())), 5)
def get_index(li, optype):
for i, di in enumerate(li):
if di['name'] == optype:
return i
raise RuntimeError('did not find optype {}'.format(optype))
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['max'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['min'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['mean'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['std'], 0)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['mean'], 0.0005)

# NaCl.
d = sd_builder.get_site_descriptors_from_struct(Structure.from_dict(NaCl["structure"]))
self.assertAlmostEqual(d['opsf'][0]['octahedral CN_6'], 0.9995)
#self.assertAlmostEqual(d['csf'][0]['octahedral CN_6'], 1)
ds = sd_builder.get_opsf_statistics(d)
self.assertAlmostEqual(ds['octahedral CN_6']['max'], 0.9995)
self.assertAlmostEqual(ds['octahedral CN_6']['min'], 0.9995)
self.assertAlmostEqual(ds['octahedral CN_6']['mean'], 0.9995)
self.assertAlmostEqual(ds['octahedral CN_6']['std'], 0)
ds = sd_builder.get_statistics(d)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['max'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['min'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['mean'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['std'], 0)

# Iron.
d = sd_builder.get_site_descriptors_from_struct(Structure.from_dict(Fe["structure"]))
self.assertAlmostEqual(d['opsf'][0]['body-centered cubic CN_8'], 0.9995)
#self.assertAlmostEqual(d['csf'][0]['body-centered cubic CN_8'], 0.755096)
ds = sd_builder.get_opsf_statistics(d)
self.assertAlmostEqual(ds['body-centered cubic CN_8']['max'], 0.9995)
self.assertAlmostEqual(ds['body-centered cubic CN_8']['min'], 0.9995)
self.assertAlmostEqual(ds['body-centered cubic CN_8']['mean'], 0.9995)
self.assertAlmostEqual(ds['body-centered cubic CN_8']['std'], 0)
ds = sd_builder.get_statistics(d)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'body-centered cubic CN_8')]['max'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'body-centered cubic CN_8')]['min'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'body-centered cubic CN_8')]['mean'], 0.9995)
self.assertAlmostEqual(ds['opsf'][get_index(ds['opsf'], 'body-centered cubic CN_8')]['std'], 0)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit f026e12

Please sign in to comment.