diff --git a/GetOrganelleLib/assembly_parser.py b/GetOrganelleLib/assembly_parser.py index 37a845f..02dad3a 100755 --- a/GetOrganelleLib/assembly_parser.py +++ b/GetOrganelleLib/assembly_parser.py @@ -4,28 +4,39 @@ from hashlib import sha256 from collections import OrderedDict +# try: +# from sympy import Symbol, solve, lambdify +# from sympy import log as symlog +# # from scipy import optimize +# except ImportError: +# def Symbol(foo, integer): +# raise ImportError("Failed in 'from sympy import Symbol, solve, lambdify, log'!") +# +# +# def solve(foo1, foo2): +# raise ImportError("Failed in 'from sympy import Symbol, solve, lambdify, log'!") +# +# +# def lambdify(args=None, expr=None): +# raise ImportError("Failed in 'from sympy import Symbol, solve, lambdify, log'!") +# +# +# def symlog(foo): +# raise ImportError("Failed in 'from sympy import Symbol, solve, lambdify, log'!") + try: - from sympy import Symbol, solve, lambdify - from scipy import optimize + from gekko import GEKKO except ImportError: - def Symbol(foo, integer): - raise ImportError("Failed in 'from sympy import Symbol, solve, lambdify'!") - - - def solve(foo1, foo2): - raise ImportError("Failed in 'from sympy import Symbol, solve, lambdify'!") - - - def lambdify(args=None, expr=None): - raise ImportError("Failed in 'from sympy import Symbol, solve, lambdify'!") - + def GEKKO(remote): + raise ImportError("Failed in 'from gekko import GEKKO'!") - class optimize: - def __init__(self): - pass - def minimize(self, fun=None, x0=None, jac=None, method=None, bounds=None, constraints=None, options=None): - raise ImportError("Failed in 'from scipy import optimize'!") + # class optimize: + # def __init__(self): + # pass + # + # def minimize(self, fun=None, x0=None, jac=None, method=None, bounds=None, constraints=None, options=None): + # raise ImportError("Failed in 'from scipy import optimize'!") PATH_OF_THIS_SCRIPT = os.path.split(os.path.realpath(__file__))[0] sys.path.insert(0, os.path.join(PATH_OF_THIS_SCRIPT, "..")) @@ -33,9 +44,9 @@ def minimize(self, fun=None, x0=None, jac=None, method=None, bounds=None, constr sys.path.insert(0, os.path.join(PATH_OF_THIS_SCRIPT, "..")) from GetOrganelleLib.seq_parser import * from GetOrganelleLib.statistical_func import * +from GetOrganelleLib.pipe_control_func import log_target_res PATH_OF_THIS_SCRIPT = os.path.split(os.path.realpath(__file__))[0] -import random from copy import deepcopy MAJOR_VERSION, MINOR_VERSION = sys.version_info[:2] @@ -244,7 +255,7 @@ def vertex_set(self): class SimpleAssembly(object): - def __init__(self, graph_file=None, min_cov=0., max_cov=inf): + def __init__(self, graph_file=None, min_cov=0., max_cov=inf, log_handler=None): """ :param graph_file: :param min_cov: @@ -255,7 +266,7 @@ def __init__(self, graph_file=None, min_cov=0., max_cov=inf): self.__uni_overlap = None if graph_file: if graph_file.endswith(".gfa"): - self.parse_gfa(graph_file, min_cov=min_cov, max_cov=max_cov) + self.parse_gfa(graph_file, min_cov=min_cov, max_cov=max_cov, log_handler=log_handler) else: self.parse_fastg(graph_file, min_cov=min_cov, max_cov=max_cov) @@ -279,7 +290,7 @@ def __iter__(self): for vertex in sorted(self.vertex_info): yield self.vertex_info[vertex] - def parse_gfa(self, gfa_file, default_cov=1., min_cov=0., max_cov=inf): + def parse_gfa(self, gfa_file, default_cov=1., min_cov=0., max_cov=inf, log_handler=None): with open(gfa_file) as gfa_open: kmer_values = set() line = gfa_open.readline() @@ -308,14 +319,19 @@ def parse_gfa(self, gfa_file, default_cov=1., min_cov=0., 
max_cov=inf): # skip RC/FC if element[0].upper() == "LN": seq_len_tag = int(element[-1]) + check_positive_value(seq_len_tag, "LN", log_handler=log_handler) elif element[0].upper() == "KC": kmer_count = int(element[-1]) + check_positive_value(kmer_count, "KC", log_handler=log_handler) elif element[0].upper() == "RC": # took read counts as kmer counts kmer_count = int(element[-1]) + check_positive_value(kmer_count, "RC", log_handler=log_handler) elif element[0].upper() == "DP": seq_depth_tag = float(element[-1]) + check_positive_value(seq_depth_tag, "DP", log_handler=log_handler) elif element[0].upper() == "RD": # took read depth as seq_depth_tag counts seq_depth_tag = int(element[-1]) + check_positive_value(seq_depth_tag, "RD", log_handler=log_handler) elif element[0].upper() == "SH": sh_256_val = ":".join(element[2:]) elif element[0].upper() == "UR": @@ -394,12 +410,16 @@ def parse_gfa(self, gfa_file, default_cov=1., min_cov=0., max_cov=inf): # skip RC/FC if element[0].upper() == "KC": kmer_count = int(element[-1]) + check_positive_value(kmer_count, "KC", log_handler=log_handler) elif element[0].upper() == "RC": # took read counts as kmer counts kmer_count = int(element[-1]) + check_positive_value(kmer_count, "RC", log_handler=log_handler) elif element[0].upper() == "DP": seq_depth_tag = float(element[-1]) + check_positive_value(seq_depth_tag, "DP", log_handler=log_handler) elif element[0].upper() == "RD": # took read depth as seq_depth_tag counts seq_depth_tag = int(element[-1]) + check_positive_value(seq_depth_tag, "RD", log_handler=log_handler) elif element[0].upper() == "SH": sh_256_val = ":".join(element[2:]) elif element[0].upper() == "UR": @@ -478,7 +498,7 @@ def parse_fastg(self, fastg_file, min_cov=0., max_cov=inf): else: this_vertex_str, next_vertices_str = seq.label.strip(";"), "" v_tag, vertex_name, l_tag, vertex_len, c_tag, vertex_cov = this_vertex_str.strip("'").split("_") - # skip vertices with cov out of bounds + # skip vertices_set with cov out of bounds vertex_cov = float(vertex_cov) if not (min_cov <= vertex_cov <= max_cov): continue @@ -492,7 +512,7 @@ def parse_fastg(self, fastg_file, min_cov=0., max_cov=inf): else: this_vertex_str, next_vertices_str = seq.label.strip(";"), "" v_tag, vertex_name, l_tag, vertex_len, c_tag, vertex_cov = this_vertex_str.strip("'").split("_") - # skip vertices that not in self.vertex_info: 1. with cov out of bounds + # skip vertices_set that not in self.vertex_info: 1. with cov out of bounds if vertex_name in self.vertex_info: # connections this_end = not this_vertex_str.endswith("'") @@ -501,7 +521,7 @@ def parse_fastg(self, fastg_file, min_cov=0., max_cov=inf): next_name = next_vertex_str.strip("'").split("_")[1] if next_name in self.vertex_info: next_end = next_vertex_str.endswith("'") - # Adding connection information (edge) to both of the related vertices + # Adding connection information (edge) to both of the related vertices_set # even it is only mentioned once in some SPAdes output files self.vertex_info[vertex_name].connections[this_end][(next_name, next_end)] = 0 # None? 
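+                            # mirror the edge on the partner vertex as well, so the connection can be traversed from either side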
self.vertex_info[next_name].connections[next_end][(vertex_name, this_end)] = 0 @@ -610,13 +630,13 @@ def write_to_gfa(self, out_file, check_postfix=True, other_attr=None): class Assembly(SimpleAssembly): - def __init__(self, graph_file=None, min_cov=0., max_cov=inf, uni_overlap=None): + def __init__(self, graph_file=None, min_cov=0., max_cov=inf, uni_overlap=None, log_handler=None): """ :param graph_file: :param min_cov: :param max_cov: """ - super(Assembly, self).__init__(graph_file=graph_file, min_cov=min_cov, max_cov=max_cov) + super(Assembly, self).__init__(graph_file=graph_file, min_cov=min_cov, max_cov=max_cov, log_handler=log_handler) if uni_overlap: self.__uni_overlap = uni_overlap else: @@ -637,9 +657,9 @@ def __init__(self, graph_file=None, min_cov=0., max_cov=inf, uni_overlap=None): # else: # return int(self.__uni_overlap) - def new_graph_with_vertex_reseeded(self, start_from=1): + def new_graph_with_vertex_reseeded(self, start_from=1, log_handler=None): those_vertices = sorted(self.vertex_info) - new_graph = Assembly(uni_overlap=self.__uni_overlap) + new_graph = Assembly(uni_overlap=self.__uni_overlap, log_handler=log_handler) name_trans = {those_vertices[go - start_from]: str(go) for go in range(start_from, start_from + len(those_vertices))} for old_name in those_vertices: @@ -704,21 +724,24 @@ def write_out_tags(self, db_names, out_file): for db_n in db_names: tagged_vertices |= self.tagged_vertices[db_n] tagged_vertices = sorted(tagged_vertices) - lines = [["EDGE", "database", "database_weight", "loci"]] + lines = [["EDGE", "database", "database_weight", "loci", "loci_weight"]] for this_vertex in tagged_vertices: if "tags" in self.vertex_info[this_vertex].other_attr: - all_tags = self.vertex_info[this_vertex].other_attr["tags"] - all_tag_list = sorted(all_tags) + all_type_tags = self.vertex_info[this_vertex].other_attr["tags"] + all_types = sorted(all_type_tags) all_weights = self.vertex_info[this_vertex].other_attr["weight"] lines.append([this_vertex, - ";".join(all_tag_list), - ";".join([tag_n + "(" + str(all_weights[tag_n]) + ")" for tag_n in all_tag_list]), - ";".join([",".join(sorted(all_tags[tag_n])) for tag_n in all_tag_list])]) + ";".join(all_types), + ";".join([tag_n + "(%.3f)" % all_weights[tag_n] for tag_n in all_types]), + ";".join([",".join(sorted(all_type_tags[tag_n])) for tag_n in all_types]), + ";".join([",".join(sorted(["%s(%.3f)" % (_l, _w) + for _l, _w in all_type_tags[tag_n].items()])) + for tag_n in all_types])]) else: here_tags = {tag_n for tag_n in db_names if this_vertex in self.tagged_vertices[tag_n]} lines.append([this_vertex, ";".join(sorted(here_tags)), - "", ""]) + "", "", ""]) open(out_file, "w").writelines(["\t".join(line) + "\n" for line in lines]) def update_orf_total_len(self, limited_vertices=None): @@ -733,35 +756,106 @@ def update_orf_total_len(self, limited_vertices=None): self.vertex_info[vertex_name].other_attr["orf"][direction] = {"lengths": this_orf_lens, "sum_len": sum(this_orf_lens)} - def update_vertex_clusters(self): - # TODO: faster algorithm exists + # def update_vertex_clusters(self): + # self.vertex_clusters = [] + # vertices_set = sorted(self.vertex_info) + # for this_vertex in vertices_set: + # connecting_those = set() + # for connected_set in self.vertex_info[this_vertex].connections.values(): + # for next_v, next_d in connected_set: + # for go_to_set, cluster in enumerate(self.vertex_clusters): + # if next_v in cluster: + # connecting_those.add(go_to_set) + # if not connecting_those: + # 
self.vertex_clusters.append({this_vertex}) + # elif len(connecting_those) == 1: + # self.vertex_clusters[connecting_those.pop()].add(this_vertex) + # else: + # sorted_those = sorted(connecting_those, reverse=True) + # self.vertex_clusters[sorted_those[-1]].add(this_vertex) + # for go_to_set in sorted_those[:-1]: + # for that_vertex in self.vertex_clusters[go_to_set]: + # self.vertex_clusters[sorted_those[-1]].add(that_vertex) + # del self.vertex_clusters[go_to_set] + def update_vertex_clusters(self): + """ + faster than original update_vertex_clusters algorithm. + """ self.vertex_clusters = [] - vertices = sorted(self.vertex_info) - for this_vertex in vertices: - connecting_those = set() - for connected_set in self.vertex_info[this_vertex].connections.values(): - for next_v, next_d in connected_set: - for go_to_set, cluster in enumerate(self.vertex_clusters): - if next_v in cluster: - connecting_those.add(go_to_set) - if not connecting_those: - self.vertex_clusters.append({this_vertex}) - elif len(connecting_those) == 1: - self.vertex_clusters[connecting_those.pop()].add(this_vertex) - else: - sorted_those = sorted(connecting_those, reverse=True) - self.vertex_clusters[sorted_those[-1]].add(this_vertex) - for go_to_set in sorted_those[:-1]: - for that_vertex in self.vertex_clusters[go_to_set]: - self.vertex_clusters[sorted_those[-1]].add(that_vertex) - del self.vertex_clusters[go_to_set] + candidate_vs = set(self.vertex_info) + while candidate_vs: + new_root = candidate_vs.pop() + self.vertex_clusters.append({new_root}) + waiting_vs = set([next_v + for this_e in (True, False) + for next_v, next_e in self.vertex_info[new_root].connections[this_e] + if next_v in candidate_vs]) + while candidate_vs and waiting_vs: + next_v = waiting_vs.pop() + self.vertex_clusters[-1].add(next_v) + candidate_vs.discard(next_v) + for next_e in (True, False): + for n_next_v, n_next_e in self.vertex_info[next_v].connections[next_e]: + if n_next_v in candidate_vs: + waiting_vs.add(n_next_v) + # for reproducible, not necessary for some cases + self.vertex_clusters.sort(key=lambda x: max(x)) + + def get_clusters(self, limited_vertices=None): + if limited_vertices is None: + candidate_vs = set(self.vertex_info) + else: + candidate_vs = set(limited_vertices) + vertex_clusters = [] + while candidate_vs: + new_root = candidate_vs.pop() + vertex_clusters.append({new_root}) + waiting_vs = set([next_v + for this_e in (True, False) + for next_v, next_e in self.vertex_info[new_root].connections[this_e] + if next_v in candidate_vs]) + while candidate_vs and waiting_vs: + next_v = waiting_vs.pop() + vertex_clusters[-1].add(next_v) + candidate_vs.discard(next_v) + for next_e in (True, False): + for n_next_v, n_next_e in self.vertex_info[next_v].connections[next_e]: + if n_next_v in candidate_vs: + waiting_vs.add(n_next_v) + # for reproducible, not necessary for some cases + return sorted(vertex_clusters, key=lambda x: max(x)) + + def check_connected(self, vertices_set): + """ + a fast algorithm modified from update_vertex_clusters + :param vertices_set: + :return: + """ + candidate_vs = set(vertices_set) + while candidate_vs: + new_root = candidate_vs.pop() + waiting_vs = set([next_v + for this_e in (True, False) + for next_v, next_e in self.vertex_info[new_root].connections[this_e] + if next_v in candidate_vs]) + while candidate_vs and waiting_vs: + next_v = waiting_vs.pop() + candidate_vs.discard(next_v) + for next_e in (True, False): + for n_next_v, n_next_e in self.vertex_info[next_v].connections[next_e]: + if n_next_v in 
candidate_vs: + waiting_vs.add(n_next_v) + if candidate_vs and not waiting_vs: + return False + return True def remove_vertex(self, vertices, update_cluster=True): for vertex_name in vertices: for this_end, connected_dict in list(self.vertex_info[vertex_name].connections.items()): for next_v, next_e in list(connected_dict): del self.vertex_info[next_v].connections[next_e][(vertex_name, this_end)] + for vertex_name in vertices: del self.vertex_info[vertex_name] for tag in self.tagged_vertices: if vertex_name in self.tagged_vertices[tag]: @@ -817,7 +911,7 @@ def rename_vertex(self, old_vertex, new_vertex, update_cluster=True): # self.merging_history[new_vertex] = self.merging_history[old_vertex] # del self.merging_history[old_vertex] - def detect_parallel_vertices(self, limited_vertices=None): + def detect_parallel_vertices(self, limited_vertices=None, detect_neighbors=True): if not limited_vertices: limiting = False limited_vertices = sorted(self.vertex_info) @@ -837,7 +931,7 @@ def detect_parallel_vertices(self, limited_vertices=None): if this_ends not in all_both_ends: all_both_ends[this_ends] = set() all_both_ends[this_ends].add((vertex_name, direction_remained)) - if limiting: + if limiting and detect_neighbors: limited_vertex_set = set(limited_vertices) for each_vertex in self.vertex_info: if each_vertex not in limited_vertex_set: @@ -905,7 +999,7 @@ def is_sequential_repeat(self, search_vertex_name, return_pair_in_the_trunk_path all_pairs_of_inner_circles.sort( key=lambda x: (self.vertex_info[x[0][0]].cov + self.vertex_info[x[1][0]].cov)) if all_pairs_of_inner_circles and return_pair_in_the_trunk_path: - # switch nearby vertices + # switch nearby vertices_set # keep those prone to be located in the "trunk road" of the repeat single_pair_in_main_path = [] if len(all_pairs_of_inner_circles) == 1: @@ -923,6 +1017,11 @@ def is_sequential_repeat(self, search_vertex_name, return_pair_in_the_trunk_path return all_pairs_of_inner_circles def merge_all_possible_vertices(self, limited_vertices=None, copy_tags=True): + # follow variables were not updated because of max_majority_copy should not be considered here + # self.copy_to_vertex + # self.vertex_to_copy + # self.vertex_to_float_copy + if not limited_vertices: limited_vertices = sorted(self.vertex_info) else: @@ -975,6 +1074,20 @@ def merge_all_possible_vertices(self, limited_vertices=None, copy_tags=True): self.vertex_info[new_vertex].seq[not this_end] \ = self.vertex_info[next_vertex].seq[next_end][:next_len - this_overlap] \ + self.vertex_info[this_vertex].seq[not this_end] + + # follow variables were not updated because of max_majority_copy should not be considered here + # self.copy_to_vertex + # self.vertex_to_copy + # self.vertex_to_float_copy + # average_cov = this_cov / self.vertex_to_float_copy[this_vertex] + # this_float_copy = self.vertex_info[new_vertex].cov / average_cov + # this_copy = min(max(1, int(round(this_float_copy, 0))), max_majority_copy) + # self.vertex_to_float_copy[new_vertex] = this_float_copy + # self.vertex_to_copy[new_vertex] = this_copy + # if this_copy not in self.copy_to_vertex: + # self.copy_to_vertex[this_copy] = set() + # self.copy_to_vertex[this_copy].add(new_vertex) + # tags if copy_tags: if "tags" in self.vertex_info[next_vertex].other_attr: @@ -987,8 +1100,14 @@ def merge_all_possible_vertices(self, limited_vertices=None, copy_tags=True): self.vertex_info[new_vertex].other_attr["tags"][db_n] \ = deepcopy(self.vertex_info[next_vertex].other_attr["tags"][db_n]) else: - 
self.vertex_info[new_vertex].other_attr["tags"][db_n] \ - |= self.vertex_info[next_vertex].other_attr["tags"][db_n] + # adjust for update in 2023-01-13 + for ln, lw in self.vertex_info[next_vertex].other_attr["tags"][db_n].items(): + if ln not in self.vertex_info[new_vertex].other_attr["tags"][db_n]: + self.vertex_info[new_vertex].other_attr["tags"][db_n][ln] = lw + else: + self.vertex_info[new_vertex].other_attr["tags"][db_n][ln] += lw + # self.vertex_info[new_vertex].other_attr["tags"][db_n] \ + # |= self.vertex_info[other_vertex].other_attr["tags"][db_n] if "weight" in self.vertex_info[next_vertex].other_attr: if "weight" not in self.vertex_info[new_vertex].other_attr: self.vertex_info[new_vertex].other_attr["weight"] \ @@ -1010,35 +1129,53 @@ def merge_all_possible_vertices(self, limited_vertices=None, copy_tags=True): self.tagged_vertices[db_n].remove(next_vertex) self.remove_vertex([this_vertex, next_vertex], update_cluster=False) break - self.update_vertex_clusters() + if merged: + self.update_vertex_clusters() return merged def estimate_copy_and_depth_by_cov(self, limited_vertices=None, given_average_cov=None, mode="embplant_pt", - re_initialize=False, log_handler=None, verbose=True, debug=False): + min_sigma=0.1, re_initialize=False, log_handler=None, verbose=True, debug=False): + """ + :param limited_vertices: + :param given_average_cov: + :param mode: + :param min_sigma: when only one sample + :param re_initialize: + :param log_handler: + :param verbose: + :param debug: + :return: + """ # overlap = self.__overlap if self.__overlap else 0 - if mode == "embplant_pt": - max_majority_copy = 2 - elif mode == "other_pt": - max_majority_copy = 10 - elif mode == "embplant_mt": - max_majority_copy = 4 - elif mode == "embplant_nr": - max_majority_copy = 2 - elif mode == "animal_mt": - max_majority_copy = 4 - elif mode == "fungus_mt": - max_majority_copy = 8 - elif mode == "fungus_nr": - max_majority_copy = 4 - elif mode == "all": - max_majority_copy = 100 - else: - max_majority_copy = 100 + # those are all empirical values + # TODO: GetOrganelle need a better algorithm for target filtering + # if mode == "embplant_pt": + # max_majority_copy = 2 + # elif mode == "other_pt": + # max_majority_copy = 10 + # elif mode == "embplant_mt": + # max_majority_copy = 4 + # elif mode == "embplant_nr": + # max_majority_copy = 2 + # elif mode == "animal_mt": + # # the difference between mt and nucl are usually not that large, + # # making the upper boundary useless for excluding non-target but excluding target by mistake + # max_majority_copy = 1000 + # elif mode == "fungus_mt": + # max_majority_copy = 8 + # elif mode == "fungus_nr": + # max_majority_copy = 4 + # elif mode == "all": + # max_majority_copy = 100 + # else: + # max_majority_copy = 100 if not limited_vertices: limited_vertices = sorted(self.vertex_info) else: limited_vertices = sorted(limited_vertices) + if not limited_vertices: + raise ProcessingGraphFailed("Too strict criteria removing all contigs in an insufficient graph") if re_initialize: for vertex_name in limited_vertices: @@ -1051,9 +1188,13 @@ def estimate_copy_and_depth_by_cov(self, limited_vertices=None, given_average_co self.copy_to_vertex[1] = set() self.copy_to_vertex[1].add(vertex_name) + cov_ls = [] + len_ls = [] if not given_average_cov: previous_val = {0.} new_val = -1. + new_std = -1. 
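+            # the loop below iteratively re-estimates the average depth as a length-weighted mean of the
+            # per-copy coverages (cov / assigned copy), floored at 0.9 x the minimum vertex coverage;
+            # new_std is then the length-weighted standard deviation around that mean:
+            # sqrt(sum(len_i * (new_val - cov_i) ** 2) / sum(len_i))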
+ # arbitrary setting, without influence of limited_vertices min_average_depth = 0.9 * min([self.vertex_info[vertex_n].cov for vertex_n in self.vertex_info]) while round(new_val, 5) not in previous_val: previous_val.add(round(new_val, 5)) @@ -1063,11 +1204,15 @@ def estimate_copy_and_depth_by_cov(self, limited_vertices=None, given_average_co for vertex_name in limited_vertices: # do we need to exclude the overlap? this_len = self.vertex_info[vertex_name].len * self.vertex_to_copy.get(vertex_name, 1) + len_ls.append(this_len) this_cov = self.vertex_info[vertex_name].cov / self.vertex_to_copy.get(vertex_name, 1) + cov_ls.append(this_cov) total_len += this_len total_product += this_len * this_cov # new_val = total_product / total_len new_val = max(total_product / total_len, min_average_depth) + new_std = (sum([_w * (new_val - _c) ** 2 for _c, _w in zip(cov_ls, len_ls)]) / sum(len_ls)) ** 0.5 + # (sum(len_ls) * (len(cov_ls) - 1) / len(cov_ls))) ** 0.5 # print("new val: ", new_val) # adjust this_copy according to new baseline depth for vertex_name in self.vertex_info: @@ -1077,7 +1222,8 @@ def estimate_copy_and_depth_by_cov(self, limited_vertices=None, given_average_co if not self.copy_to_vertex[old_copy]: del self.copy_to_vertex[old_copy] this_float_copy = self.vertex_info[vertex_name].cov / new_val - this_copy = min(max(1, int(round(this_float_copy, 0))), max_majority_copy) + # this_copy = min(max(1, int(round(this_float_copy, 0))), max_majority_copy) + this_copy = max(1, int(round(this_float_copy, 0))) self.vertex_to_float_copy[vertex_name] = this_float_copy self.vertex_to_copy[vertex_name] = this_copy if this_copy not in self.copy_to_vertex: @@ -1090,30 +1236,746 @@ def estimate_copy_and_depth_by_cov(self, limited_vertices=None, given_average_co else: sys.stdout.write("updating average " + mode + cov_str + str(round(new_val, 2)) + "\n") # print("return ", new_val) - return new_val + return new_val, new_val * min_sigma if len(limited_vertices) == 1 and new_std == 0. 
else new_std else: # adjust this_copy according to user-defined depth - for vertex_name in self.vertex_info: + for vertex_name in limited_vertices: if vertex_name in self.vertex_to_copy: old_copy = self.vertex_to_copy[vertex_name] self.copy_to_vertex[old_copy].remove(vertex_name) if not self.copy_to_vertex[old_copy]: del self.copy_to_vertex[old_copy] this_float_copy = self.vertex_info[vertex_name].cov / given_average_cov - this_copy = min(max(1, int(round(this_float_copy, 0))), max_majority_copy) + # this_copy = min(max(1, int(round(this_float_copy, 0))), max_majority_copy) + this_copy = max(1, int(round(this_float_copy, 0))) self.vertex_to_float_copy[vertex_name] = this_float_copy self.vertex_to_copy[vertex_name] = this_copy if this_copy not in self.copy_to_vertex: self.copy_to_vertex[this_copy] = set() self.copy_to_vertex[this_copy].add(vertex_name) - return given_average_cov - - def estimate_copy_and_depth_precisely(self, maximum_copy_num=8, broken_graph_allowed=False, - return_new_graphs=False, verbose=False, log_handler=None, debug=False, - target_name_for_log="target"): - - def get_formula(from_vertex, from_end, back_to_vertex, back_to_end, here_record_ends): - result_form = vertex_to_symbols[from_vertex] + cov_ls.append(self.vertex_info[vertex_name].cov / this_copy) + len_ls.append(self.vertex_info[vertex_name].len * this_copy) + new_std = (sum([_w * (given_average_cov - _c) ** 2 for _c, _w in zip(cov_ls, len_ls)]) / sum(len_ls)) ** 0.5 + return given_average_cov, \ + given_average_cov * min_sigma if len(limited_vertices) == 1 and new_std == 0. else new_std + + # def estimate_copy_and_depth_precisely_using_multinomial( + # self, expected_average_cov, # broken_graph_allowed=False, + # verbose=False, log_handler=None, debug=False, + # target_name_for_log="target", n_iterations=None): + # """ + # Currently problematic because of + # frequently reporting + # Exception: @error: Solution Not Found + # + # :param expected_average_cov: + # :param verbose: + # :param log_handler: + # :param debug: + # :param target_name_for_log: + # :return: + # """ + # # TODO: to test whether it's better than least-square + # + # def get_formula(from_vertex, from_end, back_to_vertex, back_to_end, here_record_ends): + # result_form = vertex_to_symbols[from_vertex] + # here_record_ends.add((from_vertex, from_end)) + # # if back_to_vertex ~ from_vertex (from_vertex == back_to_vertex) form a loop, skipped + # if from_vertex != back_to_vertex: + # for next_v, next_e in self.vertex_info[from_vertex].connections[from_end]: + # # if next_v ~ from_vertex (next_v == from_vertex) form a loop, add a pseudo vertex + # if (next_v, next_e) == (from_vertex, not from_end): + # # skip every self-loop 2020-06-23 + # # pseudo_self_circle_str = "P" + from_vertex + # # if pseudo_self_circle_str not in extra_str_to_symbol_m2: + # # extra_str_to_symbol_m2[pseudo_self_circle_str] = Symbol(pseudo_self_circle_str, integer=True) + # # extra_symbol_to_str_m2[extra_str_to_symbol_m2[pseudo_self_circle_str]] = pseudo_self_circle_str + # # result_form -= (extra_str_to_symbol_m2[pseudo_self_circle_str] - 1) + # pass + # # elif (next_v, next_e) != (back_to_vertex, back_to_end): + # elif (next_v, next_e) not in here_record_ends: + # result_form -= get_formula(next_v, next_e, from_vertex, from_end, here_record_ends) + # return result_form + # + # # # for compatibility between scipy and sympy + # # def least_square_function_v(x): + # # return least_square_function(*tuple(x)) + # # + # # """ create constraints by creating inequations: the copy 
of every contig has to be >= 1 """ + # # + # # def constraint_min_function(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # min_copy = np.array([1.001] * len(all_v_symbols) + + # # [1.001] * len(extra_symbol_to_str_m1) + + # # [2.001] * len(extra_symbol_to_str_m2)) + # # # effect: expression_array >= int(min_copy) + # # return expression_array - min_copy + # # + # # def constraint_min_function_for_customized_brute(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # min_copy = np.array([1.0] * len(all_v_symbols) + + # # [1.0] * len(extra_symbol_to_str_m1) + + # # [2.0] * len(extra_symbol_to_str_m2)) + # # # effect: expression_array >= min_copy + # # return expression_array - min_copy + # + # def constraint_min_function_for_gekko(g_vars): + # subs_tuples = [(symb_used_, Symbol("g_vars[" + str(go_sym) + "]")) + # for go_sym, symb_used_ in enumerate(free_copy_variables)] + # expression_array = [copy_solution[this_sym].subs(subs_tuples) for this_sym in all_symbols] + # min_copy = [1] * len(all_v_symbols) + \ + # [1] * len(extra_symbol_to_str_m1) + \ + # [2] * len(extra_symbol_to_str_m2) + # # effect: expression_array >= min_copy + # expression = [] + # if verbose or debug: + # for e, c in zip(expression_array, min_copy): + # expression.append(eval(str(e) + ">=" + str(c))) + # log_handler.info(" constraint: " + str(e) + ">=" + str(c)) + # else: + # for e, c in zip(expression_array, min_copy): + # expression.append(eval(str(e) + ">=" + str(c))) + # expression = [expr for expr in expression if not isinstance(expr, bool)] + # return expression + # + # # def constraint_max_function(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # max_copy = np.array([expected_average_cov] * len(all_v_symbols) + + # # [expected_average_cov] * len(extra_symbol_to_str_m1) + + # # [expected_average_cov * 2] * len(extra_symbol_to_str_m2)) + # # # effect: expression_array <= max_copy + # # return max_copy - expression_array + # # + # # def constraint_int_function(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # # diff = np.array([0] * len(all_symbols)) + # # return sum([abs(every_copy - int(every_copy)) for every_copy in expression_array]) + # # + # # def minimize_brute_force(func, range_list, constraint_list, round_digit=4, display_p=True, + # # in_log_handler=log_handler): + # # # time0 = time.time() + # # best_fun_val = inf + # # best_para_val = [] + # # count_round = 0 + # # count_valid = 0 + # # for value_set in product(*[list(this_range) for this_range in range_list]): + # # count_round += 1 + # # is_valid_set = True + # # for cons in constraint_list: + # # if cons["type"] == "ineq": + # # try: + # # if (cons["fun"](value_set) < 0).any(): + # # is_valid_set = False + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) + # # break + # # except 
TypeError: + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) + # # is_valid_set = False + # # break + # # elif cons["type"] == "eq": + # # try: + # # if cons["fun"](value_set) != 0: + # # is_valid_set = False + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) + # # break + # # except TypeError: + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) + # # is_valid_set = False + # # break + # # if not is_valid_set: + # # continue + # # count_valid += 1 + # # this_fun_val = func(value_set) + # # if in_log_handler: + # # if debug or display_p: + # # in_log_handler.info("value_set={} ; fun_val={}".format(value_set, this_fun_val)) + # # this_fun_val = round(this_fun_val, round_digit) + # # if this_fun_val < best_fun_val: + # # best_para_val = [value_set] + # # best_fun_val = this_fun_val + # # elif this_fun_val == best_fun_val: + # # best_para_val.append(value_set) + # # else: + # # pass + # # if in_log_handler: + # # if debug or display_p: + # # in_log_handler.info("Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round)) + # # in_log_handler.info("Brute best function value: " + str(best_fun_val)) + # # if debug: + # # in_log_handler.info("Best solution: " + str(best_para_val)) + # # else: + # # if debug or display_p: + # # sys.stdout.write( + # # "Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round) + "\n") + # # sys.stdout.write("Brute best function value: " + str(best_fun_val) + "\n") + # # if debug: + # # sys.stdout.write("Best solution: " + str(best_para_val) + "\n") + # # return best_para_val + # + # vertices_list = sorted(self.vertex_info) + # if len(vertices_list) == 1: + # cov_ = self.vertex_info[vertices_list[0]].cov + # # 2022-12-15, remove return_new_graph + # # if return_new_graphs: + # return [{"graph": deepcopy(self), "cov": cov_}] + # # else: + # # if log_handler: + # # log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2))) + # # else: + # # sys.stdout.write( + # # "Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2)) + "\n") + # # return + # + # # reduce expected_average_cov to reduce computational burden + # all_coverages = [self.vertex_info[v_name].cov for v_name in vertices_list] + # # max_contig_multiplicity = \ + # # min(max_contig_multiplicity, int(2 * math.ceil(max(all_coverages) / min(all_coverages)))) + # # if verbose: + # # if log_handler: + # # log_handler.info("Maximum multiplicity: " + str(max_contig_multiplicity)) + # # else: + # # sys.stdout.write("Maximum multiplicity: " + str(max_contig_multiplicity) + "\n") + # + # """ create constraints by creating multivariate equations """ + # vertex_to_symbols = {vertex_name: Symbol("V" + vertex_name, integer=True) # positive=True) + # for vertex_name in vertices_list} + # symbols_to_vertex = {vertex_to_symbols[vertex_name]: vertex_name for vertex_name in vertices_list} + # extra_str_to_symbol_m1 = {} + # extra_str_to_symbol_m2 = {} + # extra_symbol_to_str_m1 = {} + # extra_symbol_to_str_m2 = {} + # extra_symbol_initial_values = {} + # formulae = [] + # recorded_ends = set() + # for vertex_name in vertices_list: + # for this_end in (True, False): + # if (vertex_name, this_end) not in recorded_ends: + # recorded_ends.add((vertex_name, this_end)) + # if 
self.vertex_info[vertex_name].connections[this_end]: + # this_formula = vertex_to_symbols[vertex_name] + # formulized = False + # for n_v, n_e in self.vertex_info[vertex_name].connections[this_end]: + # if (n_v, n_e) not in recorded_ends: + # # if n_v in vertices_set: + # # recorded_ends.add((n_v, n_e)) + # try: + # this_formula -= get_formula(n_v, n_e, vertex_name, this_end, recorded_ends) + # formulized = True + # # if verbose: + # # if log_handler: + # # log_handler.info("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # # vertex_name + ECHO_DIRECTION[this_end] + ": " + + # # str(this_formula)) + # # else: + # # sys.stdout.write("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # # vertex_name + ECHO_DIRECTION[this_end] + ": " + + # # str(this_formula)+"\n") + # except RecursionError: + # if log_handler: + # log_handler.warning("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # vertex_name + ECHO_DIRECTION[this_end] + " failed!") + # else: + # sys.stdout.write("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # vertex_name + ECHO_DIRECTION[this_end] + " failed!\n") + # raise ProcessingGraphFailed("RecursionError!") + # if verbose: + # if log_handler: + # log_handler.info( + # "formulating for: " + vertex_name + ECHO_DIRECTION[this_end] + ": " + + # str(this_formula)) + # else: + # sys.stdout.write( + # "formulating for: " + vertex_name + ECHO_DIRECTION[this_end] + ": " + + # str(this_formula) + "\n") + # if formulized: + # formulae.append(this_formula) + # # 2022-12-13 remove this restriction + # # because we have a reduce_list_with_gcd for all graph component + # # elif broken_graph_allowed: + # # # Extra limitation to force terminal vertex to have only one copy, to avoid over-estimation + # # # Under-estimation would not be a problem here, + # # # because the True-multiple-copy vertex would simply have no other connections, + # # # or failed in the following estimation if it does + # # formulae.append(vertex_to_symbols[vertex_name] - 1) + # + # # add self-loop formulae + # self_loop_v = set() + # for vertex_name in vertices_list: + # if self.vertex_info[vertex_name].is_self_loop(): + # self_loop_v.add(vertex_name) + # if log_handler: + # log_handler.warning("Self-loop contig detected: Vertex_" + vertex_name) + # pseudo_self_loop_str = "P" + vertex_name + # if pseudo_self_loop_str not in extra_str_to_symbol_m1: + # extra_str_to_symbol_m1[pseudo_self_loop_str] = Symbol(pseudo_self_loop_str, integer=True) + # extra_symbol_to_str_m1[extra_str_to_symbol_m1[pseudo_self_loop_str]] = pseudo_self_loop_str + # this_formula = vertex_to_symbols[vertex_name] - extra_str_to_symbol_m1[pseudo_self_loop_str] + # extra_symbol_initial_values[extra_str_to_symbol_m1[pseudo_self_loop_str]] = \ + # self.vertex_to_copy[vertex_name] + # formulae.append(this_formula) + # if verbose: + # if log_handler: + # log_handler.info( + # "formulating for: " + vertex_name + ECHO_DIRECTION[True] + ": " + str(this_formula)) + # else: + # sys.stdout.write( + # "formulating for: " + vertex_name + ECHO_DIRECTION[True] + ": " + str(this_formula) + "\n") + # + # # add following extra limitation + # # set cov_sequential_repeat = x*near_by_cov, x is an integer + # for vertex_name in vertices_list: + # single_pair_in_the_trunk_path = self.is_sequential_repeat(vertex_name) + # if single_pair_in_the_trunk_path: + # (from_v, from_e), (to_v, to_e) = single_pair_in_the_trunk_path + # # from_v and to_v are already in the "trunk path", if they are the same, + # # the graph is like two circles sharing 
the same sequential repeat, no need to add this limitation + # if from_v != to_v: + # new_str = "E" + str(len(extra_str_to_symbol_m1) + len(extra_str_to_symbol_m2)) + # if vertex_name in self_loop_v: + # # self-loop vertex is allowed to have the multiplicity of 1 + # extra_str_to_symbol_m1[new_str] = Symbol(new_str, integer=True) + # extra_symbol_to_str_m1[extra_str_to_symbol_m1[new_str]] = new_str + # this_formula = vertex_to_symbols[vertex_name] - \ + # vertex_to_symbols[from_v] * extra_str_to_symbol_m1[new_str] + # extra_symbol_initial_values[extra_str_to_symbol_m1[new_str]] = \ + # round(self.vertex_to_float_copy[vertex_name] / self.vertex_to_float_copy[from_v]) + # else: + # extra_str_to_symbol_m2[new_str] = Symbol(new_str, integer=True) + # extra_symbol_to_str_m2[extra_str_to_symbol_m2[new_str]] = new_str + # this_formula = vertex_to_symbols[vertex_name] - \ + # vertex_to_symbols[from_v] * extra_str_to_symbol_m2[new_str] + # extra_symbol_initial_values[extra_str_to_symbol_m2[new_str]] = \ + # round(self.vertex_to_float_copy[vertex_name] / self.vertex_to_float_copy[from_v]) + # formulae.append(this_formula) + # if verbose: + # if log_handler: + # log_handler.info("formulating for: " + vertex_name + ": " + str(this_formula)) + # else: + # sys.stdout.write("formulating for: " + vertex_name + ": " + str(this_formula) + "\n") + # + # all_v_symbols = list(symbols_to_vertex) + # all_symbols = all_v_symbols + list(extra_symbol_to_str_m1) + list(extra_symbol_to_str_m2) + # if verbose or debug: + # if log_handler: + # log_handler.info("formulae: " + str(formulae)) + # else: + # sys.stdout.write("formulae: " + str(formulae) + "\n") + # # solve the equations + # copy_solution = solve(formulae, all_v_symbols) + # + # copy_solution = copy_solution if copy_solution else {} + # if type(copy_solution) == list: # delete 0 containing set, even for self-loop vertex + # go_solution = 0 + # while go_solution < len(copy_solution): + # if 0 in set(copy_solution[go_solution].values()): + # del copy_solution[go_solution] + # else: + # go_solution += 1 + # if not copy_solution: + # raise ProcessingGraphFailed("Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (1)!") + # elif type(copy_solution) == list: + # if len(copy_solution) > 2: + # raise ProcessingGraphFailed("Incomplete/Complicated " + target_name_for_log + " graph (2)!") + # else: + # copy_solution = copy_solution[0] + # + # free_copy_variables = list() + # for symbol_used in all_symbols: + # if symbol_used not in copy_solution: + # free_copy_variables.append(symbol_used) + # copy_solution[symbol_used] = symbol_used + # if verbose: + # if log_handler: + # log_handler.info("copy equations: " + str(copy_solution)) + # log_handler.info("free variables: " + str(free_copy_variables)) + # else: + # sys.stdout.write("copy equations: " + str(copy_solution) + "\n") + # sys.stdout.write("free variables: " + str(free_copy_variables) + "\n") + # + # # """ minimizing equation-based copy's deviations from coverage-based copy values """ + # # least_square_expr = 0 + # # for symbol_used in all_v_symbols: + # # # least_square_expr += copy_solution[symbol_used] + # # this_vertex = symbols_to_vertex[symbol_used] + # # this_copy = self.vertex_to_float_copy[this_vertex] + # # least_square_expr += (copy_solution[symbol_used] - this_copy) ** 2 # * self.vertex_info[this_vertex]["len"] + # # least_square_function = lambdify(args=free_copy_variables, expr=least_square_expr) + # + # if free_copy_variables: + # """ Maximize the likelihood of the multinomial 
distribution of kmers (kmer_cov * contig_len)""" + # m = GEKKO(remote=False) + # g_vars = m.Array(m.Var, + # len(free_copy_variables), + # lb=1, + # ub=int(4 * math.ceil(max(all_coverages) / min(all_coverages))), + # integer=True) + # # initialize free variables + # for go_sym, symbol_used in enumerate(free_copy_variables): + # if symbol_used in symbols_to_vertex and symbols_to_vertex[symbol_used] in self.vertex_to_copy: + # g_vars[go_sym].value = self.vertex_to_copy[symbols_to_vertex[symbol_used]] + # elif symbol_used in extra_symbol_initial_values: + # g_vars[go_sym].value = extra_symbol_initial_values[symbol_used] + # replacements = [(symbol_used, Symbol("g_vars[" + str(go_sym) + "]")) + # for go_sym, symbol_used in enumerate(free_copy_variables)] + # # account for the influence of the overlap + # total_len = 0 + # multinomial_loglike_list = [] + # v_to_len = {} + # v_to_copy = {} + # v_to_real_len = {} + # all_obs = [] + # if self.__uni_overlap: + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # v_to_real_len[this_vertex] = self.vertex_info[this_vertex].len - self.__uni_overlap + # v_to_copy[this_vertex] = eval(str(copy_solution[symbol_used].subs(replacements))) + # v_to_len[this_vertex] = v_to_copy[this_vertex] * v_to_real_len[this_vertex] + # total_len += v_to_len[this_vertex] + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # prob = v_to_len[this_vertex] / total_len + # obs = self.vertex_info[this_vertex].cov * v_to_real_len[this_vertex] + # multinomial_loglike_list.append(m.log(prob) * obs) + # all_obs.append(obs) + # if verbose: + # if log_handler: + # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # else: + # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # else: + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # overlaps = [_ovl + # for _strand in (True, False) + # for _next, _ovl in self.vertex_info[this_vertex].connections[_strand].items()] + # approximate_overlap = average_np_free(overlaps) + # v_to_real_len[this_vertex] = self.vertex_info[this_vertex].len - approximate_overlap + # v_to_copy[this_vertex] = eval(str(copy_solution[symbol_used].subs(replacements))) + # v_to_len[this_vertex] = v_to_copy[this_vertex] * v_to_real_len[this_vertex] + # total_len += v_to_len[this_vertex] + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # prob = v_to_len[this_vertex] / total_len + # obs = self.vertex_info[this_vertex].cov * v_to_real_len[this_vertex] + # multinomial_loglike_list.append(m.log(prob) * obs) + # all_obs.append(obs) + # if verbose: + # if log_handler: + # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # else: + # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # """extra arbitrary restriction to avoid over inflation of copies""" + # sum_obs = sum(all_obs) + # multinomial_loglike_list.append(-abs(sum_obs / expected_average_cov - total_len)) + # # """extra restriction to constraint the integer solution for dependant variables""" + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # multinomial_loglike_list.append(sum_obs * (v_to_copy[this_vertex] - int(v_to_copy[this_vertex])) ** 2) + # """generate the expression""" + # # multinomial_loglike_expr = m.sum(multinomial_loglike_list) will lead to No solution error + # multinomial_loglike_expr = sum(multinomial_loglike_list) + # 
exp_str_len = len(str(multinomial_loglike_expr)) + # if exp_str_len > 15000: # not allowed by Gekko:APM + # num_blocks = math.ceil(exp_str_len / 10000.) + # block_size = math.ceil(len(multinomial_loglike_list) / float(num_blocks)) + # block_list = [] + # for g_b in range(num_blocks): + # block_list.append(sum(multinomial_loglike_list[g_b * block_size: (g_b + 1)* block_size])) + # multinomial_loglike_expr = m.sum(block_list) + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # total_len += eval(str(copy_solution[symbol_used].subs(replacements))) * self.vertex_info[this_vertex].len + # # multinomial_like_expr = 0 + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # prob = eval(str(copy_solution[symbol_used].subs(replacements))) \ + # # * self.vertex_info[this_vertex].len / total_len + # # obs = self.vertex_info[this_vertex].cov * self.vertex_info[this_vertex].len + # # multinomial_like_expr += m.log(prob) * obs + # # if verbose: + # # if log_handler: + # # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # # else: + # # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # + # # multinomial_like_function = lambdify(args=free_copy_variables, expr=multinomial_like_expr) + # m.Equations(constraint_min_function_for_gekko(g_vars)) + # m.Maximize(multinomial_loglike_expr) + # # m.Minimize(least_square_function_v(g_vars)) + # # 1 for APOPT, 2 for BPOPT, 3 for IPOPT, 0 for all available solvers + # # here only 1 and 3 are available + # m.options.SOLVER = 1 + # # setting empirical options + # # 5000 costs ~ 150 sec + # if n_iterations is None: + # n_high_copy = sum([math.log2(self.vertex_to_float_copy[_v]) + # for _v in self.vertex_info if self.vertex_to_float_copy[_v] > 2]) + # n_iterations = 500 + int(len(self.vertex_info) * n_high_copy) + # if verbose or debug: + # log_handler.info("setting n_iterations=" + str(n_iterations)) + # m.solver_options = ['minlp_maximum_iterations ' + str(n_iterations), + # # minlp iterations with integer solution + # 'minlp_max_iter_with_int_sol ' + str(n_iterations), + # # treat minlp as nlp + # 'minlp_as_nlp 0', + # # nlp sub-problem max iterations + # 'nlp_maximum_iterations ' + str(n_iterations), + # # 1 = depth first, 2 = breadth first + # 'minlp_branch_method 2', + # # maximum deviation from whole number + # 'minlp_integer_tol 1.0e-6', + # # covergence tolerance + # 'minlp_gap_tol 1.0e-6'] + # if debug or verbose: + # m.solve() + # else: + # m.solve(disp=False) + # # print([x.value[0] for x in g_vars]) + # copy_results = list([x.value[0] for x in g_vars]) + # + # # # for safe running + # # if len(free_copy_variables) > 10: + # # raise ProcessingGraphFailed("Free variable > 10 is not accepted yet!") + # # + # # if expected_average_cov ** len(free_copy_variables) < 5E6: + # # # sometimes, SLSQP ignores bounds and constraints + # # copy_results = minimize_brute_force( + # # func=least_square_function_v, range_list=[range(1, expected_average_cov + 1)] * len(free_copy_variables), + # # constraint_list=({'type': 'ineq', 'fun': constraint_min_function_for_customized_brute}, + # # {'type': 'eq', 'fun': constraint_int_function}, + # # {'type': 'ineq', 'fun': constraint_max_function}), + # # display_p=verbose) + # # else: + # # constraints = ({'type': 'ineq', 'fun': constraint_min_function}, + # # {'type': 'eq', 'fun': constraint_int_function}, + # # {'type': 'ineq', 'fun': constraint_max_function}) + # # copy_results = set() + # # best_fun = inf 
+ # # opt = {'disp': verbose, "maxiter": 100} + # # for initial_copy in range(expected_average_cov * 2 + 1): + # # if initial_copy < expected_average_cov: + # # initials = np.array([initial_copy + 1] * len(free_copy_variables)) + # # elif initial_copy < expected_average_cov * 2: + # # initials = np.array([random.randint(1, expected_average_cov)] * len(free_copy_variables)) + # # else: + # # initials = np.array([self.vertex_to_copy.get(symbols_to_vertex.get(symb, False), 2) + # # for symb in free_copy_variables]) + # # bounds = [(1, expected_average_cov) for foo in range(len(free_copy_variables))] + # # try: + # # copy_result = optimize.minimize(fun=least_square_function_v, x0=initials, jac=False, + # # method='SLSQP', bounds=bounds, constraints=constraints, options=opt) + # # except Exception: + # # continue + # # if copy_result.fun < best_fun: + # # best_fun = round(copy_result.fun, 2) + # # copy_results = {tuple(copy_result.x)} + # # elif copy_result.fun == best_fun: + # # copy_results.add(tuple(copy_result.x)) + # # else: + # # pass + # # if debug or verbose: + # # if log_handler: + # # log_handler.info("Best function value: " + str(best_fun)) + # # else: + # # sys.stdout.write("Best function value: " + str(best_fun) + "\n") + # if verbose or debug: + # if log_handler: + # log_handler.info("Copy results: " + str(copy_results)) + # else: + # sys.stdout.write("Copy results: " + str(copy_results) + "\n") + # # if len(copy_results) == 1: + # # copy_results = list(copy_results) + # # elif len(copy_results) > 1: + # # # draftly sort results by freedom vertices_set + # # copy_results = sorted(copy_results, key=lambda + # # x: sum([(x[go_sym] - self.vertex_to_float_copy[symbols_to_vertex[symb_used]]) ** 2 + # # for go_sym, symb_used in enumerate(free_copy_variables) + # # if symb_used in symbols_to_vertex])) + # # else: + # # raise ProcessingGraphFailed("Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (3)!") + # else: + # copy_results = [] + # + # # if return_new_graphs: + # # """ produce all possible vertex copy combinations """ + # final_results = [] + # all_copy_sets = set() + # # maybe no more multiple results since 2022-12 gekko update + # for go_res, copy_result in enumerate([copy_results]): + # free_copy_variables_dict = {free_copy_variables[i]: int(this_copy) + # for i, this_copy in enumerate(copy_result)} + # + # """ simplify copy values """ + # # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] + # # 2022-12-15 add cluster info to simplify by graph components when the graph is broken + # all_copies = [] + # v_to_cid = {} + # for go_id, this_symbol in enumerate(all_v_symbols): + # vertex_name = symbols_to_vertex[this_symbol] + # v_to_cid[vertex_name] = go_id + # this_copy = int(copy_solution[this_symbol].evalf(subs=free_copy_variables_dict, chop=True)) + # if this_copy <= 0: + # raise ProcessingGraphFailed("Cannot identify copy number of " + vertex_name + "!") + # all_copies.append(this_copy) + # if len(self.vertex_clusters) == 1: + # if len(all_copies) == 0: + # raise ProcessingGraphFailed( + # "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + # elif len(all_copies) == 1: + # all_copies = [1] + # elif min(all_copies) == 1: + # pass + # else: + # new_all_copies = reduce_list_with_gcd(all_copies) + # if verbose and new_all_copies != all_copies: + # if log_handler: + # log_handler.info("Estimated copies: " + str(all_copies)) + # log_handler.info("Reduced copies: " + str(new_all_copies)) + # else: + # 
sys.stdout.write("Estimated copies: " + str(all_copies) + "\n") + # sys.stdout.write("Reduced copies: " + str(new_all_copies) + "\n") + # all_copies = new_all_copies + # else: + # for v_cluster in self.vertex_clusters: + # ids = [v_to_cid[_v] for _v in v_cluster] + # component_copies = [all_copies[_id] for _id in ids] + # if len(component_copies) == 0: + # raise ProcessingGraphFailed( + # "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + # elif len(component_copies) == 1: + # component_copies = [1] + # elif min(component_copies) == 1: + # pass + # else: + # new_comp_copies = reduce_list_with_gcd(component_copies) + # if verbose and new_comp_copies != component_copies: + # if log_handler: + # log_handler.info("Estimated copies: " + str(component_copies)) + # log_handler.info("Reduced copies: " + str(new_comp_copies)) + # else: + # sys.stdout.write("Estimated copies: " + str(component_copies) + "\n") + # sys.stdout.write("Reduced copies: " + str(new_comp_copies) + "\n") + # component_copies = new_comp_copies + # for sequential_id, _id in enumerate(ids): + # all_copies[_id] = component_copies[sequential_id] + # + # all_copies = tuple(all_copies) + # if all_copies not in all_copy_sets: + # all_copy_sets.add(all_copies) + # else: + # continue + # + # """ record new copy values """ + # final_results.append({"graph": deepcopy(self)}) + # for go_s, this_symbol in enumerate(all_v_symbols): + # vertex_name = symbols_to_vertex[this_symbol] + # if vertex_name in final_results[go_res]["graph"].vertex_to_copy: + # old_copy = final_results[go_res]["graph"].vertex_to_copy[vertex_name] + # final_results[go_res]["graph"].copy_to_vertex[old_copy].remove(vertex_name) + # if not final_results[go_res]["graph"].copy_to_vertex[old_copy]: + # del final_results[go_res]["graph"].copy_to_vertex[old_copy] + # this_copy = all_copies[go_s] + # final_results[go_res]["graph"].vertex_to_copy[vertex_name] = this_copy + # if this_copy not in final_results[go_res]["graph"].copy_to_vertex: + # final_results[go_res]["graph"].copy_to_vertex[this_copy] = set() + # final_results[go_res]["graph"].copy_to_vertex[this_copy].add(vertex_name) + # + # """ re-estimate baseline depth """ + # total_product = 0. 
+ # total_len = 0 + # for vertex_name in vertices_list: + # this_len = self.vertex_info[vertex_name].len \ + # * final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) + # this_cov = self.vertex_info[vertex_name].cov \ + # / final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) + # total_len += this_len + # total_product += this_len * this_cov + # final_results[go_res]["cov"] = total_product / total_len + # return final_results + # # else: + # # """ produce the first-ranked copy combination """ + # # free_copy_variables_dict = {free_copy_variables[i]: int(this_copy) + # # for i, this_copy in enumerate(copy_results)} + # # + # # """ simplify copy values """ # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] + # # all_copies = [] + # # for this_symbol in all_v_symbols: + # # vertex_name = symbols_to_vertex[this_symbol] + # # this_copy = int(copy_solution[this_symbol].evalf(subs=free_copy_variables_dict, chop=True)) + # # if this_copy <= 0: + # # raise ProcessingGraphFailed("Cannot identify copy number of " + vertex_name + "!") + # # all_copies.append(this_copy) + # # if len(all_copies) == 0: + # # raise ProcessingGraphFailed( + # # "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + # # elif len(all_copies) == 1: + # # all_copies = [1] + # # elif min(all_copies) == 1: + # # pass + # # else: + # # new_all_copies = reduce_list_with_gcd(all_copies) + # # if verbose and new_all_copies != all_copies: + # # if log_handler: + # # log_handler.info("Estimated copies: " + str(all_copies)) + # # log_handler.info("Reduced copies: " + str(new_all_copies)) + # # else: + # # sys.stdout.write("Estimated copies: " + str(all_copies) + "\n") + # # sys.stdout.write("Reduced copies: " + str(new_all_copies) + "\n") + # # all_copies = new_all_copies + # # + # # """ record new copy values """ + # # for go_s, this_symbol in enumerate(all_v_symbols): + # # vertex_name = symbols_to_vertex[this_symbol] + # # if vertex_name in self.vertex_to_copy: + # # old_copy = self.vertex_to_copy[vertex_name] + # # self.copy_to_vertex[old_copy].remove(vertex_name) + # # if not self.copy_to_vertex[old_copy]: + # # del self.copy_to_vertex[old_copy] + # # this_copy = all_copies[go_s] + # # self.vertex_to_copy[vertex_name] = this_copy + # # if this_copy not in self.copy_to_vertex: + # # self.copy_to_vertex[this_copy] = set() + # # self.copy_to_vertex[this_copy].add(vertex_name) + # # + # # if debug or verbose: + # # """ re-estimate baseline depth """ + # # total_product = 0. 
+ # # total_len = 0 + # # for vertex_name in vertices_list: + # # this_len = self.vertex_info[vertex_name].len \ + # # * self.vertex_to_copy.get(vertex_name, 1) + # # this_cov = self.vertex_info[vertex_name].cov / self.vertex_to_copy.get(vertex_name, 1) + # # total_len += this_len + # # total_product += this_len * this_cov + # # new_val = total_product / total_len + # # if log_handler: + # # log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2))) + # # else: + # # sys.stdout.write( + # # "Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2)) + "\n") + + def estimate_copy_and_depth_precisely(self, expected_average_cov=None, # broken_graph_allowed=False, + verbose=False, log_handler=None, debug=False, + target_name_for_log="target", n_iterations=None): + """ + :param expected_average_cov: not used in the least-square version + :param verbose: + :param log_handler: + :param debug: + :param target_name_for_log: + :param n_iterations: + :return: + """ + def get_formula(from_vertex, from_end, back_to_vertex, here_record_ends): + result_form = v_vars[vertices_ids[from_vertex]] here_record_ends.add((from_vertex, from_end)) # if back_to_vertex ~ from_vertex (from_vertex == back_to_vertex) form a loop, skipped if from_vertex != back_to_vertex: @@ -1129,158 +1991,173 @@ def get_formula(from_vertex, from_end, back_to_vertex, back_to_end, here_record_ pass # elif (next_v, next_e) != (back_to_vertex, back_to_end): elif (next_v, next_e) not in here_record_ends: - result_form -= get_formula(next_v, next_e, from_vertex, from_end, here_record_ends) + result_form -= get_formula(next_v, next_e, from_vertex, here_record_ends) return result_form - # for compatibility between scipy and sympy - def least_square_function_v(x): - return least_square_function(*tuple(x)) - - """ create constraints by creating inequations: the copy of every contig has to be >= 1 """ - - def constraint_min_function(x): - replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] - expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) - min_copy = np.array([1.001] * len(all_v_symbols) + - [1.001] * len(extra_symbol_to_str_m1) + - [2.001] * len(extra_symbol_to_str_m2)) - # effect: expression_array >= int(min_copy) - return expression_array - min_copy - - def constraint_min_function_for_customized_brute(x): - replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] - expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) - min_copy = np.array([1.0] * len(all_v_symbols) + - [1.0] * len(extra_symbol_to_str_m1) + - [2.0] * len(extra_symbol_to_str_m2)) - # effect: expression_array >= min_copy - return expression_array - min_copy - - def constraint_max_function(x): - replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] - expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) - max_copy = np.array([maximum_copy_num] * len(all_v_symbols) + - [maximum_copy_num] * len(extra_symbol_to_str_m1) + - [maximum_copy_num * 2] * len(extra_symbol_to_str_m2)) - # effect: expression_array <= max_copy - return max_copy - expression_array - - def constraint_int_function(x): - replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] - expression_array = 
np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) - # diff = np.array([0] * len(all_symbols)) - return sum([abs(every_copy - int(every_copy)) for every_copy in expression_array]) - - def minimize_brute_force(func, range_list, constraint_list, round_digit=4, display_p=True, - in_log_handler=log_handler): - # time0 = time.time() - best_fun_val = inf - best_para_val = [] - count_round = 0 - count_valid = 0 - for value_set in product(*[list(this_range) for this_range in range_list]): - count_round += 1 - is_valid_set = True - for cons in constraint_list: - if cons["type"] == "ineq": - try: - if (cons["fun"](value_set) < 0).any(): - is_valid_set = False - # if in_log_handler and (debug or display_p): - # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) - break - except TypeError: - # if in_log_handler and (debug or display_p): - # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) - is_valid_set = False - break - elif cons["type"] == "eq": - try: - if cons["fun"](value_set) != 0: - is_valid_set = False - # if in_log_handler and (debug or display_p): - # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) - break - except TypeError: - # if in_log_handler and (debug or display_p): - # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) - is_valid_set = False - break - if not is_valid_set: - continue - count_valid += 1 - this_fun_val = func(value_set) - if in_log_handler: - if debug or display_p: - in_log_handler.info("value_set={} ; fun_val={}".format(value_set, this_fun_val)) - this_fun_val = round(this_fun_val, round_digit) - if this_fun_val < best_fun_val: - best_para_val = [value_set] - best_fun_val = this_fun_val - elif this_fun_val == best_fun_val: - best_para_val.append(value_set) - else: - pass - if in_log_handler: - if debug or display_p: - in_log_handler.info("Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round)) - in_log_handler.info("Brute best function value: " + str(best_fun_val)) - if debug: - in_log_handler.info("Best solution: " + str(best_para_val)) - else: - if debug or display_p: - sys.stdout.write( - "Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round) + "\n") - sys.stdout.write("Brute best function value: " + str(best_fun_val) + "\n") - if debug: - sys.stdout.write("Best solution: " + str(best_para_val) + "\n") - return best_para_val + # # for compatibility between scipy and sympy + # def least_square_function_v(x): + # return least_square_expr(*tuple(x)) + # + # """ create constraints by creating inequations: the copy of every contig has to be >= 1 """ + # + # def constraint_min_function(x): + # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # min_copy = np.array([1.001] * len(all_v_symbols) + + # [1.001] * len(extra_symbol_to_str_m1) + + # [2.001] * len(extra_symbol_to_str_m2)) + # # effect: expression_array >= int(min_copy) + # return expression_array - min_copy + # + # def constraint_min_function_for_customized_brute(x): + # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # min_copy = np.array([1.0] * len(all_v_symbols) + + # [1.0] * len(extra_symbol_to_str_m1) + 
+ # [2.0] * len(extra_symbol_to_str_m2)) + # # effect: expression_array >= min_copy + # return expression_array - min_copy + # def constraint_max_function(x): + # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # max_copy = np.array([expected_average_cov] * len(all_v_symbols) + + # [expected_average_cov] * len(extra_symbol_to_str_m1) + + # [expected_average_cov * 2] * len(extra_symbol_to_str_m2)) + # # effect: expression_array <= max_copy + # return max_copy - expression_array + # + # def constraint_int_function(x): + # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # diff = np.array([0] * len(all_symbols)) + # return sum([abs(every_copy - int(every_copy)) for every_copy in expression_array]) + # + # def minimize_brute_force(func, range_list, constraint_list, round_digit=4, display_p=True, + # in_log_handler=log_handler): + # # time0 = time.time() + # best_fun_val = inf + # best_para_val = [] + # count_round = 0 + # count_valid = 0 + # for value_set in product(*[list(this_range) for this_range in range_list]): + # count_round += 1 + # is_valid_set = True + # for cons in constraint_list: + # if cons["type"] == "ineq": + # try: + # if (cons["fun"](value_set) < 0).any(): + # is_valid_set = False + # # if in_log_handler and (debug or display_p): + # # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) + # break + # except TypeError: + # # if in_log_handler and (debug or display_p): + # # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) + # is_valid_set = False + # break + # elif cons["type"] == "eq": + # try: + # if cons["fun"](value_set) != 0: + # is_valid_set = False + # # if in_log_handler and (debug or display_p): + # # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) + # break + # except TypeError: + # # if in_log_handler and (debug or display_p): + # # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) + # is_valid_set = False + # break + # if not is_valid_set: + # continue + # count_valid += 1 + # this_fun_val = func(value_set) + # if in_log_handler: + # if debug or display_p: + # in_log_handler.info("value_set={} ; fun_val={}".format(value_set, this_fun_val)) + # this_fun_val = round(this_fun_val, round_digit) + # if this_fun_val < best_fun_val: + # best_para_val = [value_set] + # best_fun_val = this_fun_val + # elif this_fun_val == best_fun_val: + # best_para_val.append(value_set) + # else: + # pass + # if in_log_handler: + # if debug or display_p: + # in_log_handler.info("Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round)) + # in_log_handler.info("Brute best function value: " + str(best_fun_val)) + # if debug: + # in_log_handler.info("Best solution: " + str(best_para_val)) + # else: + # if debug or display_p: + # sys.stdout.write( + # "Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round) + "\n") + # sys.stdout.write("Brute best function value: " + str(best_fun_val) + "\n") + # if debug: + # sys.stdout.write("Best solution: " + str(best_para_val) + "\n") + # return best_para_val + if verbose: + log_handler.info("Estimating copy and depth precisely ...") vertices_list = sorted(self.vertex_info) + 
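# ----------------------------------------------------------------------------
# Illustration (not part of the applied patch): the code below swaps the
# sympy symbolic solve plus scipy/brute-force minimisation for a single GEKKO
# mixed-integer model solved locally with APOPT. A minimal, self-contained
# sketch of the same pattern, with toy numbers and hypothetical names:

from gekko import GEKKO

float_copies = [1.1, 2.3, 0.9]               # coverage-based copy estimates
m_demo = GEKKO(remote=False)                 # local solve, as in the code below
copies = m_demo.Array(m_demo.Var, 3, lb=1, ub=8, integer=True)
for c_var, guess in zip(copies, float_copies):
    c_var.value = max(1, round(guess))       # warm start from rounded estimates
# one flow-balance equation per junction, e.g. copy_1 == copy_0 + copy_2
m_demo.Equation(copies[1] - copies[0] - copies[2] == 0)
# least-squares deviation from the coverage-based estimates
m_demo.Minimize(sum((c_var - guess) ** 2
                    for c_var, guess in zip(copies, float_copies)))
m_demo.options.SOLVER = 1                    # APOPT, the MINLP-capable solver
m_demo.solver_options = ['minlp_maximum_iterations 500',
                         'minlp_integer_tol 1.0e-2']
m_demo.solve(disp=False)
print([c_var.value[0] for c_var in copies])  # expected: [1.0, 2.0, 1.0]

# The real method additionally splits over-long objective strings into
# m.sum() blocks (the APM backend rejects expressions beyond ~15000
# characters) and derives n_iterations and minlp_gap_tol from the data.
# ----------------------------------------------------------------------------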
vertices_ids = {_v: _i for _i, _v in enumerate(vertices_list)} if len(vertices_list) == 1: cov_ = self.vertex_info[vertices_list[0]].cov - if return_new_graphs: - return [{"graph": deepcopy(self), "cov": cov_}] - else: - if log_handler: - log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2))) - else: - sys.stdout.write( - "Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2)) + "\n") - return + # 2022-12-15, remove return_new_graph + # if return_new_graphs: + return [{"graph": deepcopy(self), "cov": cov_}] + # else: + # if log_handler: + # log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2))) + # else: + # sys.stdout.write( + # "Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2)) + "\n") + # return - # reduce maximum_copy_num to reduce computational burden + # reduce expected_average_cov to reduce computational burden all_coverages = [self.vertex_info[v_name].cov for v_name in vertices_list] - maximum_copy_num = min(maximum_copy_num, int(2 * math.ceil(max(all_coverages) / min(all_coverages)))) - if verbose: - if log_handler: - log_handler.info("Maximum multiplicity: " + str(maximum_copy_num)) - else: - sys.stdout.write("Maximum multiplicity: " + str(maximum_copy_num) + "\n") + # max_contig_multiplicity = \ + # min(max_contig_multiplicity, int(2 * math.ceil(max(all_coverages) / min(all_coverages)))) + # if verbose: + # if log_handler: + # log_handler.info("Maximum multiplicity: " + str(max_contig_multiplicity)) + # else: + # sys.stdout.write("Maximum multiplicity: " + str(max_contig_multiplicity) + "\n") + """ use local gekko """ + m = GEKKO(remote=False) """ create constraints by creating multivariate equations """ - vertex_to_symbols = {vertex_name: Symbol("V" + vertex_name, integer=True) # positive=True) - for vertex_name in vertices_list} - symbols_to_vertex = {vertex_to_symbols[vertex_name]: vertex_name for vertex_name in vertices_list} - extra_str_to_symbol_m1 = {} - extra_str_to_symbol_m2 = {} - extra_symbol_to_str_m1 = {} - extra_symbol_to_str_m2 = {} + copy_upper_bound = int(4 * math.ceil(max(all_coverages) / min(all_coverages))) + v_vars = m.Array(m.Var, + len(vertices_list), + lb=1, + ub=copy_upper_bound, + integer=True) + # initialize free variables + for go_v, v_var in enumerate(v_vars): + v_var.value = self.vertex_to_copy[vertices_list[go_v]] + # for go_sym, symbol_used in enumerate(free_copy_variables): + # if symbol_used in symbols_to_vertex and symbols_to_vertex[symbol_used] in self.vertex_to_copy: + # g_vars[go_sym].value = self.vertex_to_copy[symbols_to_vertex[symbol_used]] + # elif symbol_used in extra_symbol_initial_values: + # g_vars[go_sym].value = extra_symbol_initial_values[symbol_used] + + # vertex_to_symbols = {vertex_name: Symbol("V" + vertex_name, integer=True) # positive=True) + # for vertex_name in vertices_list} + # symbols_to_vertex = {vertex_to_symbols[vertex_name]: vertex_name for vertex_name in vertices_list} formulae = [] recorded_ends = set() - for vertex_name in vertices_list: + for go_v, vertex_name in enumerate(vertices_list): for this_end in (True, False): if (vertex_name, this_end) not in recorded_ends: recorded_ends.add((vertex_name, this_end)) if self.vertex_info[vertex_name].connections[this_end]: - this_formula = vertex_to_symbols[vertex_name] + this_formula = v_vars[go_v] formulized = False for n_v, n_e in self.vertex_info[vertex_name].connections[this_end]: if (n_v, n_e) not in recorded_ends: - # if n_v in 
vertices_set: - # recorded_ends.add((n_v, n_e)) try: - this_formula -= get_formula(n_v, n_e, vertex_name, this_end, recorded_ends) + this_formula -= get_formula(n_v, n_e, vertex_name, recorded_ends) formulized = True # if verbose: # if log_handler: @@ -1310,36 +2187,41 @@ def minimize_brute_force(func, range_list, constraint_list, round_digit=4, displ str(this_formula) + "\n") if formulized: formulae.append(this_formula) - elif broken_graph_allowed: - # Extra limitation to force terminal vertex to have only one copy, to avoid over-estimation - # Under-estimation would not be a problem here, - # because the True-multiple-copy vertex would simply have no other connections, - # or failed in the following estimation if it does - formulae.append(vertex_to_symbols[vertex_name] - 1) - - # add self-loop formulae - self_loop_v = set() + # 2022-12-13 remove this restriction + # because we have a reduce_list_with_gcd for all graph component + # elif broken_graph_allowed: + # # Extra limitation to force terminal vertex to have only one copy, to avoid over-estimation + # # Under-estimation would not be a problem here, + # # because the True-multiple-copy vertex would simply have no other connections, + # # or failed in the following estimation if it does + # formulae.append(vertex_to_symbols[vertex_name] - 1) + + """ add self-loop formulae """ + self_loop_v = OrderedDict() for vertex_name in vertices_list: if self.vertex_info[vertex_name].is_self_loop(): - self_loop_v.add(vertex_name) + self_loop_v[vertex_name] = self.vertex_to_copy[vertex_name] if log_handler: log_handler.warning("Self-loop contig detected: Vertex_" + vertex_name) - pseudo_self_loop_str = "P" + vertex_name - if pseudo_self_loop_str not in extra_str_to_symbol_m1: - extra_str_to_symbol_m1[pseudo_self_loop_str] = Symbol(pseudo_self_loop_str, integer=True) - extra_symbol_to_str_m1[extra_str_to_symbol_m1[pseudo_self_loop_str]] = pseudo_self_loop_str - this_formula = vertex_to_symbols[vertex_name] - extra_str_to_symbol_m1[pseudo_self_loop_str] + # + if self_loop_v: + p_vars = m.Array(m.Var, + len(self_loop_v), + lb=1, + ub=copy_upper_bound, + integer=True) + for go_p, (vertex_name, initial_val) in enumerate(self_loop_v.items()): + # set initial value + p_vars[go_p].value = initial_val + # add formulae + this_formula = v_vars[vertices_ids[vertex_name]] - p_vars[go_p] formulae.append(this_formula) - if verbose: - if log_handler: - log_handler.info( - "formulating for: " + vertex_name + ECHO_DIRECTION[True] + ": " + str(this_formula)) - else: - sys.stdout.write( - "formulating for: " + vertex_name + ECHO_DIRECTION[True] + ": " + str(this_formula) + "\n") - - # add following extra limitation - # set cov_sequential_repeat = x*near_by_cov, x is an integer + if verbose and log_handler: + log_handler.info( + "formulating for: " + vertex_name + ECHO_DIRECTION[True] + ": " + str(this_formula)) + """ add extra restriction on repeats """ + extra_m1 = [] + extra_m2 = [] for vertex_name in vertices_list: single_pair_in_the_trunk_path = self.is_sequential_repeat(vertex_name) if single_pair_in_the_trunk_path: @@ -1347,151 +2229,232 @@ def minimize_brute_force(func, range_list, constraint_list, round_digit=4, displ # from_v and to_v are already in the "trunk path", if they are the same, # the graph is like two circles sharing the same sequential repeat, no need to add this limitation if from_v != to_v: - new_str = "E" + str(len(extra_str_to_symbol_m1) + len(extra_str_to_symbol_m2)) + initial_val = round(self.vertex_to_float_copy[vertex_name] / 
self.vertex_to_float_copy[from_v]) if vertex_name in self_loop_v: # self-loop vertex is allowed to have the multiplicity of 1 - extra_str_to_symbol_m1[new_str] = Symbol(new_str, integer=True) - extra_symbol_to_str_m1[extra_str_to_symbol_m1[new_str]] = new_str - this_formula = vertex_to_symbols[vertex_name] - \ - vertex_to_symbols[from_v] * extra_str_to_symbol_m1[new_str] + extra_m1.append([vertex_name, from_v, initial_val]) else: - extra_str_to_symbol_m2[new_str] = Symbol(new_str, integer=True) - extra_symbol_to_str_m2[extra_str_to_symbol_m2[new_str]] = new_str - this_formula = vertex_to_symbols[vertex_name] - \ - vertex_to_symbols[from_v] * extra_str_to_symbol_m2[new_str] - formulae.append(this_formula) - if verbose: - if log_handler: - log_handler.info("formulating for: " + vertex_name + ": " + str(this_formula)) - else: - sys.stdout.write("formulating for: " + vertex_name + ": " + str(this_formula) + "\n") + extra_m2.append([vertex_name, from_v, max(initial_val, 2)]) + if extra_m1: + m1_vars = m.Array(m.Var, + len(extra_m1), + lb=1, + ub=copy_upper_bound, + integer=True) + for go_m, (vertex_name, from_v, initial_val) in enumerate(extra_m1): + m1_vars[go_m].value = initial_val + this_formula = v_vars[vertices_ids[vertex_name]] - v_vars[vertices_ids[from_v]] * m1_vars[go_m] + formulae.append(this_formula) + if verbose and log_handler: + log_handler.info("formulating for: " + vertex_name + ": " + str(this_formula)) + if extra_m2: + m2_vars = m.Array(m.Var, + len(extra_m2), + lb=2, + ub=copy_upper_bound, + integer=True) + for go_m, (vertex_name, from_v, initial_val) in enumerate(extra_m2): + m2_vars[go_m].value = initial_val + this_formula = v_vars[vertices_ids[vertex_name]] - v_vars[vertices_ids[from_v]] * m2_vars[go_m] + formulae.append(this_formula) + if verbose and log_handler: + log_handler.info("formulating for: " + vertex_name + ": " + str(this_formula)) - all_v_symbols = list(symbols_to_vertex) - all_symbols = all_v_symbols + list(extra_symbol_to_str_m1) + list(extra_symbol_to_str_m2) + # # solve the equations + # copy_solution = solve(formulae, all_v_symbols) + # + # copy_solution = copy_solution if copy_solution else {} + # if type(copy_solution) == list: # delete 0 containing set, even for self-loop vertex + # go_solution = 0 + # while go_solution < len(copy_solution): + # if 0 in set(copy_solution[go_solution].values()): + # del copy_solution[go_solution] + # else: + # go_solution += 1 + # if not copy_solution: + # raise ProcessingGraphFailed("Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (1)!") + # elif type(copy_solution) == list: + # if len(copy_solution) > 2: + # raise ProcessingGraphFailed("Incomplete/Complicated " + target_name_for_log + " graph (2)!") + # else: + # copy_solution = copy_solution[0] + # + # free_copy_variables = list() + # for symbol_used in all_symbols: + # if symbol_used not in copy_solution: + # free_copy_variables.append(symbol_used) + # copy_solution[symbol_used] = symbol_used + # if verbose: + # if log_handler: + # log_handler.info("copy equations: " + str(copy_solution)) + # log_handler.info("free variables: " + str(free_copy_variables)) + # else: + # sys.stdout.write("copy equations: " + str(copy_solution) + "\n") + # sys.stdout.write("free variables: " + str(free_copy_variables) + "\n") + + # """ """ + # least_square_expr = 0 + # for symbol_used in all_v_symbols: + # # least_square_expr += copy_solution[symbol_used] + # this_vertex = symbols_to_vertex[symbol_used] + # this_copy = self.vertex_to_float_copy[this_vertex] + # 
least_square_expr += (copy_solution[symbol_used] - this_copy) ** 2 # * self.vertex_info[this_vertex]["len"] + # least_square_expr = lambdify(args=free_copy_variables, expr=least_square_expr) + + # if free_copy_variables: + """ minimizing equation-based copy's deviations from coverage-based copy values """ + # ignore overlap influence + least_square_list = [] + for go_v, vertex_name in enumerate(vertices_list): + estimated_copy = self.vertex_to_float_copy[vertex_name] + least_square_list.append((v_vars[go_v] - estimated_copy) ** 2) + least_square_expr = sum(least_square_list) if verbose or debug: - if log_handler: - log_handler.info("formulae: " + str(formulae)) - else: - sys.stdout.write("formulae: " + str(formulae) + "\n") - # solve the equations - copy_solution = solve(formulae, all_v_symbols) - - copy_solution = copy_solution if copy_solution else {} - if type(copy_solution) == list: # delete 0 containing set, even for self-loop vertex - go_solution = 0 - while go_solution < len(copy_solution): - if 0 in set(copy_solution[go_solution].values()): - del copy_solution[go_solution] - else: - go_solution += 1 - if not copy_solution: - raise ProcessingGraphFailed("Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (1)!") - elif type(copy_solution) == list: - if len(copy_solution) > 2: - raise ProcessingGraphFailed("Incomplete/Complicated " + target_name_for_log + " graph (2)!") + log_handler.info("square function: " + str(repr(least_square_expr))) + # reform least_square_expr if string length > 15000 + exp_str_len = len(str(least_square_expr)) + if exp_str_len > 15000: # not allowed by Gekko:APM + num_blocks = math.ceil(exp_str_len / 10000.) + block_size = math.ceil(len(least_square_list) / float(num_blocks)) + block_list = [] + for g_b in range(num_blocks): + block_list.append(sum(least_square_list[g_b * block_size: (g_b + 1) * block_size])) + least_square_expr = m.sum(block_list) + + # account for the influence of the overlap + # total_len = 0 + # multinomial_loglike_list = [] + # v_to_len = {} + # v_to_real_len = {} + # all_obs = [] + # if self.__uni_overlap: + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # v_to_real_len[this_vertex] = self.vertex_info[this_vertex].len - self.__uni_overlap + # v_to_len[this_vertex] = eval(str(copy_solution[symbol_used].subs(replacements))) \ + # * v_to_real_len[this_vertex] + # total_len += v_to_len[this_vertex] + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # prob = v_to_len[this_vertex] / total_len + # obs = self.vertex_info[this_vertex].cov * v_to_real_len[this_vertex] + # multinomial_loglike_list.append(m.log(prob) * obs) + # all_obs.append(obs) + # if verbose: + # if log_handler: + # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # else: + # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # else: + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # overlaps = [_ovl + # for _strand in (True, False) + # for _next, _ovl in self.vertex_info[this_vertex].connections[_strand].items()] + # approximate_overlap = average_np_free(overlaps) + # v_to_real_len[this_vertex] = self.vertex_info[this_vertex].len - approximate_overlap + # v_to_len[this_vertex] = eval(str(copy_solution[symbol_used].subs(replacements)))\ + # * v_to_real_len[this_vertex] + # total_len += v_to_len[this_vertex] + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # prob = 
v_to_len[this_vertex] / total_len + # obs = self.vertex_info[this_vertex].cov * v_to_real_len[this_vertex] + # # multinomial_loglike_list.append(m.log(prob) * obs) + # all_obs.append(obs) + # if verbose: + # if log_handler: + # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # else: + # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # """extra arbitrary restriction to avoid over inflation of copies""" + # multinomial_loglike_list.append(-abs(sum(all_obs) / expected_average_cov - total_len)) + # multinomial_loglike_expr = m.sum(multinomial_loglike_list) + + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # total_len += eval(str(copy_solution[symbol_used].subs(replacements))) * self.vertex_info[this_vertex].len + # multinomial_like_expr = 0 + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # prob = eval(str(copy_solution[symbol_used].subs(replacements))) \ + # * self.vertex_info[this_vertex].len / total_len + # obs = self.vertex_info[this_vertex].cov * self.vertex_info[this_vertex].len + # multinomial_like_expr += m.log(prob) * obs + # if verbose: + # if log_handler: + # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # else: + # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + m.Equations([f_ == 0 for f_ in formulae]) + m.Minimize(least_square_expr) + # 1 for APOPT, 2 for BPOPT, 3 for IPOPT, 0 for all available solvers + # here only 1 and 3 are available + m.options.SOLVER = 1 + # setting empirical options + # 5000 costs ~ 150 sec + # get the variation within the data + single_variations = [] + for vertex_name in self.vertex_info: + f_copy = self.vertex_to_float_copy[vertex_name] + single_variations.append(abs((math.ceil(f_copy) - f_copy) ** 2 - (math.floor(f_copy) - f_copy) ** 2)) + single_variations.sort() + # largest_var = single_variations[-1] * 4 + if n_iterations is None: + n_high_copy = sum([math.log2(self.vertex_to_float_copy[_v]) + for _v in self.vertex_info if self.vertex_to_float_copy[_v] > 2]) + n_iterations = 500 + int(len(self.vertex_info) * n_high_copy) + if verbose or debug: + log_handler.info("setting n_iterations=" + str(n_iterations)) + if verbose or debug: + log_handler.info("setting minlp_gap_tol=%.0e" % single_variations[0]) + m.solver_options = ['minlp_maximum_iterations ' + str(n_iterations), + # minlp iterations with integer solution + 'minlp_max_iter_with_int_sol ' + str(n_iterations), + # treat minlp as nlp + 'minlp_as_nlp 0', + # nlp sub-problem max iterations + 'nlp_maximum_iterations ' + str(n_iterations), + # 1 = depth first, 2 = breadth first + 'minlp_branch_method 2', + # maximum deviation from whole number: + # amount that a candidate solution variable can deviate from an integer solution + # and still be considered an integer + 'minlp_integer_tol 1.0e-2', + # covergence tolerance + 'minlp_gap_tol %.0e' % single_variations[0]] + try: + # TODO: + # there is currently no random seed option for gekko, no random effect has been observed yet though + if debug or verbose: + m.solve() else: - copy_solution = copy_solution[0] - - free_copy_variables = list() - for symbol_used in all_symbols: - if symbol_used not in copy_solution: - free_copy_variables.append(symbol_used) - copy_solution[symbol_used] = symbol_used - if verbose: - if log_handler: - log_handler.info("copy equations: " + str(copy_solution)) - log_handler.info("free variables: " + str(free_copy_variables)) + m.solve(disp=False) + except 
NameError as e: # temporary for gekko's bug + if "TimeoutExpired" in str(e): + raise ProcessingGraphFailed("Timeout.") else: - sys.stdout.write("copy equations: " + str(copy_solution) + "\n") - sys.stdout.write("free variables: " + str(free_copy_variables) + "\n") - - # """ minimizing equation-based copy values and their deviations from coverage-based copy values """ - """ minimizing equation-based copy's deviations from coverage-based copy values """ - least_square_expr = 0 - for symbol_used in all_v_symbols: - # least_square_expr += copy_solution[symbol_used] - this_vertex = symbols_to_vertex[symbol_used] - this_copy = self.vertex_to_float_copy[this_vertex] - least_square_expr += (copy_solution[symbol_used] - this_copy) ** 2 # * self.vertex_info[this_vertex]["len"] - least_square_function = lambdify(args=free_copy_variables, expr=least_square_expr) - - # for safe running - if len(free_copy_variables) > 10: - raise ProcessingGraphFailed("Free variable > 10 is not accepted yet!") - - if maximum_copy_num ** len(free_copy_variables) < 5E6: - # sometimes, SLSQP ignores bounds and constraints - copy_results = minimize_brute_force( - func=least_square_function_v, range_list=[range(1, maximum_copy_num + 1)] * len(free_copy_variables), - constraint_list=({'type': 'ineq', 'fun': constraint_min_function_for_customized_brute}, - {'type': 'eq', 'fun': constraint_int_function}, - {'type': 'ineq', 'fun': constraint_max_function}), - display_p=verbose) - else: - constraints = ({'type': 'ineq', 'fun': constraint_min_function}, - {'type': 'eq', 'fun': constraint_int_function}, - {'type': 'ineq', 'fun': constraint_max_function}) - copy_results = set() - best_fun = inf - opt = {'disp': verbose, "maxiter": 100} - for initial_copy in range(maximum_copy_num * 2 + 1): - if initial_copy < maximum_copy_num: - initials = np.array([initial_copy + 1] * len(free_copy_variables)) - elif initial_copy < maximum_copy_num * 2: - initials = np.array([random.randint(1, maximum_copy_num)] * len(free_copy_variables)) - else: - initials = np.array([self.vertex_to_copy.get(symbols_to_vertex.get(symb, False), 2) - for symb in free_copy_variables]) - bounds = [(1, maximum_copy_num) for foo in range(len(free_copy_variables))] - try: - copy_result = optimize.minimize(fun=least_square_function_v, x0=initials, jac=False, - method='SLSQP', bounds=bounds, constraints=constraints, options=opt) - except Exception: - continue - if copy_result.fun < best_fun: - best_fun = round(copy_result.fun, 2) - copy_results = {tuple(copy_result.x)} - elif copy_result.fun == best_fun: - copy_results.add(tuple(copy_result.x)) - else: - pass - if debug or verbose: - if log_handler: - log_handler.info("Best function value: " + str(best_fun)) - else: - sys.stdout.write("Best function value: " + str(best_fun) + "\n") - if verbose or debug: - if log_handler: - log_handler.info("Copy results: " + str(copy_results)) + raise e + except Exception as e: + # TODO adjust parameters according to apm result, currently I do not know how to load apm result + if "Solution Not Found" in str(e): + raise ProcessingGraphFailed("Solution not found by apm for current graph!") else: - sys.stdout.write("Copy results: " + str(copy_results) + "\n") - if len(copy_results) == 1: - copy_results = list(copy_results) - elif len(copy_results) > 1: - # draftly sort results by freedom vertices - copy_results = sorted(copy_results, key=lambda - x: sum([(x[go_sym] - self.vertex_to_float_copy[symbols_to_vertex[symb_used]]) ** 2 - for go_sym, symb_used in enumerate(free_copy_variables) - if 
symb_used in symbols_to_vertex])) - else: - raise ProcessingGraphFailed("Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (3)!") - - if return_new_graphs: - """ produce all possible vertex copy combinations """ - final_results = [] - all_copy_sets = set() - for go_res, copy_result in enumerate(copy_results): - free_copy_variables_dict = {free_copy_variables[i]: int(this_copy) - for i, this_copy in enumerate(copy_result)} - - """ simplify copy values """ # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] - all_copies = [] - for this_symbol in all_v_symbols: - vertex_name = symbols_to_vertex[this_symbol] - this_copy = int(copy_solution[this_symbol].evalf(subs=free_copy_variables_dict, chop=True)) - if this_copy <= 0: - raise ProcessingGraphFailed("Cannot identify copy number of " + vertex_name + "!") - all_copies.append(this_copy) + raise e + copy_results = list([x.value[0] for x in v_vars]) + if debug or verbose: + for go_v, vertex_name in enumerate(vertices_list): + log_handler.info(vertex_name + ": " + str(copy_results[go_v])) + # """ produce all possible vertex copy combinations """ + # maybe no more multiple results since 2022-12 gekko update + final_results = [] + all_copy_sets = set() # removing duplicates in multiple results + for go_res, copy_result in enumerate([copy_results]): + """ simplify copy values """ + # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] + # 2022-12-15 add cluster info to simplify by graph components when the graph is broken + all_copies = copy_result + if len(self.vertex_clusters) == 1: if len(all_copies) == 0: raise ProcessingGraphFailed( "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") @@ -1509,178 +2472,942 @@ def minimize_brute_force(func, range_list, constraint_list, round_digit=4, displ sys.stdout.write("Estimated copies: " + str(all_copies) + "\n") sys.stdout.write("Reduced copies: " + str(new_all_copies) + "\n") all_copies = new_all_copies - all_copies = tuple(all_copies) - if all_copies not in all_copy_sets: - all_copy_sets.add(all_copies) - else: - continue - - """ record new copy values """ - final_results.append({"graph": deepcopy(self)}) - for go_s, this_symbol in enumerate(all_v_symbols): - vertex_name = symbols_to_vertex[this_symbol] - if vertex_name in final_results[go_res]["graph"].vertex_to_copy: - old_copy = final_results[go_res]["graph"].vertex_to_copy[vertex_name] - final_results[go_res]["graph"].copy_to_vertex[old_copy].remove(vertex_name) - if not final_results[go_res]["graph"].copy_to_vertex[old_copy]: - del final_results[go_res]["graph"].copy_to_vertex[old_copy] - this_copy = all_copies[go_s] - final_results[go_res]["graph"].vertex_to_copy[vertex_name] = this_copy - if this_copy not in final_results[go_res]["graph"].copy_to_vertex: - final_results[go_res]["graph"].copy_to_vertex[this_copy] = set() - final_results[go_res]["graph"].copy_to_vertex[this_copy].add(vertex_name) - - """ re-estimate baseline depth """ - total_product = 0. 
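# ----------------------------------------------------------------------------
# Illustration (not part of the applied patch): reduce_list_with_gcd() is the
# GetOrganelleLib helper that prevents uniformly inflated multiplicities such
# as [4, 8, 4]; its effect, as used here, amounts to dividing the copies of a
# connected component by their greatest common divisor. A rough stand-in:

from math import gcd
from functools import reduce

def reduce_by_gcd_demo(component_copies):    # illustrative only
    common = reduce(gcd, component_copies)
    return [one_copy // common for one_copy in component_copies]

# reduce_by_gcd_demo([4, 8, 4])  ->  [1, 2, 1]
# Since the 2022-12 change this simplification runs per graph component
# (self.vertex_clusters), so each disconnected component of a broken graph is
# reduced independently rather than the graph as a whole.
# ----------------------------------------------------------------------------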
- total_len = 0 - for vertex_name in vertices_list: - this_len = self.vertex_info[vertex_name].len \ - * final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) - this_cov = self.vertex_info[vertex_name].cov \ - / final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) - total_len += this_len - total_product += this_len * this_cov - final_results[go_res]["cov"] = total_product / total_len - return final_results - - else: - """ produce the first-ranked copy combination """ - free_copy_variables_dict = {free_copy_variables[i]: int(this_copy) - for i, this_copy in enumerate(copy_results[0])} - - """ simplify copy values """ # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] - all_copies = [] - for this_symbol in all_v_symbols: - vertex_name = symbols_to_vertex[this_symbol] - this_copy = int(copy_solution[this_symbol].evalf(subs=free_copy_variables_dict, chop=True)) - if this_copy <= 0: - raise ProcessingGraphFailed("Cannot identify copy number of " + vertex_name + "!") - all_copies.append(this_copy) - if len(all_copies) == 0: - raise ProcessingGraphFailed( - "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") - elif len(all_copies) == 1: - all_copies = [1] - elif min(all_copies) == 1: - pass else: - new_all_copies = reduce_list_with_gcd(all_copies) - if verbose and new_all_copies != all_copies: - if log_handler: - log_handler.info("Estimated copies: " + str(all_copies)) - log_handler.info("Reduced copies: " + str(new_all_copies)) + for v_cluster in self.vertex_clusters: + ids = [vertices_ids[_v] for _v in v_cluster] + component_copies = [all_copies[_id] for _id in ids] + if len(component_copies) == 0: + raise ProcessingGraphFailed( + "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + elif len(component_copies) == 1: + component_copies = [1] + elif min(component_copies) == 1: + pass else: - sys.stdout.write("Estimated copies: " + str(all_copies) + "\n") - sys.stdout.write("Reduced copies: " + str(new_all_copies) + "\n") - all_copies = new_all_copies + new_comp_copies = reduce_list_with_gcd(component_copies) + if verbose and new_comp_copies != component_copies: + if log_handler: + log_handler.info("Estimated copies: " + str(component_copies)) + log_handler.info("Reduced copies: " + str(new_comp_copies)) + else: + sys.stdout.write("Estimated copies: " + str(component_copies) + "\n") + sys.stdout.write("Reduced copies: " + str(new_comp_copies) + "\n") + component_copies = new_comp_copies + for sequential_id, _id in enumerate(ids): + all_copies[_id] = component_copies[sequential_id] + + all_copies = tuple(all_copies) + if all_copies not in all_copy_sets: + all_copy_sets.add(all_copies) + else: + continue """ record new copy values """ - for go_s, this_symbol in enumerate(all_v_symbols): - vertex_name = symbols_to_vertex[this_symbol] - if vertex_name in self.vertex_to_copy: - old_copy = self.vertex_to_copy[vertex_name] - self.copy_to_vertex[old_copy].remove(vertex_name) - if not self.copy_to_vertex[old_copy]: - del self.copy_to_vertex[old_copy] - this_copy = all_copies[go_s] - self.vertex_to_copy[vertex_name] = this_copy - if this_copy not in self.copy_to_vertex: - self.copy_to_vertex[this_copy] = set() - self.copy_to_vertex[this_copy].add(vertex_name) - - if debug or verbose: - """ re-estimate baseline depth """ - total_product = 0. 
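# ----------------------------------------------------------------------------
# Illustration (not part of the applied patch): the "record new copy values"
# block below keeps two mirrored indexes of each returned graph in sync.
# vertex_to_copy maps a contig name to its integer multiplicity, while
# copy_to_vertex maps a multiplicity to the set of contigs carrying it, e.g.
# (hypothetical names)
#     vertex_to_copy = {"LSC": 1, "IR": 2, "SSC": 1}
#     copy_to_vertex = {1: {"LSC", "SSC"}, 2: {"IR"}}
# so assigning a new copy number means removing the contig from its old set
# (and deleting that set if it becomes empty) before adding it to the new one.
# ----------------------------------------------------------------------------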
- total_len = 0 - for vertex_name in vertices_list: - this_len = self.vertex_info[vertex_name].len \ - * self.vertex_to_copy.get(vertex_name, 1) - this_cov = self.vertex_info[vertex_name].cov / self.vertex_to_copy.get(vertex_name, 1) - total_len += this_len - total_product += this_len * this_cov - new_val = total_product / total_len - if log_handler: - log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2))) - else: - sys.stdout.write( - "Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2)) + "\n") - - def tag_in_between(self, database_n): - # add those in between the tagged vertices to tagged_vertices, which offered the only connection - updated = True - candidate_vertices = list(self.vertex_info) - while updated: - updated = False - go_to_v = 0 - while go_to_v < len(candidate_vertices): - can_v = candidate_vertices[go_to_v] - if can_v in self.tagged_vertices[database_n]: - del candidate_vertices[go_to_v] - continue - else: - if sum([bool(c_c) for c_c in self.vertex_info[can_v].connections.values()]) != 2: + final_results.append({"graph": deepcopy(self)}) + for go_v, vertex_name in enumerate(vertices_list): + if vertex_name in final_results[go_res]["graph"].vertex_to_copy: + old_copy = final_results[go_res]["graph"].vertex_to_copy[vertex_name] + final_results[go_res]["graph"].copy_to_vertex[old_copy].remove(vertex_name) + if not final_results[go_res]["graph"].copy_to_vertex[old_copy]: + del final_results[go_res]["graph"].copy_to_vertex[old_copy] + estimated_copy = all_copies[go_v] + final_results[go_res]["graph"].vertex_to_copy[vertex_name] = estimated_copy + if estimated_copy not in final_results[go_res]["graph"].copy_to_vertex: + final_results[go_res]["graph"].copy_to_vertex[estimated_copy] = set() + final_results[go_res]["graph"].copy_to_vertex[estimated_copy].add(vertex_name) + + """ re-estimate baseline depth """ + total_product = 0. 
+ total_len = 0 + for vertex_name in vertices_list: + this_len = self.vertex_info[vertex_name].len \ + * final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) + this_cov = self.vertex_info[vertex_name].cov \ + / final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) + total_len += this_len + total_product += this_len * this_cov + final_results[go_res]["cov"] = total_product / total_len + return final_results + # else: + # """ produce the first-ranked copy combination """ + # free_copy_variables_dict = {free_copy_variables[i]: int(this_copy) + # for i, this_copy in enumerate(copy_results)} + # + # """ simplify copy values """ # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] + # all_copies = [] + # for this_symbol in all_v_symbols: + # vertex_name = symbols_to_vertex[this_symbol] + # this_copy = int(copy_solution[this_symbol].evalf(subs=free_copy_variables_dict, chop=True)) + # if this_copy <= 0: + # raise ProcessingGraphFailed("Cannot identify copy number of " + vertex_name + "!") + # all_copies.append(this_copy) + # if len(all_copies) == 0: + # raise ProcessingGraphFailed( + # "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + # elif len(all_copies) == 1: + # all_copies = [1] + # elif min(all_copies) == 1: + # pass + # else: + # new_all_copies = reduce_list_with_gcd(all_copies) + # if verbose and new_all_copies != all_copies: + # if log_handler: + # log_handler.info("Estimated copies: " + str(all_copies)) + # log_handler.info("Reduced copies: " + str(new_all_copies)) + # else: + # sys.stdout.write("Estimated copies: " + str(all_copies) + "\n") + # sys.stdout.write("Reduced copies: " + str(new_all_copies) + "\n") + # all_copies = new_all_copies + # + # """ record new copy values """ + # for go_s, this_symbol in enumerate(all_v_symbols): + # vertex_name = symbols_to_vertex[this_symbol] + # if vertex_name in self.vertex_to_copy: + # old_copy = self.vertex_to_copy[vertex_name] + # self.copy_to_vertex[old_copy].remove(vertex_name) + # if not self.copy_to_vertex[old_copy]: + # del self.copy_to_vertex[old_copy] + # this_copy = all_copies[go_s] + # self.vertex_to_copy[vertex_name] = this_copy + # if this_copy not in self.copy_to_vertex: + # self.copy_to_vertex[this_copy] = set() + # self.copy_to_vertex[this_copy].add(vertex_name) + # + # if debug or verbose: + # """ re-estimate baseline depth """ + # total_product = 0. 
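# ----------------------------------------------------------------------------
# Illustration (not part of the applied patch): the "cov" attached to each
# returned graph above is a copy-corrected, length-weighted average depth,
#     cov = sum(len_v * cov_v) / sum(len_v * copy_v)
# since each contig contributes (len_v * copy_v) * (cov_v / copy_v) to the
# numerator. Toy numbers: contig A (len 10000, cov 30, copy 1) and contig B
# (len 2000, cov 62, copy 2) give
#     (10000 * 30 + 2000 * 62) / (10000 * 1 + 2000 * 2) = 424000 / 14000 ≈ 30.3,
# i.e. the estimated single-copy (baseline) k-mer coverage of the target.
# ----------------------------------------------------------------------------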
+ # total_len = 0 + # for vertex_name in vertices_list: + # this_len = self.vertex_info[vertex_name].len \ + # * self.vertex_to_copy.get(vertex_name, 1) + # this_cov = self.vertex_info[vertex_name].cov / self.vertex_to_copy.get(vertex_name, 1) + # total_len += this_len + # total_product += this_len * this_cov + # new_val = total_product / total_len + # if log_handler: + # log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2))) + # else: + # sys.stdout.write( + # "Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2)) + "\n") + + # def estimate_copy_and_depth_precisely_sympy(self, expected_average_cov=None, # broken_graph_allowed=False, + # verbose=False, log_handler=None, debug=False, + # target_name_for_log="target", n_iterations=None): + # + # def get_formula(from_vertex, from_end, back_to_vertex, here_record_ends): + # result_form = vertex_to_symbols[from_vertex] + # here_record_ends.add((from_vertex, from_end)) + # # if back_to_vertex ~ from_vertex (from_vertex == back_to_vertex) form a loop, skipped + # if from_vertex != back_to_vertex: + # for next_v, next_e in self.vertex_info[from_vertex].connections[from_end]: + # # if next_v ~ from_vertex (next_v == from_vertex) form a loop, add a pseudo vertex + # if (next_v, next_e) == (from_vertex, not from_end): + # # skip every self-loop 2020-06-23 + # # pseudo_self_circle_str = "P" + from_vertex + # # if pseudo_self_circle_str not in extra_str_to_symbol_m2: + # # extra_str_to_symbol_m2[pseudo_self_circle_str] = Symbol(pseudo_self_circle_str, integer=True) + # # extra_symbol_to_str_m2[extra_str_to_symbol_m2[pseudo_self_circle_str]] = pseudo_self_circle_str + # # result_form -= (extra_str_to_symbol_m2[pseudo_self_circle_str] - 1) + # pass + # # elif (next_v, next_e) != (back_to_vertex, back_to_end): + # elif (next_v, next_e) not in here_record_ends: + # result_form -= get_formula(next_v, next_e, from_vertex, here_record_ends) + # return result_form + # + # # # for compatibility between scipy and sympy + # # def least_square_function_v(x): + # # return least_square_expr(*tuple(x)) + # # + # # """ create constraints by creating inequations: the copy of every contig has to be >= 1 """ + # # + # # def constraint_min_function(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # min_copy = np.array([1.001] * len(all_v_symbols) + + # # [1.001] * len(extra_symbol_to_str_m1) + + # # [2.001] * len(extra_symbol_to_str_m2)) + # # # effect: expression_array >= int(min_copy) + # # return expression_array - min_copy + # # + # # def constraint_min_function_for_customized_brute(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # min_copy = np.array([1.0] * len(all_v_symbols) + + # # [1.0] * len(extra_symbol_to_str_m1) + + # # [2.0] * len(extra_symbol_to_str_m2)) + # # # effect: expression_array >= min_copy + # # return expression_array - min_copy + # + # def constraint_min_function_for_gekko(g_vars): + # subs_tuples = [(symb_used_, Symbol("g_vars[" + str(go_sym) + "]")) + # for go_sym, symb_used_ in enumerate(free_copy_variables)] + # expression_array = [copy_solution[this_sym].subs(subs_tuples) for this_sym in all_symbols] + # min_copy = [1] * 
len(all_v_symbols) + \ + # [1] * len(extra_symbol_to_str_m1) + \ + # [2] * len(extra_symbol_to_str_m2) + # # effect: expression_array >= min_copy + # expression = [] + # if verbose or debug: + # for e, c in zip(expression_array, min_copy): + # expression.append(eval(str(e) + ">=" + str(c))) + # log_handler.info(" constraint: " + str(e) + ">=" + str(c)) + # else: + # for e, c in zip(expression_array, min_copy): + # expression.append(eval(str(e) + ">=" + str(c))) + # expression = [expr for expr in expression if not isinstance(expr, bool)] + # return expression + # + # # def constraint_max_function(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # max_copy = np.array([expected_average_cov] * len(all_v_symbols) + + # # [expected_average_cov] * len(extra_symbol_to_str_m1) + + # # [expected_average_cov * 2] * len(extra_symbol_to_str_m2)) + # # # effect: expression_array <= max_copy + # # return max_copy - expression_array + # # + # # def constraint_int_function(x): + # # replacements = [(symbol_used, x[go_sym]) for go_sym, symbol_used in enumerate(free_copy_variables)] + # # expression_array = np.array([copy_solution[this_sym].subs(replacements) for this_sym in all_symbols]) + # # # diff = np.array([0] * len(all_symbols)) + # # return sum([abs(every_copy - int(every_copy)) for every_copy in expression_array]) + # # + # # def minimize_brute_force(func, range_list, constraint_list, round_digit=4, display_p=True, + # # in_log_handler=log_handler): + # # # time0 = time.time() + # # best_fun_val = inf + # # best_para_val = [] + # # count_round = 0 + # # count_valid = 0 + # # for value_set in product(*[list(this_range) for this_range in range_list]): + # # count_round += 1 + # # is_valid_set = True + # # for cons in constraint_list: + # # if cons["type"] == "ineq": + # # try: + # # if (cons["fun"](value_set) < 0).any(): + # # is_valid_set = False + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) + # # break + # # except TypeError: + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal ineq constraints".format(value_set)) + # # is_valid_set = False + # # break + # # elif cons["type"] == "eq": + # # try: + # # if cons["fun"](value_set) != 0: + # # is_valid_set = False + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) + # # break + # # except TypeError: + # # # if in_log_handler and (debug or display_p): + # # # in_log_handler.info("value_set={} ; illegal eq constraints".format(value_set)) + # # is_valid_set = False + # # break + # # if not is_valid_set: + # # continue + # # count_valid += 1 + # # this_fun_val = func(value_set) + # # if in_log_handler: + # # if debug or display_p: + # # in_log_handler.info("value_set={} ; fun_val={}".format(value_set, this_fun_val)) + # # this_fun_val = round(this_fun_val, round_digit) + # # if this_fun_val < best_fun_val: + # # best_para_val = [value_set] + # # best_fun_val = this_fun_val + # # elif this_fun_val == best_fun_val: + # # best_para_val.append(value_set) + # # else: + # # pass + # # if in_log_handler: + # # if debug or display_p: + # # in_log_handler.info("Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round)) + # # 
in_log_handler.info("Brute best function value: " + str(best_fun_val)) + # # if debug: + # # in_log_handler.info("Best solution: " + str(best_para_val)) + # # else: + # # if debug or display_p: + # # sys.stdout.write( + # # "Brute valid/candidate rounds: " + str(count_valid) + "/" + str(count_round) + "\n") + # # sys.stdout.write("Brute best function value: " + str(best_fun_val) + "\n") + # # if debug: + # # sys.stdout.write("Best solution: " + str(best_para_val) + "\n") + # # return best_para_val + # if verbose: + # log_handler.info("Estimating copy and depth precisely ...") + # + # vertices_list = sorted(self.vertex_info) + # if len(vertices_list) == 1: + # cov_ = self.vertex_info[vertices_list[0]].cov + # # 2022-12-15, remove return_new_graph + # # if return_new_graphs: + # return [{"graph": deepcopy(self), "cov": cov_}] + # # else: + # # if log_handler: + # # log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2))) + # # else: + # # sys.stdout.write( + # # "Average " + target_name_for_log + " kmer-coverage = " + str(round(cov_, 2)) + "\n") + # # return + # + # # reduce expected_average_cov to reduce computational burden + # all_coverages = [self.vertex_info[v_name].cov for v_name in vertices_list] + # # max_contig_multiplicity = \ + # # min(max_contig_multiplicity, int(2 * math.ceil(max(all_coverages) / min(all_coverages)))) + # # if verbose: + # # if log_handler: + # # log_handler.info("Maximum multiplicity: " + str(max_contig_multiplicity)) + # # else: + # # sys.stdout.write("Maximum multiplicity: " + str(max_contig_multiplicity) + "\n") + # + # """ create constraints by creating multivariate equations """ + # vertex_to_symbols = {vertex_name: Symbol("V" + vertex_name, integer=True) # positive=True) + # for vertex_name in vertices_list} + # symbols_to_vertex = {vertex_to_symbols[vertex_name]: vertex_name for vertex_name in vertices_list} + # extra_str_to_symbol_m1 = {} + # extra_str_to_symbol_m2 = {} + # extra_symbol_to_str_m1 = {} + # extra_symbol_to_str_m2 = {} + # extra_symbol_initial_values = {} + # formulae = [] + # recorded_ends = set() + # for vertex_name in vertices_list: + # for this_end in (True, False): + # if (vertex_name, this_end) not in recorded_ends: + # recorded_ends.add((vertex_name, this_end)) + # if self.vertex_info[vertex_name].connections[this_end]: + # this_formula = vertex_to_symbols[vertex_name] + # formulized = False + # for n_v, n_e in self.vertex_info[vertex_name].connections[this_end]: + # if (n_v, n_e) not in recorded_ends: + # # if n_v in vertices_set: + # # recorded_ends.add((n_v, n_e)) + # try: + # this_formula -= get_formula(n_v, n_e, vertex_name, recorded_ends) + # formulized = True + # # if verbose: + # # if log_handler: + # # log_handler.info("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # # vertex_name + ECHO_DIRECTION[this_end] + ": " + + # # str(this_formula)) + # # else: + # # sys.stdout.write("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # # vertex_name + ECHO_DIRECTION[this_end] + ": " + + # # str(this_formula)+"\n") + # except RecursionError: + # if log_handler: + # log_handler.warning("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # vertex_name + ECHO_DIRECTION[this_end] + " failed!") + # else: + # sys.stdout.write("formulating for: " + n_v + ECHO_DIRECTION[n_e] + "->" + + # vertex_name + ECHO_DIRECTION[this_end] + " failed!\n") + # raise ProcessingGraphFailed("RecursionError!") + # if verbose: + # if log_handler: + # log_handler.info( + # "formulating for: " + 
vertex_name + ECHO_DIRECTION[this_end] + ": " + + # str(this_formula)) + # else: + # sys.stdout.write( + # "formulating for: " + vertex_name + ECHO_DIRECTION[this_end] + ": " + + # str(this_formula) + "\n") + # if formulized: + # formulae.append(this_formula) + # # 2022-12-13 remove this restriction + # # because we have a reduce_list_with_gcd for all graph component + # # elif broken_graph_allowed: + # # # Extra limitation to force terminal vertex to have only one copy, to avoid over-estimation + # # # Under-estimation would not be a problem here, + # # # because the True-multiple-copy vertex would simply have no other connections, + # # # or failed in the following estimation if it does + # # formulae.append(vertex_to_symbols[vertex_name] - 1) + # + # # add self-loop formulae + # self_loop_v = set() + # for vertex_name in vertices_list: + # if self.vertex_info[vertex_name].is_self_loop(): + # self_loop_v.add(vertex_name) + # if log_handler: + # log_handler.warning("Self-loop contig detected: Vertex_" + vertex_name) + # pseudo_self_loop_str = "P" + vertex_name + # if pseudo_self_loop_str not in extra_str_to_symbol_m1: + # extra_str_to_symbol_m1[pseudo_self_loop_str] = Symbol(pseudo_self_loop_str, integer=True) + # extra_symbol_to_str_m1[extra_str_to_symbol_m1[pseudo_self_loop_str]] = pseudo_self_loop_str + # this_formula = vertex_to_symbols[vertex_name] - extra_str_to_symbol_m1[pseudo_self_loop_str] + # extra_symbol_initial_values[extra_str_to_symbol_m1[pseudo_self_loop_str]] = \ + # self.vertex_to_copy[vertex_name] + # formulae.append(this_formula) + # if verbose: + # if log_handler: + # log_handler.info( + # "formulating for: " + vertex_name + ECHO_DIRECTION[True] + ": " + str(this_formula)) + # else: + # sys.stdout.write( + # "formulating for: " + vertex_name + ECHO_DIRECTION[True] + ": " + str(this_formula) + "\n") + # + # # add following extra limitation + # # set cov_sequential_repeat = x*near_by_cov, x is an integer + # for vertex_name in vertices_list: + # single_pair_in_the_trunk_path = self.is_sequential_repeat(vertex_name) + # if single_pair_in_the_trunk_path: + # (from_v, from_e), (to_v, to_e) = single_pair_in_the_trunk_path + # # from_v and to_v are already in the "trunk path", if they are the same, + # # the graph is like two circles sharing the same sequential repeat, no need to add this limitation + # if from_v != to_v: + # new_str = "E" + str(len(extra_str_to_symbol_m1) + len(extra_str_to_symbol_m2)) + # if vertex_name in self_loop_v: + # # self-loop vertex is allowed to have the multiplicity of 1 + # extra_str_to_symbol_m1[new_str] = Symbol(new_str, integer=True) + # extra_symbol_to_str_m1[extra_str_to_symbol_m1[new_str]] = new_str + # this_formula = vertex_to_symbols[vertex_name] - \ + # vertex_to_symbols[from_v] * extra_str_to_symbol_m1[new_str] + # extra_symbol_initial_values[extra_str_to_symbol_m1[new_str]] = \ + # round(self.vertex_to_float_copy[vertex_name] / self.vertex_to_float_copy[from_v]) + # else: + # extra_str_to_symbol_m2[new_str] = Symbol(new_str, integer=True) + # extra_symbol_to_str_m2[extra_str_to_symbol_m2[new_str]] = new_str + # this_formula = vertex_to_symbols[vertex_name] - \ + # vertex_to_symbols[from_v] * extra_str_to_symbol_m2[new_str] + # extra_symbol_initial_values[extra_str_to_symbol_m2[new_str]] = \ + # round(self.vertex_to_float_copy[vertex_name] / self.vertex_to_float_copy[from_v]) + # formulae.append(this_formula) + # if verbose: + # if log_handler: + # log_handler.info("formulating for: " + vertex_name + ": " + str(this_formula)) + # 
else: + # sys.stdout.write("formulating for: " + vertex_name + ": " + str(this_formula) + "\n") + # + # all_v_symbols = list(symbols_to_vertex) + # all_symbols = all_v_symbols + list(extra_symbol_to_str_m1) + list(extra_symbol_to_str_m2) + # if verbose or debug: + # if log_handler: + # log_handler.info("formulae: " + str(formulae)) + # else: + # sys.stdout.write("formulae: " + str(formulae) + "\n") + # # solve the equations + # copy_solution = solve(formulae, all_v_symbols) + # + # copy_solution = copy_solution if copy_solution else {} + # if type(copy_solution) == list: # delete 0 containing set, even for self-loop vertex + # go_solution = 0 + # while go_solution < len(copy_solution): + # if 0 in set(copy_solution[go_solution].values()): + # del copy_solution[go_solution] + # else: + # go_solution += 1 + # if not copy_solution: + # raise ProcessingGraphFailed("Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (1)!") + # elif type(copy_solution) == list: + # if len(copy_solution) > 2: + # raise ProcessingGraphFailed("Incomplete/Complicated " + target_name_for_log + " graph (2)!") + # else: + # copy_solution = copy_solution[0] + # + # free_copy_variables = list() + # for symbol_used in all_symbols: + # if symbol_used not in copy_solution: + # free_copy_variables.append(symbol_used) + # copy_solution[symbol_used] = symbol_used + # if verbose: + # if log_handler: + # log_handler.info("copy equations: " + str(copy_solution)) + # log_handler.info("free variables: " + str(free_copy_variables)) + # else: + # sys.stdout.write("copy equations: " + str(copy_solution) + "\n") + # sys.stdout.write("free variables: " + str(free_copy_variables) + "\n") + # + # # """ """ + # # least_square_expr = 0 + # # for symbol_used in all_v_symbols: + # # # least_square_expr += copy_solution[symbol_used] + # # this_vertex = symbols_to_vertex[symbol_used] + # # this_copy = self.vertex_to_float_copy[this_vertex] + # # least_square_expr += (copy_solution[symbol_used] - this_copy) ** 2 # * self.vertex_info[this_vertex]["len"] + # # least_square_expr = lambdify(args=free_copy_variables, expr=least_square_expr) + # + # if free_copy_variables: + # """ minimizing equation-based copy's deviations from coverage-based copy values """ + # # ignore overlap influence + # m = GEKKO(remote=False) + # g_vars = m.Array(m.Var, + # len(free_copy_variables), + # lb=1, + # ub=int(4 * math.ceil(max(all_coverages) / min(all_coverages))), + # integer=True) + # # initialize free variables + # for go_sym, symbol_used in enumerate(free_copy_variables): + # if symbol_used in symbols_to_vertex and symbols_to_vertex[symbol_used] in self.vertex_to_copy: + # g_vars[go_sym].value = self.vertex_to_copy[symbols_to_vertex[symbol_used]] + # elif symbol_used in extra_symbol_initial_values: + # g_vars[go_sym].value = extra_symbol_initial_values[symbol_used] + # # get the variation within the data + # single_variations = [] + # for vertex_name in self.vertex_info: + # f_copy = self.vertex_to_float_copy[vertex_name] + # single_variations.append(abs((math.ceil(f_copy) - f_copy) ** 2 - (math.floor(f_copy) - f_copy) ** 2)) + # single_variations.sort() + # # largest_var = single_variations[-1] * 4 + # + # replacements = [(symbol_used, Symbol("g_vars[" + str(go_sym) + "]")) + # for go_sym, symbol_used in enumerate(free_copy_variables)] + # least_square_list = [] + # for symbol_used in all_v_symbols: + # this_vertex = symbols_to_vertex[symbol_used] + # this_copy = self.vertex_to_float_copy[this_vertex] + # symbol_copy = 
eval(str(copy_solution[symbol_used].subs(replacements))) + # least_square_list.append((symbol_copy - this_copy) ** 2) + # # not working + # # # constraint the number to be integer + # # least_square_list.append(largest_var * (symbol_copy - int(symbol_copy)) ** 2) + # # least_square_expr = sum(least_square_list) will lead to no solution error for many variables + # least_square_expr = sum(least_square_list) + # if verbose or debug: + # log_handler.info("square function: " + str(repr(least_square_expr))) + # # reform least_square_expr if string length > 15000 + # exp_str_len = len(str(least_square_expr)) + # if exp_str_len > 15000: # not allowed by Gekko:APM + # num_blocks = math.ceil(exp_str_len / 10000.) + # block_size = math.ceil(len(least_square_list) / float(num_blocks)) + # block_list = [] + # for g_b in range(num_blocks): + # block_list.append(sum(least_square_list[g_b * block_size: (g_b + 1) * block_size])) + # least_square_expr = m.sum(block_list) + # + # # account for the influence of the overlap + # # total_len = 0 + # # multinomial_loglike_list = [] + # # v_to_len = {} + # # v_to_real_len = {} + # # all_obs = [] + # # if self.__uni_overlap: + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # v_to_real_len[this_vertex] = self.vertex_info[this_vertex].len - self.__uni_overlap + # # v_to_len[this_vertex] = eval(str(copy_solution[symbol_used].subs(replacements))) \ + # # * v_to_real_len[this_vertex] + # # total_len += v_to_len[this_vertex] + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # prob = v_to_len[this_vertex] / total_len + # # obs = self.vertex_info[this_vertex].cov * v_to_real_len[this_vertex] + # # multinomial_loglike_list.append(m.log(prob) * obs) + # # all_obs.append(obs) + # # if verbose: + # # if log_handler: + # # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # # else: + # # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # # else: + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # overlaps = [_ovl + # # for _strand in (True, False) + # # for _next, _ovl in self.vertex_info[this_vertex].connections[_strand].items()] + # # approximate_overlap = average_np_free(overlaps) + # # v_to_real_len[this_vertex] = self.vertex_info[this_vertex].len - approximate_overlap + # # v_to_len[this_vertex] = eval(str(copy_solution[symbol_used].subs(replacements)))\ + # # * v_to_real_len[this_vertex] + # # total_len += v_to_len[this_vertex] + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # prob = v_to_len[this_vertex] / total_len + # # obs = self.vertex_info[this_vertex].cov * v_to_real_len[this_vertex] + # # # multinomial_loglike_list.append(m.log(prob) * obs) + # # all_obs.append(obs) + # # if verbose: + # # if log_handler: + # # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # # else: + # # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # # """extra arbitrary restriction to avoid over inflation of copies""" + # # multinomial_loglike_list.append(-abs(sum(all_obs) / expected_average_cov - total_len)) + # # multinomial_loglike_expr = m.sum(multinomial_loglike_list) + # + # # for symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # total_len += eval(str(copy_solution[symbol_used].subs(replacements))) * self.vertex_info[this_vertex].len + # # multinomial_like_expr = 0 + # # for 
symbol_used in all_v_symbols: + # # this_vertex = symbols_to_vertex[symbol_used] + # # prob = eval(str(copy_solution[symbol_used].subs(replacements))) \ + # # * self.vertex_info[this_vertex].len / total_len + # # obs = self.vertex_info[this_vertex].cov * self.vertex_info[this_vertex].len + # # multinomial_like_expr += m.log(prob) * obs + # # if verbose: + # # if log_handler: + # # log_handler.info(" >" + this_vertex + "\t" + str(obs)) # + "\t" + str(prob) + # # else: + # # sys.stdout.write(" >" + this_vertex + "\t" + str(obs) + "\n") + # m.Equations(constraint_min_function_for_gekko(g_vars)) + # m.Minimize(least_square_expr) + # # 1 for APOPT, 2 for BPOPT, 3 for IPOPT, 0 for all available solvers + # # here only 1 and 3 are available + # m.options.SOLVER = 1 + # # setting empirical options + # # 5000 costs ~ 150 sec + # if n_iterations is None: + # n_high_copy = sum([math.log2(self.vertex_to_float_copy[_v]) + # for _v in self.vertex_info if self.vertex_to_float_copy[_v] > 2]) + # n_iterations = 500 + int(len(self.vertex_info) * n_high_copy) + # if verbose or debug: + # log_handler.info("setting n_iterations=" + str(n_iterations)) + # if verbose or debug: + # log_handler.info("setting minlp_gap_tol=%.0e" % single_variations[0]) + # m.solver_options = ['minlp_maximum_iterations ' + str(n_iterations), + # # minlp iterations with integer solution + # 'minlp_max_iter_with_int_sol ' + str(n_iterations), + # # treat minlp as nlp + # 'minlp_as_nlp 0', + # # nlp sub-problem max iterations + # 'nlp_maximum_iterations ' + str(n_iterations), + # # 1 = depth first, 2 = breadth first + # 'minlp_branch_method 2', + # # maximum deviation from whole number: + # # amount that a candidate solution variable can deviate from an integer solution + # # and still be considered an integer + # 'minlp_integer_tol 1.0e-2', + # # covergence tolerance + # 'minlp_gap_tol %.0e' % single_variations[0]] + # if debug or verbose: + # m.solve() + # else: + # m.solve(disp=False) + # # print([x.value[0] for x in g_vars]) + # copy_results = list([x.value[0] for x in g_vars]) + # + # # # for safe running + # # if len(free_copy_variables) > 10: + # # raise ProcessingGraphFailed("Free variable > 10 is not accepted yet!") + # # + # # if expected_average_cov ** len(free_copy_variables) < 5E6: + # # # sometimes, SLSQP ignores bounds and constraints + # # copy_results = minimize_brute_force( + # # func=least_square_function_v, range_list=[range(1, expected_average_cov + 1)] * len(free_copy_variables), + # # constraint_list=({'type': 'ineq', 'fun': constraint_min_function_for_customized_brute}, + # # {'type': 'eq', 'fun': constraint_int_function}, + # # {'type': 'ineq', 'fun': constraint_max_function}), + # # display_p=verbose) + # # else: + # # constraints = ({'type': 'ineq', 'fun': constraint_min_function}, + # # {'type': 'eq', 'fun': constraint_int_function}, + # # {'type': 'ineq', 'fun': constraint_max_function}) + # # copy_results = set() + # # best_fun = inf + # # opt = {'disp': verbose, "maxiter": 100} + # # for initial_copy in range(expected_average_cov * 2 + 1): + # # if initial_copy < expected_average_cov: + # # initials = np.array([initial_copy + 1] * len(free_copy_variables)) + # # elif initial_copy < expected_average_cov * 2: + # # initials = np.array([random.randint(1, expected_average_cov)] * len(free_copy_variables)) + # # else: + # # initials = np.array([self.vertex_to_copy.get(symbols_to_vertex.get(symb, False), 2) + # # for symb in free_copy_variables]) + # # bounds = [(1, expected_average_cov) for foo in 
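
# Self-contained sketch (toy numbers, not the project's variables) of the mixed-integer
# least-squares fit set up above with GEKKO/APOPT: search integer copies >= 1 that respect a
# balance constraint while staying close to the coverage-implied float copies.
from gekko import GEKKO

float_copies = [1.1, 2.2, 0.9, 4.3]               # hypothetical coverage-based estimates
m = GEKKO(remote=False)
g_vars = m.Array(m.Var, len(float_copies), lb=1, ub=8, integer=True)
for var, f_copy in zip(g_vars, float_copies):
    var.value = max(1, round(f_copy))             # warm start, as done for tagged vertices above
m.Equation(g_vars[0] + g_vars[2] == g_vars[1])    # e.g. two incoming copies balance one outgoing
m.Minimize(m.sum([(var - f_copy) ** 2 for var, f_copy in zip(g_vars, float_copies)]))
m.options.SOLVER = 1                              # APOPT, the MINLP-capable solver
m.solve(disp=False)
print([int(var.value[0]) for var in g_vars])      # expected: [1, 2, 1, 4]
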
range(len(free_copy_variables))] + # # try: + # # copy_result = optimize.minimize(fun=least_square_function_v, x0=initials, jac=False, + # # method='SLSQP', bounds=bounds, constraints=constraints, options=opt) + # # except Exception: + # # continue + # # if copy_result.fun < best_fun: + # # best_fun = round(copy_result.fun, 2) + # # copy_results = {tuple(copy_result.x)} + # # elif copy_result.fun == best_fun: + # # copy_results.add(tuple(copy_result.x)) + # # else: + # # pass + # # if debug or verbose: + # # if log_handler: + # # log_handler.info("Best function value: " + str(best_fun)) + # # else: + # # sys.stdout.write("Best function value: " + str(best_fun) + "\n") + # if verbose or debug: + # if log_handler: + # log_handler.info("Copy results: " + str(copy_results)) + # else: + # sys.stdout.write("Copy results: " + str(copy_results) + "\n") + # # if len(copy_results) == 1: + # # copy_results = list(copy_results) + # # elif len(copy_results) > 1: + # # # draftly sort results by freedom vertices_set + # # copy_results = sorted(copy_results, key=lambda + # # x: sum([(x[go_sym] - self.vertex_to_float_copy[symbols_to_vertex[symb_used]]) ** 2 + # # for go_sym, symb_used in enumerate(free_copy_variables) + # # if symb_used in symbols_to_vertex])) + # # else: + # # raise ProcessingGraphFailed("Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (3)!") + # else: + # copy_results = [] + # + # # if return_new_graphs: + # # """ produce all possible vertex copy combinations """ + # final_results = [] + # all_copy_sets = set() + # # maybe no more multiple results since 2022-12 gekko update + # for go_res, copy_result in enumerate([copy_results]): + # free_copy_variables_dict = {free_copy_variables[i]: int(this_copy) + # for i, this_copy in enumerate(copy_result)} + # + # """ simplify copy values """ + # # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] + # # 2022-12-15 add cluster info to simplify by graph components when the graph is broken + # all_copies = [] + # v_to_cid = {} + # for go_id, this_symbol in enumerate(all_v_symbols): + # vertex_name = symbols_to_vertex[this_symbol] + # v_to_cid[vertex_name] = go_id + # this_copy = int(copy_solution[this_symbol].evalf(subs=free_copy_variables_dict, chop=True)) + # if this_copy <= 0: + # raise ProcessingGraphFailed("Cannot identify copy number of " + vertex_name + "!") + # all_copies.append(this_copy) + # if len(self.vertex_clusters) == 1: + # if len(all_copies) == 0: + # raise ProcessingGraphFailed( + # "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + # elif len(all_copies) == 1: + # all_copies = [1] + # elif min(all_copies) == 1: + # pass + # else: + # new_all_copies = reduce_list_with_gcd(all_copies) + # if verbose and new_all_copies != all_copies: + # if log_handler: + # log_handler.info("Estimated copies: " + str(all_copies)) + # log_handler.info("Reduced copies: " + str(new_all_copies)) + # else: + # sys.stdout.write("Estimated copies: " + str(all_copies) + "\n") + # sys.stdout.write("Reduced copies: " + str(new_all_copies) + "\n") + # all_copies = new_all_copies + # else: + # for v_cluster in self.vertex_clusters: + # ids = [v_to_cid[_v] for _v in v_cluster] + # component_copies = [all_copies[_id] for _id in ids] + # if len(component_copies) == 0: + # raise ProcessingGraphFailed( + # "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + # elif len(component_copies) == 1: + # component_copies = [1] + # elif min(component_copies) == 1: + # pass + # else: + # 
new_comp_copies = reduce_list_with_gcd(component_copies) + # if verbose and new_comp_copies != component_copies: + # if log_handler: + # log_handler.info("Estimated copies: " + str(component_copies)) + # log_handler.info("Reduced copies: " + str(new_comp_copies)) + # else: + # sys.stdout.write("Estimated copies: " + str(component_copies) + "\n") + # sys.stdout.write("Reduced copies: " + str(new_comp_copies) + "\n") + # component_copies = new_comp_copies + # for sequential_id, _id in enumerate(ids): + # all_copies[_id] = component_copies[sequential_id] + # + # all_copies = tuple(all_copies) + # if all_copies not in all_copy_sets: + # all_copy_sets.add(all_copies) + # else: + # continue + # + # """ record new copy values """ + # final_results.append({"graph": deepcopy(self)}) + # for go_s, this_symbol in enumerate(all_v_symbols): + # vertex_name = symbols_to_vertex[this_symbol] + # if vertex_name in final_results[go_res]["graph"].vertex_to_copy: + # old_copy = final_results[go_res]["graph"].vertex_to_copy[vertex_name] + # final_results[go_res]["graph"].copy_to_vertex[old_copy].remove(vertex_name) + # if not final_results[go_res]["graph"].copy_to_vertex[old_copy]: + # del final_results[go_res]["graph"].copy_to_vertex[old_copy] + # this_copy = all_copies[go_s] + # final_results[go_res]["graph"].vertex_to_copy[vertex_name] = this_copy + # if this_copy not in final_results[go_res]["graph"].copy_to_vertex: + # final_results[go_res]["graph"].copy_to_vertex[this_copy] = set() + # final_results[go_res]["graph"].copy_to_vertex[this_copy].add(vertex_name) + # + # """ re-estimate baseline depth """ + # total_product = 0. + # total_len = 0 + # for vertex_name in vertices_list: + # this_len = self.vertex_info[vertex_name].len \ + # * final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) + # this_cov = self.vertex_info[vertex_name].cov \ + # / final_results[go_res]["graph"].vertex_to_copy.get(vertex_name, 1) + # total_len += this_len + # total_product += this_len * this_cov + # final_results[go_res]["cov"] = total_product / total_len + # return final_results + # # else: + # # """ produce the first-ranked copy combination """ + # # free_copy_variables_dict = {free_copy_variables[i]: int(this_copy) + # # for i, this_copy in enumerate(copy_results)} + # # + # # """ simplify copy values """ # 2020-02-22 added to avoid multiplicities res such as: [4, 8, 4] + # # all_copies = [] + # # for this_symbol in all_v_symbols: + # # vertex_name = symbols_to_vertex[this_symbol] + # # this_copy = int(copy_solution[this_symbol].evalf(subs=free_copy_variables_dict, chop=True)) + # # if this_copy <= 0: + # # raise ProcessingGraphFailed("Cannot identify copy number of " + vertex_name + "!") + # # all_copies.append(this_copy) + # # if len(all_copies) == 0: + # # raise ProcessingGraphFailed( + # # "Incomplete/Complicated/Unsolvable " + target_name_for_log + " graph (4)!") + # # elif len(all_copies) == 1: + # # all_copies = [1] + # # elif min(all_copies) == 1: + # # pass + # # else: + # # new_all_copies = reduce_list_with_gcd(all_copies) + # # if verbose and new_all_copies != all_copies: + # # if log_handler: + # # log_handler.info("Estimated copies: " + str(all_copies)) + # # log_handler.info("Reduced copies: " + str(new_all_copies)) + # # else: + # # sys.stdout.write("Estimated copies: " + str(all_copies) + "\n") + # # sys.stdout.write("Reduced copies: " + str(new_all_copies) + "\n") + # # all_copies = new_all_copies + # # + # # """ record new copy values """ + # # for go_s, this_symbol in enumerate(all_v_symbols): 
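
# Hypothetical re-implementation of the reduce_list_with_gcd idea referenced above: divide all
# estimated copies by their greatest common divisor, so a result like [4, 8, 4] becomes [1, 2, 1].
from functools import reduce
from math import gcd

def reduce_copies(copies):
    common = reduce(gcd, copies)
    return [this_copy // common for this_copy in copies]

print(reduce_copies([4, 8, 4]))   # [1, 2, 1]
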
+ # # vertex_name = symbols_to_vertex[this_symbol] + # # if vertex_name in self.vertex_to_copy: + # # old_copy = self.vertex_to_copy[vertex_name] + # # self.copy_to_vertex[old_copy].remove(vertex_name) + # # if not self.copy_to_vertex[old_copy]: + # # del self.copy_to_vertex[old_copy] + # # this_copy = all_copies[go_s] + # # self.vertex_to_copy[vertex_name] = this_copy + # # if this_copy not in self.copy_to_vertex: + # # self.copy_to_vertex[this_copy] = set() + # # self.copy_to_vertex[this_copy].add(vertex_name) + # # + # # if debug or verbose: + # # """ re-estimate baseline depth """ + # # total_product = 0. + # # total_len = 0 + # # for vertex_name in vertices_list: + # # this_len = self.vertex_info[vertex_name].len \ + # # * self.vertex_to_copy.get(vertex_name, 1) + # # this_cov = self.vertex_info[vertex_name].cov / self.vertex_to_copy.get(vertex_name, 1) + # # total_len += this_len + # # total_product += this_len * this_cov + # # new_val = total_product / total_len + # # if log_handler: + # # log_handler.info("Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2))) + # # else: + # # sys.stdout.write( + # # "Average " + target_name_for_log + " kmer-coverage = " + str(round(new_val, 2)) + "\n") + + def tag_in_between(self, database_n=None): + """add those in between the tagged vertices_set to tagged_vertices, which offered the only connection""" + if database_n is None: + db_types = sorted(self.tagged_vertices) + else: + db_types = [database_n] + for db_n in db_types: + updated = True + candidate_vertices = list(self.vertex_info) + while updated: + updated = False + go_to_v = 0 + while go_to_v < len(candidate_vertices): + can_v = candidate_vertices[go_to_v] + if can_v in self.tagged_vertices[db_n]: del candidate_vertices[go_to_v] continue - count_nearby_tagged = [] - for can_end in (True, False): - for next_v, next_e in self.vertex_info[can_v].connections[can_end]: - # candidate_v is the only output vertex to next_v - if next_v in self.tagged_vertices[database_n] and \ - len(self.vertex_info[next_v].connections[next_e]) == 1: - count_nearby_tagged.append((next_v, next_e)) - break - if len(count_nearby_tagged) == 2: - del candidate_vertices[go_to_v] - # add in between - self.tagged_vertices[database_n].add(can_v) - if "weight" not in self.vertex_info[can_v].other_attr: - self.vertex_info[can_v].other_attr["weight"] = {} - if database_n not in self.vertex_info[can_v].other_attr["weight"]: - self.vertex_info[can_v].other_attr["weight"][database_n] = 0 - self.vertex_info[can_v].other_attr["weight"][database_n] += 1 * self.vertex_info[can_v].cov - if database_n != "embplant_mt": - # Adding extra circle - the contig in-between the sequential repeats - # To avoid risk of tagging mt as pt by mistake, - # the repeated contig must be at least 2 folds of the nearby tagged contigs - near_by_pairs = self.is_sequential_repeat(can_v, return_pair_in_the_trunk_path=False) - if near_by_pairs: - checking_new = [] - coverage_folds = [] - for near_by_p in near_by_pairs: - for (near_v, near_e) in near_by_p: - if (near_v, near_e) not in count_nearby_tagged: - checking_new.append(near_v) - # comment out for improper design: if the untagged is mt - # coverage_folds.append( - # round(self.vertex_info[can_v].cov / - # self.vertex_info[near_v].cov, 0)) - for near_v, near_e in count_nearby_tagged: - coverage_folds.append( - round(self.vertex_info[can_v].cov / - self.vertex_info[near_v].cov, 0)) - # if coverage folds is - if max(coverage_folds) >= 2: - for extra_v_to_add in set(checking_new): 
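
# Toy version (assumed lengths, coverages and copies) of the commented-out baseline
# re-estimation above: the single-copy depth is a length-and-copy weighted average of
# coverage divided by copy number.
contigs = [(10000, 40.0, 1), (3000, 82.0, 2), (500, 120.0, 3)]   # (length, coverage, copy)
total_len = sum(length * copy for length, cov, copy in contigs)
total_product = sum(length * copy * (cov / copy) for length, cov, copy in contigs)
print(round(total_product / total_len, 2))   # ~40.34
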
- self.tagged_vertices[database_n].add(extra_v_to_add) - try: - candidate_vertices.remove(extra_v_to_add) - except ValueError: - pass - # when a contig has no weights - if "weight" not in self.vertex_info[extra_v_to_add].other_attr: - self.vertex_info[extra_v_to_add].other_attr["weight"] = {database_n: 0} - # when a contig has weights of other database - if database_n not in self.vertex_info[extra_v_to_add].other_attr["weight"]: - self.vertex_info[extra_v_to_add].other_attr["weight"][database_n] = 0 - self.vertex_info[extra_v_to_add].other_attr["weight"][database_n] \ - += 1 * self.vertex_info[extra_v_to_add].cov - updated = True - break else: - go_to_v += 1 + if sum([bool(c_c) for c_c in self.vertex_info[can_v].connections.values()]) != 2: + del candidate_vertices[go_to_v] + continue + count_nearby_tagged = [] + for can_end in (True, False): + for next_v, next_e in self.vertex_info[can_v].connections[can_end]: + # candidate_v is the only output vertex to next_v + if next_v in self.tagged_vertices[db_n] and \ + len(self.vertex_info[next_v].connections[next_e]) == 1: + count_nearby_tagged.append((next_v, next_e)) + break + if len(count_nearby_tagged) == 2: + del candidate_vertices[go_to_v] + # add in between + self.tagged_vertices[db_n].add(can_v) + if "weight" not in self.vertex_info[can_v].other_attr: + self.vertex_info[can_v].other_attr["weight"] = {} + if db_n not in self.vertex_info[can_v].other_attr["weight"]: + self.vertex_info[can_v].other_attr["weight"][db_n] = 0 + self.vertex_info[can_v].other_attr["weight"][db_n] += 1 * self.vertex_info[can_v].cov + if db_n != "embplant_mt": + # Adding extra circle - the contig in-between the sequential repeats + # To avoid risk of tagging mt as pt by mistake, + # the repeated contig must be at least 2 folds of the nearby tagged contigs + near_by_pairs = self.is_sequential_repeat(can_v, return_pair_in_the_trunk_path=False) + if near_by_pairs: + checking_new = [] + coverage_folds = [] + for near_by_p in near_by_pairs: + for (near_v, near_e) in near_by_p: + if (near_v, near_e) not in count_nearby_tagged: + checking_new.append(near_v) + # comment out for improper design: if the untagged is mt + # coverage_folds.append( + # round(self.vertex_info[can_v].cov / + # self.vertex_info[near_v].cov, 0)) + for near_v, near_e in count_nearby_tagged: + coverage_folds.append( + round(self.vertex_info[can_v].cov / + self.vertex_info[near_v].cov, 0)) + # if coverage folds is + if max(coverage_folds) >= 2: + for extra_v_to_add in set(checking_new): + self.tagged_vertices[db_n].add(extra_v_to_add) + try: + candidate_vertices.remove(extra_v_to_add) + except ValueError: + pass + # when a contig has no weights + if "weight" not in self.vertex_info[extra_v_to_add].other_attr: + self.vertex_info[extra_v_to_add].other_attr["weight"] = {db_n: 0} + # when a contig has weights of other database + if db_n not in self.vertex_info[extra_v_to_add].other_attr["weight"]: + self.vertex_info[extra_v_to_add].other_attr["weight"][db_n] = 0 + self.vertex_info[extra_v_to_add].other_attr["weight"][db_n] \ + += 1 * self.vertex_info[extra_v_to_add].cov + updated = True + break + else: + go_to_v += 1 + + def parse_tab_file(self, + tab_file, + database_name, + type_factor, + max_gene_gap=300, + max_cov_diff=3., + log_handler=None, + append_info=False, + verbose=False, + random_obj=None, + ): + """ + :param tab_file: + :param database_name: + :param type_factor: + :param max_gene_gap: + :param max_cov_diff: + :param log_handler: + :param append_info: not recommended, keep the original 
information in the vertex info + :param verbose: + :param random_obj: + :return: + """ + if random_obj is None: + import random as random_obj - def parse_tab_file(self, tab_file, database_name, type_factor, log_handler=None): - # parse_csv, every locus only occur in one vertex (removing locations with smaller weight) + # # parse_csv, every locus only occur in one vertex (removing locations with smaller weight) + # 2022-12-22 modified for v2: locus can occur in multiple vertices that are linearly continuous + + # 1. parsing to tag_loci tag_loci = {} tab_matrix = [line.strip("\n").split("\t") for line in open(tab_file)][1:] for node_record in tab_matrix: @@ -1701,56 +3428,675 @@ def parse_tab_file(self, tab_file, database_name, type_factor, log_handler=None) if (locus_start == 1 or locus_end == self.vertex_info[vertex_name].len) \ and self.uni_overlap() and locus_len <= self.uni_overlap(): continue - if locus_name in tag_loci[locus_type]: - new_weight = locus_len * self.vertex_info[vertex_name].cov - if new_weight > tag_loci[locus_type][locus_name]["weight"]: - tag_loci[locus_type][locus_name] = {"vertex": vertex_name, "len": locus_len, - "weight": new_weight} - else: - tag_loci[locus_type][locus_name] = {"vertex": vertex_name, "len": locus_len, - "weight": locus_len * self.vertex_info[vertex_name].cov} - + # 2022-12-22 added + if locus_name not in tag_loci[locus_type]: + tag_loci[locus_type][locus_name] = [] + tag_loci[locus_type][locus_name].append( + {"vertex": vertex_name, + "len": locus_len, + "weight": locus_len * self.vertex_info[vertex_name].cov}) + # if locus_name in tag_loci[locus_type]: + # new_weight = locus_len * self.vertex_info[vertex_name].cov + # if new_weight > tag_loci[locus_type][locus_name]["weight"]: + # tag_loci[locus_type][locus_name] = {"vertex": vertex_name, "len": locus_len, + # "weight": new_weight} + # else: + # tag_loci[locus_type][locus_name] = \ + # {"vertex": vertex_name, + # "len": locus_len, + # "weight": locus_len * self.vertex_info[vertex_name].cov} + + # 2022-12-22~24 added + # 2. 
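
# Toy sketch (made-up table rows and coverages) of step 1 above: each database hit becomes a
# record whose weight is the hit length times the contig's depth, and one locus may now keep
# records on several vertices instead of only the heaviest one.
vertex_cov = {"v1": 40.0, "v2": 38.0}
rows = [("embplant_pt", "rbcL", "v1", 1400), ("embplant_pt", "rbcL", "v2", 300)]
tag_loci = {}
for locus_type, locus_name, vertex_name, hit_len in rows:
    tag_loci.setdefault(locus_type, {}).setdefault(locus_name, []).append(
        {"vertex": vertex_name, "len": hit_len, "weight": hit_len * vertex_cov[vertex_name]})
print(tag_loci["embplant_pt"]["rbcL"])   # two records, weights 56000.0 and 11400.0
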
remove redundant tags can occur in multiple vertices that are not linearly continuous + # Under current version, there is no hit-start-end information from csv produced by slim_graph.py + # for 1) easy modification and 2) compatible with older versions + # So we have to guess the order of vertices in the linear gene + # TODO: in the future, the best solution is annotating the assembly graph accurately + + sum_tag_loci = {} + idx_v_cluster = False + v_to_cluster = {} + len_cluster = len(self.vertex_clusters) + # import time + # time0 = time.time() + # gmm_time = 0 for locus_type in tag_loci: - self.tagged_vertices[locus_type] = set() + sum_tag_loci[locus_type] = {} for locus_name in tag_loci[locus_type]: - vertex_name = tag_loci[locus_type][locus_name]["vertex"] - loci_weight = tag_loci[locus_type][locus_name]["weight"] - # tags - if "tags" not in self.vertex_info[vertex_name].other_attr: - self.vertex_info[vertex_name].other_attr["tags"] = {} - if locus_type in self.vertex_info[vertex_name].other_attr["tags"]: - self.vertex_info[vertex_name].other_attr["tags"][locus_type].add(locus_name) - else: - self.vertex_info[vertex_name].other_attr["tags"][locus_type] = {locus_name} - # weight - if "weight" not in self.vertex_info[vertex_name].other_attr: - self.vertex_info[vertex_name].other_attr["weight"] = {} - if locus_type in self.vertex_info[vertex_name].other_attr["weight"]: - self.vertex_info[vertex_name].other_attr["weight"][locus_type] += loci_weight + if len(tag_loci[locus_type][locus_name]) == 1: + sum_tag_loci[locus_type][locus_name] = {"vertex": [tag_loci[locus_type][locus_name][0]["vertex"]], + "weight": [tag_loci[locus_type][locus_name][0]["weight"]]} else: - self.vertex_info[vertex_name].other_attr["weight"][locus_type] = loci_weight - self.tagged_vertices[locus_type].add(vertex_name) + if not idx_v_cluster: + for go_c, v_clusters in enumerate(self.vertex_clusters): + for v_name in v_clusters: + v_to_cluster[v_name] = go_c + # 2023-01-07 added + single_locus_info = tag_loci[locus_type][locus_name] + # 2.1 to speed up, remove tags (de-weight) out of the main connected component + # if len(single_locus_info) > 10: + if verbose and log_handler: + log_handler.info(" de-weighting minor-component tags " + locus_type + ":" + locus_name) + g_weights = [0.] 
* len_cluster + cluster_to_info_id = {c_id: [] for c_id in range(len_cluster)} + for go_r, rec in enumerate(single_locus_info): + cluster_id = v_to_cluster[rec["vertex"]] + g_weights[cluster_id] += rec["weight"] + cluster_to_info_id[cluster_id].append(go_r) + max_g_w = max(g_weights) + rm_r_ids = [] + for go_c in range(len_cluster): + # arbitrary weight different between connected components + if g_weights[go_c] * 20 < max_g_w: + rm_r_ids.extend(cluster_to_info_id[go_c]) + rm_r_ids.sort(reverse=True) + if verbose and log_handler: + log_handler.info(" " + str(len(rm_r_ids)) + "/" + str(len(single_locus_info)) + + " de-weighted: " + str([single_locus_info[_r]["vertex"] for _r in rm_r_ids])) + for go_r in rm_r_ids: + del single_locus_info[go_r] + + # if len(self._get_tagged_merged_paths([_rec["vertex"] for _rec in single_locus_info])) > 1: + # 2.2 mark tags of minor coverage as negative + if verbose and log_handler: + log_handler.info(" negatizing tags based coverage " + locus_type + ":" + locus_name) + single_locus_info.sort(key=lambda x: -x["weight"]) + vertices = [x["vertex"] for x in single_locus_info] + # self.get_clusters(limited_vertices=vertices) + # maybe increase the vertex weight in the main component + coverages = [self.vertex_info[_v].cov for _v in vertices] + # v_weights = [x["weight"] for x in single_locus_info] + v_weights = [self.vertex_info[_v].len for _v in vertices] + if verbose and log_handler: + log_handler.info(" vertices: " + str(vertices) + "; depths: " + str(coverages) + + "; weights: " + str(v_weights)) + # timex = time.time() + # most time consuming step + gmm_scheme = weighted_clustering_with_em_aic( + data_array=coverages, + data_weights=v_weights, + maximum_cluster=5, + log_handler=log_handler, + verbose_log=verbose, + random_obj=random_obj) + # print(time.time() - timex) + # gmm_time += time.time() - timex + labels = gmm_scheme["labels"] + if log_handler and verbose: + log_handler.info(" labels: " + str(list(labels))) + l_weights = [0.] * gmm_scheme["cluster_num"] + for go_r, lb in enumerate(labels): + l_weights[lb] += single_locus_info[go_r]["weight"] + selected_lb = l_weights.index(max(l_weights)) + selected_param = gmm_scheme["parameters"][selected_lb] + selected_mu, selected_sigma = selected_param["mu"], selected_param["sigma"] + keep_lbs = {go_l + for go_l, params in enumerate(gmm_scheme["parameters"]) + if params["mu"] - selected_mu > -2 * max(selected_sigma, params["sigma"])} + # rm_idx = sorted([go_r for go_r, lb in enumerate(labels) if lb not in keep_lbs], reverse=True) + ne_idx = [go_r for go_r, lb in enumerate(labels) if lb not in keep_lbs] + if verbose and log_handler: + log_handler.info(" " + str(len(ne_idx)) + "/" + str(len(single_locus_info)) + + " negatized: " + str([single_locus_info[_r]["vertex"] for _r in ne_idx])) + # for go_r in rm_idx: + # del single_locus_info[go_r] + if locus_type == database_name: + # only negatizing the target label + for go_r in ne_idx: + single_locus_info[go_r]["weight"] = -1 * abs(single_locus_info[go_r]["weight"]) + else: + for rm_id in sorted(ne_idx, reverse=True): + del single_locus_info[rm_id] + + # 2.3. 
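
# Toy sketch of the label selection above, using the same result shape ("labels", "parameters"
# with mu/sigma) but made-up numbers: pick the coverage cluster carrying the largest summed tag
# weight, keep clusters whose mean is not far below it, and mark the rest as negative evidence.
gmm = {"cluster_num": 2,
       "labels": [0, 0, 1],
       "parameters": [{"mu": 40.0, "sigma": 4.0}, {"mu": 9.0, "sigma": 2.0}]}
records = [{"vertex": "v1", "weight": 56000.0},
           {"vertex": "v2", "weight": 11400.0},
           {"vertex": "v3", "weight": 2700.0}]
label_w = [0.0] * gmm["cluster_num"]
for rec, lb in zip(records, gmm["labels"]):
    label_w[lb] += rec["weight"]
sel = label_w.index(max(label_w))
sel_mu, sel_sigma = gmm["parameters"][sel]["mu"], gmm["parameters"][sel]["sigma"]
keep = {lb for lb, p in enumerate(gmm["parameters"])
        if p["mu"] - sel_mu > -2 * max(sel_sigma, p["sigma"])}
for rec, lb in zip(records, gmm["labels"]):
    if lb not in keep:                            # low-coverage outlier: negatize its weight
        rec["weight"] = -abs(rec["weight"])
print([rec["weight"] for rec in records])         # v3 becomes negative
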
remove redundant tags that occur in parallel vertices + if verbose and log_handler: + log_handler.info(" negatizing parallel tags: " + locus_type + ":" + locus_name) + v_to_sl_id = {rec["vertex"]: go_r for go_r, rec in enumerate(single_locus_info)} + parallel_vertices_list = self.detect_parallel_vertices( + limited_vertices=list(v_to_sl_id), + detect_neighbors=False) + if parallel_vertices_list: + rm_r_ids = set() + ne_r_ids = set() + for prl_vertices_set in parallel_vertices_list: + # sort by weight, then coverage + prl_vertices = sorted( + prl_vertices_set, + key=lambda x: ( + -single_locus_info[v_to_sl_id[x[0]]]["weight"], + -self.vertex_info[x[0]].cov)) + up_v, up_e = prl_vertices[0] + up_id = v_to_sl_id[up_v] + up_lb = labels[up_id] + up_sigma = gmm_scheme["parameters"][up_lb]["sigma"] + up_cov = self.vertex_info[up_v].cov + for de_name, de_end in prl_vertices[1:]: + de_id = v_to_sl_id[de_name] + de_lb = labels[de_id] + de_sigma = gmm_scheme["parameters"][de_lb]["sigma"] + de_cov = self.vertex_info[de_name].cov + if abs(de_cov - up_cov) < 2 * max(up_sigma, de_sigma) or \ + single_locus_info[de_id]["weight"] / de_cov > \ + single_locus_info[up_id]["weight"] / up_cov: + # to be conserved + rm_r_ids.add(de_id) + else: + ne_r_ids.add(de_id) + if verbose and log_handler: + log_handler.info(" (" + str(len(ne_r_ids)) + "+" + str(len(rm_r_ids)) + + ")/" + str(len(single_locus_info)) + + " negatized: " + + str([single_locus_info[_r]["vertex"] for _r in ne_r_ids]) + + " de-weighted: " + + str([single_locus_info[_r]["vertex"] for _r in rm_r_ids])) + if locus_type == database_name: + # only negatizing the target label + for go_r in ne_r_ids: + single_locus_info[go_r]["weight"] = -1 * abs(single_locus_info[go_r]["weight"]) + for rm_id in sorted(rm_r_ids, reverse=True): + del single_locus_info[rm_id] + else: + for rm_id in sorted(ne_r_ids|rm_r_ids, reverse=True): + del single_locus_info[rm_id] + + # 2.4 search for the linear tags maximize the total gene weight + if verbose and log_handler: + log_handler.info(" linearize " + locus_type + ":" + locus_name + ":" + + str([x["vertex"] for x in single_locus_info])) + sum_tag_loci[locus_type][locus_name] = \ + self._find_linear_tags(single_locus_info, max_gene_gap, max_cov_diff, verbose, log_handler) + # print("gmm cost", gmm_time) + # print("tagging cost", time.time() - time0) + # 3. assign information in sum_tag_loci to contigs.other_attr + # 2022-12-22 modified + if not append_info: + # clean previous info + for vertex_name in self.vertex_info: + self.vertex_info[vertex_name].other_attr["tags"] = {} + for locus_type in sum_tag_loci: + self.tagged_vertices[locus_type] = set() + self.tagged_vertices[locus_type + "-"] = set() # negative type + # TODO: add locus_name weights according to taxa statistics across "all locus types" + # e.g. 
ycf15, rpl2* should have much lower weights, + # because they were more often seen to be HGTed from pt to mt + # set arbitrary values for temporary usage + if locus_type == "embplant_pt": + extra_w = {l_n: 1.5 + if l_n in {"rpoA", "rpoB", "rpoC1", "rpoC2", + "atpA", "atpB", "atpE", "atpF", "atpH", "atpI", + "rbcL", + "petB", "petG", + "rrn16", "rrn23", "rrn4.5", "rrn5"} + else 0.5 + for l_n in sum_tag_loci[locus_type]} + else: + extra_w = {} + for locus_name in sum_tag_loci[locus_type]: + # 2022-12-22 modified + + for vertex_name, loci_weight in zip(sum_tag_loci[locus_type][locus_name]["vertex"], + sum_tag_loci[locus_type][locus_name]["weight"]): + # vertex_name = tag_loci[locus_type][locus_name]["vertex"] + # loci_weight = tag_loci[locus_type][locus_name]["weight"] + # tags + adjusted_w = loci_weight * extra_w.get(locus_name, 1.) + if "tags" not in self.vertex_info[vertex_name].other_attr: + self.vertex_info[vertex_name].other_attr["tags"] = {} + if locus_type in self.vertex_info[vertex_name].other_attr["tags"]: + self.vertex_info[vertex_name].other_attr["tags"][locus_type][locus_name] = adjusted_w + else: + self.vertex_info[vertex_name].other_attr["tags"][locus_type] = {locus_name: adjusted_w} + # weight + if "weight" not in self.vertex_info[vertex_name].other_attr: + self.vertex_info[vertex_name].other_attr["weight"] = {} + if locus_type in self.vertex_info[vertex_name].other_attr["weight"]: + self.vertex_info[vertex_name].other_attr["weight"][locus_type] += adjusted_w + else: + self.vertex_info[vertex_name].other_attr["weight"][locus_type] = adjusted_w + # self.tagged_vertices[locus_type].add(vertex_name) + # 4. clarify locus_type for each contig by comparing weights, and add to self.tagged_vertices for vertex_name in self.vertex_info: if "weight" in self.vertex_info[vertex_name].other_attr: + all_weights = [(loc_type, self.vertex_info[vertex_name].other_attr["weight"][loc_type]) + for loc_type in self.vertex_info[vertex_name].other_attr["weight"]] if len(self.vertex_info[vertex_name].other_attr["weight"]) > 1: - all_weights = sorted([(loc_type, self.vertex_info[vertex_name].other_attr["weight"][loc_type]) - for loc_type in self.vertex_info[vertex_name].other_attr["weight"]], - key=lambda x: -x[1]) - best_t, best_w = all_weights[0] + all_weights.sort(key=lambda x: -x[1]) + best_t, best_w = all_weights[0] + if best_w > 0: + self.tagged_vertices[best_t].add(vertex_name) for next_t, next_w in all_weights[1:]: - if next_w * type_factor < best_w: - self.tagged_vertices[next_t].remove(vertex_name) + if next_w * type_factor >= best_w: + self.tagged_vertices[next_t].add(vertex_name) + elif best_w <= 0: + for next_t, next_w in all_weights: + if next_w < 0: + self.tagged_vertices[next_t + "-"].add(vertex_name) if database_name not in self.tagged_vertices or len(self.tagged_vertices[database_name]) == 0: raise ProcessingGraphFailed("No available " + database_name + " information found in " + tab_file) + # print("parsing cost", time.time() - time0) + + def _get_tagged_merged_paths(self, tagged_vs): + raw_tagged = set(tagged_vs) + vs_to_merge = OrderedDict([(_v, True) for _v in tagged_vs]) + vs_used = set() + # print("merging tagged_vs: " + str(raw_tagged)) + merged_paths = [] + while vs_to_merge: + check_v, foo = vs_to_merge.popitem() + vs_used.add(check_v) + # print(check_v, foo) + extend_e = True + this_path = [(check_v, extend_e)] + while True: + this_v, this_e = this_path[-1] + next_con = [(_v, _e) for _v, _e in self.vertex_info[this_v].connections[this_e]] + next_con_tagged = [(_v, _e) for 
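
# Small sketch (one hypothetical vertex, made-up weights) of step 4 above: a contig is tagged
# with the best-weighted database type, with any other type whose weight is within type_factor
# of the best, and with the negative ("-") sets when all of its weights are non-positive.
type_factor = 3.0
vertex_weights = {"embplant_pt": 5200.0, "embplant_mt": 900.0}
tagged = {"embplant_pt": set(), "embplant_mt": set(), "embplant_pt-": set(), "embplant_mt-": set()}
ranked = sorted(vertex_weights.items(), key=lambda x: -x[1])
best_t, best_w = ranked[0]
if best_w > 0:
    tagged[best_t].add("v1")
    for next_t, next_w in ranked[1:]:
        if next_w * type_factor >= best_w:
            tagged[next_t].add("v1")
else:
    for next_t, next_w in ranked:
        if next_w < 0:
            tagged[next_t + "-"].add("v1")
print(tagged)   # only embplant_pt keeps v1 here, since 900 * 3 < 5200
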
_v, _e in next_con if _v in raw_tagged] + # print("this_path", this_path) + # print("next_con_tagged", next_con_tagged) + if len(next_con_tagged) == 1: + next_v, next_e = next_con_tagged[0] + back_con = [(_v, _e) for _v, _e in self.vertex_info[next_v].connections[next_e]] + back_con_tagged = [(_v, _e) for _v, _e in back_con if _v in raw_tagged] + # print("back_con_tagged", back_con_tagged) + # if there is only one possible merging way, and the next is not equal to the previous one + if back_con_tagged == [(this_v, this_e)] and next_v not in vs_used: + # extra criteria to avoid generating chimeric path when the main path has a small gap + # if there are non tagged connections, check if there is coverage similarity between candidate + # contigs before merging + extra_match = False + if len(next_con) == len(back_con) == 1: + extra_match = True + else: + this_cov = self.vertex_info[this_v].cov + next_cov = self.vertex_info[next_v].cov + cov_diff = abs(this_cov - next_cov) + if cov_diff/max(this_cov, next_cov) < 0.2: # arbitrary empirical value + # print("cov_diff", cov_diff) + # print("testing passed 1") + extra_match = True + if len(next_con) > 1: + # if the coverage of any other connection is closer to the tagged candidate one + if min([abs(this_cov - self.vertex_info[_n].cov) for _n, _e in next_con]) \ + != cov_diff: + extra_match = False + # print(this_cov, [self.vertex_info[_n].cov for _n, _e in next_con]) + # print([abs(this_cov - self.vertex_info[_n].cov) for _n, _e in next_con]) + # print("testing failed 2", extra_match) + # else: + # print(this_cov, [self.vertex_info[_n].cov for _n, _e in next_con]) + # print([abs(this_cov - self.vertex_info[_n].cov) for _n, _e in next_con]) + # print("testing passed 2", extra_match) + # else: + # print("testing passed 3", extra_match) + if extra_match and len(back_con) > 1: + # if the coverage of any other connection is closer to the tagged candidate one + if min([abs(next_cov - self.vertex_info[_n].cov) for _n, _e in back_con]) \ + != cov_diff: + extra_match = False + # print(next_cov, [self.vertex_info[_n].cov for _n, _e in back_con]) + # print([abs(next_cov - self.vertex_info[_n].cov) for _n, _e in back_con]) + # print("testing failed 4", extra_match) + # else: + # print(this_cov, [self.vertex_info[_n].cov for _n, _e in next_con]) + # print([abs(this_cov - self.vertex_info[_n].cov) for _n, _e in next_con]) + # print("testing passed 4", extra_match) + # else: + # print("testing passed 5", extra_match) + if extra_match: + this_path.append((next_v, not next_e)) + del vs_to_merge[next_v] + vs_used.add(next_v) + continue + if extend_e: + this_path = [(_v, not _e) for _v, _e in this_path[::-1]] + extend_e = False + else: + break + merged_paths.append(this_path) + # print("merged paths: " + str(merged_paths)) + return merged_paths + + def _find_linear_tags(self, tag_locus_info, max_gene_gap, max_cov_diff, verbose=False, log_handler=None): + + def _try_merge(_current_p_id, _rev_p, _next_p_id, potential_start, middle_gap_p=None): + # if not the first one to extend/merge + if count_keep > 0: + _c_opt = deepcopy(raw_opt) + else: + _c_opt = c_opt + if _rev_p: + _c_opt["paths"][_current_p_id] = [(_v, not _e) + for _v, _e in _c_opt["paths"][_current_p_id][::-1]] + # if check_gene: + # print("_c_opt['paths'][_current_p_id]", _c_opt["paths"][_current_p_id]) + # print('_c_opt["paths"][_next_p_id]', _c_opt["paths"][_next_p_id]) + _nv, _ne = potential_start + + middle_gap_p = [] if not middle_gap_p else middle_gap_p + if _c_opt["paths"][_next_p_id][0] == 
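
# Greatly simplified, hypothetical sketch of the path-merging walk above: extend a path while
# the next tagged neighbour is unique and points back only to the current end. The real method
# additionally walks both directions and checks coverage similarity before bridging junctions.
# Toy adjacency: vertex -> {end: [(neighbor, neighbor_end), ...]}; names are made up.
connections = {
    "v1": {True: [("v2", False)], False: []},
    "v2": {True: [("v3", False)], False: [("v1", True)]},
    "v3": {True: [], False: [("v2", True)]},
}
tagged = {"v1", "v2", "v3"}

def walk(start):
    path = [(start, True)]
    while True:
        v, e = path[-1]
        nxt = [(n, ne) for n, ne in connections[v][e] if n in tagged]
        if len(nxt) == 1:
            n, ne = nxt[0]
            back = [(b, be) for b, be in connections[n][ne] if b in tagged]
            if back == [(v, e)] and n not in {p for p, _ in path}:
                path.append((n, not ne))
                continue
        return path

print(walk("v1"))   # [('v1', True), ('v2', True), ('v3', True)]
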
potential_start: + _c_opt["paths"][_current_p_id].extend(middle_gap_p) + # merge go_p and next_p + _c_opt["paths"][_current_p_id].extend(_c_opt["paths"][_next_p_id]) + # if check_gene: + # print(" merge forward, new paths", _c_opt["paths"][_current_p_id]) + elif _c_opt["paths"][_next_p_id][-1] == (_nv, not _ne): + _c_opt["paths"][_current_p_id].extend(middle_gap_p) + # merge go_p and next_p in the reverse + rev_next_p = [(_v, not _e) for _v, _e in _c_opt["paths"][_next_p_id][::-1]] + _c_opt["paths"][_current_p_id].extend(rev_next_p) + # if check_gene: + # print(" merge reverse, new paths", _c_opt["paths"][_current_p_id]) + else: + return False + del _c_opt["paths"][_next_p_id] + # update path id + _c_opt["path_id"] = {_v: _p_id + for _p_id, _p in enumerate(_c_opt["paths"]) + for _v, _e in _p} + # update tuple + _c_opt["tuple"] = self.standardize_paths(_c_opt["paths"]) + # if not the first one to extend/merge + if count_keep > 0: + candidate_options.append(_c_opt) + return True + + # 2023-01-16 added + def _heuristic_generator(path_list): + # arbitrary set empirical threshold for speeding up + if len(path_list) < 4: + for _go_p, _this_path in enumerate(path_list): + yield _go_p, _this_path + else: + # starting from the largest weight, + # pick top-4 paths or top-2 successes, whichever comes later + _new_list = [(_go_p_, _this_p_, sum([tagged_v_w.get(_v, 0.) for _v, _e, in _this_p_])) + for _go_p_, _this_p_ in enumerate(path_list)] + _new_list.sort(key=lambda x: -x[2]) + count_gen = 0 + for _go_p, _this_path, _weight in _new_list: + count_gen += 1 + if count_gen < 4 or extended.count(True) < 2: + yield _go_p, _this_path + else: + break - def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_threshold=10., + tagged_v_w = {_rec["vertex"]: _rec["weight"] for _rec in tag_locus_info} + tagged_vs = sorted([_rec["vertex"] for _rec in tag_locus_info if _rec["weight"] > 0]) + + # check_gene = False + # if "323673" in tagged_vs: + # check_gene = True + + # merge tagged vertices into paths before linear searching + merged_paths = self._get_tagged_merged_paths(tagged_vs) + candidate_options = [{}] + candidate_options[0]["tuple"] = self.standardize_paths(merged_paths) + candidate_options[0]["paths"] = [list(_p) for _p in candidate_options[0]["tuple"]] + candidate_options[0]["path_id"] = {_v: _p_id + for _p_id, _p in enumerate(candidate_options[0]["paths"]) + for _v, _e in _p} + if len(candidate_options[0]["tuple"]) > 1: + # tagged_set = set(tagged_vs) + intermediate_combinations = set([]) # to avoid repeated calculation + + # # start_v = sorted(tag_loci[locus_type][locus_name], + # # key=lambda x: (-x["weight"], x["vertex"]))[0]["vertex"] + # # start_v = tagged_vs[0] + # # sv_id = tagged_vs.index(start_v) + # candidate_options = [{"paths": [[(_v, True)] for _v in tagged_vs], + # "path_id": {_v: p_id for p_id, _v in enumerate(tagged_vs)} + # }] + # # palindromic repeats does not matter, just cause duplicates + # candidate_options[0]["tuple"] = self.standardize_paths(candidate_options[0]["paths"]) + + go_candidate = 0 + while go_candidate < len(candidate_options): + # if check_gene: + # print("go_candidate", go_candidate) + # input("") + if candidate_options[go_candidate]["tuple"] in intermediate_combinations: + del candidate_options[go_candidate] # searched + else: + c_opt = candidate_options[go_candidate] + # if check_gene: + # print("c_opt (" + str(len(c_opt["paths"])) + "):", c_opt["paths"]) + # input("") + intermediate_combinations.add(c_opt["tuple"]) + extended = [] + count_keep 
= 0 + raw_opt = deepcopy(c_opt) + for go_p, this_path in _heuristic_generator(list(raw_opt["paths"])): + extended.append(False) + for rev_p in (False, True): + # if check_gene: + # print(" go_p", go_p, rev_p, this_path) + # # print(" next_con_ls_tagged_pair:", next_con_ls_tagged_pair) + # input("") + if rev_p: + # palindromic repeats does not matter, just cause duplicates + this_path = [(this_v, not this_e) for this_v, this_e in this_path[::-1]] + + # Problematic + # next_connections = next_con_pair[int(rev_p)] + # next_connect_ls = next_con_ls_pair[int(rev_p)] + + extend_v, extend_e = this_path[-1] + next_connections = self.vertex_info[extend_v].connections[extend_e] + # constraint the coverage change + next_connect_ls = [(_n, _e) + for _n, _e in next_connections + if self.vertex_info[_n].cov / max_cov_diff + < self.vertex_info[extend_v].cov + < max_cov_diff * self.vertex_info[_n].cov] + # if check_gene: + # print(" next_connections", next_connections) + # input("") + if len(next_connect_ls) == 0: + continue + else: + for next_v, next_e in next_connect_ls: + if this_path.count((next_v, not next_e)) \ + >= len(self.vertex_info[next_v].connections[next_e]): + # real multiplicity does not matter, just search for the simplest path + # that represent the gene + continue + elif next_v in raw_opt["path_id"]: + next_p_id = raw_opt["path_id"][next_v] + if next_p_id != go_p: # not self-loop + if _try_merge(go_p, rev_p, next_p_id, (next_v, not next_e)): + # if check_gene: + # print(" merged with ", next_v, next_p_id, self.vertex_info[next_v].cov) + # input("") + count_keep += 1 + extended[-1] = True + else: + # allow gaps + # if check_gene: + # print(" check gaps") + accumulated_gap = self.vertex_info[next_v].len - \ + next_connections[(next_v, next_e)] + gap_paths = [{"p": [(next_v, not next_e)], "l": accumulated_gap}] + go_g = 0 + while go_g < len(gap_paths): + if gap_paths[go_g]["l"] > max_gene_gap: + del gap_paths[go_g] + else: + next_ext_v, next_ext_e = gap_paths[go_g]["p"][-1] + nn_cons = self.vertex_info[next_ext_v].connections[next_ext_e] + # constraint the coverage change + nn_con_ls = [(_n, _e) + for _n, _e in nn_cons + if self.vertex_info[_n].cov / max_cov_diff + < self.vertex_info[next_ext_v].cov + < max_cov_diff * self.vertex_info[_n].cov] + if len(nn_con_ls) == 0: + del gap_paths[go_g] + # elif len(nn_con_ls) == 1: + # nn_v, nn_e = nn_con_ls[0] + # gap_paths[go_g]["p"].append((nn_v, not nn_e)) + # # either jump to the next gap path option + # # or add the accumulated gap_length, + # # which both lead to search termination + # if nn_v in raw_opt["path_id"]: + # # nn_p_id = c_opt["path_id"][nn_v] + # # if nn_p_id == go_p: + # # del gap_paths[go_g] + # go_g += 1 + # else: + # gap_paths[go_g]["l"] += \ + # self.vertex_info[nn_v].len - nn_cons[(nn_v, nn_e)] + else: + dup_p = deepcopy(gap_paths[go_g]) + for go_c, (nn_v, nn_e) in enumerate(nn_con_ls): + if go_c == 0: + gap_paths[go_g]["p"].append((nn_v, not nn_e)) + if nn_v in raw_opt["path_id"]: + go_g += 1 + else: + gap_paths[go_g]["l"] += \ + self.vertex_info[nn_v].len - nn_cons[(nn_v, nn_e)] + else: + if go_c < len(nn_con_ls) - 1: + this_p = deepcopy(dup_p) + else: + this_p = dup_p + this_p["p"].append((nn_v, not nn_e)) + if nn_v in raw_opt["path_id"]: + gap_paths.insert(go_g, this_p) + go_g += 1 + else: + this_p["l"] += \ + self.vertex_info[nn_v].len - nn_cons[(nn_v, nn_e)] + gap_paths.append(this_p) + # if check_gene: + # print(" gap_paths", gap_paths) + if gap_paths: + if len(gap_paths) == 1: + p_start = nn_v, nn_e = 
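
# Small sketch (made-up lengths and overlaps) of the gap budget used above: untagged contigs
# may bridge two tagged paths as long as their accumulated non-overlapping length stays within
# max_gene_gap.
max_gene_gap = 300
bridge = [(150, 55), (120, 55), (90, 55)]   # (contig length, overlap with the previous contig)
accumulated_gap = 0
within_budget = True
for length, overlap in bridge:
    accumulated_gap += length - overlap
    if accumulated_gap > max_gene_gap:
        within_budget = False
        break
print(accumulated_gap, within_budget)   # 195 True
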
gap_paths[0]["p"][-1] + nn_p_id = raw_opt["path_id"][nn_v] + # if check_gene: + # print("go_p, rev_p, nn_p_id, p_start, gap_paths[0]['p'][:-1]") + # print(go_p, rev_p, nn_p_id, p_start, gap_paths[0]["p"][:-1]) + if nn_p_id != go_p: + if _try_merge(go_p, rev_p, nn_p_id, p_start, gap_paths[0]["p"][:-1]): + # if check_gene: + # print(" merged (gap) with ", nn_v, nn_p_id, + # self.vertex_info[nn_v].cov) + # input("") + count_keep += 1 + extended[-1] = True + # else: + # continue + else: + for go_g, gap_path in enumerate(gap_paths): + p_start = nn_v, nn_e = gap_path["p"][-1] + nn_p_id = raw_opt["path_id"][nn_v] + if nn_p_id == go_p: + continue + if _try_merge(go_p, rev_p, nn_p_id, p_start, gap_path["p"][:-1]): + # if check_gene: + # print(" merged (gap) with ", nn_v, nn_p_id, + # self.vertex_info[nn_v].cov) + # input("") + count_keep += 1 + extended[-1] = True + # if extended: + # break + # if extended: + # break + if True not in extended: + go_candidate += 1 + # pick the paths with the largest weight + # sort candidate_options by its maximum path weight (sum of v weights in a path), decreasingly + candidate_options.sort(key=lambda opt: (max([sum([tagged_v_w.get(_v, 0.) + for _v, _e, in _p]) + for _p in opt["paths"]]), + opt["tuple"]), + reverse=True) + if verbose: + for candidate_opt in candidate_options: + log_handler.info(" paths: " + str(candidate_opt["paths"])) + log_handler.info(" weights: " + str([sum([tagged_v_w.get(_v, 0.) + for _v, _e, in _p]) + for _p in candidate_opt["paths"]])) + best_paths = candidate_options[0]["paths"] + # pick the path with the largest weight + best_paths.sort(key=lambda _path: sum([tagged_v_w.get(_v, 0.) for _v, _e, in _path]), reverse=True) + best_path = best_paths[0] + if verbose and log_handler: + log_handler.info(" best_path: " + str(best_path)) + # generate info table + # labeled_vs = sorted(set([_v for _v, _e in best_path])) + labeled_vs = set([_v for _v, _e in best_path]) + res_dict = {"vertex": [], "weight": []} + for this_v in sorted(labeled_vs): + res_dict["vertex"].append(this_v) + res_dict["weight"].append(tagged_v_w.get(this_v, 0.)) + for record in tag_locus_info: + this_v = record["vertex"] + this_w = record["weight"] + if this_v not in labeled_vs: + if this_w < 0: + res_dict["vertex"].append(this_v) + res_dict["weight"].append(this_w) + elif self.check_connected(labeled_vs | {this_v}): + # connected but not real one should have negative weight + res_dict["vertex"].append(this_v) + res_dict["weight"].append(-abs(this_w)) + return res_dict + + def _revise_single_copy_coverages_based_on_graph(self, vertex_list): + # rough processing + revised_coverages = [] + for v_name in vertex_list: + forward_len = len(self.vertex_info[v_name].connections[True]) + reverse_len = len(self.vertex_info[v_name].connections[False]) + assumed_copy_num = self.vertex_to_copy.get(v_name, 1) + if forward_len <= 1 and reverse_len <= 1: + revised_coverages.append(self.vertex_info[v_name].cov / assumed_copy_num) + else: + # set arbitrary threshold as 4 + forward_main = [] + forward_minor = [] + if forward_len: + forward_coverages = [(self.vertex_info[_v].cov, _v) + for _v, _e in self.vertex_info[v_name].connections[True]] + max_f_cov = max([_x[0] for _x in forward_coverages]) + for this_cov, this_v in forward_coverages: + if this_cov * 4 > max_f_cov: + forward_main.append((this_cov, this_v)) + else: + forward_minor.append((this_cov, this_v)) + reverse_main = [] + reverse_minor = [] + if reverse_len: + reverse_coverages = [(self.vertex_info[_v].cov, _v) + for _v, _e in 
self.vertex_info[v_name].connections[False]] + max_r_cov = max([_x[0] for _x in reverse_coverages]) + for this_cov, this_v in reverse_coverages: + if this_cov * 4 > max_r_cov: + reverse_main.append((this_cov, this_v)) + else: + reverse_minor.append((this_cov, this_v)) + # skip weighting + sum_minor_cov = (sum([_x[0] for _x in forward_minor]) + sum([_x[0] for _x in forward_minor]))/2. + assumed_copy_num = max(assumed_copy_num, len(forward_main), len(reverse_main)) + revised_coverages.append((self.vertex_info[v_name].cov - sum_minor_cov) / assumed_copy_num) + return revised_coverages + + def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", min_cov_folds=5., weight_factor=100., min_sigma_factor=0.1, min_cluster=1, terminal_extra_weight=5., - verbose=False, log_handler=None, debug=False): + verbose=False, log_handler=None, debug=False, random_obj=None): + if random_obj is None: + import random as random_obj changed = False + if len(self.vertex_info) == 1 and list(self.vertex_info)[0] in self.tagged_vertices[database_n]: + only_cov = self.vertex_info[list(self.vertex_info)[0]].cov + return changed, [(only_cov, only_cov * min_sigma_factor)] # overlap = self.__overlap if self.__overlap else 0 - log_hard_cov_threshold = abs(log(hard_cov_threshold)) + # log_min_cov_folds = abs(log(min_cov_folds)) vertices = sorted( self.vertex_info, key=lambda x: (-self.vertex_info[x].other_attr.get("weight", {}).get(database_n, 0), x)) # 2022-05-06: use the coverage of the contig with the max weight instead of the max coverage @@ -1759,21 +4105,13 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre log_handler.info("coverage threshold set: " + str(standard_coverage)) elif verbose or debug: sys.stdout.write("coverage threshold set: " + str(standard_coverage) + "\n") - # 2022-05-06: use the coverage of the contig with the max weight instead of the max coverage - # v_coverages = {this_v: self.vertex_info[this_v].cov / self.vertex_to_copy.get(this_v, 1) for this_v in vertices} - # try: - # max_tagged_cov = max([v_coverages[tagged_v] for tagged_v in self.tagged_vertices[database_n]]) - # except ValueError as e: - # if log_handler: - # log_handler.info("tagged vertices: " + str(self.tagged_vertices)) - # else: - # sys.stdout.write("tagged vertices: " + str(self.tagged_vertices) + "\n") - # raise e - # removing coverage with 10 times lower/greater than tagged_cov + # removing coverage with min_cov_folds times lower/greater than tagged_cov + # if abs(log(self.vertex_info[candidate_v].cov / standard_coverage)) > log_min_cov_folds] + # 2022-12-05 only remove lower contigs removing_low_cov = [candidate_v for candidate_v in vertices - if abs(log(self.vertex_info[candidate_v].cov / standard_coverage)) > log_hard_cov_threshold] + if self.vertex_info[candidate_v].cov * min_cov_folds < standard_coverage] if removing_low_cov: if log_handler and (debug or verbose): log_handler.info("removing extremely outlying coverage contigs: " + str(removing_low_cov)) @@ -1784,11 +4122,17 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre merged = self.merge_all_possible_vertices() if merged: changed = True + tagged_vs = [_v for _v in self.tagged_vertices[database_n] + if self.vertex_info[_v].other_attr.get("weight", {}).get(database_n, -1) > 0] + self.estimate_copy_and_depth_by_cov(tagged_vs, + min_sigma=min_sigma_factor, + debug=debug, log_handler=log_handler, + verbose=verbose, mode=database_n) vertices = sorted(self.vertex_info) - v_coverages = {this_v: 
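
# Rough, hypothetical sketch of the coverage revision above: on each side, neighbours whose
# depth is more than 4x below the strongest neighbour are treated as minor noise, their depth
# is subtracted (here the minor sums of both sides are averaged, which is assumed to be the
# intent), and the remainder is divided by the assumed copy number.
def revise_single_copy_cov(cov, forward_neighbor_covs, reverse_neighbor_covs, assumed_copy=1):
    def split(covs):
        if not covs:
            return [], []
        top = max(covs)
        return [c for c in covs if c * 4 > top], [c for c in covs if c * 4 <= top]
    f_main, f_minor = split(forward_neighbor_covs)
    r_main, r_minor = split(reverse_neighbor_covs)
    minor_cov = (sum(f_minor) + sum(r_minor)) / 2.
    copy_num = max(assumed_copy, len(f_main), len(r_main))
    return (cov - minor_cov) / copy_num

print(revise_single_copy_cov(90., [42., 40., 8.], [44., 6.]))   # (90 - 7) / 2 = 41.5
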
self.vertex_info[this_v].cov / self.vertex_to_copy.get(this_v, 1) - for this_v in vertices} - - coverages = np.array([v_coverages[this_v] for this_v in vertices]) + # v_coverages = {this_v: self.vertex_info[this_v].cov / self.vertex_to_copy.get(this_v, 1) + # for this_v in vertices} + # coverages = np.array([v_coverages[this_v] for this_v in vertices]) + coverages = np.array(self._revise_single_copy_coverages_based_on_graph(vertices)) cover_weights = np.array([self.vertex_info[this_v].len # multiply by copy number * self.vertex_to_copy.get(this_v, 1) @@ -1796,20 +4140,34 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre * (terminal_extra_weight if self.vertex_info[this_v].is_terminal() else 1) for this_v in vertices]) tag_kinds = [tag_kind for tag_kind in self.tagged_vertices if self.tagged_vertices[tag_kind]] - tag_kinds.sort(key=lambda x: x != database_n) + set_kinds = [tag_kind for tag_kind in tag_kinds if not tag_kind.endswith("-")] + # introduced 2023-01-11 + ban_kind_set = set([ban_kind for ban_kind in tag_kinds if ban_kind.endswith("-") and ban_kind in set_kinds]) + set_kinds.sort(key=lambda x: x != database_n) + # force labeled vertex to be in specific cluster, which provide the supervision information for the clustering set_cluster = {} - for v_id, vertex_name in enumerate(vertices): - for go_tag, this_tag in enumerate(tag_kinds): + for go_tag, this_tag in enumerate(set_kinds): + for v_id, vertex_name in enumerate(vertices): if vertex_name in self.tagged_vertices[this_tag]: if v_id not in set_cluster: set_cluster[v_id] = set() set_cluster[v_id].add(go_tag) - min_tag_kind = {0} - for v_id in set_cluster: - if 0 not in set_cluster[v_id]: - min_tag_kind |= set_cluster[v_id] - min_cluster = max(min_cluster, len(min_tag_kind)) - + # introduced 2023-01-11 + ban_cluster = {} + for go_tag, this_tag in enumerate(set_kinds): + ban_tag = this_tag + "-" + if ban_tag in ban_kind_set: + for v_id, vertex_name in enumerate(vertices): + if vertex_name in self.tagged_vertices[ban_tag]: + if v_id not in ban_cluster: + ban_cluster[v_id] = set() + ban_cluster[v_id].add(go_tag) + # # min number of clusters + # min_tag_kind = {0} + # for v_id in set_cluster: + # if 0 not in set_cluster[v_id]: + # min_tag_kind |= set_cluster[v_id] + # min_cluster = max(min_cluster, len(min_tag_kind)) # old way: # set_cluster = {v_coverages[tagged_v]: 0 for tagged_v in self.tagged_vertices[mode]} @@ -1818,13 +4176,34 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre if log_handler and (debug or verbose): log_handler.info("Vertices: " + str(vertices)) log_handler.info("Coverages: " + str([float("%.1f" % cov_x) for cov_x in coverages])) - elif verbose or debug: - sys.stdout.write("Vertices: " + str(vertices) + "\n") - sys.stdout.write("Coverages: " + str([float("%.1f" % cov_x) for cov_x in coverages]) + "\n") - gmm_scheme = weighted_gmm_with_em_aic(coverages, data_weights=cover_weights, - minimum_cluster=min_cluster, maximum_cluster=6, - cluster_limited=set_cluster, min_sigma_factor=min_sigma_factor, - log_handler=log_handler, verbose_log=verbose) + log_handler.info("Ban cluster: " + str(ban_cluster)) + log_handler.info("Set cluster: " + str(set_cluster)) + # elif verbose or debug: + # sys.stdout.write("Vertices: " + str(vertices) + "\n") + # sys.stdout.write("Coverages: " + str([float("%.1f" % cov_x) for cov_x in coverages]) + "\n") + singleton_constraint = [tuple(_cl) for _cl in set_cluster.values() if len(_cl) == 1] + if min_cluster > 1 and 
len(singleton_constraint) == len(coverages) and len(set(singleton_constraint)) == 1: + set_cluster = {} + if log_handler and (debug or verbose): + log_handler.info("Set cluster reset to: " + str(set_cluster)) + # print(min_cluster) + try: + gmm_scheme = weighted_clustering_with_em_aic(coverages, data_weights=cover_weights, + minimum_cluster=min_cluster, maximum_cluster=6, + cluster_limited=set_cluster, cluster_bans=ban_cluster, + min_sigma_factor=min_sigma_factor, + log_handler=log_handler, verbose_log=verbose, + random_obj=random_obj) + except ValueError as e: + if "Solution Not Found" in str(e): + # just calculate the mu and sigma using the single component model + gmm_scheme = weighted_clustering_with_em_aic(coverages, data_weights=cover_weights, + minimum_cluster=1, maximum_cluster=1, + min_sigma_factor=min_sigma_factor, + log_handler=log_handler, verbose_log=verbose, + random_obj=random_obj) + else: + raise e cluster_num = gmm_scheme["cluster_num"] parameters = gmm_scheme["parameters"] # for debug @@ -1845,13 +4224,14 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre # for lb in selected_label_type: # this_add_up = 0 # for go in np.where(labels == lb)[0]: - # this_add_up += self.vertex_info[vertices[go]].get("weight", {}).get(mode, 0) + # this_add_up += self.vertex_info[vertices_set[go]].get("weight", {}).get(mode, 0) # label_weights[lb] = this_add_up label_weights = {lb: sum([self.vertex_info[vertices[go]].other_attr.get("weight", {}).get(database_n, 0) for go in np.where(labels == lb)[0]]) for lb in selected_label_type} selected_label_type.sort(key=lambda x: -label_weights[x]) remained_label_type = {selected_label_type[0]} + # add minor weights if qualified for candidate_lb_type in selected_label_type[1:]: if label_weights[candidate_lb_type] * weight_factor >= selected_label_type[0]: remained_label_type.add(candidate_lb_type) @@ -1865,6 +4245,16 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre if abs(can_mu - parameters[remained_l]["mu"]) < 2 * parameters[remained_l]["sigma"]: extra_kept.add(candidate_lb_type) break + # does not help - 2023-01-17 + # if database_n == "embplant_mt": + # if abs(can_mu - parameters[remained_l]["mu"]) < 2 * parameters[remained_l]["sigma"]: + # extra_kept.add(candidate_lb_type) + # break + # else: + # # all contigs with larger coverages will be kept + # if parameters[remained_l]["mu"] - can_mu < 2 * parameters[remained_l]["sigma"]: + # extra_kept.add(candidate_lb_type) + # break remained_label_type |= extra_kept else: remained_label_type = {selected_label_type[0]} @@ -1880,10 +4270,10 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre # 2 # exclude_label_type = set() - # if len(tag_kinds) > 1: + # if len(set_kinds) > 1: # for go_l, this_label in enumerate(labels): - # for this_tag in tag_kinds[1:]: - # if vertices[go_l] in self.tagged_vertices[this_tag]: + # for this_tag in set_kinds[1:]: + # if vertices_set[go_l] in self.tagged_vertices[this_tag]: # exclude_label_type.add(this_label) # break # exclude_label_type = sorted(exclude_label_type) @@ -1901,28 +4291,37 @@ def filter_by_coverage(self, drop_num=1, database_n="embplant_pt", hard_cov_thre # check_ex += 1 candidate_dropping_label_type = {l_t: inf for l_t in set(range(cluster_num)) - remained_label_type} - for lab_tp in candidate_dropping_label_type: + for lab_tp in list(candidate_dropping_label_type): check_mu = parameters[lab_tp]["mu"] check_sigma = parameters[lab_tp]["sigma"] for remained_l in 
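
# Minimal sketch (toy numbers) of the single-component fallback above: when the constrained
# clustering reports "Solution Not Found", a one-cluster fit is used instead, which for a
# single Gaussian component reduces (implementation details aside) to the weighted mean and
# standard deviation of the coverages.
import numpy as np

coverages = np.array([35.0, 38.0, 120.0, 41.0])
weights = np.array([10000.0, 8000.0, 500.0, 12000.0])   # e.g. contig lengths as weights
mu = np.average(coverages, weights=weights)
sigma = np.sqrt(np.average((coverages - mu) ** 2, weights=weights))
print(round(float(mu), 1), round(float(sigma), 1))
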
remained_label_type: rem_mu = parameters[remained_l]["mu"] rem_sigma = parameters[remained_l]["sigma"] this_dist = abs(rem_mu - check_mu) - 2 * (check_sigma + rem_sigma) + # does not help - 2023-01-17 + # if database_n == "embplant_mt": + # this_dist = abs(rem_mu - check_mu) - 2 * (check_sigma + rem_sigma) + # else: + # if rem_mu - check_mu < rem_sigma: + # del candidate_dropping_label_type[lab_tp] + # break + # else: + # this_dist = rem_mu - check_mu - 2 * (check_sigma + rem_sigma) candidate_dropping_label_type[lab_tp] = min(candidate_dropping_label_type[lab_tp], this_dist) dropping_type = sorted(candidate_dropping_label_type, key=lambda x: -candidate_dropping_label_type[x]) - drop_num = max(len(tag_kinds) - 1, drop_num) + drop_num = max(len(set_kinds) - 1, drop_num) dropping_type = dropping_type[:drop_num] if debug or verbose: if log_handler: for lab_tp in dropping_type: if candidate_dropping_label_type[lab_tp] < 0: - log_handler.warning("Indistinguishable vertices " + log_handler.warning("Indistinguishable vertices_set " + str([vertices[go] for go in np.where(labels == lab_tp)[0]]) + " removed!") else: for lab_tp in dropping_type: if candidate_dropping_label_type[lab_tp] < 0: - sys.stdout.write("Warning: indistinguishable vertices " + sys.stdout.write("Warning: indistinguishable vertices_set " + str([vertices[go] for go in np.where(labels == lab_tp)[0]]) + " removed!\n") vertices_to_del = {vertices[go] for go, lb in enumerate(labels) if lb in set(dropping_type)} @@ -1952,13 +4351,16 @@ def exclude_other_hits(self, database_n): def reduce_to_subgraph(self, bait_vertices, bait_offsets=None, limit_extending_len=None, - extending_len_weighted_by_depth=False): + extending_len_weighted_by_depth=False, + verbose=False, + log_handler=None): """ :param bait_vertices: :param bait_offsets: :param limit_extending_len: - :param limit_offset_current_vertex: :param extending_len_weighted_by_depth: + :param verbose: + :param log_handler: :return: """ if bait_offsets is None: @@ -1973,7 +4375,9 @@ def reduce_to_subgraph(self, bait_vertices, bait_offsets=None, else: rm_sub_ids.append(go_sub) rm_contigs.update(vertices) - # rm vertices + # rm vertices_set + if rm_contigs and verbose and log_handler: + log_handler.info("removing clusters without baits(" + str(len(rm_contigs)) + "):" + str(rm_contigs)) self.remove_vertex(rm_contigs, update_cluster=False) # rm clusters for sub_id in rm_sub_ids[::-1]: @@ -1996,7 +4400,7 @@ def reduce_to_subgraph(self, bait_vertices, bait_offsets=None, changed = True best_explored_record[(this_v, this_e)] = (quota_len, base_cov) for (next_v, next_e), this_overlap in self.vertex_info[this_v].connections[this_e].items(): - # not the starting vertices + # not the starting vertices_set if next_v not in bait_vertices: new_quota_len = quota_len - (self.vertex_info[next_v].len - this_overlap) * \ max(1, self.vertex_info[next_v].cov / base_cov) @@ -2020,12 +4424,14 @@ def reduce_to_subgraph(self, bait_vertices, bait_offsets=None, changed = False for (this_v, this_e), quota_len in sorted(explorers.items()): # if there's any this_v active: quota_len>0 AND (not_recorded OR recorded_changed)) - if quota_len > 0 and quota_len != best_explored_record.get((this_v, this_e), None): + # TODO: test new code + # if quota_len > 0 and quota_len != best_explored_record.get((this_v, this_e), None): + if quota_len > 0 and quota_len > best_explored_record.get((this_v, this_e), 0): changed = True best_explored_record[(this_v, this_e)] = quota_len # for this_direction in (True, False): for (next_v, 
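
# Toy illustration (hypothetical mu/sigma values) of how dropped coverage clusters are chosen
# above: a candidate's separation from every kept cluster is |mu_kept - mu_cand| minus two
# sigmas of each; the best-separated candidates are dropped first, and a negative separation
# triggers the "indistinguishable vertices" warning.
parameters = {0: {"mu": 40.0, "sigma": 4.0}, 1: {"mu": 150.0, "sigma": 10.0}, 2: {"mu": 43.0, "sigma": 5.0}}
remained = {0}
drop_num = 1

def separation(label):
    mu, sigma = parameters[label]["mu"], parameters[label]["sigma"]
    return min(abs(parameters[r]["mu"] - mu) - 2 * (sigma + parameters[r]["sigma"]) for r in remained)

candidates = sorted(set(parameters) - remained, key=lambda lb: -separation(lb))
for label in candidates[:drop_num]:
    if separation(label) < 0:
        print("Warning: indistinguishable cluster", label)
print("dropping:", candidates[:drop_num])   # [1]: the distant 150x cluster goes first
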
next_e), this_overlap in self.vertex_info[this_v].connections[this_e].items(): - # not the starting vertices + # not the starting vertices_set if next_v not in bait_vertices: new_quota_len = quota_len - (self.vertex_info[next_v].len - this_overlap) # if next_v is active: quota_len>0 AND (not_explored OR larger_len)) @@ -2067,7 +4473,7 @@ def generate_consensus_vertex(self, vertices, directions, copy_tags=True, check_ MergingHistory( [(ConsensusHistory([(v_n, v_e) for v_n, v_e in zip(vertices, directions)]), directions[0])]) new_vertex = str(self.vertex_info[vertices[0]].merging_history) - # new_vertex = "(" + "|".join(vertices) + ")" + # new_vertex = "(" + "|".join(vertices_set) + ")" self.vertex_info[new_vertex] = deepcopy(self.vertex_info[vertices[0]]) self.vertex_info[new_vertex].name = new_vertex self.vertex_info[new_vertex].cov = sum([self.vertex_info[v].cov for v in vertices]) @@ -2076,13 +4482,13 @@ def generate_consensus_vertex(self, vertices, directions, copy_tags=True, check_ # del self.vertex_info[new_vertex]["long"] # self.merging_history[new_vertex] = set() - # for candidate_v in vertices: + # for candidate_v in vertices_set: # if candidate_v in self.merging_history: # for sub_v_n in self.merging_history[candidate_v]: # self.merging_history[new_vertex].add(sub_v_n) # else: # self.merging_history[new_vertex].add(candidate_v) - # for candidate_v in vertices: + # for candidate_v in vertices_set: # if candidate_v in self.merging_history: # del self.merging_history[candidate_v] @@ -2114,8 +4520,14 @@ def generate_consensus_vertex(self, vertices, directions, copy_tags=True, check_ self.vertex_info[new_vertex].other_attr["tags"][db_n] \ = deepcopy(self.vertex_info[other_vertex].other_attr["tags"][db_n]) else: - self.vertex_info[new_vertex].other_attr["tags"][db_n] \ - |= self.vertex_info[other_vertex].other_attr["tags"][db_n] + # adjust for update in 2023-01-13 + for ln, lw in self.vertex_info[other_vertex].other_attr["tags"][db_n].items(): + if ln not in self.vertex_info[new_vertex].other_attr["tags"][db_n]: + self.vertex_info[new_vertex].other_attr["tags"][db_n][ln] = lw + else: + self.vertex_info[new_vertex].other_attr["tags"][db_n][ln] += lw + # self.vertex_info[new_vertex].other_attr["tags"][db_n] \ + # |= self.vertex_info[other_vertex].other_attr["tags"][db_n] if "weight" in self.vertex_info[other_vertex].other_attr: if "weight" not in self.vertex_info[new_vertex].other_attr: self.vertex_info[new_vertex].other_attr["weight"] \ @@ -2138,17 +4550,26 @@ def generate_consensus_vertex(self, vertices, directions, copy_tags=True, check_ else: log_handler.info("Consensus made: " + new_vertex + "\n") - def processing_polymorphism(self, database_name, limited_vertices=None, + def processing_polymorphism(self, database_name, average_depth=None, limited_vertices=None, contamination_depth=3., contamination_similarity=0.95, degenerate=False, degenerate_depth=1.5, degenerate_similarity=0.98, warning_count=4, only_keep_max_cov=False, verbose=False, debug=False, log_handler=None): + if average_depth is None: + tagged_vs = [_v for _v in self.tagged_vertices[database_name] + if self.vertex_info[_v].other_attr.get("weight", {}).get(database_name, -1) > 0] + average_depth, ave_std = self.estimate_copy_and_depth_by_cov( + tagged_vs, debug=debug, log_handler=log_handler, + verbose=verbose, mode=database_name) + else: + average_depth = float(average_depth) + parallel_vertices_list = self.detect_parallel_vertices(limited_vertices=limited_vertices) # overlap = self.__overlap if self.__overlap else 0 if 
verbose or debug: if log_handler: - log_handler.info("detected parallel vertices " + str(parallel_vertices_list)) + log_handler.info("detected parallel vertices_set " + str(parallel_vertices_list)) else: - sys.stdout.write("detected parallel vertices " + str(parallel_vertices_list) + "\n") + sys.stdout.write("detected parallel vertices_set " + str(parallel_vertices_list) + "\n") degenerate_depth = abs(log(degenerate_depth)) contamination_depth = abs(log(contamination_depth)) @@ -2271,7 +4692,7 @@ def processing_polymorphism(self, database_name, limited_vertices=None, contaminating_cov = np.array([self.vertex_info[con_v].cov for con_v in removing_contaminating_v]) contaminating_weight = np.array([len(self.vertex_info[con_v].seq[True]) for con_v in removing_contaminating_v]) - for candidate_rm_v in removing_contaminating_v: + for candidate_rm_v in list(removing_contaminating_v): # fixed in 2022-12-30 if candidate_rm_v in self.tagged_vertices[database_name]: removing_contaminating_v.remove(candidate_rm_v) self.remove_vertex(removing_contaminating_v) @@ -2287,112 +4708,126 @@ def processing_polymorphism(self, database_name, limited_vertices=None, self.remove_vertex(removing_below_cut_off) if verbose or debug: if log_handler: - log_handler.info("removing contaminating vertices: " + " ".join(list(removing_contaminating_v))) - log_handler.info("removing contaminating-like vertices: " + " ".join(list(removing_below_cut_off))) + log_handler.info("removing contaminating vertices_set: " + " ".join(list(removing_contaminating_v))) + log_handler.info("removing contaminating-like vertices_set: " + " ".join(list(removing_below_cut_off))) else: sys.stdout.write( - "removing contaminating vertices: " + " ".join(list(removing_contaminating_v)) + "\n") + "removing contaminating vertices_set: " + " ".join(list(removing_contaminating_v)) + "\n") sys.stdout.write( - "removing contaminating-like vertices: " + " ".join(list(removing_below_cut_off)) + "\n") + "removing contaminating-like vertices_set: " + " ".join(list(removing_below_cut_off)) + "\n") if removing_irrelevant_v: for candidate_rm_v in list(removing_irrelevant_v): if candidate_rm_v in self.tagged_vertices[database_name]: removing_irrelevant_v.remove(candidate_rm_v) - self.remove_vertex(removing_irrelevant_v) - if verbose or debug: - if log_handler: - log_handler.info("removing parallel vertices: " + " ".join(list(removing_irrelevant_v))) - else: - sys.stdout.write("removing parallel vertices: " + " ".join(list(removing_irrelevant_v)) + "\n") + if removing_irrelevant_v: + self.remove_vertex(removing_irrelevant_v) + if verbose or debug: + # if log_handler: + log_handler.info("removing parallel vertices_set: " + " ".join(list(removing_irrelevant_v))) + # else: + # sys.stdout.write("removing parallel vertices_set: " + " ".join(list(removing_irrelevant_v)) + "\n") if count_contamination_or_degenerate >= warning_count: - if log_handler: - log_handler.warning("The graph might suffer from contamination or polymorphism!") - if count_using_only_max: - log_handler.warning("Only the contig with the max cov was kept for each of those " + - str(count_using_only_max) + " polymorphic loci.") - else: - sys.stdout.write("Warning: The graph might suffer from contamination or polymorphism!") - if count_using_only_max: - sys.stdout.write("Warning: Only the contig with the max cov was kept for each of those " + - str(count_using_only_max) + " polymorphic loci.\n") - - def find_target_graph(self, tab_file, database_name, mode="embplant_pt", type_factor=3, 
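
The 2023-01-13 adjustment above merges per-locus tag values by summation instead of the old set union, apparently because the per-database tags now map locus names to numeric weights. The merge logic in isolation (dictionary names are illustrative):

    def merge_tag_weights(target, other):
        # Accumulate per-locus weights from `other` into `target`,
        # matching the consensus-vertex merge above.
        for locus_name, locus_weight in other.items():
            target[locus_name] = target.get(locus_name, 0) + locus_weight
        return target

    # merge_tag_weights({"rbcL": 2.0}, {"rbcL": 1.5, "matK": 1.0})
    # -> {"rbcL": 3.5, "matK": 1.0}
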
weight_factor=100.0, - max_contig_multiplicity=8, min_sigma_factor=0.1, expected_max_size=inf, expected_min_size=0, - hard_cov_threshold=10., contamination_depth=3., contamination_similarity=0.95, + # if log_handler: + log_handler.warning("The graph might suffer from contamination or polymorphism!") + if count_using_only_max and removing_irrelevant_v: + log_handler.warning("Only the contig with the max cov was kept for each of those " + + str(count_using_only_max) + " polymorphic loci.") + # else: + # sys.stdout.write("Warning: The graph might suffer from contamination or polymorphism!") + # if count_using_only_max: + # sys.stdout.write("Warning: Only the contig with the max cov was kept for each of those " + + # str(count_using_only_max) + " polymorphic loci.\n") + + def find_target_graph(self, + # tab_file, + db_name, + mode="embplant_pt", + # type_factor=3, + weight_factor=100.0, + min_sigma_factor=0.1, + expected_max_size=inf, + expected_min_size=0, + # max_contig_multiplicity=8, + hard_cov_threshold=5., contamination_depth=3., contamination_similarity=0.95, degenerate=True, degenerate_depth=1.5, degenerate_similarity=0.98, only_keep_max_cov=True, min_single_copy_percent=50, meta=False, - broken_graph_allowed=False, temp_graph=None, verbose=True, + broken_graph_allowed=False, + selected_graph=None, + temp_graph=None, verbose=True, read_len_for_log=None, kmer_for_log=None, - log_handler=None, debug=False): + log_handler=None, debug=False, + random_obj=None, + ): """ - :param tab_file: - :param database_name: + :param db_name: :param mode: - :param type_factor: :param weight_factor: - :param max_contig_multiplicity: :param min_sigma_factor: :param expected_max_size: :param expected_min_size: :param hard_cov_threshold: - :param contamination_depth: - :param contamination_similarity: + :param contamination_depth: for processing polymorphism + :param contamination_similarity: for processing polymorphism :param degenerate: :param degenerate_depth: :param degenerate_similarity: :param only_keep_max_cov: :param min_single_copy_percent: [0-100] :param broken_graph_allowed: + :param selected_graph: :param temp_graph: :param verbose: :param read_len_for_log: :param kmer_for_log: :param log_handler: :param debug: + :param random_obj: :return: """ # overlap = self.__overlap if self.__overlap else 0 - def log_target_res(final_res_combinations_inside): - echo_graph_id = int(bool(len(final_res_combinations_inside) - 1)) - for go_res, final_res_one in enumerate(final_res_combinations_inside): - this_graph = final_res_combinations_inside[go_res]["graph"] - this_k_cov = round(final_res_combinations_inside[go_res]["cov"], 3) - if read_len_for_log and kmer_for_log: - this_b_cov = round(this_k_cov * read_len_for_log / (read_len_for_log - kmer_for_log + 1), 3) - else: - this_b_cov = None - if log_handler: - if echo_graph_id: - log_handler.info("Graph " + str(go_res + 1)) - for vertex_set in sorted(this_graph.vertex_clusters): - copies_in_a_set = {this_graph.vertex_to_copy[v_name] for v_name in vertex_set} - if copies_in_a_set != {1}: - for in_vertex_name in sorted(vertex_set): - log_handler.info("Vertex_" + in_vertex_name + " #copy = " + - str(this_graph.vertex_to_copy.get(in_vertex_name, 1))) - cov_str = " kmer-coverage" if bool(self.uni_overlap()) else " coverage" - log_handler.info("Average " + mode + cov_str + - ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_k_cov) - if this_b_cov: - log_handler.info("Average " + mode + " base-coverage" + - ("(" + str(go_res + 1) + ")") * 
echo_graph_id + " = " + "%.1f" % this_b_cov) - else: - if echo_graph_id: - sys.stdout.write("Graph " + str(go_res + 1) + "\n") - for vertex_set in sorted(this_graph.vertex_clusters): - copies_in_a_set = {this_graph.vertex_to_copy[v_name] for v_name in vertex_set} - if copies_in_a_set != {1}: - for in_vertex_name in sorted(vertex_set): - sys.stdout.write("Vertex_" + in_vertex_name + " #copy = " + - str(this_graph.vertex_to_copy.get(in_vertex_name, 1)) + "\n") - cov_str = " kmer-coverage" if bool(self.uni_overlap()) else " coverage" - sys.stdout.write("Average " + mode + cov_str + - ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_k_cov + "\n") - if this_b_cov: - sys.stdout.write("Average " + mode + " base-coverage" + ("(" + str(go_res + 1) + ")") * - echo_graph_id + " = " + "%.1f" % this_b_cov + "\n") + # def log_target_res(final_res_combinations_inside): + # echo_graph_id = int(bool(len(final_res_combinations_inside) - 1)) + # for go_res, final_res_one in enumerate(final_res_combinations_inside): + # this_graph = final_res_one["graph"] + # this_k_cov = round(final_res_one["cov"], 3) + # if read_len_for_log and kmer_for_log: + # this_b_cov = round(this_k_cov * read_len_for_log / (read_len_for_log - kmer_for_log + 1), 3) + # else: + # this_b_cov = None + # if log_handler: + # if echo_graph_id: + # log_handler.info("Graph " + str(go_res + 1)) + # for vertex_set in sorted(this_graph.vertex_clusters): + # copies_in_a_set = {this_graph.vertex_to_copy[v_name] for v_name in vertex_set} + # if copies_in_a_set != {1}: + # for in_vertex_name in sorted(vertex_set): + # log_handler.info("Vertex_" + in_vertex_name + " #copy = " + + # str(this_graph.vertex_to_copy.get(in_vertex_name, 1))) + # cov_str = " kmer-coverage" if bool(self.uni_overlap()) else " coverage" + # log_handler.info("Average " + mode + cov_str + + # ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_k_cov) + # if this_b_cov: + # log_handler.info("Average " + mode + " base-coverage" + + # ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_b_cov) + # else: + # if echo_graph_id: + # sys.stdout.write("Graph " + str(go_res + 1) + "\n") + # for vertex_set in sorted(this_graph.vertex_clusters): + # copies_in_a_set = {this_graph.vertex_to_copy[v_name] for v_name in vertex_set} + # if copies_in_a_set != {1}: + # for in_vertex_name in sorted(vertex_set): + # sys.stdout.write("Vertex_" + in_vertex_name + " #copy = " + + # str(this_graph.vertex_to_copy.get(in_vertex_name, 1)) + "\n") + # cov_str = " kmer-coverage" if bool(self.uni_overlap()) else " coverage" + # sys.stdout.write("Average " + mode + cov_str + + # ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_k_cov + "\n") + # if this_b_cov: + # sys.stdout.write("Average " + mode + " base-coverage" + ("(" + str(go_res + 1) + ")") * + # echo_graph_id + " = " + "%.1f" % this_b_cov + "\n") + if random_obj is None: + import random as random_obj if temp_graph: if temp_graph.endswith(".gfa"): @@ -2415,51 +4850,113 @@ def add_temp_id(old_temp_file, extra_str=""): else: return old_temp_file + extra_str - def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): + def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, step_tag): if _temp_graph: - tmp_graph_1 = add_temp_id(_temp_graph, ".%02d.%02d" % (count_all_temp[0], go_id)) - tmp_csv_1 = add_temp_id(_temp_csv, ".%02d.%02d" % (count_all_temp[0], go_id)) + tmp_graph_1 = add_temp_id(_temp_graph, ".%02d.%s" % (count_all_temp[0], step_tag)) + 
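
find_target_graph now accepts an optional random_obj and falls back to the stdlib random module when none is given, which lets callers pass a seeded random.Random for reproducible runs. The same dependency-injection pattern in miniature (the function name below is hypothetical):

    import random

    def pick_seed_vertex(vertex_names, random_obj=None):
        # Fall back to the module-level RNG when no seeded generator is
        # supplied; pass random.Random(12345) for reproducible choices.
        if random_obj is None:
            random_obj = random
        return random_obj.choice(sorted(vertex_names))

    # pick_seed_vertex({"1", "2", "3"}, random_obj=random.Random(0))
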
tmp_csv_1 = add_temp_id(_temp_csv, ".%02d.%s" % (count_all_temp[0], step_tag)) if verbose: if log_handler: - log_handler.info("Writing out temp graph (%d): %s" % (go_id, tmp_graph_1)) + log_handler.info("Writing out temp graph (%s): %s" % (step_tag, tmp_graph_1)) else: - sys.stdout.write("Writing out temp graph (%d): %s" % (go_id, tmp_graph_1) + "\n") + sys.stdout.write("Writing out temp graph (%s): %s" % (step_tag, tmp_graph_1) + "\n") _assembly.write_to_gfa(tmp_graph_1) - _assembly.write_out_tags([_database_name], tmp_csv_1) + if _database_name in ("embplant_pt", "embplant_mt"): + _database_name = ["embplant_pt", "embplant_mt"] + else: + _database_name = [_database_name] + _assembly.write_out_tags(_database_name, tmp_csv_1) count_all_temp[0] += 1 + def write_selected(_assembly, _selected_graph): + if _selected_graph is None: + pass + else: + # write out selected graph + log_handler.info("Output selected graph file " + str(_selected_graph)) + _assembly.write_to_gfa(_selected_graph) + # db_name can be different from the "mode" when the mode is anonym + if db_name in ("embplant_pt", "embplant_mt"): + _this_dbs = ["embplant_pt", "embplant_mt"] + else: + _this_dbs = [db_name] + _assembly.write_out_tags(_this_dbs, _selected_graph[:-3] + "csv") + + def check_remaining_singleton(): + if len(new_assembly.vertex_info) == 0: + raise ProcessingGraphFailed("Too strict criteria removing all contigs in an insufficient graph") + elif len(new_assembly.vertex_info) == 1: + the_only_v = list(new_assembly.vertex_info)[0] + if the_only_v in new_assembly.tagged_vertices[db_name]: + if new_assembly.vertex_info[the_only_v].is_self_loop() or broken_graph_allowed: + return True + else: + if verbose: + raise ProcessingGraphFailed("Linear graph: " + the_only_v + "! # tags: " + + str(new_assembly.vertex_info[the_only_v].other_attr. + get("tags", {db_name: ""})[db_name])) + else: + raise ProcessingGraphFailed("Linear graph") + + def gen_contigs_with_no_connections(): + if verbose and log_handler: + log_handler.info("Removing all connections and generate the output ..") + for _del_v_con in new_assembly.vertex_info: + new_assembly.vertex_info[_del_v_con].connections = {True: OrderedDict(), + False: OrderedDict()} + # new_assembly.merge_all_possible_vertices() + new_assembly.update_vertex_clusters() + new_assembly.copy_to_vertex = {1: set(new_assembly.vertex_info)} + new_assembly.vertex_to_copy = {v_n: 1 for v_n in new_assembly.vertex_info} + return [{"graph": new_assembly, + "cov": np.average([v_info.cov for foo, v_info in new_assembly.vertex_info.items()])}] + if broken_graph_allowed and not meta: weight_factor = 10000. - if meta: - try: - self.parse_tab_file( - tab_file, database_name=database_name, type_factor=type_factor, log_handler=log_handler) - except ProcessingGraphFailed: - return [] - else: - self.parse_tab_file(tab_file, database_name=database_name, type_factor=type_factor, log_handler=log_handler) + # if meta: + # try: + # self.parse_tab_file( + # tab_file, + # database_name=db_name, + # type_factor=type_factor, + # max_gene_gap=250, + # max_cov_diff=hard_cov_threshold, + # verbose=verbose, + # log_handler=log_handler) + # except ProcessingGraphFailed: + # return [] + # else: + # self.parse_tab_file( + # tab_file, + # database_name=db_name, + # type_factor=type_factor, + # max_gene_gap=250, + # max_cov_diff=hard_cov_threshold, # contamination_depth? 
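
gen_contigs_with_no_connections() above is the last-resort output: every edge is deleted, each remaining contig is reported as an independent single-copy fragment, and the returned coverage is the plain average over all of them. A toy version of the same idea on a plain dictionary (the layout below is illustrative, not the Assembly class):

    import numpy as np

    contigs = {"1": {"cov": 41.2, "edges": {"2"}},
               "2": {"cov": 39.8, "edges": {"1", "3"}},
               "3": {"cov": 120.5, "edges": {"2"}}}

    for name in contigs:
        contigs[name]["edges"] = set()              # drop all connections
    vertex_to_copy = {name: 1 for name in contigs}  # everything single copy
    result = [{"graph": contigs,
               "cov": float(np.average([c["cov"] for c in contigs.values()]))}]
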
+ # verbose=verbose, + # log_handler=log_handler) + new_assembly = deepcopy(self) is_reasonable_res = False data_contains_outlier = False try: + # if True: while not is_reasonable_res: is_reasonable_res = True if verbose or debug: if log_handler: - log_handler.info("tagged vertices: " + str(sorted(new_assembly.tagged_vertices[database_name]))) + log_handler.info("tagged vertices_set: " + str(sorted(new_assembly.tagged_vertices[db_name]))) log_handler.info("tagged coverage: " + str(["%.1f" % new_assembly.vertex_info[log_v].cov - for log_v in sorted(new_assembly.tagged_vertices[database_name])])) + for log_v in sorted(new_assembly.tagged_vertices[db_name])])) else: - sys.stdout.write("tagged vertices: " + str(sorted(new_assembly.tagged_vertices[database_name])) + sys.stdout.write("tagged vertices_set: " + str(sorted(new_assembly.tagged_vertices[db_name])) + "\n") sys.stdout.write("tagged coverage: " + str(["%.1f" % new_assembly.vertex_info[log_v].cov - for log_v in sorted(new_assembly.tagged_vertices[database_name])]) + "\n") + for log_v in sorted(new_assembly.tagged_vertices[db_name])]) + "\n") new_assembly.merge_all_possible_vertices() - new_assembly.tag_in_between(database_n=database_name) - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 1) + new_assembly.tag_in_between() + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "a") changed = True count_large_round = 0 while changed: @@ -2477,55 +4974,70 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): # remove low coverages first_round = True delete_those_vertices = set() - parameters = [] + # parameters = [] this_del = False - new_assembly.estimate_copy_and_depth_by_cov( - new_assembly.tagged_vertices[database_name], debug=debug, log_handler=log_handler, + tagged_vs = [_v for _v in new_assembly.tagged_vertices[db_name] + if new_assembly.vertex_info[_v].other_attr.get("weight", {}).get(db_name, -1) > 0] + new_ave_cov, ave_std = new_assembly.estimate_copy_and_depth_by_cov( + tagged_vs, + min_sigma=min_sigma_factor, + debug=debug, log_handler=log_handler, verbose=verbose, mode=mode) while first_round or delete_those_vertices or this_del: if data_contains_outlier: - this_del, parameters = \ - new_assembly.filter_by_coverage(database_n=database_name, + this_del, foo_parameters = \ + new_assembly.filter_by_coverage(database_n=db_name, weight_factor=weight_factor, - hard_cov_threshold=hard_cov_threshold, + min_cov_folds=hard_cov_threshold, min_sigma_factor=min_sigma_factor, min_cluster=2, log_handler=log_handler, - verbose=verbose, debug=debug) + verbose=verbose, debug=debug, + random_obj=random_obj + ) data_contains_outlier = False if not this_del: raise ProcessingGraphFailed( "Unable to generate result with single copy vertex percentage < {}%" .format(min_single_copy_percent)) else: - this_del, parameters = \ - new_assembly.filter_by_coverage(database_n=database_name, + this_del, foo_parameters = \ + new_assembly.filter_by_coverage(database_n=db_name, weight_factor=weight_factor, - hard_cov_threshold=hard_cov_threshold, + min_cov_folds=hard_cov_threshold, min_sigma_factor=min_sigma_factor, log_handler=log_handler, verbose=verbose, - debug=debug) + debug=debug, + random_obj=random_obj + ) if verbose or debug: if log_handler: - log_handler.info("tagged vertices: " + - str(sorted(new_assembly.tagged_vertices[database_name]))) + log_handler.info("tagged vertices_set: " + + str(sorted(new_assembly.tagged_vertices[db_name]))) log_handler.info("tagged coverage: " + str(["%.1f" % 
new_assembly.vertex_info[log_v].cov for log_v - in sorted(new_assembly.tagged_vertices[database_name])])) + in sorted(new_assembly.tagged_vertices[db_name])])) else: - sys.stdout.write("tagged vertices: " + - str(sorted(new_assembly.tagged_vertices[database_name])) + "\n") + sys.stdout.write("tagged vertices_set: " + + str(sorted(new_assembly.tagged_vertices[db_name])) + "\n") log_handler.info("tagged coverage: " + str(["%.1f" % new_assembly.vertex_info[log_v].cov for log_v in - sorted(new_assembly.tagged_vertices[database_name])]) + "\n") - new_assembly.estimate_copy_and_depth_by_cov( - new_assembly.tagged_vertices[database_name], debug=debug, log_handler=log_handler, + sorted(new_assembly.tagged_vertices[db_name])]) + "\n") + if this_del and temp_graph: + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "b") + tagged_vs = \ + [_v for _v in new_assembly.tagged_vertices[db_name] + if new_assembly.vertex_info[_v].other_attr.get("weight", {}).get(db_name, -1) > 0] + new_ave_cov, ave_std = new_assembly.estimate_copy_and_depth_by_cov( + tagged_vs, + min_sigma=min_sigma_factor, + debug=debug, log_handler=log_handler, verbose=verbose, mode=mode) first_round = False - if new_assembly.exclude_other_hits(database_n=database_name): + if new_assembly.exclude_other_hits(database_n=db_name): changed = True cluster_trimmed = False @@ -2536,13 +5048,15 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): pass else: cluster_weights = \ - [sum([new_assembly.vertex_info[x_v].other_attr["weight"][database_name] + [sum([new_assembly.vertex_info[x_v].other_attr["weight"][db_name] for x_v in x if "weight" in new_assembly.vertex_info[x_v].other_attr and - database_name in new_assembly.vertex_info[x_v].other_attr["weight"]]) + db_name in new_assembly.vertex_info[x_v].other_attr["weight"]]) for x in new_assembly.vertex_clusters] + if verbose and log_handler: + log_handler.info("cluster_weights: " + str(cluster_weights)) best = max(cluster_weights) best_id = cluster_weights.index(best) if broken_graph_allowed: @@ -2552,12 +5066,15 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): id_remained.add(j) else: for del_v in new_assembly.vertex_clusters[j]: - if del_v in new_assembly.tagged_vertices[database_name]: + if del_v in new_assembly.tagged_vertices[db_name]: new_cov = new_assembly.vertex_info[del_v].cov - for mu, sigma in parameters: - if abs(new_cov - mu) < sigma: - id_remained.add(j) - break + # 2023-01-04 modified + if abs(new_cov - new_ave_cov) < 3 * ave_std: + id_remained.add(j) + # for mu, sigma in parameters: + # if abs(new_cov - mu) < sigma: + # id_remained.add(j) + # break if j in id_remained: break else: @@ -2567,26 +5084,39 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): del temp_cluster_weights[best_id] second = max(temp_cluster_weights) if best < second * weight_factor: - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 2) - raise ProcessingGraphFailed("Multiple isolated " + mode + " components detected! 
" - "Broken or contamination?") + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "c") + raise ProcessingGraphFailed( + "Multiple isolated " + mode + " components detected!") for j, w in enumerate(cluster_weights): if w == second: for del_v in new_assembly.vertex_clusters[j]: - if del_v in new_assembly.tagged_vertices[database_name]: + if del_v in new_assembly.tagged_vertices[db_name]: new_cov = new_assembly.vertex_info[del_v].cov # for debug # print(new_cov) # print(parameters) - for mu, sigma in parameters: - if abs(new_cov - mu) < sigma: - write_temp_out(new_assembly, database_name, - temp_graph, temp_csv, 3) - raise ProcessingGraphFailed( - "Complicated graph: please check around EDGE_" + del_v + "!" - "# tags: " + - str(new_assembly.vertex_info[del_v].other_attr. - get("tags", {database_name: ""})[database_name])) + # 2023-01-04 modified + if abs(new_cov - new_ave_cov) < 3 * ave_std: + raise ProcessingGraphFailed( + "Complicated graph: please check around EDGE_" + del_v + "!" + "# tags: " + + str(new_assembly.vertex_info[del_v].other_attr. + get("tags", {db_name: ""})[db_name])) + else: + if (verbose or debug) and log_handler: + log_handler.warning( + "removing tagged but low-coverage isolated contig: " + + del_v + ":" + + str(new_assembly.vertex_info[del_v].other_attr["tags"])) + # for mu, sigma in parameters: + # if abs(new_cov - mu) < sigma: + # write_temp_out(new_assembly, db_name, + # temp_graph, temp_csv, "d") + # raise ProcessingGraphFailed( + # "Complicated graph: please check around EDGE_" + del_v + "!" + # "# tags: " + + # str(new_assembly.vertex_info[del_v].other_attr. + # get("tags", {db_name: ""})[db_name])) # remove other clusters vertices_to_del = set() @@ -2603,29 +5133,68 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): cluster_trimmed = True changed = True - # merge vertices + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "e") + if check_remaining_singleton(): + break + + # merge vertices_set new_assembly.merge_all_possible_vertices() - new_assembly.tag_in_between(database_n=database_name) + new_assembly.tag_in_between() + if check_remaining_singleton(): + break # no tip contigs allowed if broken_graph_allowed: pass else: + if verbose and log_handler: + log_handler.info("Start removing terminal contigs ..") + total_weight = sum([new_assembly.vertex_info[x_v].other_attr["weight"][db_name] + for x_v in new_assembly.vertex_info + if + "weight" in new_assembly.vertex_info[x_v].other_attr + and + db_name in new_assembly.vertex_info[x_v].other_attr["weight"]]) first_round = True delete_those_vertices = set() while first_round or delete_those_vertices: first_round = False delete_those_vertices = set() - for vertex_name in new_assembly.vertex_info: + for _v_n in new_assembly.vertex_info: # both ends must have edge(s) if sum([bool(len(cn)) - for cn in new_assembly.vertex_info[vertex_name].connections.values()]) != 2: - if vertex_name in new_assembly.tagged_vertices[database_name]: - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 4) - raise ProcessingGraphFailed( - "Incomplete/Complicated graph: please check around EDGE_" + vertex_name + "!") + for cn in new_assembly.vertex_info[_v_n].connections.values()]) != 2: + # To keep a terminal vertex + # 1. tagged + # 2. normal depth (3 sigma) + # 3. 
enough weight + this_cov = new_assembly.vertex_info[_v_n].cov / self.vertex_to_copy.get(_v_n, 1) + if verbose and log_handler: + log_handler.info(" checking " + _v_n) + log_handler.info(" Average[std]~v_cov: " + "%.4f" % new_ave_cov + "[" + "%.4f" % ave_std + "]~" + + "%.4f" % this_cov) + if "weight" in new_assembly.vertex_info[_v_n].other_attr: + log_handler.info(" v_weight/total_weight: " + + str(new_assembly.vertex_info[_v_n].other_attr["weight"]. + get(db_name, 0.)) + "/" + str(total_weight)) + + if _v_n in new_assembly.tagged_vertices[db_name]: + if abs(new_ave_cov - this_cov) <= 3 * ave_std and \ + "weight" in new_assembly.vertex_info[_v_n].other_attr and \ + new_assembly.vertex_info[_v_n].other_attr["weight"].get(db_name, 0.) * \ + weight_factor > total_weight: + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "f") + raise ProcessingGraphFailed( + "Incomplete/Complicated graph: please check around EDGE_" + _v_n + "!") + else: + if (verbose or debug) and log_handler: + log_handler.warning( + "removing tagged but low-coverage terminal contig: " + _v_n + ":" + + str(new_assembly.vertex_info[_v_n].other_attr["tags"])) + delete_those_vertices.add(_v_n) else: - delete_those_vertices.add(vertex_name) + delete_those_vertices.add(_v_n) if delete_those_vertices: if verbose or debug: if log_handler: @@ -2635,48 +5204,66 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): "removing terminal contigs: " + str(delete_those_vertices) + "\n") new_assembly.remove_vertex(delete_those_vertices) changed = True + if check_remaining_singleton(): + break - # merge vertices + # merge vertices_set new_assembly.merge_all_possible_vertices() - new_assembly.processing_polymorphism(database_name=database_name, + if check_remaining_singleton(): + break + tagged_vs = [_v for _v in new_assembly.tagged_vertices[db_name] + if new_assembly.vertex_info[_v].other_attr.get("weight", {}).get(db_name, -1) > 0] + new_ave_cov, ave_std = new_assembly.estimate_copy_and_depth_by_cov( + tagged_vs, + min_sigma=min_sigma_factor, + debug=debug, log_handler=log_handler, + verbose=verbose, mode=mode) + new_assembly.processing_polymorphism(database_name=db_name, + average_depth=new_ave_cov, contamination_depth=contamination_depth, contamination_similarity=contamination_similarity, degenerate=False, degenerate_depth=degenerate_depth, degenerate_similarity=degenerate_similarity, verbose=verbose, debug=debug, log_handler=log_handler) - new_assembly.tag_in_between(database_n=database_name) - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 5) - - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 6) - new_assembly.processing_polymorphism(database_name=database_name, - contamination_depth=contamination_depth, - contamination_similarity=contamination_similarity, - degenerate=degenerate, degenerate_depth=degenerate_depth, - degenerate_similarity=degenerate_similarity, - warning_count=1, only_keep_max_cov=only_keep_max_cov, - verbose=verbose, debug=debug, log_handler=log_handler) - new_assembly.merge_all_possible_vertices() - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 7) + new_assembly.tag_in_between() + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "g") + + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "h") + if check_remaining_singleton(): + pass + else: + new_assembly.processing_polymorphism(database_name=db_name, + contamination_depth=contamination_depth, + contamination_similarity=contamination_similarity, + 
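
The terminal-contig check above only treats a dead end as significant when all three numbered conditions hold: the contig is tagged, its depth lies within three standard deviations of the tagged average, and its hit weight times weight_factor exceeds the total weight. In that case the graph is reported as "Incomplete/Complicated" instead of being trimmed; otherwise the tip is simply removed. The decision condensed into one predicate (the name and argument order are assumptions):

    def terminal_contig_is_significant(is_tagged, contig_cov, ave_cov, ave_std,
                                       contig_weight, total_weight,
                                       weight_factor=100.0):
        # True  -> the dangling contig looks like genuine target sequence,
        #          so the pipeline raises rather than silently trimming it;
        # False -> the tip is deleted and the cleaning loop continues.
        return (is_tagged
                and abs(ave_cov - contig_cov) <= 3 * ave_std
                and contig_weight * weight_factor > total_weight)

    # terminal_contig_is_significant(True, 38.0, 40.0, 2.5, 5000.0, 120000.0) -> True
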
degenerate=degenerate, degenerate_depth=degenerate_depth, + degenerate_similarity=degenerate_similarity, + warning_count=1, only_keep_max_cov=only_keep_max_cov, + verbose=verbose, debug=debug, log_handler=log_handler) + new_assembly.merge_all_possible_vertices() + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "i") - # create idealized vertices and edges + # create idealized vertices_set and edges try: - new_average_cov = new_assembly.estimate_copy_and_depth_by_cov(log_handler=log_handler, - verbose=verbose, - mode="all", debug=debug) - if verbose: - if log_handler: - log_handler.info("Estimating copy and depth precisely ...") - else: - sys.stdout.write("Estimating copy and depth precisely ...\n") - final_res_combinations = new_assembly.estimate_copy_and_depth_precisely( - maximum_copy_num=max_contig_multiplicity, broken_graph_allowed=broken_graph_allowed, - return_new_graphs=True, log_handler=log_handler, - verbose=verbose, debug=debug) - if verbose: - if log_handler: - log_handler.info(str(len(final_res_combinations)) + " candidate graph(s) generated.") - else: - sys.stdout.write(str(len(final_res_combinations)) + " candidate graph(s) generated.\n") + new_average_cov, ave_std = new_assembly.estimate_copy_and_depth_by_cov( + min_sigma=min_sigma_factor, + log_handler=log_handler, + verbose=verbose, + mode="all", + debug=debug) + if check_remaining_singleton(): + final_res_combinations = [{"graph": new_assembly, "cov": new_average_cov}] + else: + final_res_combinations = new_assembly.estimate_copy_and_depth_precisely( + expected_average_cov=new_average_cov, + # broken_graph_allowed=broken_graph_allowed, + log_handler=log_handler, + verbose=verbose, debug=debug) + # maybe no more multiple results since 2022-12 gekko update + # if verbose: + # if log_handler: + # log_handler.info(str(len(final_res_combinations)) + " candidate graph(s) generated.") + # else: + # sys.stdout.write(str(len(final_res_combinations)) + " candidate graph(s) generated.\n") absurd_copy_nums = True no_single_copy = True while absurd_copy_nums: @@ -2739,19 +5326,21 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): if go_ve != keep_this: dropping_names.append(this_name) # if log_handler: - # log_handler.info("Dropping vertices " + " ".join(dropping_names)) + # log_handler.info("Dropping vertices_set " + " ".join(dropping_names)) # else: - # log_handler.info("Dropping vertices " + "".join(dropping_names) + "\n") + # log_handler.info("Dropping vertices_set " + "".join(dropping_names) + "\n") new_possible_graph.remove_vertex(dropping_names) new_possible_graph.merge_all_possible_vertices() - new_possible_graph.estimate_copy_and_depth_by_cov( + new_ave_cov, ave_std = new_possible_graph.estimate_copy_and_depth_by_cov( + min_sigma=min_sigma_factor, log_handler=log_handler, verbose=verbose, mode="all", debug=debug) + final_res_combinations.extend( new_possible_graph.estimate_copy_and_depth_precisely( - maximum_copy_num=max_contig_multiplicity, - broken_graph_allowed=broken_graph_allowed, return_new_graphs=True, + expected_average_cov=new_ave_cov, + # broken_graph_allowed=broken_graph_allowed, log_handler=log_handler, verbose=verbose, debug=debug)) - + write_temp_out(new_possible_graph, db_name, temp_graph, temp_csv, "j") del final_res_combinations[go_graph] if not final_res_combinations and absurd_copy_nums: # if absurd_copy_nums: @@ -2761,130 +5350,165 @@ def write_temp_out(_assembly, _database_name, _temp_graph, _temp_csv, go_id): if no_single_copy: raise ProcessingGraphFailed("No 
single copy region?! Detecting path(s) failed!") except ImportError as e: + write_selected(_assembly=new_assembly, _selected_graph=selected_graph) raise e - except (RecursionError, Exception) as e: - if broken_graph_allowed: - unlabelled_contigs = [check_v for check_v in list(new_assembly.vertex_info) - if check_v not in new_assembly.tagged_vertices[database_name]] - if unlabelled_contigs: - if verbose or debug: - if log_handler: - log_handler.info("removing unlabelled contigs: " + str(unlabelled_contigs)) - else: - sys.stdout.write("removing unlabelled contigs: " + str(unlabelled_contigs) + "\n") - new_assembly.remove_vertex(unlabelled_contigs) - new_assembly.merge_all_possible_vertices() - else: + # except (RecursionError, Exception) as e: + # 2022-12-21 remove base class Exception + except RecursionError as e: # RecursionError is created by complex graph + unlabelled_contigs = [check_v for check_v in list(new_assembly.vertex_info) + if check_v not in new_assembly.tagged_vertices[db_name]] + connections_removed = False + if unlabelled_contigs: + if verbose or debug: + if log_handler: + log_handler.info("removing unlabelled contigs: " + str(unlabelled_contigs)) + else: + sys.stdout.write("removing unlabelled contigs: " + str(unlabelled_contigs) + "\n") + new_assembly.remove_vertex(unlabelled_contigs) + new_assembly.merge_all_possible_vertices() + write_selected(_assembly=new_assembly, _selected_graph=selected_graph) + else: + write_selected(_assembly=new_assembly, _selected_graph=selected_graph) + if broken_graph_allowed: # delete all previous connections if all present contigs are labelled for del_v_connection in new_assembly.vertex_info: new_assembly.vertex_info[del_v_connection].connections = {True: OrderedDict(), False: OrderedDict()} - new_assembly.update_vertex_clusters() - new_average_cov = new_assembly.estimate_copy_and_depth_by_cov( - re_initialize=True, log_handler=log_handler, verbose=verbose, mode="all", debug=debug) - outer_continue = False - for remove_all_connections in (False, True): - if remove_all_connections: # delete all previous connections - for del_v_connection in new_assembly.vertex_info: - new_assembly.vertex_info[del_v_connection].connections = {True: OrderedDict(), - False: OrderedDict()} + # new_assembly.update_vertex_clusters() + connections_removed = True + else: + if verbose and log_handler: + log_handler.exception("") + raise e + + new_average_cov, ave_std = new_assembly.estimate_copy_and_depth_by_cov( + re_initialize=True, min_sigma=min_sigma_factor, + log_handler=log_handler, verbose=verbose, mode="all", debug=debug) + outer_continue = False + for remove_all_connections in (False, True): + # if connections_removed and remove_all_connections: + # is_reasonable_res = False + # outer_continue = True + # break + # if remove_all_connections and not connections_removed: # delete all previous connections + # for del_v_connection in new_assembly.vertex_info: + # new_assembly.vertex_info[del_v_connection].connections = {True: OrderedDict(), + # False: OrderedDict()} + # new_assembly.merge_all_possible_vertices() + if remove_all_connections: + final_res_combinations = gen_contigs_with_no_connections() + # new_assembly.copy_to_vertex = {1: set(new_assembly.vertex_info)} + # new_assembly.vertex_to_copy = {v_n: 1 for v_n in new_assembly.vertex_info} + # final_res_combinations = [ + # {"graph": new_assembly, + # "cov": np.average([v_info.cov for foo, v_info in new_assembly.vertex_info.items()])}] + # print("new_assembly.copy_to_vertex", new_assembly.copy_to_vertex) + 
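
When a RecursionError escapes from the copy-number estimation, the rescue above first discards every contig that carries no database hit and retries; only if all remaining contigs are already labelled (and a broken graph is allowed) are the connections stripped. The first of those two steps in isolation (argument names are placeholders for new_assembly.vertex_info and new_assembly.tagged_vertices[db_name]):

    def unlabelled_contigs(vertex_names, tagged_vertices):
        # Contigs without any database hit, i.e. candidates for removal
        # before retrying the copy-number estimation.
        return [v for v in vertex_names if v not in tagged_vertices]

    # unlabelled_contigs(["1", "2", "5"], {"2"}) -> ["1", "5"]
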
else: new_assembly.update_vertex_clusters() try: - here_max_copy = 1 if remove_all_connections else max_contig_multiplicity final_res_combinations = new_assembly.estimate_copy_and_depth_precisely( - maximum_copy_num=here_max_copy, broken_graph_allowed=True, return_new_graphs=True, + expected_average_cov=new_average_cov, log_handler=log_handler, verbose=verbose, debug=debug) except ImportError as e: raise e except Exception as e: if verbose or debug: - if log_handler: - log_handler.info(str(e)) + if remove_all_connections or connections_removed: + log_handler.info("Unlikely error: " + str(e)) else: - sys.stdout.write(str(e) + "\n") + log_handler.info(str(e)) continue - test_first_g = final_res_combinations[0]["graph"] - if 1 in test_first_g.copy_to_vertex: - single_copy_percent = sum([test_first_g.vertex_info[s_v].len - for s_v in test_first_g.copy_to_vertex[1]]) \ - / float(sum([test_first_g.vertex_info[a_v].len - for a_v in test_first_g.vertex_info])) - if single_copy_percent < 0.5: - if verbose: - if log_handler: - log_handler.warning( - "Result with single copy vertex percentage < 50% is " - "unacceptable, continue dropping suspicious vertices ...") - else: - sys.stdout.write( - "Warning: Result with single copy vertex percentage < 50% is " - "unacceptable, continue dropping suspicious vertices ...") - data_contains_outlier = True - is_reasonable_res = False - outer_continue = True - break - else: - log_target_res(final_res_combinations) - return final_res_combinations - else: + test_first_g = final_res_combinations[0]["graph"] + if 1 in test_first_g.copy_to_vertex: + single_copy_percent = sum([test_first_g.vertex_info[s_v].len + for s_v in test_first_g.copy_to_vertex[1]]) \ + / float(sum([test_first_g.vertex_info[a_v].len + for a_v in test_first_g.vertex_info])) + if single_copy_percent < 0.5: if verbose: if log_handler: - log_handler.warning("Result with single copy vertex percentage < 50% is " - "unacceptable, continue dropping suspicious vertices ...") + log_handler.warning( + "Result with single copy vertex percentage < 50% is " + "unacceptable, continue dropping suspicious vertices_set ...") else: - sys.stdout.write("Warning: Result with single copy vertex percentage < 50% is " - "unacceptable, continue dropping suspicious vertices ...") + sys.stdout.write( + "Warning: Result with single copy vertex percentage < 50% is " + "unacceptable, continue dropping suspicious vertices_set ...") data_contains_outlier = True is_reasonable_res = False outer_continue = True break - if outer_continue: - continue - elif temp_graph: - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 8) - raise ProcessingGraphFailed("Complicated " + mode + " graph! 
Detecting path(s) failed!") - else: - if verbose and log_handler: - log_handler.exception("") - raise e - else: + else: + log_target_res(final_res_combinations) + return final_res_combinations + else: + if verbose: + if log_handler: + log_handler.warning("Result with single copy vertex percentage < 50% is " + "unacceptable, continue dropping suspicious vertices_set ...") + else: + sys.stdout.write("Warning: Result with single copy vertex percentage < 50% is " + "unacceptable, continue dropping suspicious vertices_set ...") + data_contains_outlier = True + is_reasonable_res = False + outer_continue = True + break + if outer_continue: + continue + else: # no error raised during graph cleaning + write_selected(_assembly=new_assembly, _selected_graph=selected_graph) test_first_g = final_res_combinations[0]["graph"] if 1 in test_first_g.copy_to_vertex or min_single_copy_percent == 0: single_copy_percent = sum([test_first_g.vertex_info[s_v].len for s_v in test_first_g.copy_to_vertex[1]]) \ / float(sum([test_first_g.vertex_info[a_v].len for a_v in test_first_g.vertex_info])) - if single_copy_percent < min_single_copy_percent / 100.: - if verbose: - if log_handler: - log_handler.warning("Result with single copy vertex percentage < {}% is " - "unacceptable, continue dropping suspicious vertices ..." - .format(min_single_copy_percent)) - else: - sys.stdout.write("Warning: Result with single copy vertex percentage < {}% is " - "unacceptable, continue dropping suspicious vertices ..." - .format(min_single_copy_percent)) - data_contains_outlier = True - is_reasonable_res = False - continue - else: - log_target_res(final_res_combinations) - return final_res_combinations else: + single_copy_percent = 0. + if single_copy_percent < min_single_copy_percent / 100.: if verbose: if log_handler: log_handler.warning("Result with single copy vertex percentage < {}% is " - "unacceptable, continue dropping suspicious vertices ..." + "unacceptable, continue dropping suspicious vertices_set ..." .format(min_single_copy_percent)) else: sys.stdout.write("Warning: Result with single copy vertex percentage < {}% is " - "unacceptable, continue dropping suspicious vertices ..." + "unacceptable, continue dropping suspicious vertices_set ..." 
.format(min_single_copy_percent)) - data_contains_outlier = True - is_reasonable_res = False - continue + if broken_graph_allowed: + tagged_vs = \ + [_v for _v in new_assembly.tagged_vertices[db_name] + if new_assembly.vertex_info[_v].other_attr.get("weight", {}).get(db_name, -1) > 0] + new_ave_cov, ave_std = new_assembly.estimate_copy_and_depth_by_cov( + tagged_vs, + min_sigma=min_sigma_factor, + debug=debug, log_handler=log_handler, + verbose=verbose, mode=mode) + del_vs = [] + for vertex_name in new_assembly.vertex_info: + if new_ave_cov - new_assembly.vertex_info[vertex_name].cov > 2 * ave_std or \ + new_ave_cov > 3 * new_assembly.vertex_info[vertex_name].cov: + del_vs.append(vertex_name) + if del_vs: + new_assembly.remove_vertex(del_vs) + is_reasonable_res = False + continue + else: + return gen_contigs_with_no_connections() + else: + data_contains_outlier = True + is_reasonable_res = False + continue + else: + log_target_res(final_res_combinations, + log_handler=log_handler, + read_len_for_log=read_len_for_log, + kmer_for_log=kmer_for_log, + universal_overlap=bool(self.uni_overlap()), + mode=mode) + return final_res_combinations except KeyboardInterrupt as e: - write_temp_out(new_assembly, database_name, temp_graph, temp_csv, 9) + write_temp_out(new_assembly, db_name, temp_graph, temp_csv, "k") if log_handler: log_handler.exception("") raise e @@ -2899,17 +5523,17 @@ def peel_subgraph(self, subgraph, mode="", log_handler=None, verbose=False): limited_vertices = set(self.vertex_info) & set(subgraph_vertices) if not limited_vertices: if log_handler: - log_handler.warning("No overlapped vertices found for peeling!") + log_handler.warning("No overlapped vertices_set found for peeling!") else: - sys.stdout.write("No overlapped vertices found for peeling!\n") + sys.stdout.write("No overlapped vertices_set found for peeling!\n") if verbose: if log_handler: - log_handler.warning("graph vertices: " + str(sorted(self.vertex_info))) - log_handler.warning("subgraph vertices: " + str(sorted(subgraph.vertex_info))) + log_handler.warning("graph vertices_set: " + str(sorted(self.vertex_info))) + log_handler.warning("subgraph vertices_set: " + str(sorted(subgraph.vertex_info))) else: - sys.stdout.write("graph vertices: " + str(sorted(self.vertex_info))) - sys.stdout.write("subgraph vertices: " + str(sorted(subgraph.vertex_info))) - average_cov = self.estimate_copy_and_depth_by_cov( + sys.stdout.write("graph vertices_set: " + str(sorted(self.vertex_info))) + sys.stdout.write("subgraph vertices_set: " + str(sorted(subgraph.vertex_info))) + average_cov, ave_std = self.estimate_copy_and_depth_by_cov( limited_vertices, mode=mode, re_initialize=True, verbose=verbose) vertices_peeling_ratios = {} checked = set() @@ -3020,14 +5644,19 @@ def add_gap_nodes_with_spades_res(self, scaffold_fasta, scaffold_paths, min_cov= return gap_added def get_all_circular_paths(self, mode="embplant_pt", - library_info=None, log_handler=None, reverse_start_direction_for_pt=False): - + library_info=None, + log_handler=None, + reverse_start_direction_for_pt=False, + max_paths_num=inf): + # import time + # count_time = [0.] 
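
The new broken-graph branch above re-estimates the tagged average depth and drops contigs whose coverage is either more than two standard deviations below it or less than a third of it, then loops again; only when nothing is left to drop does it fall back to the connection-free output. The dropping rule as a predicate (hypothetical name):

    def is_suspiciously_shallow(cov, ave_cov, ave_std):
        # Mirrors: ave_cov - cov > 2 * ave_std  or  ave_cov > 3 * cov
        return (ave_cov - cov > 2 * ave_std) or (ave_cov > 3 * cov)

    # is_suspiciously_shallow(cov=9.0,  ave_cov=40.0, ave_std=5.0) -> True
    # is_suspiciously_shallow(cov=35.0, ave_cov=40.0, ave_std=5.0) -> False
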
+ # count_search = [0] def circular_directed_graph_solver(ongoing_path, next_connections, vertices_left, check_all_kinds, palindromic_repeat_vertices): - # print("-----------------------------") - # print("ongoing_path", ongoing_path) - # print("next_connect", next_connections) - # print("vertices_lef", vertices_left) + # flush_str = "valid/searching: " + str(len(paths)) + "/" + str(count_search[0]) + # sys.stdout.write(flush_str + "\b" * len(flush_str)) + # sys.stdout.flush() + # count_search[0] += 1 if not vertices_left: new_path = deepcopy(ongoing_path) if palindromic_repeat_vertices: @@ -3055,8 +5684,22 @@ def circular_directed_graph_solver(ongoing_path, next_connections, vertices_left return for next_vertex, next_end in next_connections: - # print("next_vertex", next_vertex) + if len(paths) >= max_paths_num: + return + # print("ongoing_path", ongoing_path) + # print("next_vertex", next_vertex, next_connections) + # print("vertices_left", vertices_left) + # input() if next_vertex in vertices_left: + # to speed up + if vertices_left[next_vertex] == 1 and len(next_connections) >= 2: + # len(next_connections) >= 2 actually makes no big difference here according to a single test + # maybe add some threshold to do following calculation, e.g. left copies numbers ... + # time0 = time.time() + # costs very limited time + if not self.check_connected(set(vertices_left) - {next_vertex}): + # count_time[0] += time.time() - time0 + continue new_path = deepcopy(ongoing_path) new_left = deepcopy(vertices_left) new_path.append((next_vertex, not next_end)) @@ -3096,12 +5739,13 @@ def circular_directed_graph_solver(ongoing_path, next_connections, vertices_left else: new_connect_list = sorted(new_connections) # if next_connections is SSC, reorder - if mode == "embplant_pt" and len(new_connect_list) == 2 and new_connect_list[0][0] == \ - new_connect_list[1][0]: - new_connect_list.sort( - key=lambda x: -self.vertex_info[x[0]].other_attr["orf"][x[1]]["sum_len"]) + # if mode == "embplant_pt" and len(new_connect_list) == 2 and new_connect_list[0][0] == \ + # new_connect_list[1][0]: + # new_connect_list.sort( + # key=lambda x: -self.vertex_info[x[0]].other_attr["orf"][x[1]]["sum_len"]) circular_directed_graph_solver(new_path, new_connect_list, new_left, check_all_kinds, palindromic_repeat_vertices) + return # for palindromic repeats palindromic_repeats = set() @@ -3157,6 +5801,7 @@ def circular_directed_graph_solver(ongoing_path, next_connections, vertices_left del vertex_to_copy[start_vertex] circular_directed_graph_solver(first_path, first_connections, vertex_to_copy, do_check_all_start_kinds, palindromic_repeats) + # log_handler.info("check_connected costs: " + str(count_time[0])) if not paths: raise ProcessingGraphFailed("Detecting path(s) from remaining graph failed!") @@ -3333,40 +5978,75 @@ def reseed_a_path(input_path, input_unique_vertex): "simply different in SSC direction (two flip-flop configurations)!\n") return sorted_paths - def get_all_paths(self, mode="embplant_pt", log_handler=None): - - def standardize_paths(raw_paths, undirected_vertices): + def standardize_paths(self, raw_paths, undirected_vertices={}, only_res=True): + if undirected_vertices: + corrected_paths = [[(this_v, True) if this_v in undirected_vertices else (this_v, this_e) + for this_v, this_e in path_part] + for path_part in raw_paths] + else: + corrected_paths = deepcopy(raw_paths) + here_standardized_path = [] + for part_path in corrected_paths: if undirected_vertices: - corrected_paths = [[(this_v, True) if this_v in 
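
Two speed-ups appear in the circular path solver above: the search stops once max_paths_num paths have been collected, and a branch is skipped when consuming the last copy of a vertex would disconnect the contigs still waiting to be placed (self.check_connected on the remaining vertex set). A small stand-in for that connectivity test (check_connected itself is a method of the assembly class; the adjacency-dict helper below is only an illustration):

    def still_connected(adjacency, keep):
        # Depth-first reachability over the subgraph induced by `keep`.
        keep = set(keep)
        if not keep:
            return True
        start = next(iter(keep))
        seen, stack = {start}, [start]
        while stack:
            node = stack.pop()
            for neighbour in adjacency.get(node, ()):
                if neighbour in keep and neighbour not in seen:
                    seen.add(neighbour)
                    stack.append(neighbour)
        return seen == keep

    adj = {"1": {"2"}, "2": {"1", "3"}, "3": {"2"}}
    # Using up the last copy of "2" would strand "1" and "3":
    # still_connected(adj, {"1", "3"}) -> False, so that branch is skipped.
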
undirected_vertices else (this_v, this_e) - for this_v, this_e in path_part] - for path_part in raw_paths] + rev_part = [(this_v, True) if this_v in undirected_vertices else (this_v, not this_e) + for this_v, this_e in part_path[::-1]] else: - corrected_paths = deepcopy(raw_paths) - here_standardized_path = [] - for part_path in corrected_paths: - if undirected_vertices: - rev_part = [(this_v, True) if this_v in undirected_vertices else (this_v, not this_e) - for this_v, this_e in part_path[::-1]] - else: - rev_part = [(this_v, not this_e) for this_v, this_e in part_path[::-1]] - if (part_path[0][0], not part_path[0][1]) \ - in self.vertex_info[part_path[-1][0]].connections[part_path[-1][1]]: - # circular - this_part_derived = [part_path, rev_part] - for change_start in range(1, len(part_path)): - this_part_derived.append(part_path[change_start:] + part_path[:change_start]) - this_part_derived.append(rev_part[change_start:] + rev_part[:change_start]) - try: - standard_part = tuple(sorted(this_part_derived, key=lambda x: smart_trans_for_sort(x))[0]) - except TypeError: - for j in this_part_derived: - print(j) - exit() - else: - standard_part = tuple(sorted([part_path, rev_part], key=lambda x: smart_trans_for_sort(x))[0]) - here_standardized_path.append(standard_part) + rev_part = [(this_v, not this_e) for this_v, this_e in part_path[::-1]] + if (part_path[0][0], not part_path[0][1]) \ + in self.vertex_info[part_path[-1][0]].connections[part_path[-1][1]]: + # circular + this_part_derived = [part_path, rev_part] + for change_start in range(1, len(part_path)): + this_part_derived.append(part_path[change_start:] + part_path[:change_start]) + this_part_derived.append(rev_part[change_start:] + rev_part[:change_start]) + # try: + standard_part = tuple(sorted(this_part_derived, key=lambda x: smart_trans_for_sort(x))[0]) + # except TypeError: + # for j in this_part_derived: + # print(j) + # exit() + else: + standard_part = tuple(sorted([part_path, rev_part], key=lambda x: smart_trans_for_sort(x))[0]) + here_standardized_path.append(standard_part) + if only_res: + return tuple(sorted(here_standardized_path, key=lambda x: smart_trans_for_sort(x))) + else: return corrected_paths, tuple(sorted(here_standardized_path, key=lambda x: smart_trans_for_sort(x))) + def get_all_paths(self, mode="embplant_pt", max_paths_num=inf, log_handler=None): + + # def standardize_paths(raw_paths, undirected_vertices): + # if undirected_vertices: + # corrected_paths = [[(this_v, True) if this_v in undirected_vertices else (this_v, this_e) + # for this_v, this_e in path_part] + # for path_part in raw_paths] + # else: + # corrected_paths = deepcopy(raw_paths) + # here_standardized_path = [] + # for part_path in corrected_paths: + # if undirected_vertices: + # rev_part = [(this_v, True) if this_v in undirected_vertices else (this_v, not this_e) + # for this_v, this_e in part_path[::-1]] + # else: + # rev_part = [(this_v, not this_e) for this_v, this_e in part_path[::-1]] + # if (part_path[0][0], not part_path[0][1]) \ + # in self.vertex_info[part_path[-1][0]].connections[part_path[-1][1]]: + # # circular + # this_part_derived = [part_path, rev_part] + # for change_start in range(1, len(part_path)): + # this_part_derived.append(part_path[change_start:] + part_path[:change_start]) + # this_part_derived.append(rev_part[change_start:] + rev_part[:change_start]) + # # try: + # standard_part = tuple(sorted(this_part_derived, key=lambda x: smart_trans_for_sort(x))[0]) + # # except TypeError: + # # for j in this_part_derived: + # # 
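
standardize_paths, now promoted to a method, normalises each circular path part by enumerating every rotation of the path and of its reversed, orientation-flipped counterpart and keeping the variant that sorts first under smart_trans_for_sort. The same idea with plain lexicographic ordering (a simplification of the real sort key):

    def canonical_circular(path):
        # One representative among all rotations of a circular path and of
        # its reverse with flipped orientations; elements are
        # (vertex_name, forward_bool) pairs.
        rev = [(v, not e) for v, e in reversed(path)]
        candidates = []
        for variant in (list(path), rev):
            for shift in range(len(variant)):
                candidates.append(tuple(variant[shift:] + variant[:shift]))
        return min(candidates)

    # canonical_circular([("2", True), ("1", False), ("3", True)]) equals
    # canonical_circular([("1", False), ("3", True), ("2", True)])
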
print(j) + # # exit() + # else: + # standard_part = tuple(sorted([part_path, rev_part], key=lambda x: smart_trans_for_sort(x))[0]) + # here_standardized_path.append(standard_part) + # return corrected_paths, tuple(sorted(here_standardized_path, key=lambda x: smart_trans_for_sort(x))) + def directed_graph_solver(ongoing_paths, next_connections, vertices_left, in_all_start_ve, undirected_vertices): # print("-----------------------------") # print("ongoing_path", ongoing_path) @@ -3374,14 +6054,36 @@ def directed_graph_solver(ongoing_paths, next_connections, vertices_left, in_all # print("vertices_lef", vertices_left) # print("vertices_lef", len(vertices_left)) if not vertices_left: - new_paths, new_standardized = standardize_paths(ongoing_paths, undirected_vertices) + new_paths, new_standardized = self.standardize_paths(ongoing_paths, undirected_vertices, False) if new_standardized not in paths_set: paths.append(new_paths) paths_set.add(new_standardized) return + # to speed up under max_paths_num, rank choices by complete, then incomplete + if len(next_connections) >= 2: + go_n = 0 + incomplete_choices = [] + while go_n < len(next_connections): + next_vertex, next_end = next_connections[go_n] + if next_vertex in vertices_left: + # maybe add some threshold to do following calculation to speed up, e.g. left copies numbers ... + if vertices_left[next_vertex] > 1 or \ + not self.check_connected(set(vertices_left) - {next_vertex}): + incomplete_choices.append(next_connections.pop(go_n)) + else: + go_n += 1 + else: + del next_connections[go_n] + next_connections.extend(incomplete_choices) + find_next = False for next_vertex, next_end in next_connections: + # print("next_vertex, next_end: {}, {}".format(next_vertex, next_end)) + # print("vertices_left: {}".format(vertices_left)) + # input() + if len(paths) >= max_paths_num: + return # print("next_vertex", next_vertex, next_end) if next_vertex in vertices_left: find_next = True @@ -3393,21 +6095,26 @@ def directed_graph_solver(ongoing_paths, next_connections, vertices_left, in_all del new_left[next_vertex] new_connect_list = sorted(self.vertex_info[next_vertex].connections[not next_end]) if not new_left: - new_paths, new_standardized = standardize_paths(new_paths, undirected_vertices) + new_paths, new_standardized = self.standardize_paths(new_paths, undirected_vertices, False) if new_standardized not in paths_set: paths.append(new_paths) paths_set.add(new_standardized) return else: - if mode == "embplant_pt" and len(new_connect_list) == 2 and new_connect_list[0][0] == \ - new_connect_list[1][0]: - new_connect_list.sort( - key=lambda x: self.vertex_info[x[0]].other_attr["orf"][x[1]]["sum_len"]) + # if mode == "embplant_pt" and len(new_connect_list) == 2 and new_connect_list[0][0] == \ + # new_connect_list[1][0]: + # new_connect_list.sort( + # key=lambda x: self.vertex_info[x[0]].other_attr["orf"][x[1]]["sum_len"]) + # to_print = "len(paths)={}".format(len(paths)) + # sys.stdout.write(to_print + "\b" * len(to_print)) + # sys.stdout.flush() directed_graph_solver(new_paths, new_connect_list, new_left, in_all_start_ve, undirected_vertices) if not find_next: new_all_start_ve = deepcopy(in_all_start_ve) while new_all_start_ve: + if len(paths) >= max_paths_num: + return new_start_vertex, new_start_end = new_all_start_ve.pop(0) if new_start_vertex in vertices_left: new_paths = deepcopy(ongoing_paths) @@ -3418,15 +6125,20 @@ def directed_graph_solver(ongoing_paths, next_connections, vertices_left, in_all del new_left[new_start_vertex] new_connect_list = 
sorted(self.vertex_info[new_start_vertex].connections[new_start_end]) if not new_left: - new_paths, new_standardized = standardize_paths(new_paths, undirected_vertices) + new_paths, new_standardized = self.standardize_paths(new_paths, undirected_vertices, False) if new_standardized not in paths_set: paths.append(new_paths) paths_set.add(new_standardized) else: - if mode == "embplant_pt" and len(new_connect_list) == 2 and new_connect_list[0][0] == \ - new_connect_list[1][0]: - new_connect_list.sort( - key=lambda x: self.vertex_info[x[0]].other_attr["orf"][x[1]]["sum_len"]) + # if mode == "embplant_pt" and len(new_connect_list) == 2 and new_connect_list[0][0] == \ + # new_connect_list[1][0]: + # new_connect_list.sort( + # key=lambda x: self.vertex_info[x[0]].other_attr["orf"][x[1]]["sum_len"]) + # to_print = "len(paths)={}".format(len(paths)) + # sys.stdout.write(to_print + "\b" * len(to_print)) + # sys.stdout.flush() + # TODO:prioritize those makes the component more connected + directed_graph_solver(new_paths, new_connect_list, new_left, new_all_start_ve, undirected_vertices) break @@ -3972,7 +6684,7 @@ def __init__(self, scaffold_fasta, scaffold_paths, assembly_obj, min_cov=0., max # trim_last == trim_this == 0 # seemingly_gap_len == real_gap_len if -real_gap_len > graph_overlap: - # if there's >kmer overlap but both of those vertices should fix + # if there's >kmer overlap but both of those vertices_set should fix # leading gap sequence to be illegal if log_handler: log_handler.warning( @@ -4046,6 +6758,10 @@ def smart_trans_for_sort(candidate_item): return all_e +def average_np_free(vals): + return sum(vals) / float(len(vals)) + + def average_weighted_np_free(vals, weights): return sum([val * weights[go_v] for go_v, val in enumerate(vals)]) / float(sum(weights)) @@ -4101,3 +6817,18 @@ def get_graph_coverages_range_simple(fasta_matrix, drop_low_percent=0.10, drop_h cov_mean, cov_std = 0., 0. coverages = [0.] return max(cov_mean - cov_std, min(coverages)), cov_mean, min(cov_mean + cov_std, max(coverages)) + + +def check_positive_value(value, flag, log_handler): + if value < 0: + if log_handler is None: + sys.stdout.write( + "Warning: illegitimate " + flag + " value " + str(value) + " adjusted to " + str(-value) + "!\n") + else: + log_handler.warning("illegitimate " + flag + " value " + str(value) + " adjusted to " + str(-value) + "!") + return -value + elif value == 0: + raise ValueError("illegitimate " + flag + " value " + str(value) + "!") + else: + return value + diff --git a/GetOrganelleLib/pipe_control_func.py b/GetOrganelleLib/pipe_control_func.py index 519edd9..79fbfb5 100755 --- a/GetOrganelleLib/pipe_control_func.py +++ b/GetOrganelleLib/pipe_control_func.py @@ -712,7 +712,7 @@ def __init__(self, sample_out_dir, prefix=""): if "circular genome" in detail_record: this_circular = "yes" elif " - WARNING: " in line and line[:4].isdigit(): - if "Degenerate base(s) used!" in line: + if "Ambiguous base(s) used!" 
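
Illustrative sketch (not from the patch): expected behavior of the new check_positive_value helper, assuming the patched GetOrganelleLib is importable. Negative GFA coverage tags are flipped with a warning (see feature 24 in the 1.8.0.0 changelog below), zero-valued tags are rejected.

from GetOrganelleLib.assembly_parser import check_positive_value

print(check_positive_value(12.5, "DP", log_handler=None))   # 12.5, returned unchanged
print(check_positive_value(-30, "KC", log_handler=None))    # warns on stdout, returns 30
try:
    check_positive_value(0, "RC", log_handler=None)
except ValueError as err:
    print(err)                                              # zero tags raise an error
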
in line: this_degenerate = "yes" if "circular" in this_record: this_record["circular"] += " & " + this_circular @@ -984,6 +984,52 @@ def get_static_html_context(remote_url, try_times=5, timeout=10, verbose=False, return {"status": False, "info": "unknown", "content": ""} +def log_target_res(final_res_combinations_inside, + log_handler=None, + read_len_for_log=None, + kmer_for_log=None, + universal_overlap=False, + mode="",): + echo_graph_id = int(bool(len(final_res_combinations_inside) - 1)) + for go_res, final_res_one in enumerate(final_res_combinations_inside): + this_graph = final_res_one["graph"] + this_k_cov = round(final_res_one["cov"], 3) + if read_len_for_log and kmer_for_log: + this_b_cov = round(this_k_cov * read_len_for_log / (read_len_for_log - kmer_for_log + 1), 3) + else: + this_b_cov = None + if log_handler: + if echo_graph_id: + log_handler.info("Graph " + str(go_res + 1)) + for vertex_set in sorted(this_graph.vertex_clusters): + copies_in_a_set = {this_graph.vertex_to_copy[v_name] for v_name in vertex_set} + if copies_in_a_set != {1}: + for in_vertex_name in sorted(vertex_set): + log_handler.info("Vertex_" + in_vertex_name + " #copy = " + + str(this_graph.vertex_to_copy.get(in_vertex_name, 1))) + cov_str = " kmer-coverage" if universal_overlap else " coverage" + log_handler.info("Average " + mode + cov_str + + ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_k_cov) + if this_b_cov: + log_handler.info("Average " + mode + " base-coverage" + + ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_b_cov) + else: + if echo_graph_id: + sys.stdout.write("Graph " + str(go_res + 1) + "\n") + for vertex_set in sorted(this_graph.vertex_clusters): + copies_in_a_set = {this_graph.vertex_to_copy[v_name] for v_name in vertex_set} + if copies_in_a_set != {1}: + for in_vertex_name in sorted(vertex_set): + sys.stdout.write("Vertex_" + in_vertex_name + " #copy = " + + str(this_graph.vertex_to_copy.get(in_vertex_name, 1)) + "\n") + cov_str = " kmer-coverage" if universal_overlap else " coverage" + sys.stdout.write("Average " + mode + cov_str + + ("(" + str(go_res + 1) + ")") * echo_graph_id + " = " + "%.1f" % this_k_cov + "\n") + if this_b_cov: + sys.stdout.write("Average " + mode + " base-coverage" + ("(" + str(go_res + 1) + ")") * + echo_graph_id + " = " + "%.1f" % this_b_cov + "\n") + + def download_file_with_progress(remote_url, output_file, log_handler=None, allow_empty=False, sha256_v=None, try_times=5, timeout=100000, alternative_url_list=None, verbose=False): time_0 = time.time() diff --git a/GetOrganelleLib/seq_parser.py b/GetOrganelleLib/seq_parser.py index d08ef9e..96d6d48 100755 --- a/GetOrganelleLib/seq_parser.py +++ b/GetOrganelleLib/seq_parser.py @@ -2,7 +2,7 @@ import sys import math import re -import random +from multiprocessing import Pool, Manager major_version, minor_version = sys.version_info[:2] if major_version == 2 and minor_version >= 7: @@ -74,6 +74,7 @@ def __init__(self, input_fasta_file=None, indexed=False): self.read_fasta(input_fasta_file) if indexed: for go_s, seq in enumerate(self.sequences): + assert seq.label not in self.__dict self.__dict[seq.label] = go_s def __len__(self): @@ -176,6 +177,43 @@ def __enumerate__(self): yield i, self.__getitem__(i) +def get_fasta_lengths( + file_path: str, + blast_form_seq_name: bool = False) -> dict: + # initialize an empty dictionary to store the sequence lengths + seq_lengths = {} + + # open the fasta file for reading + with open(file_path, "r") as input_handler: + # initialize 
variables to keep track of the current sequence name and length + seq_name = None + seq_len = 0 + # iterate over each line in the file + for line in input_handler: + # if the line starts with '>', it indicates a new sequence + if line.startswith(">"): + # if we have a previous sequence, store its length in the dictionary + if seq_name: + seq_lengths[seq_name] = seq_len + if blast_form_seq_name: + # get the new sequence name by splitting the line and taking the first element + seq_name = line.split()[0][1:] + else: + seq_name = line.strip()[1:] + # reset the sequence length to 0 + seq_len = 0 + else: + # if the line does not start with '>', it is part of the current sequence + # add the length of the line to the current sequence length + seq_len += len(line.strip()) + # If we have a final sequence, store its length in the dictionary + if seq_name: + seq_lengths[seq_name] = seq_len + + # Return the dictionary containing the sequence lengths + return seq_lengths + + def read_fasta(fasta_dir): names = [] seqs = [] @@ -1190,35 +1228,120 @@ def split_seq_by_quality_pattern(sequence, quality_str, low_quality_pattern, min return tuple(seq_list) -def fq_simple_generator(fq_dir_list, go_to_line=1, split_pattern=None, min_sub_seq=0, max_n_reads=float("inf")): +def fq_handler_re_seek(fq_handler, + fq_file, + seek_soft_start=None, + seek_soft_end=None): + if seek_soft_start is not None: + seek_start = fq_handler.seek(seek_soft_start) + if seek_soft_end is None: + seek_soft_end = os.path.getsize(fq_file) + line_str = fq_handler.readline() + while line_str: + if line_str.startswith("@"): + break + else: + seek_start = fq_handler.tell() + if seek_start >= seek_soft_end: + break + line_str = fq_handler.readline() + fq_handler.seek(seek_start) + return fq_handler + + +def fq_simple_generator(fq_dir_list, + go_to_line=1, + split_pattern=None, + min_sub_seq=0, + max_n_reads=float("inf"), + seek_soft_start=None, + seek_soft_end=None): + """ + :param fq_dir_list: + :param go_to_line: + :param split_pattern: + :param min_sub_seq: + :param max_n_reads: + :param seek_soft_start: starts by the start of the head line after seek_soft_start + :param seek_soft_end: ends by the end of the quality line after seek_soft_end + :return: + """ if not ((type(fq_dir_list) is list) or (type(fq_dir_list) is tuple)): fq_dir_list = [fq_dir_list] + max_n_lines = 4 * max_n_reads - if split_pattern and len(split_pattern) > 2: - for fq_dir in fq_dir_list: - count = 0 - with open(fq_dir, 'r') as fq_handler: - seq_line = fq_handler.readline() - while seq_line: - if count % 4 == go_to_line: - fq_handler.readline() - quality_str = fq_handler.readline()[:-1] - count += 2 - yield split_seq_by_quality_pattern(seq_line[:-1], quality_str, split_pattern, min_sub_seq) - count += 1 - if count >= max_n_lines: - break + if seek_soft_end is None: + if split_pattern and len(split_pattern) > 2: + for fq_dir in fq_dir_list: + count = 0 + with open(fq_dir, 'r') as fq_handler: + fq_handler = fq_handler_re_seek(fq_handler=fq_handler, fq_file=fq_dir, + seek_soft_start=seek_soft_start) seq_line = fq_handler.readline() + while seq_line: + if count % 4 == go_to_line: + fq_handler.readline() + quality_str = fq_handler.readline()[:-1] + count += 2 + yield split_seq_by_quality_pattern(seq_line[:-1], quality_str, split_pattern, min_sub_seq) + count += 1 + if count >= max_n_lines: + break + seq_line = fq_handler.readline() + else: + for fq_dir in fq_dir_list: + count = 0 + with open(fq_dir, 'r') as fq_handler: + fq_handler = 
fq_handler_re_seek(fq_handler=fq_handler, fq_file=fq_dir, + seek_soft_start=seek_soft_start) + for fq_line in fq_handler: + if count % 4 == go_to_line: + yield fq_line[:-1] + if count >= max_n_lines: + break + count += 1 else: - for fq_dir in fq_dir_list: - count = 0 - with open(fq_dir, 'r') as fq_handler: - for fq_line in fq_handler: - if count % 4 == go_to_line: - yield fq_line[:-1] - if count >= max_n_lines: - break - count += 1 + if split_pattern and len(split_pattern) > 2: + for fq_dir in fq_dir_list: + count = 0 + with open(fq_dir, 'r') as fq_handler: + fq_handler = fq_handler_re_seek(fq_handler=fq_handler, fq_file=fq_dir, + seek_soft_start=seek_soft_start, seek_soft_end=seek_soft_end) + if fq_handler.tell() >= seek_soft_end: + continue + seq_line = fq_handler.readline() + while seq_line: + if count % 4 == go_to_line: + fq_handler.readline() + quality_str = fq_handler.readline()[:-1] + count += 2 + yield split_seq_by_quality_pattern(seq_line[:-1], quality_str, split_pattern, min_sub_seq) + if fq_handler.tell() >= seek_soft_end: + break + count += 1 + if count >= max_n_lines: + break + seq_line = fq_handler.readline() + else: + for fq_dir in fq_dir_list: + count = 0 + with open(fq_dir, 'r') as fq_handler: + fq_handler = fq_handler_re_seek(fq_handler=fq_handler, fq_file=fq_dir, + seek_soft_start=seek_soft_start, seek_soft_end=seek_soft_end) + if fq_handler.tell() >= seek_soft_end: + continue + fq_line = fq_handler.readline() + # for fq_line in fq_handler: + # for loop will cause OSError: telling position disabled by next() call + while fq_line: + if count % 4 == go_to_line: + yield fq_line[:-1] + if fq_handler.tell() >= seek_soft_end: + break + if count >= max_n_lines: + break + count += 1 + fq_line = fq_handler.readline() def chop_seqs(seq_iter, word_size, mesh_size=1, previous_words=None): @@ -1404,7 +1527,7 @@ def get_orf_lengths(sequence_string, threshold=200, which_frame=None, def simulate_fq_simple( from_fasta_file, out_dir, out_name=None, is_circular=False, sim_read_len=100, sim_read_jump_size=None, generate_paired=False, paired_insert_size=300, generate_spot_num=None, generate_depth=None, - resume=True): + resume=True, random_obj=None): """ :param from_fasta_file: :param out_dir: @@ -1414,12 +1537,14 @@ def simulate_fq_simple( :param sim_read_jump_size: int; mutually exclusive with generate_spot_num, generate_depth; randomly off :param generate_paired: :param paired_insert_size: - :param randomly: :param generate_spot_num: int; mutually exclusive with sim_read_jump_size, generate_depth; randomly on :param generate_depth: int; mutually exclusive with sim_read_jump_size, generate_spot_num; randomly on :param resume: continue + :param random_obj :return: """ + if random_obj is None: + import random as random_obj if bool(sim_read_jump_size) + bool(generate_spot_num) + bool(generate_depth) == 0: raise Exception("One of sim_read_jump_size, generate_spot_num, generate_depth must be given!") elif bool(sim_read_jump_size) + bool(generate_spot_num) + bool(generate_depth) > 1: @@ -1476,7 +1601,7 @@ def simulate_fq_simple( cat_all_seqs.append(from_sequence) cat_all_seqs = "".join(cat_all_seqs) cat_all_seqs_rev = complementary_seq(cat_all_seqs) - chosen_start_ids = [random.choice(start_ids) for foo in range(generate_spot_num)] + chosen_start_ids = [random_obj.choice(start_ids) for foo in range(generate_spot_num)] with open(to_fq_files[0] + ".Temp", "w") as output_handler_1: with open(to_fq_files[1] + ".Temp", "w") as output_handler_2: for go_base in chosen_start_ids: @@ -1521,7 +1646,7 @@ 
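
Illustrative sketch (not from the patch): the idea behind fq_handler_re_seek and the new seek_soft_start/seek_soft_end parameters is to jump to an arbitrary byte offset and then advance to the next line starting with "@" before yielding records ("soft", since a quality line may also begin with "@"). readline()/tell() are used instead of iterating the handler, matching the comment above about tell() being disabled after next().

import os

def realign_to_next_header(handler, file_size, soft_start):
    handler.seek(soft_start)
    seek_start = handler.tell()
    line_str = handler.readline()
    while line_str:
        if line_str.startswith("@"):                 # treat this as the next record head
            break
        seek_start = handler.tell()
        if seek_start >= file_size:
            break
        line_str = handler.readline()
    handler.seek(seek_start)

with open("toy.fq", "w") as out_handler:
    out_handler.write("@read_1\nACGT\n+\nFFFF\n@read_2\nTTTT\n+\nFFFF\n")
with open("toy.fq") as in_handler:
    realign_to_next_header(in_handler, os.path.getsize("toy.fq"), soft_start=3)
    print(in_handler.readline().strip())             # @read_2
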
def simulate_fq_simple( accumulated_len += len(from_sequence) cat_all_seqs.append(from_sequence) cat_all_seqs = "".join(cat_all_seqs) - chosen_start_ids = [random.choice(start_ids) for foo in range(generate_spot_num)] + chosen_start_ids = [random_obj.choice(start_ids) for foo in range(generate_spot_num)] with open(to_fq_files[0] + ".Temp", "w") as output_handler: for go_base in chosen_start_ids: output_handler.write("".join(["@", str(count_read), "\n", @@ -1533,13 +1658,203 @@ def simulate_fq_simple( os.rename(to_fq_f + ".Temp", to_fq_f) -def get_read_len_mean_max_count(fq_or_fq_files, maximum_n_reads, sampling_percent=1.): +# def get_read_len_mean_max_count(fq_or_fq_files, maximum_n_reads, sampling_percent=1., n_process=1): +# if type(fq_or_fq_files) is str: +# fq_files = [fq_or_fq_files] +# else: +# fq_files = fq_or_fq_files +# if n_process > 1: +# return __get_read_len_mean_max_count_mp(fq_files=fq_files, +# maximum_n_reads=maximum_n_reads, +# sampling_percent=sampling_percent, +# n_process=n_process) +# else: +# return __get_read_len_mean_max_count_single(fq_files=fq_files, +# maximum_n_reads=maximum_n_reads, +# sampling_percent=sampling_percent) +# +# +# def __get_read_len_mean_max_count_mp( +# fq_files, +# maximum_n_reads, +# sampling_percent=1., +# n_process=2): +# +# def open_maximum_reads(_fq_f): +# if maximum_n_reads == float("inf"): +# for _go_l, _str_line in enumerate(open(_fq_f, "r")): +# if _go_l % 4 == 1: +# if _str_line: +# yield _str_line +# else: +# break +# else: +# for _go_l, _str_line in zip(range(int(maximum_n_reads * 4)), open(_fq_f, "r")): +# if _go_l % 4 == 1: +# if _str_line: +# yield _str_line +# else: +# break +# """ set up iter arguments """ +# # arbitrary num of seqs +# block_size = int(5E6) +# """ set up variables shared by processes """ +# manager = Manager() +# read_lengths = manager.list() +# all_counts = manager.dict() +# for fq_f in fq_files: +# all_counts[fq_f] = 0 +# # lock = manager.Lock() +# """ set up the process pool """ +# pool_obj = Pool(processes=n_process) +# jobs = [] +# try: +# for fq_f in fq_files: +# seq_lines = [_str_line for _go_l, _str_line in zip(range(block_size), open_maximum_reads(_fq_f=fq_f))] +# while seq_lines: +# arg_tuple = (seq_lines, fq_f, sampling_percent, read_lengths, all_counts,) +# jobs.append(pool_obj.apply_async(__get_read_len_mean_max_count_worker, arg_tuple)) +# seq_lines = [_str_line for _go_l, _str_line in zip(range(block_size), open_maximum_reads(_fq_f=fq_f))] +# print(arg_tuple[1:]) +# pool_obj.close() +# print("pool closed") +# for go_j, job in enumerate(jobs): +# print("go_j get", go_j) +# job.get() +# pool_obj.join() +# # print("pool joined") +# except KeyboardInterrupt: +# pool_obj.terminate() +# raise KeyboardInterrupt +# return sum(read_lengths) / len(read_lengths), max(read_lengths), [all_counts[_fq_f] for _fq_f in fq_files] +# +# +# def __get_read_len_mean_max_count_worker(seq_lines, fq_file, sampling_percent, read_lengths, all_counts): +# here_r_l = [] +# if sampling_percent == 1: +# count_r = 0 +# for seq in seq_lines: +# count_r += 1 +# here_r_l.append(len(seq.strip("N"))) +# all_counts[fq_file] += count_r +# else: +# sampling_percent = int(1 / sampling_percent) +# count_r = 0 +# for seq in seq_lines: +# count_r += 1 +# if count_r % sampling_percent == 0: +# here_r_l.append(len(seq.strip("N"))) +# all_counts[fq_file] += count_r +# read_lengths.extend(here_r_l) + +def get_read_len_mean_max_count(fq_or_fq_files, maximum_n_reads, sampling_percent=1., n_process=1): if type(fq_or_fq_files) is str: - 
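
Illustrative sketch (not from the patch): the new random_obj parameter lets callers inject a seeded random.Random instance, so simulation and clustering become reproducible under one global seed, while falling back to the random module keeps the old behavior. pick_starts below is a made-up name used only for this example.

import random

def pick_starts(start_ids, n, random_obj=None):
    if random_obj is None:
        import random as random_obj    # fall back to the module-level state
    return [random_obj.choice(start_ids) for _ in range(n)]

seeded = random.Random(12345)
print(pick_starts(list(range(100)), 3, random_obj=seeded))
# the same seed always yields the same simulated start positions
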
fq_or_fq_files = [fq_or_fq_files] + fq_files = [fq_or_fq_files] + else: + fq_files = fq_or_fq_files + if n_process > 1: + """ giving up: processes always jump into D (uninterruptible sleep) """ + + # """ using seek: not working probably because of NFS locking """ + # """ set up iter arguments """ + # iter_args = [[[], [], []]] + # file_sizes = [os.path.getsize(_fq_f) for _fq_f in fq_files] + # # block size for each process + # block_size = max(int(1E9), math.ceil(sum(file_sizes) / n_process)) + # accumulated_size = 0 + # for go_f, fq_f in enumerate(fq_files): + # start_seek = 0 + # while start_seek < file_sizes[go_f]: + # end_seek = start_seek + block_size - accumulated_size + # if end_seek > file_sizes[go_f]: + # iter_args[-1][0].append(fq_f) + # iter_args[-1][1].append(start_seek) + # iter_args[-1][2].append(file_sizes[go_f]) + # accumulated_size += file_sizes[go_f] - start_seek + # break + # else: + # iter_args[-1][0].append(fq_f) + # iter_args[-1][1].append(start_seek) + # iter_args[-1][2].append(end_seek) + # accumulated_size = 0 + # iter_args.append([[], [], []]) + # start_seek = end_seek + # if not iter_args[-1][0]: + # del iter_args[-1] + """ switch to one file one process """ + iter_args = [[[fq_f_], [None], [None]] for fq_f_ in fq_files] + """ set up variables shared by processes """ + manager = Manager() + read_lengths = manager.list() + all_counts = manager.dict() + for fq_f in fq_files: + all_counts[fq_f] = 0 + # lock = manager.Lock() + """ set up the process pool """ + pool_obj = Pool(processes=n_process) + jobs = [] + try: + for file_args in iter_args: + arg_tuple = tuple(file_args + [maximum_n_reads, read_lengths, all_counts, sampling_percent]) + print(arg_tuple) + jobs.append(pool_obj.apply_async(__get_read_len_mean_max_count_worker, arg_tuple)) + # print(len(iter_args)) + pool_obj.close() + # print("pool closed") + for go_j, job in enumerate(jobs): + # print("go_j get", go_j) + job.get() + pool_obj.join() + # print("pool joined") + except KeyboardInterrupt: + pool_obj.terminate() + raise KeyboardInterrupt + return sum(read_lengths) / len(read_lengths), max(read_lengths), [all_counts[_fq_f] for _fq_f in fq_files] + else: + return __get_read_len_mean_max_count_single(fq_files=fq_files, + maximum_n_reads=maximum_n_reads, + sampling_percent=sampling_percent) + +def __get_read_len_mean_max_count_worker( + fq_files, + seek_soft_starts, + seek_soft_ends, + maximum_n_reads, + read_lengths, + all_counts, + sampling_percent=1.): + """ using seek: not working probably because of NFS locking """ + here_r_l = [] + if sampling_percent == 1: + for go_f, fq_file in enumerate(fq_files): + count_r = 0 + for seq in fq_simple_generator( + fq_file, seek_soft_start=seek_soft_starts[go_f], seek_soft_end=seek_soft_ends[go_f]): + count_r += 1 + here_r_l.append(len(seq.strip("N"))) + if count_r >= maximum_n_reads: + break + all_counts[fq_file] += count_r + else: + sampling_percent = int(1 / sampling_percent) + for go_f, fq_file in enumerate(fq_files): + count_r = 0 + for seq in fq_simple_generator( + fq_file, seek_soft_start=seek_soft_starts[go_f], seek_soft_end=seek_soft_ends[go_f]): + count_r += 1 + if count_r % sampling_percent == 0: + here_r_l.append(len(seq.strip("N"))) + if count_r >= maximum_n_reads: + break + all_counts[fq_file] += count_r + read_lengths.extend(here_r_l) + + +def __get_read_len_mean_max_count_single(fq_files, maximum_n_reads, sampling_percent=1.): read_lengths = [] all_counts = [] if sampling_percent == 1: - for fq_f in fq_or_fq_files: + for fq_f in fq_files: count_r = 0 
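
Illustrative sketch (not from the patch): the parallel branch above follows a one-file-per-process pattern, collecting results through Manager-backed shared containers and calling job.get() so worker exceptions are re-raised in the parent. The toy file names below are hypothetical.

from multiprocessing import Pool, Manager

def count_lines(file_name, shared_counts):
    n = 0
    with open(file_name) as handler:
        for _line in handler:
            n += 1
    shared_counts[file_name] = n        # Manager dict proxy, safe across processes

if __name__ == "__main__":
    file_list = ["toy_1.fq", "toy_2.fq"]
    for fq_f in file_list:
        with open(fq_f, "w") as handler:
            handler.write("@r1\nACGT\n+\nFFFF\n")
    manager = Manager()
    shared_counts = manager.dict()
    pool_obj = Pool(processes=2)
    jobs = [pool_obj.apply_async(count_lines, (fq_f, shared_counts)) for fq_f in file_list]
    pool_obj.close()
    for job in jobs:
        job.get()                       # re-raises any exception from the worker
    pool_obj.join()
    print(dict(shared_counts))          # {'toy_1.fq': 4, 'toy_2.fq': 4}
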
for seq in fq_simple_generator(fq_f): count_r += 1 @@ -1549,7 +1864,7 @@ def get_read_len_mean_max_count(fq_or_fq_files, maximum_n_reads, sampling_percen all_counts.append(count_r) else: sampling_percent = int(1 / sampling_percent) - for fq_f in fq_or_fq_files: + for fq_f in fq_files: count_r = 0 for seq in fq_simple_generator(fq_f): count_r += 1 diff --git a/GetOrganelleLib/statistical_func.py b/GetOrganelleLib/statistical_func.py index ebbaf87..ba841e9 100755 --- a/GetOrganelleLib/statistical_func.py +++ b/GetOrganelleLib/statistical_func.py @@ -1,22 +1,16 @@ -try: - from scipy import stats -except ImportError: - class stats: - class norm: - def logpdf(foo1, foo2, foo3): - raise ImportError("Failed in 'from scipy import stats, inf, log'!") - inf = float("inf") - from math import log -try: - from scipy import inf, log -except ImportError: - try: - from numpy import inf, log - except ImportError: - inf = float("inf") - from math import log +# try: +# from scipy import stats, inf, log +# except ImportError: +# class stats: +# class norm: +# def logpdf(foo1, foo2, foo3): +# raise ImportError("Failed in 'from scipy import stats, inf, log'!") +# inf = float("inf") +from math import log, inf, sqrt, pi +from itertools import permutations from copy import deepcopy +# add try except so that when statistical_func.py is called by other scripts, it will not prompt error immediately try: import numpy as np except ImportError: @@ -29,8 +23,12 @@ def vstack(foo1): raise ImportError("No module named numpy") def where(foo1): raise ImportError("No module named numpy") -import random import sys +try: + from gekko import GEKKO +except ImportError: + def GEKKO(remote): + raise ImportError("Failed in 'from gekko import GEKKO'!") def bic(loglike, len_param, len_data): @@ -47,19 +45,44 @@ def weighted_mean_and_std(values, weights): return mean, std -def weighted_gmm_with_em_aic(data_array, data_weights=None, minimum_cluster=1, maximum_cluster=5, min_sigma_factor=1E-5, - cluster_limited=None, log_handler=None, verbose_log=False): +def norm_logpdf(numpy_array, mu, sigma): + u = (numpy_array-mu)/abs(sigma) + return log(1/(sqrt(2*pi)*abs(sigma)))-u*u/2 + + +def weighted_clustering_with_em_aic(data_array, + data_weights=None, + minimum_cluster=1, + maximum_cluster=5, + min_sigma_factor=1E-5, + cluster_limited=None, + cluster_bans=None, + log_handler=None, + verbose_log=False, + random_obj=None): """ + The current implementation is using a categorical distribution, + with assignment of data exclusively to specific components. + :param data_array: :param data_weights: :param minimum_cluster: :param maximum_cluster: :param min_sigma_factor: :param cluster_limited: {dat_id1: {0, 1}, dat_id2: {0}, dat_id3: {0} ...} + cluster_limited has priority over cluster_bans if conflicted + :param cluster_bans: {dat_a: {0, 1}, dat_b: {0}, dat_c: {0} ...} :param log_handler: :param verbose_log: + :param random_obj: to control the random process using a universal seed :return: """ + # import time + # time0 = time.time() + # pdf_time = [0.] 
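
Illustrative sketch (not from the patch): a quick sanity check that the scipy-free norm_logpdf formula above matches the standard Gaussian log-density (numpy assumed available, as in statistical_func.py).

import numpy as np
from math import log, sqrt, pi

def norm_logpdf(numpy_array, mu, sigma):
    u = (numpy_array - mu) / abs(sigma)
    return log(1 / (sqrt(2 * pi) * abs(sigma))) - u * u / 2

x = np.array([0.0, 1.0, 2.5])
dens = np.exp(norm_logpdf(x, mu=1.0, sigma=0.5))
ref = np.exp(-(x - 1.0) ** 2 / (2 * 0.5 ** 2)) / (sqrt(2 * pi) * 0.5)
print(np.allclose(dens, ref))    # True
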
+ if random_obj is None: + import random as random_obj + min_sigma = min_sigma_factor * np.average(data_array, weights=data_weights) def model_loglike(dat_arr, dat_w, lbs, parameters): @@ -68,73 +91,192 @@ def model_loglike(dat_arr, dat_w, lbs, parameters): points = dat_arr[lbs == go_to_cl] weights = dat_w[lbs == go_to_cl] if len(points): - total_loglike += sum(stats.norm.logpdf(points, pr["mu"], pr["sigma"]) * weights + log(pr["percent"])) + # total_loglike += sum(norm_logpdf(points, pr["mu"], pr["sigma"]) * weights) + # total_loglike += sum(norm_logpdf(points, pr["mu"], pr["sigma"]) * weights + log(pr["percent"])) + # total_loglike += sum(norm_logpdf(points, pr["mu"], pr["sigma"]) * weights + log(pr["percent"])) + total_loglike += sum((norm_logpdf(points, pr["mu"], pr["sigma"]) + log(pr["percent"])) * weights) return total_loglike + def revise_labels_according_to_constraints(_raw_lbs, _fixed, loglike_table=None): + _new_labels = list(_raw_lbs) + for _dat_id in range(int(data_len)): + if _dat_id in _fixed: + if _raw_lbs[_dat_id] in _fixed[_dat_id]: + # _new_labels[_dat_id] = _raw_lbs[_dat_id] + pass + else: + if loglike_table is None: + _new_labels[_dat_id] = sorted(_fixed[_dat_id])[0] + else: + val_increasing_order = list(loglike_table[:, _dat_id].argsort()) + for order_id in range(len(val_increasing_order) - 2, -1, -1): # search from the sub-optimum + this_label = val_increasing_order.index(order_id) + if this_label in _fixed[_dat_id]: + _new_labels[_dat_id] = this_label + # v1.8.0-pre6 fix a bug + break + else: + # _new_labels[_dat_id] = _raw_lbs[_dat_id] + pass + return np.array(_new_labels) + def assign_cluster_labels(dat_arr, dat_w, parameters, lb_fixed): # assign every data point to its most likely cluster if len(parameters) == 1: return np.array([0] * int(data_len)) + # elif len(in_params) == len(dat_arr): + # return np.array(range(len(in_params))) else: # the parameter set of the first cluster - loglike_res = stats.norm.logpdf(dat_arr, parameters[0]["mu"], parameters[0]["sigma"]) * dat_w + \ - log(parameters[0]["percent"]) + # timex = time.time() + # loglike_res = norm_logpdf(dat_arr, parameters[0]["mu"], parameters[0]["sigma"]) * dat_w + # loglike_res = norm_logpdf(dat_arr, parameters[0]["mu"], parameters[0]["sigma"]) * dat_w + \ + # log(parameters[0]["percent"]) + loglike_res = (norm_logpdf(dat_arr, parameters[0]["mu"], parameters[0]["sigma"]) + + log(parameters[0]["percent"])) * dat_w # the parameter set of the rest cluster for pr in parameters[1:]: loglike_res = np.vstack( - (loglike_res, stats.norm.logpdf(dat_arr, pr["mu"], pr["sigma"]) * dat_w + log(pr["percent"]))) + # (loglike_res, norm_logpdf(dat_arr, pr["mu"], pr["sigma"]) * dat_w)) + # (loglike_res, norm_logpdf(dat_arr, pr["mu"], pr["sigma"]) * dat_w + log(pr["percent"]))) + (loglike_res, (norm_logpdf(dat_arr, pr["mu"], pr["sigma"]) + log(pr["percent"])) * dat_w)) + # print(loglike_res) + # pdf_time[0] += time.time() - timex # assign labels new_labels = loglike_res.argmax(axis=0) if lb_fixed: - intermediate_labels = [] - for here_dat_id in range(int(data_len)): - if here_dat_id in lb_fixed: - if new_labels[here_dat_id] in lb_fixed[here_dat_id]: - intermediate_labels.append(new_labels[here_dat_id]) - else: - intermediate_labels.append(sorted(lb_fixed[here_dat_id])[0]) - else: - intermediate_labels.append(new_labels[here_dat_id]) - new_labels = np.array(intermediate_labels) + # intermediate_labels = [] + # for here_dat_id in range(int(data_len)): + # if here_dat_id in lb_fixed: + # if new_labels[here_dat_id] in 
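
Illustrative sketch (not from the patch): revise_labels_according_to_constraints moves a constrained data point to its best allowed cluster using the per-cluster log-likelihood table. The toy below shows the simplest form of that idea; the patched helper walks the argsort order from the sub-optimum instead.

import numpy as np

loglike_table = np.array([[-1.0, -5.0, -2.0],    # cluster 0
                          [-3.0, -0.5, -9.0],    # cluster 1
                          [-2.5, -4.0, -0.1]])   # cluster 2; columns are data points
labels = loglike_table.argmax(axis=0)            # unconstrained labels: [0, 1, 2]
allowed = {2: {0, 1}}                            # data point 2 may only take cluster 0 or 1
for dat_id, allowed_set in allowed.items():
    if labels[dat_id] not in allowed_set:
        labels[dat_id] = max(allowed_set, key=lambda lb: loglike_table[lb, dat_id])
print(labels.tolist())                           # [0, 1, 0]
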
lb_fixed[here_dat_id]: + # intermediate_labels.append(new_labels[here_dat_id]) + # else: + # intermediate_labels.append(sorted(lb_fixed[here_dat_id])[0]) + # else: + # intermediate_labels.append(new_labels[here_dat_id]) + # np.array(intermediate_labels) + new_labels = revise_labels_according_to_constraints(new_labels, lb_fixed, loglike_res) # new_labels = np.array([ # sorted(cluster_limited[dat_item])[0] # if new_labels[here_dat_id] not in cluster_limited[dat_item] else new_labels[here_dat_id] # if dat_item in cluster_limited else # new_labels[here_dat_id] # for here_dat_id, dat_item in enumerate(data_array)]) - limited_values = set(dat_arr[list(lb_fixed)]) - else: - limited_values = set() - # if there is an empty cluster, - # and if there is another non-empty cluster with two ends not in the fixed (lb_fixed), - # then move one of the end (min or max) from that non-empty cluster to the empty cluster + # limited_values = set(dat_arr[list(lb_fixed)]) + # else: + # limited_values = set() + if len(set(new_labels)) > len(parameters): + raise ValueError("Assigning failed!") label_counts = {lb: 0 for lb in range(len(parameters))} for ct_lb in new_labels: label_counts[ct_lb] += 1 - for empty_lb in label_counts: - if label_counts[empty_lb] == 0: - non_empty_lbs = {ne_lb: [min, max] for ne_lb in label_counts if label_counts[ne_lb] > 1} - for af_lb in sorted(non_empty_lbs): - these_points = dat_arr[new_labels == af_lb] - if max(these_points) in limited_values: - non_empty_lbs[af_lb].remove(max) - if min(these_points) in limited_values: - non_empty_lbs[af_lb].remove(min) - if not non_empty_lbs[af_lb]: - del non_empty_lbs[af_lb] - if non_empty_lbs: - chose_lb = random.choice(list(non_empty_lbs)) - chose_points = dat_arr[new_labels == chose_lb] - # random.choice([min, max]), then use the resulting function to pick the point - data_point = random.choice(non_empty_lbs[chose_lb])(chose_points) - transfer_index = random.choice(np.where(dat_arr == data_point)[0]) - new_labels[transfer_index] = empty_lb - label_counts[chose_lb] -= 1 + # 2023-01-15 added + empty_lbs = [] + # filled_lbs = [] + for label_id in range(len(parameters)): + if label_counts[label_id] == 0: + empty_lbs.append(label_id) + # else: + # filled_lbs.append(label_id) + if empty_lbs: + # find the new option that reduce the likelihood least + # if len(empty_lbs) == 1: + # new_lb = empty_lbs[0] + # loglike_reduced = np.max(loglike_res, axis=0) - loglike_res[new_lb] + # orders_to_d = {order_id: _d_id for _d_id, order_id in enumerate(loglike_reduced.argsort())} + # for try_order_id in range(len(dat_arr)): + # try_data_id = orders_to_d[try_order_id] + # # if old label is not fixed, or the new label is within the constraint + # if try_data_id not in lb_fixed or new_lb in lb_fixed[try_data_id]: + # new_labels[try_data_id] = new_lb + # break + # else: + # raise ValueError("Assigning failed!") + # else: + tmp_info = {} + to_change = {} + lb_counts = deepcopy(label_counts) + for new_lb in empty_lbs: + tmp_info[new_lb] = {} + loglike_reduced = np.max(loglike_res, axis=0) - loglike_res[new_lb] + tmp_info[new_lb]["reduced"] = loglike_reduced + orders_to_d = {order_id: _d_id for _d_id, order_id in enumerate(loglike_reduced.argsort())} + tmp_info[new_lb]["orders"] = orders_to_d + for try_order_id in range(len(dat_arr)): + try_data_id = orders_to_d[try_order_id] + # if old label is not fixed, or the new label is within the constraint + if (try_data_id not in lb_fixed or new_lb in lb_fixed[try_data_id]) and \ + lb_counts[new_labels[try_data_id]] > 1: + 
to_change[try_data_id] = new_lb + lb_counts[new_lb] += 1 + lb_counts[new_labels[try_data_id]] -= 1 + break + else: + raise ValueError("Assigning failed!") + if len(to_change) == len(empty_lbs): # each empty lbs has a unique data to fill + for data_id, new_lb in to_change.items(): + new_labels[data_id] = new_lb + else: + # slow but easy to coding way + replace_res = {} + for new_lb_order in permutations(empty_lbs): + replace_res[new_lb_order] = {"loglike": 0., + "change": {}, + "failed": False, + "counts": deepcopy(label_counts)} + for new_lb in new_lb_order: + for try_order_id in range(len(dat_arr)): + try_data_id = tmp_info[new_lb]["orders"][try_order_id] + # if the data id was not used "by other new lbs", AND + # if the donor has more than two occurrences + # if data id is not fixed, or the new label is within the constraint + if try_data_id not in replace_res[new_lb_order]["change"] and \ + replace_res[new_lb_order]["counts"][new_labels[try_data_id]] > 1 and \ + (try_data_id not in lb_fixed or new_lb in lb_fixed[try_data_id]): + replace_res[new_lb_order]["change"][try_data_id] = new_lb + replace_res[new_lb_order]["loglike"] -= tmp_info[new_lb]["reduced"][try_data_id] + replace_res[new_lb_order]["counts"][new_lb] += 1 + replace_res[new_lb_order]["counts"][new_labels[try_data_id]] -= 1 + break + else: + replace_res[new_lb_order]["failed"] = True + replace_res[new_lb_order]["loglike"] = -inf + best_order, best_info = sorted(replace_res.items(), key=lambda x: -x[1]["loglike"])[0] + if best_info["failed"]: + raise ValueError("Assigning failed!") + else: + for data_id, new_lb in best_info["change"].items(): + new_labels[data_id] = new_lb + # if there is an empty cluster, + # and if there is another non-empty cluster with two ends not in the fixed (lb_fixed), + # then move one of the end (min or max) from that non-empty cluster to the empty cluster + # for empty_lb in label_counts: + # if label_counts[empty_lb] == 0: + # non_empty_lbs = {ne_lb: [min, max] for ne_lb in label_counts if label_counts[ne_lb] > 1} + # for af_lb in sorted(non_empty_lbs): + # these_points = dat_arr[new_labels == af_lb] + # if max(these_points) in limited_values: + # non_empty_lbs[af_lb].remove(max) + # if min(these_points) in limited_values: + # non_empty_lbs[af_lb].remove(min) + # if not non_empty_lbs[af_lb]: + # del non_empty_lbs[af_lb] + # if non_empty_lbs: + # chose_lb = random_obj.choice(list(non_empty_lbs)) + # chose_points = dat_arr[new_labels == chose_lb] + # # random.choice([min, max]), then use the resulting function to pick the point + # data_point = random_obj.choice(non_empty_lbs[chose_lb])(chose_points) + # # random.choice(np.array([0])) triggers: IndexError: Cannot choose from an empty sequence + # transfer_index = random_obj.choice(list(np.where(dat_arr == data_point)[0])) + # new_labels[transfer_index] = empty_lb + # label_counts[chose_lb] -= 1 + # # 2022-12-18 fix a long-lasting issue + # label_counts[empty_lb] += 1 return new_labels - def updating_parameter(dat_arr, dat_w, lbs, parameters): - - for go_to_cl, pr in enumerate(parameters): + def updating_parameter(dat_arr, dat_w, lbs, in_params): + new_params = deepcopy(in_params) + for go_to_cl, pr in enumerate(new_params): these_points = dat_arr[lbs == go_to_cl] these_weights = dat_w[lbs == go_to_cl] if len(these_points) > 1: @@ -144,18 +286,20 @@ def updating_parameter(dat_arr, dat_w, lbs, parameters): pr["percent"] = sum(these_weights) # / data_len elif len(these_points) == 1: pr["sigma"] = max(dat_arr.std() / data_len, min_sigma) - pr["mu"] = 
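
Illustrative sketch (not from the patch): when several clusters end up empty, the code above tries every order of the empty labels (itertools.permutations) and keeps the order that sacrifices the least log-likelihood, the "slow but easy" exhaustive assignment. A toy version of that selection, without the donor-count and constraint checks:

from itertools import permutations

# costs[lb][d]: log-likelihood lost if data point d is donated to empty label lb (toy numbers)
costs = {0: {0: 0.8, 1: 0.2, 2: 0.5},
         1: {0: 0.1, 1: 0.9, 2: 0.4}}
best_order, best_loss, best_assignment = None, float("inf"), None
for order in permutations(costs):
    taken, loss, assign = set(), 0.0, {}
    for lb in order:
        # give this empty label the cheapest data point that is still free
        d, c = min(((d, c) for d, c in costs[lb].items() if d not in taken), key=lambda x: x[1])
        taken.add(d)
        loss += c
        assign[d] = lb                   # data point -> new label
    if loss < best_loss:
        best_order, best_loss, best_assignment = order, loss, assign
print(best_order, best_assignment, round(best_loss, 2))   # (0, 1) {1: 0, 0: 1} 0.3
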
np.average(these_points, weights=these_weights) + pr["sigma"] * (2 * random.random() - 1) + # 2023-01-15 + # pr["mu"] = np.average(these_points, weights=these_weights) + pr["sigma"] * (2 * random_obj.random() - 1) + pr["mu"] = these_points[0] pr["percent"] = sum(these_weights) # / data_len else: # exclude pr["mu"] = max(dat_arr) * 1E4 pr["sigma"] = min_sigma pr["percent"] = 1E-10 - return parameters + return new_params data_array = np.array(data_array) data_len = float(len(data_array)) - if not len(data_weights): + if data_weights is None or not len(data_weights): data_weights = np.array([1. for foo in range(int(data_len))]) else: assert len(data_weights) == data_len @@ -164,72 +308,136 @@ def updating_parameter(dat_arr, dat_w, lbs, parameters): data_weights = np.array([raw_w / average_weights for raw_w in data_weights]) results = [] - if cluster_limited: - cls = set() - for sub_cls in cluster_limited.values(): - cls |= sub_cls - freedom_dat_item = int(data_len) - len(cluster_limited) + len(cls) - else: - freedom_dat_item = int(data_len) - minimum_cluster = min(freedom_dat_item, minimum_cluster) - maximum_cluster = min(freedom_dat_item, maximum_cluster) + # adjust the min and max number of clusters according to constraints + if cluster_limited is None: + # freedom_dat_item = int(data_len) + cluster_limited = {} + # else: + # cls = set() + # for sub_cls in cluster_limited.values(): + # cls |= sub_cls + # freedom_dat_item = max(0, int(data_len) - max(0, len(cluster_limited) - len(cls))) + # min_choices = 0 + # if cluster_bans is None: + # cluster_bans = {} + # else: + # for ban_dat_id, sub_cls in cluster_bans.items(): + # if ban_dat_id not in cluster_limited: # cluster_limited has priority over cluster_bans + # min_choices = max(len(sub_cls) + 1, min_choices) + # # assert min_choices < freedom_dat_item, "unrealistic constraints: \ncluster_limited: " + \ + # # str(cluster_limited) + "\ncluster_bans: " + \ + # # str(cluster_bans) + # minimum_cluster = min(freedom_dat_item, max(minimum_cluster, min_choices)) + # maximum_cluster = min(freedom_dat_item, max(maximum_cluster, min_choices)) + # 2023-01-15 + # print(minimum_cluster, maximum_cluster) + minimum_cluster = max(minimum_cluster, len(set([tuple(_cl) for _cl in cluster_limited.values() if len(_cl) == 1]))) + maximum_cluster = min(maximum_cluster, len(data_array)) + # print(minimum_cluster, maximum_cluster) + + # timey = time.time() + # round_times = [] # iteratively try the num of clusters for total_cluster_num in range(minimum_cluster, maximum_cluster + 1): - # initialization - labels = np.random.choice(total_cluster_num, int(data_len)) + cluster_num_failure = False + if log_handler and verbose_log: + log_handler.info("assessing %i clusters" % total_cluster_num) if cluster_limited: - temp_labels = [] - for dat_id in range(int(data_len)): - if dat_id in cluster_limited: - if labels[dat_id] in cluster_limited[dat_id]: - temp_labels.append(labels[dat_id]) - else: - temp_labels.append(sorted(cluster_limited[dat_id])[0]) - else: - temp_labels.append(labels[dat_id]) - labels = np.array(temp_labels) + this_limit = deepcopy(cluster_limited) + for dat_id in cluster_bans: + if dat_id not in this_limit: + this_limit[dat_id] = set() + for potential_lb_id in range(total_cluster_num): + if potential_lb_id not in cluster_bans[dat_id]: + this_limit[dat_id].add(potential_lb_id) + else: + this_limit = {} + # initialization + # # labels = np_rd_obj.choice(total_cluster_num, int(data_len)) + # # TODO: each cluster has to have at least one occurrence, which 
will be complicated combined with this_limit + # min_occurrences = list(range(total_cluster_num)) + # random_obj.shuffle(min_occurrences) + # labels = np.array(min_occurrences + + # random_obj.choices(range(total_cluster_num), k=int(data_len) - total_cluster_num)) + # labels = revise_labels_according_to_constraints(labels, this_limit) + # using decreasing order rather than random @2023-01-18, to create more reasonable initial parameter set + extra_for_first_lb = int(data_len) % total_cluster_num + each_len = (int(data_len) - extra_for_first_lb) // total_cluster_num + cov_decreasing_order = np.flip(data_array.argsort()) + labels = np.zeros(int(data_len), dtype=np.int8) + for go_l, go_d in enumerate(range(extra_for_first_lb, int(data_len), each_len)): + labels[cov_decreasing_order[go_d: go_d + each_len]] = go_l + # initialize the parameters norm_parameters = updating_parameter(data_array, data_weights, labels, + # [{"mu": 0, "sigma": 1} [{"mu": 0, "sigma": 1, "percent": total_cluster_num/data_len} for foo in range(total_cluster_num)]) + if log_handler and verbose_log: + log_handler.info(" initial labels: " + str(list(labels))) + log_handler.info(" initial params: " + str(norm_parameters)) loglike_shift = inf prev_loglike = -inf - epsilon = 0.01 + epsilon = 0.001 count_iterations = 0 - best_loglike = prev_loglike + best_loglike = -inf best_parameter = norm_parameters - try: - while loglike_shift > epsilon: - count_iterations += 1 - # expectation - labels = assign_cluster_labels(data_array, data_weights, norm_parameters, cluster_limited) - # maximization - updated_parameters = updating_parameter(data_array, data_weights, labels, deepcopy(norm_parameters)) - # loglike shift - this_loglike = model_loglike(data_array, data_weights, labels, updated_parameters) - loglike_shift = abs(this_loglike - prev_loglike) - # update - prev_loglike = this_loglike - norm_parameters = updated_parameters - if this_loglike > best_loglike: - best_parameter = updated_parameters - best_loglike = this_loglike - labels = assign_cluster_labels(data_array, data_weights, best_parameter, None) - results.append({"loglike": best_loglike, "iterates": count_iterations, "cluster_num": total_cluster_num, - "parameters": best_parameter, "labels": labels, - "aic": aic(prev_loglike, 2 * total_cluster_num), - "bic": bic(prev_loglike, 2 * total_cluster_num, data_len)}) - except TypeError as e: - if log_handler: - log_handler.error("This error might be caused by outdated version of scipy!") + count_best = 1 + while loglike_shift > epsilon and count_iterations < 500 and count_best < 50: + count_iterations += 1 + # expectation + try: + labels = assign_cluster_labels(data_array, data_weights, norm_parameters, this_limit) + except ValueError as e: + if str(e) == "Assigning failed!": + if verbose_log and log_handler: + log_handler.info(" assigning failed for %i clusters" % total_cluster_num) + cluster_num_failure = True + break + else: + raise e + # maximization + updated_parameters = updating_parameter(data_array, data_weights, labels, deepcopy(norm_parameters)) + this_loglike = model_loglike(data_array, data_weights, labels, updated_parameters) + if log_handler and verbose_log: + log_handler.info(" iter_%i labels: " % count_iterations + str(list(labels))) + log_handler.info(" iter_%i params: " % count_iterations + str(updated_parameters)) + log_handler.info(" iter_%i loglike: " % count_iterations + str(this_loglike)) + loglike_shift = abs((this_loglike - prev_loglike) / this_loglike) + # update + prev_loglike = this_loglike + 
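
Illustrative sketch (not from the patch): since the 2023-01-18 change noted above, the EM initialization is deterministic, slicing the coverage values in decreasing order into roughly equal chunks, one chunk per cluster, instead of drawing random labels. A toy run of that slicing:

import numpy as np

data_array = np.array([5., 1., 9., 2., 8.])    # toy coverages
total_cluster_num = 2
data_len = len(data_array)
extra_for_first_lb = data_len % total_cluster_num
each_len = (data_len - extra_for_first_lb) // total_cluster_num
cov_decreasing_order = np.flip(data_array.argsort())
labels = np.zeros(data_len, dtype=np.int8)
for go_l, go_d in enumerate(range(extra_for_first_lb, data_len, each_len)):
    labels[cov_decreasing_order[go_d: go_d + each_len]] = go_l
print(labels.tolist())    # [0, 1, 0, 1, 0]: the three largest coverages seed cluster 0
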
norm_parameters = updated_parameters + if this_loglike > best_loglike: + best_parameter = updated_parameters + best_loglike = this_loglike + count_best = 1 else: - sys.stdout.write("This error might be caused by outdated version of scipy!\n") - raise e + count_best += 1 + if cluster_num_failure: + break + # 2023-01-15 replace: labels = assign_cluster_labels(data_array, data_weights, best_parameter, None) + labels = assign_cluster_labels(data_array, data_weights, best_parameter, this_limit) + results.append({"loglike": best_loglike, "iterates": count_iterations, "cluster_num": total_cluster_num, + "parameters": best_parameter, "labels": labels, + "aic": aic(best_loglike, 2 * total_cluster_num), + "bic": bic(best_loglike, 2 * total_cluster_num, data_len)}) + # except TypeError as e: + # if log_handler: + # log_handler.error("This error might be caused by outdated version of scipy!") + # else: + # sys.stdout.write("This error might be caused by outdated version of scipy!\n") + # raise e + # round_times.append(time.time() - timey) + # timey = time.time() if verbose_log: if log_handler: log_handler.info(str(results)) else: sys.stdout.write(str(results) + "\n") + # TODO: if all clustering failed + if not results: + raise ValueError("Solution Not Found!") best_scheme = sorted(results, key=lambda x: x["bic"])[0] + # print(time.time() - time0, round_times, pdf_time) return best_scheme diff --git a/GetOrganelleLib/versions.py b/GetOrganelleLib/versions.py index f9a9e3b..344ecb4 100644 --- a/GetOrganelleLib/versions.py +++ b/GetOrganelleLib/versions.py @@ -5,6 +5,44 @@ def get_versions(): versions = [ + { + "number": "1.8.0.0", + "features": [ + "1. remove redundant disentangling for get_organelle_from_assembly.py", + "2. separate temporary files from different rounds", + "3. remove upper boundary for coverage-based filtering", + "4. remove max multiplicity boundary", + "5. remove scipy & sympy: " + " 5.1. use gekko instead of scipy for multiplicity estimation, with initials and option for multinomial;" + " 5.2. use custom norm_logpdf function instead of slow scipy.stats.norm.logpdf", + "6. limit the number of paths before generating", + "7. log copy info for --no-slim", + "8. statistical_func.py: fix a bug" + " random.choice(np.array([0])) triggers IndexError: Cannot choose from an empty sequence", + "9. assembly.parse_tab_file: " + " only keep one vertex labeled for each gene tags -> " + " gene tags can occur in multiple vertices that are linearly continuous;", + "10. using negative tag weights to better differentiate target and non-target but similar neighbors", + "11. change the default of depth_factor because the baseline is now set to be average rather than the max", + "12. slim_graph.generate_baits_offsets: fix a bug generating wrong offsets: min -> max", + "13. slim_graph.py: 1. rm_contigs: < depth_cutoff -> >= depth_cutoff. " + " 2. add contig_min_hit_percent (add associated func to seq_parser.py). " + " 3. allow no depth_cutoff (depth_cutoff==-1). ", + "14. add weight_factor to get_organelle_from_assembly.py", + "15. parsing the same graph_file and tab_file only once before multiple disentanglement trials", + "16. skip clustering and other steps when len(vertex_info)==1", + "17. discard np.random, fix consistent issue, modify find_target_graph to generate contigs", + "18. create soft link instead of copying the original read file(s) into the working directory", + "19. 
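
Illustrative sketch (not from the patch): each candidate cluster number is scored with AIC/BIC and the scheme with the smallest BIC wins. The sketch uses the textbook definitions, with 2 parameters (mu, sigma) per cluster as in the call above; the package's own aic/bic helpers live in statistical_func.py and may differ in detail.

from math import log

def aic(loglike, len_param):
    return 2 * len_param - 2 * loglike

def bic(loglike, len_param, len_data):
    return log(len_data) * len_param - 2 * loglike

candidates = [{"cluster_num": 1, "loglike": -120.0},
              {"cluster_num": 2, "loglike": -95.0},
              {"cluster_num": 3, "loglike": -93.5}]
n_data = 40
for res in candidates:
    k = 2 * res["cluster_num"]
    res["aic"] = aic(res["loglike"], k)
    res["bic"] = bic(res["loglike"], k, n_data)
best_scheme = sorted(candidates, key=lambda x: x["bic"])[0]
print(best_scheme["cluster_num"])    # 2: the third cluster does not pay for its extra parameters
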
get_organelle_from_reads.py.make_read_index(): " + " on --continue and temp.indices.1 existed, separate conditions for speeding up ", + "20. Assembly.reduce_to_graph: minor changes", + "21. get_organelle_from_assembly.py: raise exception on empty graphs", + "22. add gb_to_tbl.py for common format conversion; add biopython as the dependency", + "23. fix a bug: summary_get_organelle_output.py does not recognize degenerate/ambiguous bases (issue 279)", + "24. automatic converting negative coverage to positive and report it (probably Bandage output issue)", + ], + "time": "2023-08-20 16:00 UTC+8" + }, { "number": "1.7.7.1", "features": [ @@ -423,7 +461,7 @@ def get_versions(): "2. get_organelle_from_assembly.py: fix a bug on parsing gfa file with long seq head names; " " --keep-temp fixed; fix a bug with '-h'; ", "3. Utilities/slim_fastg.py: --no-merge -> --merge; disable merge by default", - "4. GetOrganelleLib/assembly_parser.py: fix a bug with generating new vertices, " + "4. GetOrganelleLib/assembly_parser.py: fix a bug with generating new vertices_set, " " as well as merge_all_possible_contigs; export plastome-LSC direction according to convention based on " " accumulated orf lengths (the conventional reverse direction has more accumulated orf lengths), which " " makes users easier to use; remove processing_polymorphism() before filter_by_coverage() to better " @@ -569,7 +607,7 @@ def get_versions(): "features": [ "1.get_organelle_reads.py: fix a bug with --continue & --prefix when LogInfo() added; ", "2.assembly_parser.py & statistical_func.py: " - "if single copy vertex percentage is < 50%, continue dropping suspicious vertices", + "if single copy vertex percentage is < 50%, continue dropping suspicious vertices_set", "3.pip_control_func.py: for --prefix", ]}, {"number": "1.4.3a", @@ -611,7 +649,7 @@ def get_versions(): {"number": "1.4.1", "features": [ "1.assembly_parser.py: Assembly.export_path() and Assembly.merge_all_possible_vertices():" - " name of merged vertices optimized", + " name of merged vertices_set optimized", "2.README: PATH configuration", "3.mk_batch_for_iteratively_mapping_assembling.py: -t threads", "4.get_organelle_reads.py --fast mode modified", @@ -716,7 +754,7 @@ def get_versions(): ]}, {"number": "1.2.0b", "features": [ - "1.Assembly.parse_fastg(): (more robust) Add connection information to both of the related vertices" + "1.Assembly.parse_fastg(): (more robust) Add connection information to both of the related vertices_set" " even it is only mentioned once;", "2.Assembly.is_sequential_repeat(): fix a bug that leak in the reverse direction;", "3.add depth_factor to the main script;", diff --git a/Utilities/disentangle_organelle_assembly.py b/Utilities/disentangle_organelle_assembly.py index 48d716d..7a80b8d 100755 --- a/Utilities/disentangle_organelle_assembly.py +++ b/Utilities/disentangle_organelle_assembly.py @@ -30,8 +30,9 @@ def get_options(print_title): "organelle genome), and would directly give up linear/broken graphs. Choose this option " "to try for linear/broken cases.") parser.add_argument("--weight-f", dest="weight_factor", type=float, default=100.0, - help="weight factor for excluding non-target contigs. Default:%(default)s") - parser.add_argument("--depth-f", dest="depth_factor", type=float, default=10., + help="weight factor for excluding isolated/terminal suspicious contigs with gene labels. 
" + "Default:%(default)s") + parser.add_argument("--depth-f", dest="depth_factor", type=float, default=5., help="Depth factor for excluding non-target contigs. Default:%(default)s") parser.add_argument("--type-f", dest="type_factor", type=float, default=3., help="Type factor for identifying genome type tag. Default:%(default)s") @@ -71,9 +72,9 @@ def get_options(print_title): help="Minimum coverage for a contig to be included in disentangling. Default:%(default)s") parser.add_argument("--max-depth", dest="max_cov", type=float, default=inf, help="Minimum coverage for a contig to be included in disentangling. Default:%(default)s") - parser.add_argument("--max-multiplicity", dest="max_multiplicity", type=int, default=8, - help="Maximum multiplicity of contigs for disentangling genome paths. " - "Should be 1~12. Default:%(default)s") + # parser.add_argument("--max-multiplicity", dest="max_multiplicity", type=int, default=8, + # help="Maximum multiplicity of contigs for disentangling genome paths. " + # "Should be 1~12. Default:%(default)s") parser.add_argument("--prefix", dest="prefix", default="target", help="Prefix of output files inside output directory. Default:%(default)s") parser.add_argument("--keep-temp", dest="keep_temp_graph", default=False, action="store_true", @@ -104,7 +105,7 @@ def get_options(print_title): sys.stdout.write("Insufficient arguments!\n") sys.exit() else: - assert 12 >= options.max_multiplicity >= 1 + # assert 12 >= options.max_multiplicity >= 1 assert options.max_paths_num > 0 if options.output_directory and not os.path.exists(options.output_directory): os.mkdir(options.output_directory) @@ -119,8 +120,6 @@ def get_options(print_title): # options.expected_max_size /= 2 elif options.mode in ("embplant_nr", "animal_mt", "fungus_nr"): options.expected_max_size /= 8 - random.seed(options.random_seed) - np.random.seed(options.random_seed) return options, log_handler @@ -131,37 +130,44 @@ def main(): " from slim_fastg.py-produced files (csv & fastg). 
" + \ "\n\n" options, log_handler = get_options(print_title) + random.seed(options.random_seed) + np.random.seed(options.random_seed) @set_time_limit(options.time_limit) - def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, type_factor, mode="embplant_pt", - hard_cov_threshold=10., expected_max_size=inf, expected_min_size=0, + def disentangle_circular_assembly(input_graph, + # tab_file, + prefix, + weight_factor, + # type_factor, + mode="embplant_pt", + hard_cov_threshold=5., expected_max_size=inf, expected_min_size=0, contamination_depth=3., contamination_similarity=5., degenerate=True, degenerate_depth=1.5, degenerate_similarity=1.5, - min_sigma_factor=0.1, max_copy_in=10, only_max_cov=True, + min_sigma_factor=0.1, only_max_cov=True, # max_copy_in=10, keep_temp=False, acyclic_allowed=False, verbose=False, inner_logging=None, debug=False): - if options.resume and os.path.exists(prefix + ".graph1.selected_graph.gfa"): - pass + if options.resume and os.path.exists(prefix + ".graph1.path_sequence.gfa"): if inner_logging: inner_logging.info(">>> Result graph existed!") else: sys.stdout.write(">>> Result graph existed!\n") else: - time_a = time.time() - if inner_logging: - inner_logging.info(">>> Parsing " + fastg_file + " ..") - else: - sys.stdout.write("Parsing " + fastg_file + " ..\n") - input_graph = Assembly(fastg_file, min_cov=options.min_cov, max_cov=options.max_cov) + # time_a = time.time() + # if inner_logging: + # inner_logging.info(">>> Parsing " + fastg_file + " ..") + # else: + # sys.stdout.write("Parsing " + fastg_file + " ..\n") + # input_graph = Assembly(fastg_file, min_cov=options.min_cov, max_cov=options.max_cov) time_b = time.time() - if inner_logging: - inner_logging.info(">>> Parsing input fastg file finished: " + str(round(time_b - time_a, 4)) + "s") - else: - sys.stdout.write("\n>>> Parsing input fastg file finished: " + str(round(time_b - time_a, 4)) + "s\n") + # if inner_logging: + # inner_logging.info(">>> Parsing input fastg file finished: " + str(round(time_b - time_a, 4)) + "s") + # else: + # sys.stdout.write("\n>>> Parsing input fastg file finished: " + str(round(time_b - time_a, 4)) + "s\n") temp_graph = prefix + ".temp.fastg" if keep_temp else None - - copy_results = input_graph.find_target_graph(tab_file, database_name=mode, mode=mode, - type_factor=type_factor, + selected_graph = prefix + ".graph.selected_graph.gfa" + copy_results = input_graph.find_target_graph( # tab_file, + db_name=mode, mode=mode, + # type_factor=type_factor, weight_factor=weight_factor, hard_cov_threshold=hard_cov_threshold, contamination_depth=contamination_depth, @@ -170,13 +176,14 @@ def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, t degenerate_similarity=degenerate_similarity, expected_max_size=expected_max_size, expected_min_size=expected_min_size, - max_contig_multiplicity=max_copy_in, only_keep_max_cov=only_max_cov, min_sigma_factor=min_sigma_factor, + selected_graph=selected_graph, temp_graph=temp_graph, broken_graph_allowed=acyclic_allowed, verbose=verbose, log_handler=inner_logging, - debug=debug) + debug=debug, + random_obj=random) time_c = time.time() if inner_logging: inner_logging.info(">>> Detecting target graph finished: " + str(round(time_c - time_b, 4)) + "s") @@ -194,12 +201,14 @@ def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, t go_res += 1 broken_graph = copy_res["graph"] count_path = 0 - - these_paths = broken_graph.get_all_paths(mode=mode, log_handler=inner_logging) + # use 
options.max_paths_num + 1 to trigger the warning + these_paths = broken_graph.get_all_paths(mode=mode, log_handler=inner_logging, + max_paths_num=options.max_paths_num + 1) # reducing paths if len(these_paths) > options.max_paths_num: this_warn_str = "Only exporting " + str(options.max_paths_num) + " out of all " + \ - str(len(these_paths)) + " possible paths. (see '--max-paths-num' to change it.)" + str(options.max_paths_num) + \ + "+ possible paths. (see '--max-paths-num' to change it.)" if inner_logging: inner_logging.warning(this_warn_str) else: @@ -242,7 +251,7 @@ def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, t # still_complete.append(False) open(prefix + ".graph" + str(go_res) + other_tag + "." + str(count_path) + ".path_sequence.fasta", "w").write("\n".join(all_contig_str)) - broken_graph.write_to_gfa(prefix + ".graph" + str(go_res) + ".selected_graph.gfa") + broken_graph.write_to_gfa(prefix + ".graph" + str(go_res) + ".path_sequence.gfa") else: for go_res, copy_res in enumerate(copy_results): go_res += 1 @@ -250,13 +259,16 @@ def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, t # should add making one-step-inversion pairs for paths, # which would be used to identify existence of a certain isomer using mapping information count_path = 0 - + # use options.max_paths_num + 1 to trigger the warning these_paths = idealized_graph.get_all_circular_paths( - mode=mode, log_handler=inner_logging, reverse_start_direction_for_pt=options.reverse_lsc) + mode=mode, log_handler=inner_logging, reverse_start_direction_for_pt=options.reverse_lsc, + max_paths_num=options.max_paths_num + 1, + ) # reducing paths if len(these_paths) > options.max_paths_num: this_warn_str = "Only exporting " + str(options.max_paths_num) + " out of all " + \ - str(len(these_paths)) + " possible paths. (see '--max-paths-num' to change it.)" + str(options.max_paths_num) + \ + "+ possible paths. (see '--max-paths-num' to change it.)" if inner_logging: inner_logging.warning(this_warn_str) else: @@ -282,7 +294,7 @@ def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, t inner_logging.info(print_str) else: sys.stdout.write(print_str + "\n") - idealized_graph.write_to_gfa(prefix + ".graph" + str(go_res) + ".selected_graph.gfa") + idealized_graph.write_to_gfa(prefix + ".graph" + str(go_res) + ".path_sequence.gfa") if degenerate_base_used: inner_logging.warning("Degenerate base(s) used!") time_d = time.time() @@ -292,9 +304,35 @@ def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, t sys.stdout.write("\n\n>>> Solving and unfolding graph finished: " + str(round(time_d - time_c, 4)) + "s\n") try: - disentangle_circular_assembly(options.fastg_file, options.tab_file, + time_1 = time.time() + if log_handler: + log_handler.info(">>> Parsing " + options.fastg_file + " ..") + else: + sys.stdout.write("Parsing " + options.fastg_file + " ..\n") + assembly_graph_obj = Assembly( + options.fastg_file, min_cov=options.min_cov, max_cov=options.max_cov, log_handler=log_handler) + if log_handler: + log_handler.info(">>> Loading and cleaning labels along " + options.fastg_file) + else: + sys.stdout.write("Loading and cleaning labels along " + options.fastg_file + "\n") + assembly_graph_obj.parse_tab_file( + options.tab_file, + database_name=options.mode, + type_factor=options.type_factor, + max_gene_gap=250, + max_cov_diff=options.depth_factor, # contamination_depth? 
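
Illustrative sketch (not from the patch): max_paths_num + 1 is requested on purpose, so that receiving more than max_paths_num results proves additional paths exist without enumerating them all; the warning then reports "N+" rather than an exact count. The pattern in isolation:

def first_n(iterable, n):
    out = []
    for item in iterable:
        out.append(item)
        if len(out) >= n:
            break
    return out

max_paths_num = 3
these_paths = first_n(range(10), max_paths_num + 1)     # ask for one extra on purpose
if len(these_paths) > max_paths_num:
    print("Only exporting %i out of %i+ possible paths." % (max_paths_num, max_paths_num))
    these_paths = these_paths[:max_paths_num]
print(these_paths)                                      # [0, 1, 2]
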
+ verbose=options.verbose, + log_handler=log_handler, + random_obj=random) + time_2 = time.time() + if log_handler: + log_handler.info(">>> Parsing input fastg file finished: " + str(round(time_2 - time_1, 4)) + "s") + else: + sys.stdout.write("\n>>> Parsing input fastg file finished: " + str(round(time_2 - time_1, 4)) + "s\n") + disentangle_circular_assembly(assembly_graph_obj, + # options.tab_file, os.path.join(options.output_directory, options.prefix), - type_factor=options.type_factor, + # type_factor=options.type_factor, mode=options.mode, weight_factor=options.weight_factor, hard_cov_threshold=options.depth_factor, @@ -305,7 +343,7 @@ def disentangle_circular_assembly(fastg_file, tab_file, prefix, weight_factor, t expected_max_size=options.expected_max_size, expected_min_size=options.expected_min_size, min_sigma_factor=options.min_sigma_factor, - max_copy_in=options.max_multiplicity, + # max_copy_in=options.max_multiplicity, only_max_cov=options.only_keep_max_cov, acyclic_allowed=options.acyclic_allowed, keep_temp=options.keep_temp_graph, inner_logging=log_handler, verbose=options.verbose, debug=options.debug) diff --git a/Utilities/evaluate_assembly_using_mapping.py b/Utilities/evaluate_assembly_using_mapping.py index 83fbac8..5d11ce0 100755 --- a/Utilities/evaluate_assembly_using_mapping.py +++ b/Utilities/evaluate_assembly_using_mapping.py @@ -13,7 +13,11 @@ from GetOrganelleLib.statistical_func import * from GetOrganelleLib.versions import get_versions PATH_OF_THIS_SCRIPT = os.path.split(os.path.realpath(__file__))[0] -from sympy import Interval +try: + from sympy import Interval +except ImportError: + print("please install sympy to execute this script, e.g. using: pip install sympy") + sys.exit() import sys import platform SYSTEM_NAME = "" diff --git a/Utilities/gb_to_tbl.py b/Utilities/gb_to_tbl.py new file mode 100755 index 0000000..37b5a1e --- /dev/null +++ b/Utilities/gb_to_tbl.py @@ -0,0 +1,165 @@ +#! /usr/bin/env python +__author__ = 'Jianjun Jin' + +import os +import sys +from platform import system +if system() == "Windows": + line_br = "\r\n" +elif system() == "Darwin": + line_br = "\r" +else: + line_br = "\n" +try: + from Bio import SeqIO, SeqFeature +except ImportError: + sys.stdout.write("Python package biopython not found!" + line_br + + "You could use \"pip install biopython\" to install it." + line_br) + sys.exit() +from optparse import OptionParser +from glob import glob + + +"""Convert Genbank format file to tbl format (https://www.ncbi.nlm.nih.gov/Sequin/table.html)""" +# example1 from Seq2 of https://www.ncbi.nlm.nih.gov/WebSub/html/help/feature-table.html + + +def get_options(description=""): + usage = "This is a GetOrganelle script for converting genbank format to tbl format, \n" \ + "which can be further used to submit through Banklt (https://www.ncbi.nlm.nih.gov/WebSub/)\n" \ + "Usage: gb_to_tbl.py gb_files" + parser = OptionParser(description=description, usage=usage) + parser.add_option("-o", dest="output", + help="Output directory. Default: along with the original file.") + parser.add_option("-t", dest="gene_types", default="CDS,tRNA,rRNA,gene,repeat_region,source", + help="Annotation type taken as gene. Set 'all' to report all types. Default: %default") + parser.add_option("-q", dest="qualifiers", default="gene,product,note,standard_name,rpt_type", + help="The qualifiers to record. Set 'all' to report all qualifiers. 
Default: %default.") + # parser.add_option("--ignore-format-error", dest="ignore_format_error", default=False, action="store_true", + # help="Skip the Error: key \"*\" not found in annotation. Not suggested.") + options, argv = parser.parse_args() + if not len(argv): + parser.print_help() + sys.exit() + if system() == "Windows": + new_argv = [] + for input_fn_pattern in argv: + new_argv.extend(glob(input_fn_pattern)) + argv = new_argv + return options, argv + + +def parse_bio_gb_locations(location_feature): + if type(location_feature) == SeqFeature.CompoundLocation: + return [parse_bio_gb_locations(location)[0] for location in location_feature.parts] + elif type(location_feature) == SeqFeature.FeatureLocation: + return [(int(location_feature.start), int(location_feature.end), location_feature.strand)] + else: + raise ValueError(str(type(location_feature))) + + +def location_feature_to_str(location_feature, feature_type): + locations = parse_bio_gb_locations(location_feature=location_feature) + # switch location[0][0] and locations[0][1] if strand/location[0][2] says reverse/-1 + # locations[0][0] + 1 because Location records indices in Biopython + # example1: 2626 2590 tRNA + lines = ["\t".join([str(locations[0][0] + 1), str(locations[0][1])][::locations[0][2]]) + + "\t" + feature_type + line_br] + # add more parts if location_feature is a CompoundLocation + # example1: 2570 2535 + for loc in locations[1:]: + lines.append("\t".join([str(loc[0] + 1), str(loc[1])][::loc[2]]) + line_br) + return "".join(lines) + + +def genbank_to_tbl(genbank_f, out_base, accepted_type_dict, qualifiers_dict): + this_records = list(SeqIO.parse(genbank_f, "genbank")) + if accepted_type_dict is None: + record_all_types = True + accepted_type_dict = {} + else: + record_all_types = False + if qualifiers_dict is None: + record_all_qualifiers = True + qualifiers_dict = {} + else: + record_all_qualifiers = False + with open(out_base + ".fasta", "w") as output_fs: + pass + with open(out_base + ".tbl", "w") as output_tbl: + pass + for go_record, seq_record in enumerate(this_records): + this_seq_name = (str(go_record + 1) + "--") * int(bool(len(this_records) > 1)) + \ + seq_record.name * int(bool(seq_record.name)) + if not this_seq_name: + raise Exception("Sequence name not found in the " + str(go_record + 1) + " record of " + genbank_f) + else: + with open(out_base + ".fasta", "a") as output_fs: + output_fs.write(">" + this_seq_name + line_br + str(seq_record.seq) + line_br) + with open(out_base + ".tbl", "a") as output_tbl: + output_tbl.write(">Feature " + this_seq_name + line_br) + for feature in seq_record.features: + if record_all_types or feature.type.lower() in accepted_type_dict: + this_type = accepted_type_dict.get(feature.type.lower(), feature.type) + try: + # some locations are compound locations + location_str = location_feature_to_str(feature.location, feature_type=this_type) + except ValueError as e: + sys.stdout.write("Warning: abnormal location " + str(e) + + line_br + str(feature) + + "\nin the " + + str(go_record + 1) + " record of " + genbank_f + " .. skipped!" 
+ line_br) + continue + else: + output_tbl.write(location_str) + for qualifier_k in feature.qualifiers: + if record_all_qualifiers or qualifier_k.lower() in qualifiers_dict: + this_qualifier = qualifiers_dict.get(qualifier_k.lower(), qualifier_k) + this_values = feature.qualifiers[qualifier_k] + for this_val in this_values: + output_tbl.write("\t\t\t" + this_qualifier + "\t" + this_val + "\n") + + +def main(): + options, argv = get_options( + "Convert Genbank format file to tbl format (https://www.ncbi.nlm.nih.gov/Sequin/table.html)" + line_br + + "By jinjianjun@mail.kib.ac.cn") + if options.gene_types != "all": + accepted_types = {this_type.lower(): this_type for this_type in options.gene_types.split(",")} + else: + accepted_types = None + if options.qualifiers != "all": + qualifiers = {this_q.lower(): this_q for this_q in options.qualifiers.split(",")} + else: + qualifiers = None + + # check output file names + output_base_names = [] + if options.output: + if not os.path.exists(options.output): + os.mkdir(options.output) + elif os.path.isfile(options.output): + raise FileExistsError(options.output + " is a file!") + else: + options.output = "" + for gb_file in argv: + this_out_f = os.path.join(options.output, os.path.basename(gb_file)) + if this_out_f.endswith(".gb"): + this_out_f = this_out_f[:-3] + elif this_out_f.endswith(".gbk"): + this_out_f = this_out_f[:-4] + elif this_out_f.endswith(".genbank"): + this_out_f = this_out_f[:-8] + output_base_names.append(this_out_f) + + # converting + for go_gb, gb_file in enumerate(argv): + if os.path.isfile(gb_file): + genbank_to_tbl(gb_file, output_base_names[go_gb], + accepted_type_dict=accepted_types, qualifiers_dict=qualifiers) + else: + sys.stdout.write("Error: " + gb_file + " not found!" + line_br) + + +if __name__ == '__main__': + main() diff --git a/Utilities/get_organelle_config.py b/Utilities/get_organelle_config.py index 473fa6c..31713d2 100755 --- a/Utilities/get_organelle_config.py +++ b/Utilities/get_organelle_config.py @@ -87,7 +87,9 @@ def get_options(description): parser.add_argument("--use-local", dest="use_local", help="Input a path. This local database path must include subdirectories " "LabelDatabase and SeedDatabase, under which there is the fasta file(s) named by the " - "organelle type you want add, such as fungus_mt.fasta. ") + "organelle type you want add, such as fungus_mt.fasta. " + "See https://github.com/Kinggerm/GetOrganelleDB#option-2-initialization-from-local-files " + "for the guidelines. 
") parser.add_argument("--clean", dest="clean", default=False, action="store_true", help="Remove all configured database files (==\"--rm all\").") parser.add_argument("--list", dest="list_available", default=False, action="store_true", diff --git a/Utilities/slim_graph.py b/Utilities/slim_graph.py index 9e0d651..762b32e 100755 --- a/Utilities/slim_graph.py +++ b/Utilities/slim_graph.py @@ -16,9 +16,9 @@ import GetOrganelleLib from GetOrganelleLib.versions import get_versions from GetOrganelleLib.pipe_control_func import * +PATH_OF_THIS_SCRIPT = os.path.split(os.path.realpath(__file__))[0] import math import copy -PATH_OF_THIS_SCRIPT = os.path.split(os.path.realpath(__file__))[0] import platform SYSTEM_NAME = "" if platform.system() == "Linux": @@ -95,9 +95,9 @@ def get_options(print_title): "\nplant_mt->embplant_mt; plant_nr->embplant_nr") parser.add_argument("--no-hits", dest="treat_no_hits", default="ex_no_con", help="Provide treatment for non-hitting contigs.\t" - "\nex_no_con \t keep those connect with hitting-include contigs. (Default)" - "\nex_no_hit \t exclude all." - "\nkeep_all \t keep all") + "\nDefault: ex_no_con: keep those connect with hitting-include contigs. " + "\nex_no_hit: exclude all. " + "\nkeep_all: keep all.") parser.add_argument("--max-slim-extending-len", dest="max_slim_extending_len", default=MAX_SLIM_EXTENDING_LENS["anonym"], type=float, @@ -119,7 +119,8 @@ def get_options(print_title): "rather than both.") parser.add_argument("--depth-cutoff", dest="depth_cutoff", default=10000.0, type=float, help="After detection for target coverage, those beyond certain times (depth cutoff) of the" - " detected coverage would be excluded. Default: %(default)s") + " detected coverage would be excluded. Use -1 to disable this process. " + "Default: %(default)s") parser.add_argument("--min-depth", dest="min_depth", default=0., type=float, help="Input a float or integer number. Filter fastg file by a minimum depth. Default: %(default)s.") parser.add_argument("--max-depth", dest="max_depth", default=inf, type=float, @@ -141,9 +142,16 @@ def get_options(print_title): parser.add_argument("-o", "--out-dir", dest="out_dir", help="By default the output would be along with the input fastg file. " "But you could assign a new directory with this option.") + parser.add_argument("--perc-hit", "--contig-min-hit-percent", dest="contig_min_hit_percent", default=0., type=float, + help="[0.0, 1.0], " + "For each database, if the hits in a contig covers less than contig_min_hit_percent, " + "these hits will be discarded, biologically meaning that " + "contig does not represent that database. " + "This is useful to exclude true organelle contigs from the assembled genome " + "without removing the nu-pt or nu-mt. Default: %(default)s") parser.add_argument("-e", "--evalue", dest="evalue", default=1e-25, type=float, help="blastn evalue threshold. Default: %(default)s") - parser.add_argument("--percent", "--perc-identity", dest="percent_identity", default=None, type=float, + parser.add_argument("--perc-identity", dest="percent_identity", default=None, type=float, help="blastn percent identity threshold. Default unset.") parser.add_argument("--blast-options", dest="blast_options", default="", help="other blastn options. e.g. 
--blast-options \"-word_size 13\".") @@ -242,7 +250,8 @@ def get_options(print_title): 'one of them should be assigned priority!\n') exit() if ex_chosen == 1 and in_chosen == 0 and (options.treat_no_hits in ["ex_no_con", "ex_no_hit"]): - sys.stdout.write('\n\nOption Error: no contigs survive according to you choice!\n') + sys.stdout.write( + "\n\nOption Error: no contigs survive according to you choice! Use \'--no-hits keep_all\'\n") exit() if options.include_priority: include_priority_str = str(options.include_priority) @@ -405,6 +414,8 @@ def _check_default_db(this_sub_organelle, extra_type=""): os.mkdir(log_output_dir) assert not (options.out_base and options.prefix), "\"--out-base\" conflicts with \"--prefix\"!" assert not (options.out_base and len(options.assemblies) > 1), "\"--out-base\" conflicts with multiple input files!" + assert options.depth_cutoff > 0 or options.depth_cutoff == -1 + assert 0. <= options.contig_min_hit_percent <= 1. if options.out_base: # Replace illegal characters for blastn options.out_base = os.path.basename(options.out_base.replace("'", "_")) @@ -483,11 +494,31 @@ def blast_and_call_names( index_files, out_file, threads, + contig_min_hit_percent=0., e_value=1e-25, percent_identity=None, other_options="", which_blast="", log_handler=None): + """ + :param fasta_file: + :param index_files: + :param out_file: + :param threads: + :param contig_min_hit_percent: float + For each database, + if the hits in a contig covers less than contig_min_hit_percent, these hits will be discarded, + biologically meaning that contig does not represent the database. + This is useful to exclude true organelle contigs from the assembled genome without removing the nu-pt or nu-mt. + :param e_value: + :param percent_identity: + :param other_options: + :param which_blast: + :param log_handler: + :return: + """ + assert 0 <= contig_min_hit_percent < 1 + from GetOrganelleLib.seq_parser import get_fasta_lengths names = {} if index_files: time0 = time.time() @@ -517,6 +548,7 @@ def blast_and_call_names( for line in blast_out_lines: line_split = line.strip().split('\t') query, hit = line_split[0], line_split[1] + # TODO: maybe add hit start and end to provide information to assembly.parse_tab_file() q_start, q_end, q_score = int(line_split[6]), int(line_split[7]), float(line_split[2]) q_min, q_max = min(q_start, q_end), max(q_start, q_end) # q_score = abs(q_max - q_min + 1)*q_score @@ -548,6 +580,34 @@ def blast_and_call_names( log_handler.info("Parsing blast result finished.") else: sys.stdout.write('Parsing blast result cost: '+str(round(time.time()-time1, 2)) + "\n") + if contig_min_hit_percent > 0: + query_lengths = get_fasta_lengths(fasta_file, blast_form_seq_name=True) + for query in sorted(names): + for this_database in sorted(names[query]): + # ranges are the ordered non overlapping ranges in the contig + # used to calculate the hit percentage of the contig + ranges = [] + for hit, q_info in names[query][this_database].items(): # mix q_info from different hits + for q_min, q_max, q_score in q_info: + # each time, try to insert (q_min, q_max) into ranges + i = 0 + while i < len(ranges): + this_min, this_max = ranges[i] + if q_max < this_min: + break + elif q_min > this_max: + i += 1 + continue + else: # overlap, then merge + q_min = min(q_min, this_min) + q_max = max(q_max, this_max) + del ranges[i] + ranges.insert(i, [q_min, q_max]) + hit_cover = float(sum([q_max_ - q_min_ + 1 for q_min_, q_max_ in ranges])) + if hit_cover / query_lengths[query] < contig_min_hit_percent: + del 
names[query][this_database] + if not names[query]: + del names[query] return names @@ -717,7 +777,7 @@ def generate_baits_offsets(in_names, databases, assembly_graph): these_hit_list.extend(hit_info_list) all_loc_values = [x[0] for x in these_hit_list] + [x[1] for x in these_hit_list] min_loc = min(all_loc_values) - max_loc = min(all_loc_values) + max_loc = max(all_loc_values) vertex_trimming[(vertex_name, False)] = min_loc - 1 vertex_trimming[(vertex_name, True)] = assembly_graph.vertex_info[vertex_name].len - max_loc return vertex_trimming @@ -726,16 +786,22 @@ def generate_baits_offsets(in_names, databases, assembly_graph): def reduce_matrix(in_names, ex_names, seq_matrix, assembly_graph, max_slim_extending_len, bait_offsets, aver_in_dep, depth_cutoff, treat_no_hits, include_priority_assigned, include_assigned, exclude_priority_assigned, exclude_assigned, - log_handler=None): + verbose, log_handler=None): if log_handler: log_handler.info("Mapping names ...") time0 = time.time() # candidate_short_to_2fulls = {} if assembly_graph: - if aver_in_dep: + if depth_cutoff != -1 and aver_in_dep: rm_contigs = [this_v.v_name for this_v in assembly_graph - if 2 ** abs(math.log(this_v.cov / aver_in_dep, 2)) < depth_cutoff] + if 2 ** abs(math.log(this_v.cov / aver_in_dep, 2)) >= depth_cutoff] + if verbose: + if log_handler: + log_handler.info("Depth-cutoff-based removing(" + str(len(rm_contigs)) + "): " + str(rm_contigs)) + else: + sys.stdout.write( + "Depth-cutoff-based removing(" + str(len(rm_contigs)) + "): " + str(rm_contigs) + "\n") assembly_graph.remove_vertex(rm_contigs) if exclude_priority_assigned: assembly_graph.remove_vertex(ex_names) @@ -743,7 +809,9 @@ def reduce_matrix(in_names, ex_names, seq_matrix, assembly_graph, max_slim_exten assembly_graph.reduce_to_subgraph(bait_vertices=in_names, bait_offsets=bait_offsets, limit_extending_len=max_slim_extending_len, - extending_len_weighted_by_depth=True) + extending_len_weighted_by_depth=True, + verbose=verbose, + log_handler=log_handler) elif treat_no_hits == "ex_no_hit": assembly_graph.remove_vertex([rm_c for rm_c in assembly_graph.vertex_info if rm_c not in in_names]) else: @@ -754,7 +822,9 @@ def reduce_matrix(in_names, ex_names, seq_matrix, assembly_graph, max_slim_exten assembly_graph.reduce_to_subgraph(bait_vertices=in_names, bait_offsets=bait_offsets, limit_extending_len=max_slim_extending_len, - extending_len_weighted_by_depth=True) + extending_len_weighted_by_depth=True, + verbose=verbose, + log_handler=log_handler) elif treat_no_hits == "ex_no_hit": assembly_graph.remove_vertex([rm_c for rm_c in assembly_graph.vertex_info if rm_c not in in_names]) else: @@ -763,6 +833,13 @@ def reduce_matrix(in_names, ex_names, seq_matrix, assembly_graph, max_slim_exten assembly_graph.remove_vertex(ex_names) else: pass + if verbose: + if log_handler: + log_handler.info("Vertices(" + str(len(assembly_graph.vertex_info)) + "): " + + str(assembly_graph.vertex_clusters)) + else: + sys.stdout.write("Vertices(" + str(len(assembly_graph.vertex_info)) + "): " + + str(assembly_graph.vertex_clusters) + "\n") else: # accepted = set() if exclude_priority_assigned: @@ -973,7 +1050,15 @@ def main(): this_assembly = Assembly() this_matrix = None if is_graph: - this_assembly = Assembly(graph_file=fas_file, min_cov=options.min_depth, max_cov=options.max_depth) + this_assembly = Assembly( + graph_file=fas_file, min_cov=options.min_depth, max_cov=options.max_depth, log_handler=log_handler) + if options.verbose_log: + if log_handler: + log_handler.info("Vertices(" + 
str(len(this_assembly.vertex_info)) + "): " + + str(this_assembly.vertex_clusters)) + else: + sys.stdout.write("Vertices(" + str(len(this_assembly.vertex_info)) + "): " + + str(this_assembly.vertex_clusters) + "\n") # merge contigs if options.merge_contigs: this_assembly.merge_all_possible_vertices() @@ -1015,21 +1100,39 @@ def main(): # structure: names[query][this_database][label] = [(q_min, q_max, q_score)] in_names = blast_and_call_names(fasta_file=blast_fas, index_files=include_indices, out_file=blast_fas+'.blast_in', threads=options.threads, + contig_min_hit_percent=options.contig_min_hit_percent, e_value=options.evalue, percent_identity=options.percent_identity, other_options=options.blast_options, which_blast=options.which_blast, log_handler=log_handler) ex_names = blast_and_call_names(fasta_file=blast_fas, index_files=exclude_indices, out_file=blast_fas+'.blast_ex', threads=options.threads, + contig_min_hit_percent=options.contig_min_hit_percent, e_value=options.evalue, percent_identity=options.percent_identity, other_options=options.blast_options, which_blast=options.which_blast, log_handler=log_handler) - if bool(include_indices) or bool(exclude_indices): + if options.verbose_log: + if log_handler: + log_handler.info("in_names(" + str(len(in_names)) + "): " + str(sorted(in_names))) + log_handler.info("ex_names(" + str(len(ex_names)) + "): " + str(sorted(ex_names))) + else: + sys.stdout.write("in_names(" + str(len(in_names)) + "): " + str(sorted(in_names)) + "\n") + sys.stdout.write("ex_names(" + str(len(ex_names)) + "): " + str(sorted(ex_names)) + "\n") + if options.depth_cutoff != -1 and (bool(include_indices) or bool(exclude_indices)): in_names_r, ex_names_r, aver_dep = modify_in_ex_according_to_depth( in_names=in_names, ex_names=ex_names, significant=options.significant, assembly_graph=this_assembly, depth_cutoff=options.depth_cutoff, log_handler=log_handler) + if options.verbose_log: + if log_handler: + log_handler.info("in_names_r(" + str(len(in_names_r)) + "): " + str(sorted(in_names_r))) + log_handler.info("ex_names_r(" + str(len(ex_names_r)) + "): " + str(sorted(ex_names_r))) + else: + sys.stdout.write( + "in_names_r(" + str(len(in_names_r)) + "): " + str(sorted(in_names_r)) + "\n") + sys.stdout.write( + "ex_names_r(" + str(len(ex_names_r)) + "): " + str(sorted(ex_names_r)) + "\n") else: in_names_r, ex_names_r, aver_dep = in_names, ex_names, None - # prepare bait_offsets: trim unlabeled terminal regions from bait vertices, for more accurate + # prepare bait_offsets: trim unlabeled terminal regions from bait vertices_set, for more accurate # control of "maximum slimming extending length" if this_assembly and options.treat_no_hits == "ex_no_con" and \ options.max_slim_extending_len not in (None, inf): @@ -1037,6 +1140,11 @@ def main(): in_names=in_names, databases=include_indices, assembly_graph=this_assembly) else: bait_offsets = {} + # if options.verbose_log: + # if log_handler: + # log_handler.info("bait_offsets: " + str(bait_offsets)) + # else: + # sys.stdout.write("bait_offsets: " + str(bait_offsets) + "\n") new_assembly, new_matrix = \ reduce_matrix(in_names=in_names_r, ex_names=ex_names_r, seq_matrix=this_matrix, assembly_graph=this_assembly, max_slim_extending_len=options.max_slim_extending_len, @@ -1045,11 +1153,19 @@ def main(): include_priority_assigned=options.include_priority, include_assigned=options.include, exclude_priority_assigned=options.exclude_priority, - exclude_assigned=options.exclude, log_handler=log_handler) + exclude_assigned=options.exclude,
verbose=options.verbose_log, + log_handler=log_handler) if log_handler: log_handler.info("Generating slimmed file to " + out_fas) else: sys.stdout.write("Generating slimmed file to " + out_fas + "\n") + if options.verbose_log: + if log_handler: + log_handler.info("Vertices(" + str(len(new_assembly.vertex_info)) + "): " + + str(new_assembly.vertex_clusters)) + else: + sys.stdout.write("Vertices(" + str(len(new_assembly.vertex_info)) + "): " + + str(new_assembly.vertex_clusters) + "\n") if is_graph: if is_fastg: new_assembly.write_to_fastg(out_fas, check_postfix=False) diff --git a/get_organelle_from_assembly.py b/get_organelle_from_assembly.py index 41016d3..18ec27f 100755 --- a/get_organelle_from_assembly.py +++ b/get_organelle_from_assembly.py @@ -14,8 +14,8 @@ import subprocess import sys import os +from copy import deepcopy from shutil import copyfile, rmtree - PATH_OF_THIS_SCRIPT = os.path.split(os.path.realpath(__file__))[0] import platform @@ -129,13 +129,16 @@ def get_options(description, version): parser.add_argument("--spades-out-dir", dest="spades_scaffolds_path", help="Input spades output directory with 'scaffolds.fasta' and 'scaffolds.paths', which are " "used for scaffolding disconnected contigs with GAPs. Default: disabled") - parser.add_argument("--depth-factor", dest="depth_factor", default=10.0, type=float, + parser.add_argument("--depth-factor", dest="depth_factor", default=5.0, type=float, help="Depth factor for differentiate genome type of contigs. " "The genome type of contigs are determined by blast. " "Default: %(default)s") parser.add_argument("--type-f", dest="type_factor", type=float, default=3., help="Type factor for identifying contig type tag when multiple tags exist in one contig. " "Default:%(default)s") + parser.add_argument("--weight-f", dest="weight_factor", type=float, default=100.0, + help="weight factor for excluding isolated/terminal suspicious contigs with gene labels. " + "Default:%(default)s") parser.add_argument("--contamination-depth", dest="contamination_depth", default=3., type=float, help="Depth factor for confirming contamination in parallel contigs. Default: %(default)s") parser.add_argument("--contamination-similarity", dest="contamination_similarity", default=0.9, type=float, @@ -178,9 +181,9 @@ def get_options(description, version): "Choose this flag to export all combinations.") parser.add_argument("--min-sigma", dest="min_sigma_factor", type=float, default=0.1, help="Minimum deviation factor for excluding non-target contigs. Default:%(default)s") - parser.add_argument("--max-multiplicity", dest="max_multiplicity", type=int, default=8, - help="Maximum multiplicity of contigs for disentangling genome paths. " - "Should be 1~12. Default:%(default)s") + # parser.add_argument("--max-multiplicity", dest="max_multiplicity", type=int, default=8, + # help="Maximum multiplicity of contigs for disentangling genome paths. " + # "Should be 1~12. 
Default:%(default)s") parser.add_argument("-t", dest="threads", type=int, default=1, help="Maximum threads to use.") parser.add_argument("--prefix", dest="prefix", default="", @@ -278,7 +281,8 @@ def _check_default_db(this_sub_organelle, extra_type=""): sys.stdout.write("\n############################################################################" "\nERROR: default " + this_sub_organelle + "," * int(bool(extra_type)) + extra_type + " database not added yet!\n" - "\nInstall it by: get_organelle_config.py -a " + this_sub_organelle + + "These two types must be used together!\n" * int(bool(extra_type)) + + "\nInstall it(them) by: get_organelle_config.py -a " + this_sub_organelle + "," * int(bool(extra_type)) + extra_type + "\nor\nInstall all types by: get_organelle_config.py -a all\n") exit() @@ -312,7 +316,7 @@ def _check_default_db(this_sub_organelle, extra_type=""): if not os.path.exists(scaffold_paths): raise FileNotFoundError(scaffold_paths + " not found!") assert options.threads > 0 - assert 12 >= options.max_multiplicity >= 1 + # assert 12 >= options.max_multiplicity >= 1 assert options.max_paths_num > 0 assert options.script_resume + options.script_overwrite < 2, "'--overwrite' conflicts with '--continue'" organelle_type_len = len(options.organelle_type) @@ -355,20 +359,20 @@ def _check_default_db(this_sub_organelle, extra_type=""): sys.exit() else: lib_versions_info.append("numpy " + np.__version__) + # try: + # import sympy + # except ImportError: + # log_handler.error("sympy is not available! Please install sympy!") + # sys.exit() + # else: + # lib_versions_info.append("sympy " + sympy.__version__) try: - import sympy + import gekko except ImportError: - log_handler.error("sympy is not available! Please install sympy!") + log_handler.error("gekko is not available! Please install gekko!") sys.exit() else: - lib_versions_info.append("sympy " + sympy.__version__) - try: - import scipy - except ImportError: - log_handler.error("scipy is not available! 
Please install scipy!") - sys.exit() - else: - lib_versions_info.append("scipy " + scipy.__version__) + lib_versions_info.append("gekko " + gekko.__version__) log_handler.info("PYTHON LIBS: " + "; ".join(lib_versions_info)) dep_versions_info = [] if not options.no_slim: @@ -505,13 +509,8 @@ def _check_default_db(this_sub_organelle, extra_type=""): log_handler.info("Options \"" + " ".join(remove_ops) + "\" taken/invalid for wrapped slim_graph.py, removed.") options.slim_options = " ".join(slim_op_parts) - random.seed(options.random_seed) - try: - import numpy as np - except ImportError: - pass - else: - np.random.seed(options.random_seed) + # random.seed(options.random_seed) + # np.random.seed(options.random_seed) return options, log_handler @@ -581,19 +580,43 @@ def slim_assembly_graph(organelle_types, in_custom, ex_custom, graph_in, graph_o def extract_organelle_genome(out_base, slim_out_fg, slim_out_csv, organelle_prefix, organelle_type, blast_db, verbose, log_handler, expected_maximum_size, expected_minimum_size, no_slim, options): from GetOrganelleLib.assembly_parser import Assembly, ProcessingGraphFailed + import random + random.seed(options.random_seed) + + # import numpy as np + + # testing random effect + # import random as rd_standard + # + # np.random.seed(options.random_seed) + # class test_random: + # def __init__(self): + # rd_standard.seed(options.random_seed) + # self.rd = rd_standard + # def random(self): + # print("random", self.rd.random()) + # return self.rd.random() + # def choice(self, *vars, **kwargs): + # print("choice", self.rd.random()) + # return self.rd.choice(*vars, **kwargs) + # def choices(self, *vars, **kwargs): + # print("choices", self.rd.random()) + # return self.rd.choices(*vars, **kwargs) + # random = test_random() - def disentangle_assembly(fastg_file, tab_file, output, weight_factor, log_dis, time_limit, type_factor=3., + def disentangle_assembly(assembly_obj, fastg_file, tab_file, output, weight_factor, log_dis, time_limit, + type_factor=3., mode="embplant_pt", blast_db_base="embplant_pt", contamination_depth=3., contamination_similarity=0.95, degenerate=True, degenerate_depth=1.5, degenerate_similarity=0.98, expected_max_size=inf, expected_min_size=0, hard_cov_threshold=10., - min_sigma_factor=0.1, here_max_copy=10, + min_sigma_factor=0.1, # here_max_copy=10, here_only_max_c=True, spades_scaffolds_path=None, here_acyclic_allowed=False, here_verbose=False, timeout_flag_str="'--disentangle-time-limit'", temp_graph=None): @set_time_limit(time_limit, flag_str=timeout_flag_str) - def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="embplant_pt", + def disentangle_inside(input_graph, fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="embplant_pt", in_db_n="embplant_pt", c_d=3., c_s=0.95, deg=True, deg_dep=1.5, deg_sim=0.98, - hard_c_t=10., min_s_f=0.1, max_copy_in=10, max_cov_in=True, + hard_c_t=10., min_s_f=0.1, max_cov_in=True, max_s=inf, min_s=0, spades_scaffold_p_in=None, acyclic_allowed_in=False, verbose_in=False, in_temp_graph=None): if spades_scaffold_p_in is not None: @@ -605,7 +628,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb else: log_in.info("Disentangling " + fastg_f + " as a circular genome ... 
") image_produced = False - input_graph = Assembly(fastg_f) + # input_graph = Assembly(fastg_f) if spades_scaffold_p_in is not None: if not input_graph.add_gap_nodes_with_spades_res(os.path.join(spades_scaffold_p_in, "scaffolds.fasta"), os.path.join(spades_scaffold_p_in, "scaffolds.paths"), @@ -613,6 +636,16 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb log_handler=log_handler): raise ProcessingGraphFailed("No new connections.") else: + log_handler.info("Re-loading labels along " + slim_out_fg) + input_graph.parse_tab_file( + tab_f, + database_name=in_db_n, + type_factor=type_f, + max_gene_gap=250, + max_cov_diff=hard_c_t, # contamination_depth? + verbose=verbose, + log_handler=log_handler, + random_obj=random) if in_temp_graph: if in_temp_graph.endswith(".gfa"): this_tmp_graph = in_temp_graph[:-4] + ".scaffolds.gfa" @@ -620,27 +653,34 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb this_tmp_graph = in_temp_graph + ".scaffolds.gfa" input_graph.write_to_gfa(this_tmp_graph) if no_slim: - input_graph.estimate_copy_and_depth_by_cov(mode=mode_in, log_handler=log_in, verbose=verbose_in) + new_average_cov = \ + input_graph.estimate_copy_and_depth_by_cov(mode=mode_in, log_handler=log_in, verbose=verbose_in) target_results = input_graph.estimate_copy_and_depth_precisely( - maximum_copy_num=max_copy_in, - broken_graph_allowed=acyclic_allowed_in, return_new_graphs=True, verbose=verbose_in, + expected_average_cov=new_average_cov, + # broken_graph_allowed=acyclic_allowed_in, + verbose=verbose_in, log_handler=log_in) + log_target_res(target_results, + log_handler=log_handler, + universal_overlap=bool(input_graph.uni_overlap()), + mode=mode_in) else: - target_results = input_graph.find_target_graph(tab_f, - mode=mode_in, database_name=in_db_n, type_factor=type_f, + selected_graph = o_p + ".graph.selected_graph.gfa" + target_results = input_graph.find_target_graph(mode=mode_in, db_name=in_db_n, # type_factor=type_f, hard_cov_threshold=hard_c_t, contamination_depth=c_d, contamination_similarity=c_s, degenerate=deg, degenerate_depth=deg_dep, degenerate_similarity=deg_sim, expected_max_size=max_s, expected_min_size=min_s, - max_contig_multiplicity=max_copy_in, only_keep_max_cov=max_cov_in, min_sigma_factor=min_s_f, weight_factor=w_f, broken_graph_allowed=acyclic_allowed_in, log_handler=log_in, verbose=verbose_in, - temp_graph=in_temp_graph) + temp_graph=in_temp_graph, + selected_graph=selected_graph, + random_obj=random) if not target_results: raise ProcessingGraphFailed("No target graph detected!") if len(target_results) > 1: @@ -656,12 +696,14 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb go_res += 1 broken_graph = res["graph"] count_path = 0 - - these_paths = broken_graph.get_all_paths(mode=mode_in, log_handler=log_in) + # use options.max_paths_num + 1 to trigger the warning + these_paths = broken_graph.get_all_paths(mode=mode_in, log_handler=log_in, + max_paths_num=options.max_paths_num + 1) # reducing paths if len(these_paths) > options.max_paths_num: log_in.warning("Only exporting " + str(options.max_paths_num) + " out of all " + - str(len(these_paths)) + " possible paths. (see '--max-paths-num' to change it.)") + str(options.max_paths_num) + + "+ possible paths. 
(see '--max-paths-num' to change it.)") these_paths = these_paths[:options.max_paths_num] # exporting paths, reporting results @@ -715,7 +757,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb " scaffold(s) to " + out_n) open(out_n, "w").write("\n".join(all_contig_str)) if set(still_complete[-len(these_paths):]) == {"complete"}: - this_out_base = o_p + ".complete.graph" + str(go_res) + ".selected_graph." + this_out_base = o_p + ".complete.graph" + str(go_res) + ".path_sequence." log_in.info("Writing GRAPH to " + this_out_base + "gfa") broken_graph.write_to_gfa(this_out_base + "gfa") image_produced = draw_assembly_graph_using_bandage( @@ -723,7 +765,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb assembly_graph_ob=broken_graph, log_handler=log_handler, verbose_log=verbose_in, which_bandage=options.which_bandage) elif set(still_complete[-len(these_paths):]) == {"nearly-complete"}: - this_out_base = o_p + ".nearly-complete.graph" + str(go_res) + ".selected_graph." + this_out_base = o_p + ".nearly-complete.graph" + str(go_res) + ".path_sequence." log_in.info("Writing GRAPH to " + this_out_base + "gfa") broken_graph.write_to_gfa(this_out_base + "gfa") image_produced = draw_assembly_graph_using_bandage( @@ -731,7 +773,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb assembly_graph_ob=broken_graph, log_handler=log_handler, verbose_log=verbose_in, which_bandage=options.which_bandage) else: - this_out_base = o_p + ".contigs.graph" + str(go_res) + ".selected_graph." + this_out_base = o_p + ".contigs.graph" + str(go_res) + ".path_sequence." log_in.info("Writing GRAPH to " + this_out_base + "gfa") broken_graph.write_to_gfa(this_out_base + "gfa") # image_produced = draw_assembly_graph_using_bandage( @@ -751,14 +793,16 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb go_res += 1 idealized_graph = res["graph"] count_path = 0 - + # use options.max_paths_num + 1 to trigger the warning these_paths = idealized_graph.get_all_circular_paths( - mode=mode_in, log_handler=log_in, reverse_start_direction_for_pt=options.reverse_lsc) + mode=mode_in, log_handler=log_in, reverse_start_direction_for_pt=options.reverse_lsc, + max_paths_num=options.max_paths_num + 1) # reducing paths if len(these_paths) > options.max_paths_num: log_in.warning("Only exporting " + str(options.max_paths_num) + " out of all " + - str(len(these_paths)) + " possible paths. (see '--max-paths-num' to change it.)") + str(options.max_paths_num) + + "+ possible paths. (see '--max-paths-num' to change it.)") these_paths = these_paths[:options.max_paths_num] # exporting paths, reporting results @@ -781,7 +825,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb ":".join([str(len_val) for len_val in ir_stats[:3]])) log_in.info( "Writing PATH" + str(count_path) + " of " + status_str + " " + mode_in + " to " + out_n) - temp_base_out = o_p + "." + status_str + ".graph" + str(go_res) + ".selected_graph." + temp_base_out = o_p + "." + status_str + ".graph" + str(go_res) + ".path_sequence." 
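Note: across these hunks, get_all_paths() and get_all_circular_paths() are now called with max_paths_num=options.max_paths_num + 1, so path enumeration stops one path beyond the export limit; the code then only knows that at least that many paths exist, which is why the warnings report "N+ possible paths" instead of an exact total. A minimal, self-contained sketch of this capped-enumeration pattern (collect_paths_capped and the enumerate_paths argument are hypothetical stand-ins, not GetOrganelle's API):

```python
import logging


def collect_paths_capped(enumerate_paths, max_paths_num, log_handler=logging.getLogger(__name__)):
    """Collect at most max_paths_num + 1 paths; the extra one only signals that more exist."""
    capped = []
    for path in enumerate_paths():  # hypothetical generator yielding candidate paths
        capped.append(path)
        if len(capped) > max_paths_num:  # one path beyond the limit is enough to trigger the warning
            break
    if len(capped) > max_paths_num:
        # the true total was never counted, hence "N+" rather than an exact number
        log_handler.warning("Only exporting %d out of all %d+ possible paths. "
                            "(see '--max-paths-num' to change it.)" % (max_paths_num, max_paths_num))
        capped = capped[:max_paths_num]
    return capped


# toy usage: a graph with "many" paths, capped at 5
print(collect_paths_capped(lambda: iter(range(100)), 5))
```

Stopping at max_paths_num + 1 avoids enumerating a potentially combinatorial number of isomers just to print their count.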
log_in.info("Writing GRAPH to " + temp_base_out + "gfa") idealized_graph.write_to_gfa(temp_base_out + "gfa") image_produced = draw_assembly_graph_using_bandage( @@ -805,26 +849,56 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb " using Bandage to confirm the final result.") log_in.info("Writing output finished.") - disentangle_inside(fastg_f=fastg_file, tab_f=tab_file, o_p=output, w_f=weight_factor, log_in=log_dis, + disentangle_inside(input_graph=deepcopy(assembly_obj), + fastg_f=fastg_file, tab_f=tab_file, o_p=output, w_f=weight_factor, log_in=log_dis, type_f=type_factor, mode_in=mode, in_db_n=blast_db_base, c_d=contamination_depth, c_s=contamination_similarity, deg=degenerate, deg_dep=degenerate_depth, deg_sim=degenerate_similarity, hard_c_t=hard_cov_threshold, min_s_f=min_sigma_factor, - max_copy_in=here_max_copy, max_cov_in=here_only_max_c, + max_cov_in=here_only_max_c, # max_copy_in=here_max_copy, max_s=expected_max_size, min_s=expected_min_size, acyclic_allowed_in=here_acyclic_allowed, spades_scaffold_p_in=spades_scaffolds_path, verbose_in=here_verbose, in_temp_graph=temp_graph) + # parsing tab file + # if meta: + # try: + # self.parse_tab_file( + # tab_file, + # database_name=db_name, + # type_factor=type_factor, + # max_gene_gap=250, + # max_cov_diff=hard_cov_threshold, + # verbose=verbose, + # log_handler=log_handler) + # except ProcessingGraphFailed: + # return [] + # else: + # only parsing the assembly obj and tab file once + log_handler.info("Parsing " + slim_out_fg) + assembly_graph_obj = Assembly(slim_out_fg, log_handler=log_handler) + log_handler.info("Loading and cleaning labels along " + slim_out_fg) + assembly_graph_obj.parse_tab_file( + slim_out_csv, + database_name=blast_db, + type_factor=options.type_factor, + max_gene_gap=250, + max_cov_diff=options.depth_factor, # contamination_depth? 
+ verbose=verbose, + log_handler=log_handler, + random_obj=random) + # start timeout_flag = "'--disentangle-time-limit'" export_succeeded = False path_prefix = os.path.join(out_base, organelle_prefix) - graph_temp_file = path_prefix + ".temp.gfa" if options.keep_temp_files else None + graph_temp_file1 = path_prefix + ".temp.R1.gfa" if options.keep_temp_files else None try: """disentangle""" - disentangle_assembly(fastg_file=slim_out_fg, blast_db_base=blast_db, mode=organelle_type, + disentangle_assembly(assembly_obj=assembly_graph_obj, + fastg_file=slim_out_fg, blast_db_base=blast_db, mode=organelle_type, tab_file=slim_out_csv, output=path_prefix, - weight_factor=100, type_factor=options.type_factor, + weight_factor=options.weight_factor, type_factor=options.type_factor, hard_cov_threshold=options.depth_factor, contamination_depth=options.contamination_depth, contamination_similarity=options.contamination_similarity, @@ -832,12 +906,12 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb degenerate_similarity=options.degenerate_similarity, expected_max_size=expected_maximum_size, expected_min_size=expected_minimum_size, - here_max_copy=options.max_multiplicity, + # here_max_copy=options.max_multiplicity, here_only_max_c=options.only_keep_max_cov, min_sigma_factor=options.min_sigma_factor, here_acyclic_allowed=False, here_verbose=verbose, log_dis=log_handler, time_limit=options.disentangle_time_limit, timeout_flag_str=timeout_flag, - temp_graph=graph_temp_file) + temp_graph=graph_temp_file1) except ImportError as e: log_handler.error("Disentangling failed: " + str(e)) return False @@ -856,12 +930,14 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb else: export_succeeded = True - if not export_succeeded: + if not export_succeeded and options.spades_scaffolds_path: + graph_temp_file1s = path_prefix + ".temp.R1S.gfa" if options.keep_temp_files else None try: """disentangle""" - disentangle_assembly(fastg_file=slim_out_fg, blast_db_base=blast_db, mode=organelle_type, + disentangle_assembly(assembly_obj=assembly_graph_obj, + fastg_file=slim_out_fg, blast_db_base=blast_db, mode=organelle_type, tab_file=slim_out_csv, output=path_prefix, - weight_factor=100, type_factor=options.type_factor, + weight_factor=options.weight_factor, type_factor=options.type_factor, hard_cov_threshold=options.depth_factor, contamination_depth=options.contamination_depth, contamination_similarity=options.contamination_similarity, @@ -869,13 +945,13 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb degenerate_similarity=options.degenerate_similarity, expected_max_size=expected_maximum_size, expected_min_size=expected_minimum_size, - here_max_copy=options.max_multiplicity, + # here_max_copy=options.max_multiplicity, here_only_max_c=options.only_keep_max_cov, min_sigma_factor=options.min_sigma_factor, spades_scaffolds_path=options.spades_scaffolds_path, here_acyclic_allowed=False, here_verbose=verbose, log_dis=log_handler, time_limit=options.disentangle_time_limit, timeout_flag_str=timeout_flag, - temp_graph=graph_temp_file) + temp_graph=graph_temp_file1s) except ImportError as e: log_handler.error("Disentangling failed: " + str(e)) return False @@ -895,13 +971,16 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb export_succeeded = True if not export_succeeded: + graph_temp_file2 = path_prefix + ".temp.R2.gfa" if options.keep_temp_files else None try: """disentangle the graph as 
scaffold(s)/contig(s)""" - disentangle_assembly(fastg_file=slim_out_fg, blast_db_base=blast_db, mode=organelle_type, + disentangle_assembly(assembly_obj=assembly_graph_obj, + fastg_file=slim_out_fg, blast_db_base=blast_db, mode=organelle_type, tab_file=slim_out_csv, output=path_prefix, - weight_factor=100, type_factor=options.type_factor, + weight_factor=options.weight_factor, type_factor=options.type_factor, here_verbose=verbose, log_dis=log_handler, - hard_cov_threshold=options.depth_factor * 0.8, + hard_cov_threshold=options.depth_factor * 0.6, + # TODO the adjustment should be changed if it's RNA data contamination_depth=options.contamination_depth, contamination_similarity=options.contamination_similarity, degenerate=options.degenerate, @@ -910,10 +989,10 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb expected_max_size=expected_maximum_size, expected_min_size=expected_minimum_size, min_sigma_factor=options.min_sigma_factor, - here_max_copy=options.max_multiplicity, + # here_max_copy=options.max_multiplicity, here_only_max_c=options.only_keep_max_cov, here_acyclic_allowed=True, time_limit=3600, timeout_flag_str=timeout_flag, - temp_graph=graph_temp_file) + temp_graph=graph_temp_file2) except (ImportError, AttributeError) as e: log_handler.error("Disentangling failed: " + str(e).strip()) except RuntimeError as e: @@ -983,7 +1062,8 @@ def main(): raise Exception("Input assembly graph file must have name suffix '.gfa' or '.fastg'.") processed_graph_file = os.path.join(options.output_base, options.prefix + "initial_assembly_graph." + in_postfix) if options.max_depth != inf or options.min_depth != 0.: - this_graph = Assembly(options.input_graph, max_cov=options.max_depth, min_cov=options.min_depth) + this_graph = Assembly( + options.input_graph, max_cov=options.max_depth, min_cov=options.min_depth, log_handler=log_handler) if in_postfix.endswith("gfa"): this_graph.write_to_gfa(out_file=processed_graph_file, check_postfix=False) else: @@ -1023,7 +1103,8 @@ def main(): exit() else: if os.path.getsize(slimmed_graph_file) == 0: - return "Slimming " + processed_graph_file + " finished with no target organelle contigs found!" + raise Exception( + "Slimming " + processed_graph_file + " finished with no target organelle contigs found!") log_handler.info("Slimming assembly graph finished.\n") organelle_type_prefix = [] @@ -1038,7 +1119,7 @@ def main(): for go_t, sub_organelle_type in enumerate(options.organelle_type): og_prefix = options.prefix + organelle_type_prefix[go_t] graph_existed = bool([gfa_f for gfa_f in os.listdir(options.output_base) - if gfa_f.startswith(og_prefix) and gfa_f.endswith(".selected_graph.gfa")]) + if gfa_f.startswith(og_prefix) and gfa_f.endswith(".path_sequence.gfa")]) fasta_existed = bool([fas_f for fas_f in os.listdir(options.output_base) if fas_f.startswith(og_prefix) and fas_f.endswith(".path_sequence.fasta")]) if options.script_resume and graph_existed and fasta_existed: diff --git a/get_organelle_from_reads.py b/get_organelle_from_reads.py index c3027d8..4fc904d 100755 --- a/get_organelle_from_reads.py +++ b/get_organelle_from_reads.py @@ -316,8 +316,8 @@ def get_options(description, version): "with the same list length to organelle_type (followed by '-F'). " "This is optional for any organelle mentioned in '-F' but required for 'anonym'. " "By default, certain database(s) in " + str(LBL_DB_PATH) + " would be used " - "contingent on the organelle types chosen (-F). 
" - "The default value become invalid when '--genes' or '--ex-genes' is used.") + "contingent on the organelle types chosen (-F). " + "The default value become invalid when '--genes' or '--ex-genes' is used.") group_assembly.add_argument("--ex-genes", dest="exclude_genes", help="This is optional and Not suggested, since non-target contigs could contribute " "information for better downstream coverage-based clustering. " @@ -327,10 +327,13 @@ def get_options(description, version): "Could be a list of databases split by comma(s) but " "NOT required to have the same list length to organelle_type (followed by '-F'). " "The default value will become invalid when '--genes' or '--ex-genes' is used.") - group_assembly.add_argument("--disentangle-df", dest="disentangle_depth_factor", default=10.0, type=float, + group_assembly.add_argument("--disentangle-df", dest="disentangle_depth_factor", default=5.0, type=float, help="Depth factor for differentiate genome type of contigs. " "The genome type of contigs are determined by blast. " "Default: %(default)s") + group_assembly.add_argument("--disentangle-tf", dest="disentangle_type_factor", type=float, default=3., + help="Type factor for identifying contig type tag when multiple tags exist in one contig. " + "Default:%(default)s") group_assembly.add_argument("--contamination-depth", dest="contamination_depth", default=3., type=float, help="Depth factor for confirming contamination in parallel contigs. Default: %(default)s") group_assembly.add_argument("--contamination-similarity", dest="contamination_similarity", default=0.9, @@ -528,7 +531,8 @@ def _check_default_db(this_sub_organelle, extra_type=""): sys.stderr.write("\n############################################################################" "\nERROR: default " + this_sub_organelle + "," * int(bool(extra_type)) + extra_type + " database not added yet!\n" - "\nInstall it by: get_organelle_config.py -a " + this_sub_organelle + + "These two types must be used together!\n" * int(bool(extra_type)) + + "\nInstall it(them) by: get_organelle_config.py -a " + this_sub_organelle + "," * int(bool(extra_type)) + extra_type + "\nor\nInstall all types by: get_organelle_config.py -a all\n") exit() @@ -661,23 +665,23 @@ def _check_default_db(this_sub_organelle, extra_type=""): lib_not_available = [] lib_versions_info.append("GetOrganelleLib " + GetOrganelleLib.__version__) try: - import numpy + import numpy as np except ImportError: lib_not_available.append("numpy") else: - lib_versions_info.append("numpy " + numpy.__version__) - try: - import sympy - except ImportError: - lib_not_available.append("sympy") - else: - lib_versions_info.append("sympy " + sympy.__version__) + lib_versions_info.append("numpy " + np.__version__) + # try: + # import sympy + # except ImportError: + # lib_not_available.append("sympy") + # else: + # lib_versions_info.append("sympy " + sympy.__version__) try: - import scipy + import gekko except ImportError: - lib_not_available.append("scipy") + lib_not_available.append("gekko") else: - lib_versions_info.append("scipy " + scipy.__version__) + lib_versions_info.append("gekko " + gekko.__version__) try: import psutil except ImportError: @@ -1136,7 +1140,8 @@ def estimate_maximum_n_reads_using_mapping( coverages_2 = [pos for ref in coverage_info for pos in coverage_info[ref] if pos > 0] base_cov_values = get_cover_range(coverages_2, guessing_percent=BASE_COV_SAMPLING_PERCENT) mean_read_len, max_read_len, all_read_nums = \ - get_read_len_mean_max_count(mapped_fq, maximum_n_reads_hard_bound) + 
get_read_len_mean_max_count(mapped_fq, maximum_n_reads_hard_bound, n_process=1) + # get_read_len_mean_max_count(mapped_fq, maximum_n_reads_hard_bound, n_process=threads) if executable(os.path.join(which_spades, "spades.py -h")) and \ executable(os.path.join(which_bowtie2, "bowtie2")): try: @@ -1563,7 +1568,7 @@ def check_parameters(word_size, original_fq_files, seed_fs_files, seed_fq_files, simulate_fq_simple(from_fasta_file=this_modified_graph, out_dir=seed_fq_files[go_t] + ".spades", out_name="get_org.assembly_graph.simulated.fq", - sim_read_jump_size=7, resume=resume) + sim_read_jump_size=7, resume=resume, random_obj=random) closest_seed_f = os.path.join(seed_fq_files[go_t] + ".spades", "get_org.closest_seed.fasta") seed_seq_list = SequenceList(seed_fs_files[go_t]) for seq_record in seed_seq_list: @@ -1827,6 +1832,337 @@ def make_read_index(original_fq_files, direction_according_to_user_input, all_re else: log_handler.info("indices for fastq existed!") len_indices = len([x for x in open(temp2_clusters_dir[1], 'r')]) + elif resume and os.path.exists(temp1_contig_dir[1]) and not rm_duplicates: + if index_in_memory: + log_handler.info("Reading existed indices for fastq ...") + # + if keep_seq_parts: + forward_reverse_reads = [x.strip().split("\t") for x in open(temp1_contig_dir[1], 'r')] + cancel_seq_parts = True if max([len(x) for x in forward_reverse_reads]) == 1 else False + else: + forward_reverse_reads = [x.strip() for x in open(temp1_contig_dir[1], 'r')] + + # lengths = [] + use_user_direction = False + for id_file, file_name in enumerate(original_fq_files): + file_in = open(file_name, "r") + count_this_read_n = 0 + line = file_in.readline() + # if anti seed input, name & direction should be recognized + if anti_seed: + while line and count_this_read_n < all_read_limits[id_file]: + if line.startswith("@"): + count_this_read_n += 1 + # parsing name & direction + if use_user_direction: + this_name = line[1:].strip() + direction = direction_according_to_user_input[id_file] + else: + try: + if ' ' in line: + this_head = line[1:].split(' ') + this_name, direction = this_head[0], int(this_head[1][0]) + elif '#' in line: + this_head = line[1:].split('#') + this_name, direction = this_head[0], int(this_head[1].strip("/")[0]) + elif line[-3] == "/" and line[-2].isdigit(): # 2019-04-22 added + this_name, direction = line[1:-3], int(line[-2]) + elif line[1:].strip().isdigit(): + log_handler.info("Using user-defined read directions. ") + use_user_direction = True + this_name = line[1:].strip() + direction = direction_according_to_user_input[id_file] + else: + log_handler.info('Unrecognized head: ' + file_name + ': ' + str(line.strip())) + log_handler.info("Using user-defined read directions. ") + use_user_direction = True + this_name = line[1:].strip() + direction = direction_according_to_user_input[id_file] + except (ValueError, IndexError): + log_handler.info('Unrecognized head: ' + file_name + ': ' + str(line.strip())) + log_handler.info("Using user-defined read directions. 
") + use_user_direction = True + this_name = line[1:].strip() + direction = direction_according_to_user_input[id_file] + if (this_name, direction) in anti_lines: + line_count += 4 + for i in range(4): + line = file_in.readline() + continue + this_seq = file_in.readline().strip() + # drop nonsense reads + if len(this_seq) < word_size: + line_count += 4 + for i in range(3): + line = file_in.readline() + continue + file_in.readline() + quality_str = file_in.readline() + if do_split_low_quality: + this_seq = split_seq_by_quality_pattern(this_seq, quality_str, low_quality, word_size) + # drop nonsense reads + if not this_seq: + line_count += 4 + line = file_in.readline() + continue + line_clusters.append([line_count]) + else: + log_handler.error("Illegal fq format in line " + str(line_count) + ' ' + str(line)) + exit() + if echo_step != inf and line_count % echo_step == 0: + to_print = str("%s" % datetime.datetime.now())[:23].replace('.', ',') + " - INFO: " + str( + (line_count + 4) // 4) + " reads" + sys.stdout.write(to_print + '\b' * len(to_print)) + sys.stdout.flush() + line_count += 4 + line = file_in.readline() + else: + while line and count_this_read_n < all_read_limits[id_file]: + if line.startswith("@"): + count_this_read_n += 1 + this_seq = file_in.readline().strip() + + # drop nonsense reads + if len(this_seq) < word_size: + line_count += 4 + for i in range(3): + line = file_in.readline() + continue + + file_in.readline() + quality_str = file_in.readline() + if do_split_low_quality: + this_seq = split_seq_by_quality_pattern(this_seq, quality_str, low_quality, word_size) + # drop nonsense reads + if not this_seq: + line_count += 4 + line = file_in.readline() + continue + line_clusters.append([line_count]) + else: + log_handler.error("Illegal fq format in line " + str(line_count) + ' ' + str(line)) + exit() + if echo_step != inf and line_count % echo_step == 0: + to_print = str("%s" % datetime.datetime.now())[:23].replace('.', ',') + " - INFO: " + str( + (line_count + 4) // 4) + " reads" + sys.stdout.write(to_print + '\b' * len(to_print)) + sys.stdout.flush() + line_count += 4 + line = file_in.readline() + line = file_in.readline() + file_in.close() + if line: + log_handler.info("For " + file_name + ", only top " + str(int(all_read_limits[id_file])) + + " reads are used in downstream analysis.") + if this_process: + memory_usage = "Mem " + str(round(this_process.memory_info().rss / 1024.0 / 1024 / 1024, 3)) + " G, " + else: + memory_usage = '' + + del name_to_line + + if not index_in_memory: + # dump line clusters + len_indices = len(line_clusters) + temp2_indices_file_out = open(temp2_clusters_dir[0], 'w') + for this_index in range(len_indices): + temp2_indices_file_out.write('\t'.join([str(x) for x in line_clusters[this_index]])) + temp2_indices_file_out.write('\n') + temp2_indices_file_out.close() + os.rename(temp2_clusters_dir[0], temp2_clusters_dir[1]) + + del seq_duplicates + len_indices = len(line_clusters) + log_handler.info(memory_usage + str(len_indices) + " reads") + elif resume and os.path.exists(temp1_contig_dir[1]): + # lengths = [] + use_user_direction = False + for id_file, file_name in enumerate(original_fq_files): + file_in = open(file_name, "r") + count_this_read_n = 0 + line = file_in.readline() + # if anti seed input, name & direction should be recognized + if anti_seed: + while line and count_this_read_n < all_read_limits[id_file]: + if line.startswith("@"): + count_this_read_n += 1 + # parsing name & direction + if use_user_direction: + this_name = 
line[1:].strip() + direction = direction_according_to_user_input[id_file] + else: + try: + if ' ' in line: + this_head = line[1:].split(' ') + this_name, direction = this_head[0], int(this_head[1][0]) + elif '#' in line: + this_head = line[1:].split('#') + this_name, direction = this_head[0], int(this_head[1].strip("/")[0]) + elif line[-3] == "/" and line[-2].isdigit(): # 2019-04-22 added + this_name, direction = line[1:-3], int(line[-2]) + elif line[1:].strip().isdigit(): + log_handler.info("Using user-defined read directions. ") + use_user_direction = True + this_name = line[1:].strip() + direction = direction_according_to_user_input[id_file] + else: + log_handler.info('Unrecognized head: ' + file_name + ': ' + str(line.strip())) + log_handler.info("Using user-defined read directions. ") + use_user_direction = True + this_name = line[1:].strip() + direction = direction_according_to_user_input[id_file] + except (ValueError, IndexError): + log_handler.info('Unrecognized head: ' + file_name + ': ' + str(line.strip())) + log_handler.info("Using user-defined read directions. ") + use_user_direction = True + this_name = line[1:].strip() + direction = direction_according_to_user_input[id_file] + + if (this_name, direction) in anti_lines: + line_count += 4 + for i in range(4): + line = file_in.readline() + continue + this_seq = file_in.readline().strip() + # drop nonsense reads + if len(this_seq) < word_size: + line_count += 4 + for i in range(3): + line = file_in.readline() + continue + + file_in.readline() + quality_str = file_in.readline() + if do_split_low_quality: + this_seq = split_seq_by_quality_pattern(this_seq, quality_str, low_quality, word_size) + # drop nonsense reads + if not this_seq: + line_count += 4 + line = file_in.readline() + continue + + if keep_seq_parts: + if cancel_seq_parts and len(this_seq) > 1: + cancel_seq_parts = False + this_c_seq = complementary_seqs(this_seq) + # lengths.extend([len(seq_part) for seq_part in this_seq]) + else: + this_seq = this_seq[0] + this_c_seq = complementary_seq(this_seq) + # lengths.append(len(this_seq)) + else: + this_c_seq = complementary_seq(this_seq) + # lengths.append(len(this_seq)) + if this_seq in seq_duplicates: + line_clusters[seq_duplicates[this_seq]].append(line_count) + elif this_c_seq in seq_duplicates: + line_clusters[seq_duplicates[this_c_seq]].append(line_count) + else: + if index_in_memory: + forward_reverse_reads.append(this_seq) + forward_reverse_reads.append(this_c_seq) + seq_duplicates[this_seq] = this_index + line_clusters.append([line_count]) + this_index += 1 + if len(seq_duplicates) > rm_duplicates: + seq_duplicates = {} + else: + log_handler.error("Illegal fq format in line " + str(line_count) + ' ' + str(line)) + exit() + if echo_step != inf and line_count % echo_step == 0: + to_print = str("%s" % datetime.datetime.now())[:23].replace('.', ',') + " - INFO: " + str( + (line_count + 4) // 4) + " reads" + sys.stdout.write(to_print + '\b' * len(to_print)) + sys.stdout.flush() + line_count += 4 + line = file_in.readline() + else: + while line and count_this_read_n < all_read_limits[id_file]: + if line.startswith("@"): + count_this_read_n += 1 + this_seq = file_in.readline().strip() + + # drop nonsense reads + if len(this_seq) < word_size: + line_count += 4 + for i in range(3): + line = file_in.readline() + continue + + file_in.readline() + quality_str = file_in.readline() + if do_split_low_quality: + this_seq = split_seq_by_quality_pattern(this_seq, quality_str, low_quality, word_size) + # drop nonsense reads + if not 
this_seq: + line_count += 4 + line = file_in.readline() + continue + if keep_seq_parts: + if cancel_seq_parts and len(this_seq) > 1: + cancel_seq_parts = False + this_c_seq = complementary_seqs(this_seq) + # lengths.extend([len(seq_part) for seq_part in this_seq]) + else: + this_seq = this_seq[0] + this_c_seq = complementary_seq(this_seq) + # lengths.append(len(this_seq)) + else: + this_c_seq = complementary_seq(this_seq) + # lengths.append(len(this_seq)) + if this_seq in seq_duplicates: + line_clusters[seq_duplicates[this_seq]].append(line_count) + elif this_c_seq in seq_duplicates: + line_clusters[seq_duplicates[this_c_seq]].append(line_count) + else: + if index_in_memory: + forward_reverse_reads.append(this_seq) + forward_reverse_reads.append(this_c_seq) + seq_duplicates[this_seq] = this_index + line_clusters.append([line_count]) + this_index += 1 + if len(seq_duplicates) > rm_duplicates: + seq_duplicates = {} + else: + log_handler.error("Illegal fq format in line " + str(line_count) + ' ' + str(line)) + exit() + if echo_step != inf and line_count % echo_step == 0: + to_print = str("%s" % datetime.datetime.now())[:23].replace('.', ',') + " - INFO: " + str( + (line_count + 4) // 4) + " reads" + sys.stdout.write(to_print + '\b' * len(to_print)) + sys.stdout.flush() + line_count += 4 + line = file_in.readline() + line = file_in.readline() + file_in.close() + if line: + log_handler.info("For " + file_name + ", only top " + str(int(all_read_limits[id_file])) + + " reads are used in downstream analysis.") + if this_process: + memory_usage = "Mem " + str(round(this_process.memory_info().rss / 1024.0 / 1024 / 1024, 3)) + " G, " + else: + memory_usage = '' + + del name_to_line + + if not index_in_memory: + # dump line clusters + len_indices = len(line_clusters) + temp2_indices_file_out = open(temp2_clusters_dir[0], 'w') + for this_index in range(len_indices): + temp2_indices_file_out.write('\t'.join([str(x) for x in line_clusters[this_index]])) + temp2_indices_file_out.write('\n') + temp2_indices_file_out.close() + os.rename(temp2_clusters_dir[0], temp2_clusters_dir[1]) + + del seq_duplicates + len_indices = len(line_clusters) + if len_indices == 0 and line_count // 4 > 0: + log_handler.error("No qualified reads found!") + log_handler.error("Word size (" + str(word_size) + ") CANNOT be larger than your " + "post-trimmed maximum read length!") + exit() + log_handler.info(memory_usage + str(len_indices) + " candidates in all " + str(line_count // 4) + " reads") else: if not index_in_memory: temp1_contig_out = open(temp1_contig_dir[0], 'w') @@ -3345,8 +3681,14 @@ def separate_fq_by_pair(out_base, prefix, verbose_log, log_handler): def extract_organelle_genome(out_base, spades_output, ignore_kmer_res, slim_out_fg, organelle_prefix, organelle_type, blast_db, read_len_for_log, verbose, log_handler, basic_prefix, expected_maximum_size, expected_minimum_size, do_spades_scaffolding, options): + from GetOrganelleLib.assembly_parser import ProcessingGraphFailed, Assembly - def disentangle_assembly(fastg_file, tab_file, output, weight_factor, log_dis, time_limit, type_factor=3., + random.seed(options.random_seed) + # import numpy as np + # np.random.seed(options.random_seed) + + def disentangle_assembly(assembly_obj, fastg_file, tab_file, output, weight_factor, log_dis, time_limit, + type_factor=3., mode="embplant_pt", blast_db_base="embplant_pt", contamination_depth=3., contamination_similarity=0.95, degenerate=True, degenerate_depth=1.5, degenerate_similarity=0.98, @@ -3355,7 +3697,14 @@ def 
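For readers following the indexing loops above: every FASTQ record occupies four lines (header, sequence, separator, quality), and a read is clustered together with any earlier read that matches it or its reverse complement, which is what the `seq_duplicates`/`line_clusters` bookkeeping implements at scale. The following is only a minimal, self-contained sketch of that idea with hypothetical helper names, not the GetOrganelle implementation:

    # Sketch only: 4-line FASTQ parsing plus forward/reverse-complement deduplication.
    _COMP = str.maketrans("ACGTacgt", "TGCAtgca")

    def reverse_complement(seq):
        return seq.translate(_COMP)[::-1]

    def cluster_fastq_records(fq_path, min_len=0):
        clusters = []          # each cluster: record indices sharing one sequence
        seq_to_cluster = {}    # representative sequence -> cluster index
        with open(fq_path) as fq:
            record_index = 0
            while True:
                header = fq.readline()
                if not header:
                    break
                seq = fq.readline().strip()
                fq.readline()          # '+' separator line
                fq.readline()          # quality line (ignored in this sketch)
                if header.startswith("@") and len(seq) >= min_len:
                    rc = reverse_complement(seq)
                    key = seq if seq in seq_to_cluster else (rc if rc in seq_to_cluster else None)
                    if key is None:
                        seq_to_cluster[seq] = len(clusters)
                        clusters.append([record_index])
                    else:
                        clusters[seq_to_cluster[key]].append(record_index)
                record_index += 1
        return clusters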
@@ -3345,8 +3681,14 @@ def separate_fq_by_pair(out_base, prefix, verbose_log, log_handler):

 def extract_organelle_genome(out_base, spades_output, ignore_kmer_res, slim_out_fg, organelle_prefix,
                              organelle_type, blast_db, read_len_for_log, verbose, log_handler, basic_prefix,
                              expected_maximum_size, expected_minimum_size, do_spades_scaffolding, options):
+    from GetOrganelleLib.assembly_parser import ProcessingGraphFailed, Assembly
-    def disentangle_assembly(fastg_file, tab_file, output, weight_factor, log_dis, time_limit, type_factor=3.,
+    random.seed(options.random_seed)
+    # import numpy as np
+    # np.random.seed(options.random_seed)
+
+    def disentangle_assembly(assembly_obj, fastg_file, tab_file, output, weight_factor, log_dis, time_limit,
+                             type_factor=3.,
                              mode="embplant_pt", blast_db_base="embplant_pt", contamination_depth=3.,
                              contamination_similarity=0.95, degenerate=True, degenerate_depth=1.5,
                              degenerate_similarity=0.98,
@@ -3355,7 +3697,14 @@ def disentangle_assembly(fastg_file, tab_file, output, weight_factor, log_dis, t
                              here_acyclic_allowed=False, here_verbose=False,
                              timeout_flag_str="'--disentangle-time-limit'", temp_graph=None):
         @set_time_limit(time_limit, flag_str=timeout_flag_str)
-        def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="embplant_pt",
+        def disentangle_inside(input_graph,
+                               fastg_f,
+                               tab_f,
+                               o_p,
+                               w_f,
+                               log_in,
+                               type_f=3.,
+                               mode_in="embplant_pt",
                                in_db_n="embplant_pt", c_d=3., c_s=0.95, deg=True, deg_dep=1.5, deg_sim=0.98,
                                hard_c_t=10., min_s_f=0.1, max_c_in=True, max_s=inf, min_s=0,
                                with_spades_scaffolds_in=False,
@@ -3371,7 +3720,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                 log_in.info("Disentangling " + fastg_f + " as a/an " + in_db_n + "-insufficient graph ... ")
             else:
                 log_in.info("Disentangling " + fastg_f + " as a circular genome ... ")
-            input_graph = Assembly(fastg_f)
+            # input_graph = Assembly(fastg_f)
             if with_spades_scaffolds_in:
                 if not input_graph.add_gap_nodes_with_spades_res(os.path.join(spades_output, "scaffolds.fasta"),
                                                                  os.path.join(spades_output, "scaffolds.paths"),
@@ -3379,14 +3728,26 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                                                                  log_handler=log_handler):
                     raise ProcessingGraphFailed("No new connections.")
                 else:
+                    log_handler.info("Re-loading labels along " + out_fastg)
+                    input_graph.parse_tab_file(
+                        tab_f,
+                        database_name=in_db_n,
+                        type_factor=type_f,
+                        max_gene_gap=250,
+                        max_cov_diff=hard_c_t,  # contamination_depth?
+                        verbose=verbose,
+                        log_handler=log_handler,
+                        random_obj=random)
                     if in_temp_graph:
                         if in_temp_graph.endswith(".gfa"):
                             this_tmp_graph = in_temp_graph[:-4] + ".scaffolds.gfa"
                         else:
                             this_tmp_graph = in_temp_graph + ".scaffolds.gfa"
                         input_graph.write_to_gfa(this_tmp_graph)
-            target_results = input_graph.find_target_graph(tab_f,
-                                                           mode=mode_in, database_name=in_db_n, type_factor=type_f,
+            target_results = input_graph.find_target_graph(  # tab_f,
+                                                           mode=mode_in,
+                                                           db_name=in_db_n,
+                                                           # type_factor=type_f,
                                                            hard_cov_threshold=hard_c_t,
                                                            contamination_depth=c_d,
                                                            contamination_similarity=c_s,
@@ -3400,7 +3761,9 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                                                            read_len_for_log=read_len_for_log,
                                                            kmer_for_log=int(this_K[1:]),
                                                            log_handler=log_in, verbose=verbose_in,
-                                                           temp_graph=in_temp_graph)
+                                                           temp_graph=in_temp_graph,
+                                                           selected_graph=o_p + ".graph.selected_graph.gfa",
+                                                           random_obj=random)
             if not target_results:
                 raise ProcessingGraphFailed("No target graph detected!")
             if len(target_results) > 1:
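The `random.seed(options.random_seed)` call added above, together with passing the seeded `random` module around as `random_obj`, is what makes runs that involve randomized tie-breaking repeatable. A minimal sketch of that pattern, with a hypothetical consumer function (not GetOrganelle code):

    import random

    def pick_candidate(candidates, random_obj=random):
        # Hypothetical tie-breaker: any choice made through the shared,
        # pre-seeded random object is reproducible across runs.
        return random_obj.choice(sorted(candidates))

    random.seed(12345)  # e.g. options.random_seed
    print(pick_candidate(["path_a", "path_b", "path_c"]))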
@@ -3416,11 +3779,14 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                     go_res += 1
                     broken_graph = res["graph"]
                     count_path = 0
-                    these_paths = broken_graph.get_all_paths(mode=mode_in, log_handler=log_in)
+                    # use options.max_paths_num + 1 to trigger the warning
+                    these_paths = broken_graph.get_all_paths(mode=mode_in, log_handler=log_in,
+                                                             max_paths_num=options.max_paths_num + 1)
                     # reducing paths
                     if len(these_paths) > options.max_paths_num:
                         log_in.warning("Only exporting " + str(options.max_paths_num) + " out of all " +
-                                       str(len(these_paths)) + " possible paths. (see '--max-paths-num' to change it.)")
+                                       str(options.max_paths_num) +
+                                       "+ possible paths. (see '--max-paths-num' to change it.)")
                         these_paths = these_paths[:options.max_paths_num]
                     # exporting paths, reporting results
                     for this_paths, other_tag in these_paths:
@@ -3474,7 +3840,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                             open(out_n, "w").write("\n".join(all_contig_str))

                    if set(still_complete[-len(these_paths):]) == {"complete"}:
-                        this_out_base = o_p + ".complete.graph" + str(go_res) + ".selected_graph."
+                        this_out_base = o_p + ".complete.graph" + str(go_res) + ".path_sequence."
                         log_in.info("Writing GRAPH to " + this_out_base + "gfa")
                         broken_graph.write_to_gfa(this_out_base + "gfa")
                         image_produced = draw_assembly_graph_using_bandage(
@@ -3483,7 +3849,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                             assembly_graph_ob=broken_graph,
                             log_handler=log_handler, verbose_log=verbose_in, which_bandage=options.which_bandage)
                     elif set(still_complete[-len(these_paths):]) == {"nearly-complete"}:
-                        this_out_base = o_p + ".nearly-complete.graph" + str(go_res) + ".selected_graph."
+                        this_out_base = o_p + ".nearly-complete.graph" + str(go_res) + ".path_sequence."
                         log_in.info("Writing GRAPH to " + this_out_base + "gfa")
                         broken_graph.write_to_gfa(this_out_base + "gfa")
                         image_produced = draw_assembly_graph_using_bandage(
@@ -3492,7 +3858,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                             assembly_graph_ob=broken_graph,
                             log_handler=log_handler, verbose_log=verbose_in, which_bandage=options.which_bandage)
                     else:
-                        this_out_base = o_p + ".contigs.graph" + str(go_res) + ".selected_graph."
+                        this_out_base = o_p + ".contigs.graph" + str(go_res) + ".path_sequence."
                         log_in.info("Writing GRAPH to " + this_out_base + "gfa")
                         broken_graph.write_to_gfa(this_out_base + "gfa")
                         # image_produced = draw_assembly_graph_using_bandage(
@@ -3513,13 +3879,15 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                     go_res += 1
                     idealized_graph = res["graph"]
                     count_path = 0
-
+                    # use options.max_paths_num + 1 to trigger the warning
                     these_paths = idealized_graph.get_all_circular_paths(
-                        mode=mode_in, log_handler=log_in, reverse_start_direction_for_pt=options.reverse_lsc)
+                        mode=mode_in, log_handler=log_in, reverse_start_direction_for_pt=options.reverse_lsc,
+                        max_paths_num=options.max_paths_num + 1)
                     # reducing paths
                     if len(these_paths) > options.max_paths_num:
                         log_in.warning("Only exporting " + str(options.max_paths_num) + " out of all " +
-                                       str(len(these_paths)) + " possible paths. (see '--max-paths-num' to change it.)")
+                                       str(options.max_paths_num) +
+                                       "+ possible paths. (see '--max-paths-num' to change it.)")
                         these_paths = these_paths[:options.max_paths_num]

                     # exporting paths, reporting results
@@ -3542,7 +3910,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                                     ":".join([str(len_val) for len_val in ir_stats[:3]]))
                         log_in.info(
                             "Writing PATH" + str(count_path) + " of " + status_str + " " + mode_in + " to " + out_n)
-                        temp_base_out = o_p + "." + status_str + ".graph" + str(go_res) + ".selected_graph."
+                        temp_base_out = o_p + "." + status_str + ".graph" + str(go_res) + ".path_sequence."
                         log_in.info("Writing GRAPH to " + temp_base_out + "gfa")
                         idealized_graph.write_to_gfa(temp_base_out + "gfa")
                         image_produced = draw_assembly_graph_using_bandage(
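The `max_paths_num=options.max_paths_num + 1` change above works because the caller cannot know the true number of alternative paths without enumerating all of them; asking for just one item beyond the export limit is enough to know whether a warning is needed, which is also why the message now reads "N+ possible paths" instead of an exact total. A generic sketch of the pattern with a hypothetical enumerator:

    def export_limited(enumerate_paths, max_paths_num, warn):
        # Request one extra path only to detect that the limit was exceeded.
        paths = enumerate_paths(max_paths_num + 1)
        if len(paths) > max_paths_num:
            warn("Only exporting %d out of all %d+ possible paths." % (max_paths_num, max_paths_num))
            paths = paths[:max_paths_num]
        return paths

    # toy usage: an "enumerator" that would otherwise yield ten paths
    print(export_limited(lambda n: list(range(10))[:n], 3, print))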
log_in.info("Writing GRAPH to " + temp_base_out + "gfa") idealized_graph.write_to_gfa(temp_base_out + "gfa") image_produced = draw_assembly_graph_using_bandage( @@ -3570,8 +3938,10 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb " using Bandage to confirm the final result.") log_in.info("Writing output finished.") - disentangle_inside(fastg_f=fastg_file, tab_f=tab_file, o_p=output, w_f=weight_factor, log_in=log_dis, - type_f=type_factor, mode_in=mode, in_db_n=blast_db_base, + disentangle_inside(input_graph=deepcopy(assembly_obj), + fastg_f=fastg_file, tab_f=tab_file, o_p=output, w_f=weight_factor, log_in=log_dis, + type_f=type_factor, + mode_in=mode, in_db_n=blast_db_base, c_d=contamination_depth, c_s=contamination_similarity, deg=degenerate, deg_dep=degenerate_depth, deg_sim=degenerate_similarity, hard_c_t=hard_cov_threshold, min_s_f=min_sigma_factor, max_c_in=here_only_max_c, @@ -3591,32 +3961,52 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb timeout_flag = "'--disentangle-time-limit'" export_succeeded = False path_prefix = os.path.join(out_base, organelle_prefix) - graph_temp_file = path_prefix + ".temp.gfa" if options.keep_temp_files else None + graph_temp_file1 = path_prefix + "R1.temp.gfa" if options.keep_temp_files else None + file_to_assembly_obj = {} for go_k, kmer_dir in enumerate(kmer_dirs): out_fastg = slim_out_fg[go_k] if out_fastg and os.path.getsize(out_fastg): try: """disentangle""" - out_csv = out_fastg[:-5] + "csv" # if it is the first round (the largest kmer), copy the slimmed result to the main spades output # if go_k == 0: # main_spades_folder = os.path.split(kmer_dir)[0] # os.system("cp " + out_fastg + " " + main_spades_folder) # os.system("cp " + out_csv + " " + main_spades_folder) - disentangle_assembly(fastg_file=out_fastg, blast_db_base=blast_db, - mode=organelle_type, tab_file=out_csv, output=path_prefix, - weight_factor=100, hard_cov_threshold=options.disentangle_depth_factor, + log_handler.info("Parsing " + out_fastg) + assembly_graph_obj = Assembly(out_fastg, log_handler=log_handler) + log_handler.info("Loading and cleaning labels along " + out_fastg) + assembly_graph_obj.parse_tab_file( + out_csv, + database_name=blast_db, + type_factor=options.disentangle_type_factor, + max_gene_gap=250, + max_cov_diff=options.disentangle_depth_factor, # contamination_depth? 
+ verbose=verbose, + log_handler=log_handler, + random_obj=random) + file_to_assembly_obj[out_fastg] = assembly_graph_obj + disentangle_assembly(assembly_obj=assembly_graph_obj, + fastg_file=out_fastg, + blast_db_base=blast_db, + mode=organelle_type, + tab_file=out_csv, + output=path_prefix, + weight_factor=100, + type_factor=options.disentangle_type_factor, + hard_cov_threshold=options.disentangle_depth_factor, contamination_depth=options.contamination_depth, contamination_similarity=options.contamination_similarity, - degenerate=options.degenerate, degenerate_depth=options.degenerate_depth, + degenerate=options.degenerate, + degenerate_depth=options.degenerate_depth, degenerate_similarity=options.degenerate_similarity, expected_max_size=expected_maximum_size, expected_min_size=expected_minimum_size, here_only_max_c=True, here_acyclic_allowed=False, here_verbose=verbose, log_dis=log_handler, time_limit=options.disentangle_time_limit, timeout_flag_str=timeout_flag, - temp_graph=graph_temp_file) + temp_graph=graph_temp_file1) except ImportError as e: log_handler.error("Disentangling failed: " + str(e)) return False @@ -3639,6 +4029,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb break if not export_succeeded and do_spades_scaffolding: + graph_temp_file1s = path_prefix + "R1S.temp.gfa" if options.keep_temp_files else None largest_k_graph_f_exist = bool(slim_out_fg[0]) if kmer_dirs and largest_k_graph_f_exist: out_fastg = slim_out_fg[0] @@ -3646,9 +4037,15 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb try: """disentangle""" out_csv = out_fastg[:-5] + "csv" - disentangle_assembly(fastg_file=out_fastg, blast_db_base=blast_db, - mode=organelle_type, tab_file=out_csv, output=path_prefix, - weight_factor=100, hard_cov_threshold=options.disentangle_depth_factor, + disentangle_assembly(assembly_obj=file_to_assembly_obj[out_fastg], + fastg_file=out_fastg, + blast_db_base=blast_db, + mode=organelle_type, + tab_file=out_csv, + output=path_prefix, + weight_factor=100, + type_factor=options.disentangle_type_factor, + hard_cov_threshold=options.disentangle_depth_factor, contamination_depth=options.contamination_depth, contamination_similarity=options.contamination_similarity, degenerate=options.degenerate, degenerate_depth=options.degenerate_depth, @@ -3658,7 +4055,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb here_only_max_c=True, with_spades_scaffolds=True, here_acyclic_allowed=False, here_verbose=verbose, log_dis=log_handler, time_limit=options.disentangle_time_limit, timeout_flag_str=timeout_flag, - temp_graph=graph_temp_file) + temp_graph=graph_temp_file1s) except FileNotFoundError: log_handler.warning("scaffolds.fasta and/or scaffolds.paths not found!") except ImportError as e: @@ -3682,6 +4079,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb export_succeeded = True if not export_succeeded: + graph_temp_file2 = path_prefix + "R2.temp.gfa" if options.keep_temp_files else None largest_k_graph_f_exist = bool(slim_out_fg[0]) if kmer_dirs and largest_k_graph_f_exist: for go_k, kmer_dir in enumerate(kmer_dirs): @@ -3694,11 +4092,18 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb if out_fastg_list: out_fastg = out_fastg_list[0] out_csv = out_fastg[:-5] + "csv" - disentangle_assembly(fastg_file=out_fastg, blast_db_base=blast_db, - mode=organelle_type, tab_file=out_csv, - output=path_prefix, weight_factor=100, 
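The `file_to_assembly_obj` cache introduced above exists so that each slimmed FASTG is parsed and label-annotated only once, while every disentangling attempt still works on its own mutable copy (`deepcopy(assembly_obj)` in the call further up). A minimal sketch of the idea with a hypothetical graph class, not the GetOrganelle Assembly implementation:

    from copy import deepcopy

    class ToyGraph:  # hypothetical stand-in for a parsed assembly graph
        def __init__(self, path):
            self.path = path
            self.vertices = {"v1": 1.0, "v2": 2.0}  # pretend this parse is expensive

    graph_cache = {}

    def get_working_copy(fastg_path):
        # Parse each file once, then hand out independent copies so one
        # disentangling round cannot corrupt the input of the next one.
        if fastg_path not in graph_cache:
            graph_cache[fastg_path] = ToyGraph(fastg_path)
        return deepcopy(graph_cache[fastg_path])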
@@ -3694,11 +4092,18 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                     if out_fastg_list:
                         out_fastg = out_fastg_list[0]
                         out_csv = out_fastg[:-5] + "csv"
-                        disentangle_assembly(fastg_file=out_fastg, blast_db_base=blast_db,
-                                             mode=organelle_type, tab_file=out_csv,
-                                             output=path_prefix, weight_factor=100, here_verbose=verbose,
+                        disentangle_assembly(assembly_obj=file_to_assembly_obj[out_fastg],
+                                             fastg_file=out_fastg,
+                                             blast_db_base=blast_db,
+                                             mode=organelle_type,
+                                             tab_file=out_csv,
+                                             output=path_prefix,
+                                             weight_factor=100,
+                                             type_factor=options.disentangle_type_factor,
+                                             here_verbose=verbose,
                                              log_dis=log_handler,
-                                             hard_cov_threshold=options.disentangle_depth_factor * 0.8,
+                                             hard_cov_threshold=options.disentangle_depth_factor * 0.6,
+                                             # TODO the adjustment should be changed if it's RNA data
                                              contamination_depth=options.contamination_depth,
                                              contamination_similarity=options.contamination_similarity,
                                              degenerate=options.degenerate,
@@ -3708,7 +4113,7 @@ def disentangle_inside(fastg_f, tab_f, o_p, w_f, log_in, type_f=3., mode_in="emb
                                              expected_min_size=expected_minimum_size,
                                              here_only_max_c=True, here_acyclic_allowed=True,
                                              time_limit=3600, timeout_flag_str=timeout_flag,
-                                             temp_graph=graph_temp_file)
+                                             temp_graph=graph_temp_file2)
                     except (ImportError, AttributeError) as e:
                         log_handler.error("Disentangling failed: " + str(e))
                         break
@@ -3907,13 +4312,23 @@ def main():
                 else:
                     target_fq = os.path.join(out_base,
                                              str(file_id + 1) + "-" + os.path.basename(read_file))
-                if os.path.realpath(target_fq) == os.path.realpath(os.path.join(os.getcwd(), read_file)):
+                if os.path.exists(target_fq) and os.path.islink(target_fq):
+                    if os.path.realpath(target_fq) != os.path.realpath(os.path.join(os.getcwd(), read_file)):
+                        log_handler.error("Existed symlink (%s -> %s) does not link to the input file (%s)!" %
+                                          (target_fq,
+                                           os.path.realpath(target_fq),
+                                           os.path.realpath(os.path.join(os.getcwd(), read_file))))
+                        exit()
+                elif os.path.realpath(target_fq) == os.path.realpath(os.path.join(os.getcwd(), read_file)):
                     log_handler.error("Do not put original reads file(s) in the output directory!")
                     exit()
                 if not (os.path.exists(target_fq) and resume):
                     if all_read_nums[file_id] > READ_LINE_TO_INF:
-                        os.system("cp " + read_file + " " + target_fq + ".Temp")
-                        os.system("mv " + target_fq + ".Temp " + target_fq)
+                        # os.system("cp " + read_file + " " + target_fq + ".Temp")
+                        # os.system("mv " + target_fq + ".Temp " + target_fq)
+                        if os.path.exists(target_fq):
+                            os.remove(target_fq)
+                        os.system("ln -s " + os.path.abspath(read_file) + " " + target_fq)
                     else:
                         os.system("head -n " + str(int(4 * all_read_nums[file_id])) + " " +
                                   read_file + " > " + target_fq + ".Temp")
@@ -3930,8 +4345,9 @@ def main():
             get_read_quality_info(original_fq_files, sampling_reads_for_quality, options.min_quality_score,
                                   log_handler, maximum_ignore_percent=options.maximum_ignore_percent)
             log_handler.info("Counting read lengths ...")
-            mean_read_len, max_read_len, all_read_nums = get_read_len_mean_max_count(original_fq_files,
-                                                                                     options.maximum_n_reads)
+            mean_read_len, max_read_len, all_read_nums = get_read_len_mean_max_count(
+                original_fq_files, options.maximum_n_reads, n_process=1)
+                # original_fq_files, options.maximum_n_reads, n_process=options.threads)
             log_handler.info("Mean = " + str(round(mean_read_len, 1)) + " bp, maximum = " + str(max_read_len) + " bp.")
             log_handler.info("Reads used = " + "+".join([str(sub_num) for sub_num in all_read_nums]))

@@ -4151,7 +4567,7 @@ def main():
             for go_t, sub_organelle_type in enumerate(options.organelle_type):
                 og_prefix = options.prefix + organelle_type_prefix[go_t]
                 graph_existed = bool([gfa_f for gfa_f in os.listdir(out_base)
-                                      if gfa_f.startswith(og_prefix) and gfa_f.endswith(".selected_graph.gfa")])
+                                      if gfa_f.startswith(og_prefix) and gfa_f.endswith(".path_sequence.gfa")])
                 fasta_existed = bool([fas_f for fas_f in os.listdir(out_base)
                                       if fas_f.startswith(og_prefix) and fas_f.endswith(".path_sequence.fasta")])
                 if resume and graph_existed and fasta_existed:
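The main() hunk above stops copying very large read files and symlinks them instead, refusing to reuse an existing link that points somewhere else. A minimal sketch of the same checks using os.symlink directly rather than a shelled-out "ln -s" (function name and error handling are hypothetical):

    import os

    def link_read_file(read_file, target_fq):
        src = os.path.abspath(read_file)
        if os.path.islink(target_fq):
            # An existing link must already point at this exact input file.
            if os.path.realpath(target_fq) != os.path.realpath(src):
                raise RuntimeError("Existing symlink %s does not point to %s" % (target_fq, src))
            return
        if os.path.realpath(target_fq) == os.path.realpath(src):
            raise RuntimeError("Do not put original reads file(s) in the output directory!")
        if os.path.exists(target_fq):
            os.remove(target_fq)
        os.symlink(src, target_fq)  # same intent as os.system("ln -s ...")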
diff --git a/setup.py b/setup.py
index 707eeef..e6c03ac 100644
--- a/setup.py
+++ b/setup.py
@@ -41,25 +41,26 @@
         install_dependencies.append("numpy==1.16.4")
 else:
     sys.stdout.write("Existed module numpy " + str(numpy.__version__) + "\n")
-try:
-    import scipy
-except ImportError:
-    if MAJOR_VERSION == 3:
-        install_dependencies.append("scipy>=1.3.0")
-    else:
-        # higher version not compatible with python2
-        install_dependencies.append("scipy==1.2.1")
-else:
-    sys.stdout.write("Existed module numpy " + str(scipy.__version__) + "\n")
-try:
-    import sympy
-except ImportError:
-    if MAJOR_VERSION == 3:
-        install_dependencies.append("sympy>=1.4")
-    else:
-        install_dependencies.append("sympy==1.4")
-else:
-    sys.stdout.write("Existed module sympy " + str(sympy.__version__) + "\n")
+# try:
+#     import scipy
+# except ImportError:
+#     if MAJOR_VERSION == 3:
+#         install_dependencies.append("scipy>=1.3.0")
+#     else:
+#         # higher version not compatible with python2
+#         install_dependencies.append("scipy==1.2.1")
+# else:
+#     sys.stdout.write("Existed module scipy " + str(scipy.__version__) + "\n")
+# try:
+#     import sympy
+#     from sympy import Symbol, solve, lambdify, log
+# except ImportError:
+#     if MAJOR_VERSION == 3:
+#         install_dependencies.append("sympy>=1.4")
+#     else:
+#         install_dependencies.append("sympy==1.4")
+# else:
+#     sys.stdout.write("Existed module sympy " + str(sympy.__version__) + "\n")
 try:
     import requests
 except ImportError:
@@ -67,6 +68,16 @@
 else:
     sys.stdout.write("Existed module requests " + str(requests.__version__) + "\n")

+try:
+    import gekko
+except ImportError:
+    install_dependencies.append("gekko>=1.0.4")
+else:
+    sys.stdout.write("Existed module gekko " + str(gekko.__version__) + "\n")
+
+install_dependencies.append("biopython")
+
+
 PATH_OF_THIS_SCRIPT = os.path.split(os.path.realpath(__file__))[0]
 LIB_NAME = "GetOrganelleLib"
 # LIB_DIR = os.path.join(PATH_OF_THIS_SCRIPT, LIB_NAME)
@@ -113,6 +124,7 @@ def get_recursive_files(target_dir, start_from="", exclude_files=None):
                       "Utilities/disentangle_organelle_assembly.py",
                       "Utilities/evaluate_assembly_using_mapping.py",
                       "Utilities/fastg_to_gfa.py",
+                      "Utilities/gb_to_tbl.py",
                       "Utilities/get_annotated_regions_from_gb.py",
                       "Utilities/get_organelle_config.py",
                       "Utilities/get_pair_reads.py",
@@ -128,18 +140,18 @@ def get_recursive_files(target_dir, start_from="", exclude_files=None):
                       "Utilities/summary_get_organelle_output.py",
                       "Utilities/reconstruct_graph_from_fasta.py"]
 # rename execution program if not python
-dep_scripts_to_change = []
-if os.path.isdir(os.path.join(DEP_DIR, SYSTEM_NAME, "SPAdes", "bin")):
-    for spades_script in os.listdir(os.path.join(DEP_DIR, SYSTEM_NAME, "SPAdes", "bin")):
-        if spades_script.endswith(".py") and not spades_script.startswith("."):
-            dep_scripts_to_change.append(os.path.join(DEP_DIR, SYSTEM_NAME, "SPAdes", "bin", spades_script))
-if os.path.exists(os.path.join(DEP_DIR, SYSTEM_NAME, "bowtie2", "bowtie2-build")):
-    dep_scripts_to_change.append(os.path.join(DEP_DIR, SYSTEM_NAME, "bowtie2", "bowtie2-build"))
-if os.path.basename(sys.executable) != "python":
-    for rename_py_script in scripts_to_install + dep_scripts_to_change:
-        original_lines = open(rename_py_script, encoding="utf-8").readlines()
-        original_lines[0] = "#!" + sys.executable + "\n"
-        open(rename_py_script, "w", encoding="utf-8").writelines(original_lines)
+# dep_scripts_to_change = []
+# if os.path.isdir(os.path.join(DEP_DIR, SYSTEM_NAME, "SPAdes", "bin")):
+#     for spades_script in os.listdir(os.path.join(DEP_DIR, SYSTEM_NAME, "SPAdes", "bin")):
+#         if spades_script.endswith(".py") and not spades_script.startswith("."):
+#             dep_scripts_to_change.append(os.path.join(DEP_DIR, SYSTEM_NAME, "SPAdes", "bin", spades_script))
+# if os.path.exists(os.path.join(DEP_DIR, SYSTEM_NAME, "bowtie2", "bowtie2-build")):
+#     dep_scripts_to_change.append(os.path.join(DEP_DIR, SYSTEM_NAME, "bowtie2", "bowtie2-build"))
+# if os.path.basename(sys.executable) != "python":
+#     for rename_py_script in scripts_to_install + dep_scripts_to_change:
+#         original_lines = open(rename_py_script, encoding="utf-8").readlines()
+#         original_lines[0] = "#!" + sys.executable + "\n"
+#         open(rename_py_script, "w", encoding="utf-8").writelines(original_lines)


 # check local BLAST
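setup.py keeps its existing probe-and-append pattern for dependencies, now applied to gekko: try the import, and only append a requirement string when the import fails. A stripped-down sketch of that pattern (the version pin is taken from the hunk above; how the list is ultimately consumed is assumed to be the usual install_requires argument):

    import sys

    install_dependencies = []
    try:
        import gekko
    except ImportError:
        install_dependencies.append("gekko>=1.0.4")
    else:
        sys.stdout.write("Existing module gekko " + str(gekko.__version__) + "\n")

    # install_dependencies is later handed to setuptools.setup(install_requires=...)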