Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multipeak #13

Merged
merged 6 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cohere_core/controller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .reconstruction_GA import *
from .reconstruction_single import *
from .reconstruction_multi import *
from .phasing import *
from .reconstruction_coupled import *
from .phasing import *
from .multigrain import *
5 changes: 5 additions & 0 deletions cohere_core/controller/op_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,10 @@ def get_flow_arr(params, flow_items_list, curr_gen=None, first_run=False):
if 'progress_trigger' in params:
flow_arr[i] = trigger_row(params['progress_trigger'], iter_no)
flow_arr[i][-1] = 1
elif flow_items_list[i] == 'switch_peaks':
if 'switch_peak_trigger' in params:
flow_arr[i] = trigger_row(params['switch_peak_trigger'], iter_no)
flow_arr[i][-1] = 1


return pc_start is not None, flow_arr
258 changes: 257 additions & 1 deletion cohere_core/controller/phasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

import time
import os
from math import pi
import random
import importlib

import numpy as np

import cohere_core.utilities.dvc_utils as dvut
import cohere_core.utilities.utils as ut
import cohere_core.utilities.config_verifier as ver
Expand Down Expand Up @@ -131,6 +136,9 @@ def update_phase(self, ds_image):
phase_condition = (phase > self.params['phm_phase_min']) & (phase < self.params['phm_phase_max'])
self.support *= phase_condition

def flip(self):
    """Replace the support with the complex conjugate of its axis-reversed (flipped) copy."""
    reversed_support = devlib.flip(self.support)
    self.support = devlib.conj(reversed_support)


class Rec:
"""
Expand Down Expand Up @@ -332,7 +340,7 @@ def init(self, dir=None, gen=None):
self.gen = gen
self.prev_dir = dir
self.sigma = self.params['shrink_wrap_gauss_sigma']
self.support_obj = Support(self.params, self.dims, dir)
self.support_obj = Support(self.params, self.dims[:-1], dir)
if self.is_pc:
self.pc_obj = Pcdi(self.params, self.data, dir)

Expand Down Expand Up @@ -377,6 +385,7 @@ def breed(self):


def iterate(self):
self.iter = -1
start_t = time.time()
for f in self.flow:
f()
Expand Down Expand Up @@ -534,6 +543,253 @@ def get_ratio(self, divident, divisor):
return ratio


class CoupledRec(Rec):
    """
    Performs a coupled reconstruction of multiple Bragg peaks using iterative phase retrieval. It alternates between a
    shared object with a density and atomic displacement field and a working object with an amplitude and phase. The
    general outline of this process is as follows:
    1. Initialize the shared object with random values.
    2. Randomly select a diffraction pattern and corresponding reciprocal lattice vector G from the collected data.
    3. Set the working object to the projection of the shared object onto G.
    4. Apply standard phase retrieval techniques to the working object for a set number of iterations.
    5. Update the G-projection of the shared object to be a weighted average of its current value with the working
    object.
    6. Repeat steps 2-5.

    params : dict
        parameters used in reconstruction. Refer to x for parameters description
    data_file : str
        name of file containing data for each peak to be reconstructed

    """
    __all__ = []

    def __init__(self, params, data_file):
        super().__init__(params, data_file)

        # Fill in defaults for the multipeak-specific parameters so downstream
        # code can index them unconditionally.
        if "switch_peak_trigger" not in params:
            params["switch_peak_trigger"] = [0, 10]
        if "mp_max_weight" not in params:
            params["mp_max_weight"] = 0.9
        if "mp_taper" not in params:
            params["mp_taper"] = 0.75

    def init_dev(self, device_id):
        """
        Select the compute device and load the stacked multipeak data onto it.

        device_id : int
            device to use; -1 means no explicit device selection.
        Returns 0 on success, -1 on failure (bad device, unreadable file, or
        unrecognized file extension).
        """
        self.dev = device_id
        if device_id != -1:
            try:
                devlib.set_device(device_id)
            except Exception as e:
                print(e)
                print('may need to restart GUI')
                return -1
        if self.data_file.endswith('tif'):
            try:
                data_np = ut.read_tif(self.data_file)
                data = devlib.from_numpy(data_np)
            except Exception as e:
                print(e)
                return -1
        elif self.data_file.endswith('npy'):
            try:
                data = devlib.load(self.data_file)
            except Exception as e:
                print(e)
                return -1
        else:
            print('no data file found')
            return -1

        # in the formatted data the max is in the center, we want it in the corner, so do fft shift
        self.data = devlib.fftshift(devlib.absolute(data))
        # axis 0 indexes the peaks; the remaining three axes are the 3D data dims
        self.dims = devlib.dims(self.data)[1:]
        self.num_peaks = devlib.dims(self.data)[0]
        print('data shape:', self.dims)
        print('data sets:', self.num_peaks)

        if self.need_save_data:
            self.saved_data = devlib.copy(self.data)
            self.need_save_data = False

        return 0

    def init(self, img_dir=None, gen=None):
        """
        Initialize the shared object, projection vectors, iteration flow and
        support for a coupled multipeak reconstruction.

        img_dir : str or None
            directory holding a previous image.npy to continue from; when None
            (or the file is missing) a random starting guess is used.
        gen : int or None
            current GA generation, if running within a genetic algorithm.
        Returns 0.
        """
        if self.ds_image is not None:
            first_run = False
        elif img_dir is None or not os.path.isfile(img_dir + '/image.npy'):
            self.ds_image = devlib.random(self.dims, dtype=self.data.dtype)
            first_run = True
        else:
            self.ds_image = devlib.load(img_dir + '/image.npy')
            first_run = False
        # Define the shared image: 4 channels per voxel — density plus three
        # displacement components — seeded with the starting amplitude.
        self.shared_image = devlib.absolute(self.ds_image[:, :, :, None]) * devlib.array([1, 1, 1, 1])

        # Define the vectors used when projecting to each peak.
        # Each G vector is padded with a leading 0 so it is orthogonal to the
        # density channel (rho_hat) of the 4-vector voxels.
        G_0 = 2*pi/self.params["lattice_size"]
        self.G_vectors = devlib.array([[0, *x] for x in self.params["orientations"]]) * G_0
        self.GdotG = devlib.array([devlib.dot(x, x) for x in self.G_vectors])
        self.g_vec = self.G_vectors[0]
        self.gdotg = self.GdotG[0]
        self.pk = 0  # This is the current peak that's being reconstructed
        self.rho_hat = devlib.array([1, 0, 0, 0])  # This is the density "unit vector" in the shared object.

        iter_functions = [self.next,
                          self.resolution_trigger,
                          self.reset_resolution,
                          self.shrink_wrap_trigger,
                          self.phase_support_trigger,
                          self.to_reciprocal_space,
                          self.new_func_trigger,
                          self.pc_trigger,
                          self.pc_modulus,
                          self.modulus,
                          self.set_prev_pc_trigger,
                          self.to_direct_space,
                          self.er,
                          self.hio,
                          self.new_alg,
                          self.twin_trigger,
                          self.average_trigger,
                          self.progress_trigger,
                          self.switch_peaks]

        flow_items_list = [f.__name__ for f in iter_functions]

        self.is_pc, flow = of.get_flow_arr(self.params, flow_items_list, gen, first_run)

        # Flatten the (operation x iteration) schedule into an ordered call list.
        self.flow = []
        (op_no, iter_no) = flow.shape
        for i in range(iter_no):
            for j in range(op_no):
                if flow[j, i] == 1:
                    self.flow.append(iter_functions[j])

        # Define the multipeak projection weighting and tapering.
        # 1.57 ~= pi/2, so the cos^2 window tapers the per-iteration update
        # weight toward zero over the final (1 - mp_taper) fraction of iterations.
        coeff = self.params["mp_taper"] / (self.params["mp_taper"] - 1)
        self.proj_weight = devlib.square(devlib.cos(devlib.linspace(coeff*1.57, 1.57, iter_no).clip(0, 2)))
        self.proj_weight = self.proj_weight * self.params["mp_max_weight"]

        self.aver = None
        self.iter = -1
        self.errs = []
        self.gen = gen
        self.prev_dir = img_dir
        self.sigma = self.params['shrink_wrap_gauss_sigma']
        self.support_obj = Support(self.params, self.dims, img_dir)
        if self.is_pc:
            # Bug fix: previously passed the builtin `dir` (the parameter was
            # renamed to img_dir but this call site was not updated).
            self.pc_obj = Pcdi(self.params, self.data, img_dir)

        # for the fast GA the data needs to be saved, as it would be changed by each lr generation
        # for non-fast GA the Rec object is created in each generation with the initial data
        if self.saved_data is not None:
            if self.params['low_resolution_generations'] > self.gen:
                self.data = devlib.gaussian_filter(self.saved_data, self.params['ga_lowpass_filter_sigmas'][self.gen])
            else:
                self.data = self.saved_data
        else:
            if self.gen is not None and self.params['low_resolution_generations'] > self.gen:
                self.data = devlib.gaussian_filter(self.data, self.params['ga_lowpass_filter_sigmas'][self.gen])

        if 'll_sigma' not in self.params or not first_run:
            self.iter_data = self.data[0]
        else:
            self.iter_data = self.data.copy()[0]

        if first_run:
            max_data = devlib.amax(self.data)
            self.ds_image *= get_norm(self.ds_image) * max_data

            # the line below are for testing to set the initial guess to support
            # self.ds_image = devlib.full(self.dims, 1.0) + 1j * devlib.full(self.dims, 1.0)

        self.ds_image *= self.support_obj.get_support()
        return 0

    def save_res(self, save_dir):
        """
        Save the shared object (density + displacement channels), support,
        per-peak projected images, errors and metrics under save_dir.
        Returns 0.
        """
        from array import array

        self.shared_image = self.shared_image * self.support_obj.get_support()[:, :, :, None]
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        devlib.save(save_dir + "/image", self.shared_image)
        devlib.save(save_dir + "/shared_density", self.shared_image[:, :, :, 0])
        devlib.save(save_dir + "/shared_u1", self.shared_image[:, :, :, 1])
        devlib.save(save_dir + "/shared_u2", self.shared_image[:, :, :, 2])
        devlib.save(save_dir + "/shared_u3", self.shared_image[:, :, :, 3])
        devlib.save(save_dir + '/support', self.support_obj.get_support())
        for i, hkl in enumerate(self.params["orientations"]):
            # Project the shared object onto each peak's G vector and save it.
            self.g_vec = self.G_vectors[i]
            self.gdotg = self.GdotG[i]
            self.to_working_image()
            suffix = ''.join(str(v) for v in hkl)
            devlib.save(save_dir + '/image_' + suffix, self.ds_image)
        errs = array('f', self.errs)

        with open(save_dir + "/errors.txt", "w+") as err_f:
            err_f.write('\n'.join(map(str, errs)))

        devlib.save(save_dir + '/errors', errs)

        metric = dvut.all_metrics(self.ds_image, self.errs)
        with open(save_dir + "/metrics.txt", "w+") as f:
            f.write(str(metric))

        return 0

    def switch_peaks(self):
        """Fold the working image back into the shared object, then pick a new
        (different) peak at random and project the shared object onto it."""
        self.to_shared_image()
        self.pk = random.choice([x for x in range(self.num_peaks) if x != self.pk])
        self.iter_data = self.data[self.pk]
        self.g_vec = self.G_vectors[self.pk]
        self.gdotg = self.GdotG[self.pk]
        self.to_working_image()

    def to_shared_image(self):
        """Blend the working image into the shared object: replace the current
        G-projection (and density) with a beta-weighted average of old and new."""
        beta = self.proj_weight[self.iter]
        old_image = (devlib.dot(self.shared_image, self.g_vec) / self.gdotg)[:, :, :, None] * self.g_vec + \
                    devlib.dot(self.shared_image, self.rho_hat)[:, :, :, None] * self.rho_hat
        new_image = (devlib.angle(self.ds_image) / self.gdotg)[:, :, :, None] * self.g_vec + \
                    devlib.absolute(self.ds_image)[:, :, :, None] * self.rho_hat
        self.shared_image = self.shared_image + beta*(new_image - old_image)

    def to_working_image(self):
        """Project the shared object onto the current G vector: phase = u.G,
        amplitude = density channel."""
        phi = devlib.dot(self.shared_image, self.g_vec)
        rho = self.shared_image[:, :, :, 0]
        self.ds_image = rho * devlib.exp(1j*phi)

    def progress_trigger(self):
        """Print a one-line progress report: iteration, current hkl, latest error, max density."""
        pk = self.params["orientations"][self.pk]
        print(f'| iter {self.iter:>4} '
              f'| [{pk[0]:>2}, {pk[1]:>2}, {pk[2]:>2}] '
              f'| err {self.errs[-1]:0.6f} '
              f'| max {self.shared_image[:, :, :, 0].max():0.5g}'
              )

    def get_density(self):
        """Return the density channel of the shared object."""
        return self.shared_image[:, :, :, 0]

    def get_distortion(self):
        """Return the displacement channels with the component axis moved to the front."""
        return self.shared_image[:, :, :, 1:].swapaxes(3, 2).swapaxes(2, 1).swapaxes(1, 0)

    def flip(self):
        """Spatially flip the shared object and negate the displacement field
        (density is unchanged); flip the support to match."""
        self.shared_image = devlib.flip(self.shared_image, axis=(0, 1, 2))
        self.shared_image[:, :, :, 1:] *= -1
        self.support_obj.flip()

    def shift_to_center(self, ind, cutoff=None):
        """
        Circularly shift the shared object and support so the voxel at `ind`
        moves to the array center.
        NOTE(review): centers using dims[0]//2 on all axes — assumes a cubic
        array; confirm for non-cubic data. `cutoff` is currently unused and is
        kept for interface compatibility.
        """
        shift_dist = -devlib.array(ind) + (self.dims[0]//2)
        self.shared_image = devlib.shift(self.shared_image, shift_dist, axis=(0, 1, 2))
        self.support_obj.support = devlib.shift(self.support_obj.support, shift_dist)


def reconstruction(datafile, **kwargs):
"""
Reconstructs the image from experiment data in datafile according to given parameters. The results: image.npy, support.npy, and errors.npy are saved in 'saved_dir' defined in kwargs, or if not defined, in the directory of datafile.
Expand Down
Loading