Merge pull request #201 from bacpop/bacpop-17
Bacpop 17
johnlees authored Apr 29, 2022
2 parents 46aff5d + 73c5aef commit f2f0fcf
Showing 3 changed files with 139 additions and 76 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
@@ -3,7 +3,7 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.4.5'
__version__ = '2.4.6'

# Minimum sketchlib version
SKETCHLIB_MAJOR = 1
110 changes: 84 additions & 26 deletions PopPUNK/assign.py
@@ -49,10 +49,89 @@ def assign_query(dbFuncs,
gpu_dist,
gpu_graph,
deviceid,
web,
json_sketch,
save_partial_query_graph):
"""Code for assign query mode. Written as a separate function so it can be called
"""Code for assign query mode for CLI"""

createDatabaseDir = dbFuncs['createDatabaseDir']
constructDatabase = dbFuncs['constructDatabase']
readDBParams = dbFuncs['readDBParams']

if ref_db == output:
sys.stderr.write("--output and --ref-db must be different to "
"prevent overwrite.\n")
sys.exit(1)

# Find distances to reference db
kmers, sketch_sizes, codon_phased = readDBParams(ref_db)

# construct database
createDatabaseDir(output, kmers)
qNames = constructDatabase(q_files,
kmers,
sketch_sizes,
output,
threads,
overwrite,
codon_phased = codon_phased,
calc_random = False,
use_gpu = gpu_sketch,
deviceid = deviceid)

isolateClustering = assign_query_hdf5(dbFuncs,
ref_db,
qNames,
output,
qc_dict,
update_db,
write_references,
distances,
threads,
overwrite,
plot_fit,
graph_weights,
max_a_dist,
max_pi_dist,
type_isolate,
model_dir,
strand_preserved,
previous_clustering,
external_clustering,
core,
accessory,
gpu_sketch,
gpu_dist,
gpu_graph,
deviceid,
save_partial_query_graph)
return(isolateClustering)

def assign_query_hdf5(dbFuncs,
ref_db,
qNames,
output,
qc_dict,
update_db,
write_references,
distances,
threads,
overwrite,
plot_fit,
graph_weights,
max_a_dist,
max_pi_dist,
type_isolate,
model_dir,
strand_preserved,
previous_clustering,
external_clustering,
core,
accessory,
gpu_sketch,
gpu_dist,
gpu_graph,
deviceid,
save_partial_query_graph):
"""Code for assign query mode taking hdf5 as input. Written as a separate function so it can be called
by web APIs"""

# Modules imported here as graph tool is very slow to load (it pulls in all of GTK?)
@@ -80,10 +159,6 @@ def assign_query(dbFuncs,
from .utils import update_distance_matrices
from .utils import createOverallLineage

from .web import sketch_to_hdf5

createDatabaseDir = dbFuncs['createDatabaseDir']
constructDatabase = dbFuncs['constructDatabase']
joinDBs = dbFuncs['joinDBs']
queryDatabase = dbFuncs['queryDatabase']
readDBParams = dbFuncs['readDBParams']
@@ -115,7 +190,7 @@ def assign_query(dbFuncs,
prev_clustering = model_prefix

# Find distances to reference db
kmers, sketch_sizes, codon_phased = readDBParams(ref_db)
kmers = readDBParams(ref_db)[0]

# Iterate through different types of model fit with a refined model when specified
# Core and accessory assignments use the same model and same overall set of distances
@@ -150,22 +225,7 @@ def assign_query(dbFuncs,
sys.exit(1)
else:
rNames = getSeqsInDb(os.path.join(ref_db, os.path.basename(ref_db) + ".h5"))
# construct database - use a single database directory for all query outputs
if (web and json_sketch is not None):
qNames = sketch_to_hdf5(json_sketch, output)
elif (fit_type == 'default'):
# construct database
createDatabaseDir(output, kmers)
qNames = constructDatabase(q_files,
kmers,
sketch_sizes,
output,
threads,
overwrite,
codon_phased = codon_phased,
calc_random = False,
use_gpu = gpu_sketch,
deviceid = deviceid)

if (fit_type == 'default' or (fit_type != 'default' and use_ref_graph)):
# run query
qrDistMat = queryDatabase(rNames = rNames,
@@ -610,8 +670,6 @@ def main():
args.gpu_dist,
args.gpu_graph,
args.deviceid,
web=False,
json_sketch=None,
save_partial_query_graph=False)

sys.stderr.write("\nDone\n")
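
The assign.py change above splits the old assign_query into two functions: assign_query, which sketches the input files into a new query database and then delegates, and assign_query_hdf5, which starts from query names already present in an HDF5 database so that web callers (which receive sketches as JSON) can skip the sketching step. Below is a minimal sketch of how a web caller might use the split, assuming a reference database on disk and a dict of JSON sketches in hand; the keyword values are illustrative placeholders, and how dbFuncs is constructed (e.g. via PopPUNK.utils.setupDBFuncs) is assumed rather than taken from this diff.

    import os
    from PopPUNK.assign import assign_query_hdf5
    from PopPUNK.web import sketch_to_hdf5

    # sketches_dict: {query_name: JSON sketch string}; output: existing output directory
    qNames = sketch_to_hdf5(sketches_dict, output)   # writes output/<output>.h5, returns query names

    isolateClustering = assign_query_hdf5(
        dbFuncs, ref_db, qNames, output, qc_dict,
        update_db=False, write_references=False,
        distances=os.path.join(ref_db, os.path.basename(ref_db) + ".dists"),
        threads=1, overwrite=True, plot_fit=0, graph_weights=False,
        max_a_dist=0.5, max_pi_dist=0.5, type_isolate=None,
        model_dir=ref_db, strand_preserved=False,
        previous_clustering=None, external_clustering=None,
        core=False, accessory=False,
        gpu_sketch=False, gpu_dist=False, gpu_graph=False, deviceid=0,
        save_partial_query_graph=True)
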
103 changes: 54 additions & 49 deletions PopPUNK/web.py
@@ -170,47 +170,49 @@ def get_colours(query, clusters):
colours.append('blue')
return colours

def sketch_to_hdf5(sketch, output):
"""Convert JSON sketch to query hdf5 database"""
kmers = []
dists = []

sketch_dict = json.loads(sketch)
qNames = ["query"]
def sketch_to_hdf5(sketches_dict, output):
"""Convert dict of JSON sketches to query hdf5 database"""
qNames = []
queryDB = h5py.File(os.path.join(output, os.path.basename(output) + '.h5'), 'w')
sketches = queryDB.create_group("sketches")
sketch_props = sketches.create_group(qNames[0])

for key, value in sketch_dict.items():
try:
kmers.append(int(key))
dists.append(np.array(value, dtype='uint64'))
except (TypeError, ValueError):
if key == "version":
sketches.attrs['sketch_version'] = value
elif key == "codon_phased":
sketches.attrs['codon_phased'] = value
elif key == "densified":
sketches.attrs['densified'] = value
elif key == "bases":
sketch_props.attrs['base_freq'] = value
elif key == "bbits":
sketch_props.attrs['bbits'] = value
elif key == "length":
sketch_props.attrs['length'] = value
elif key == "missing_bases":
sketch_props.attrs['missing_bases'] = value
elif key == "sketchsize64":
sketch_props.attrs['sketchsize64'] = value
elif key == "species":
pass
else:
sys.stderr.write(key + " not recognised")

sketch_props.attrs['kmers'] = kmers
for k_index in range(len(kmers)):
k_spec = sketch_props.create_dataset(str(kmers[k_index]), data=dists[k_index], dtype='uint64')
k_spec.attrs['kmer-size'] = kmers[k_index]

for top_key, top_value in sketches_dict.items():
qNames.append(top_key)
kmers = []
dists = []
sketch_dict = json.loads(top_value)
sketch_props = sketches.create_group(top_key)

for key, value in sketch_dict.items():
try:
kmers.append(int(key))
dists.append(np.array(value, dtype='uint64'))
except (TypeError, ValueError):
if key == "version":
sketches.attrs['sketch_version'] = value
elif key == "codon_phased":
sketches.attrs['codon_phased'] = value
elif key == "densified":
pass
elif key == "bases":
sketch_props.attrs['base_freq'] = value
elif key == "bbits":
sketch_props.attrs['bbits'] = value
elif key == "length":
sketch_props.attrs['length'] = value
elif key == "missing_bases":
sketch_props.attrs['missing_bases'] = value
elif key == "sketchsize64":
sketch_props.attrs['sketchsize64'] = value
elif key == "species":
pass
else:
sys.stderr.write(key + " not recognised")

sketch_props.attrs['kmers'] = kmers
for k_index in range(len(kmers)):
k_spec = sketch_props.create_dataset(str(kmers[k_index]), data=dists[k_index], dtype='uint64')
k_spec.attrs['kmer-size'] = kmers[k_index]
queryDB.close()
return qNames
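
With this change, sketch_to_hdf5 takes a dict keyed by query name, with each value a JSON sketch string, and writes one group per query into the output HDF5 database. An illustrative call follows, with made-up sketch values, assuming the output directory already exists; keys that are k-mer sizes map to uint64 sketch arrays, and the remaining keys are the metadata fields handled in the loop above.

    import json
    from PopPUNK.web import sketch_to_hdf5

    # All numbers below are made up for illustration only.
    sketch_json = json.dumps({
        "14": [123, 456, 789],
        "17": [321, 654, 987],
        "bbits": 14, "sketchsize64": 156, "length": 2000000,
        "missing_bases": 0, "bases": [0.3, 0.2, 0.2, 0.3],
        "version": "1.7.0", "codon_phased": False, "densified": False})
    sketches_dict = {"sample1": sketch_json, "sample2": sketch_json}
    qNames = sketch_to_hdf5(sketches_dict, "query_out")   # -> ["sample1", "sample2"]
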

Expand Down Expand Up @@ -290,13 +292,14 @@ def get_aliases(aliasDF, clusterLabels, species):
alias_dict = {"GPSC":str(GPS_name)}
return alias_dict

def summarise_clusters(output, species, species_db):
def summarise_clusters(output, species, species_db, qNames):
"""Retreieve assigned query and all cluster prevalences.
Write list of all isolates in cluster for tree subsetting"""
totalDF = pd.read_csv(os.path.join(output, os.path.basename(output) + "_clusters.csv"))
queryDF = totalDF.loc[totalDF['Taxon'] == "query"]
queryDF = totalDF[totalDF['Taxon'].isin(qNames)]
queryDF = queryDF.reset_index(drop=True)
query = str(queryDF["Cluster"][0])
queries_names = list(queryDF["Taxon"])
queries_clusters = list(queryDF["Cluster"])
num_samples = len(totalDF["Taxon"])
totalDF["Cluster"] = totalDF["Cluster"].astype(str)
cluster_list = list(totalDF["Cluster"])
@@ -307,19 +310,21 @@ def summarise_clusters(output, species, species_db):
uniquetotalDF = totalDF.drop_duplicates(subset=['Cluster'])
clusters = list(uniquetotalDF['Cluster'])
prevalences = list(uniquetotalDF["Prevalence"])
query_prevalence = prevalences[clusters.index(query)]
# write list of all isolates in cluster
clusterDF = totalDF.loc[totalDF['Cluster'] == query]
to_include = list(clusterDF['Taxon'])
with open(os.path.join(output, "include.txt"), "w") as i:
i.write("\n".join(to_include))
queries_prevalence = []
for query in queries_clusters:
queries_prevalence.append(prevalences[clusters.index(str(query))])
# write list of all isolates in cluster
clusterDF = totalDF.loc[totalDF['Cluster'] == str(query)]
to_include = list(clusterDF['Taxon'])
with open(os.path.join(output, "include" + str(query) + ".txt"), "w") as i:
i.write("\n".join(to_include))
# get aliases
if os.path.isfile(os.path.join(species_db, "aliases.csv")):
aliasDF = pd.read_csv(os.path.join(species_db, "aliases.csv"))
alias_dict = get_aliases(aliasDF, list(clusterDF['Taxon']), species)
else:
alias_dict = {"Aliases": "NA"}
return query, query_prevalence, clusters, prevalences, alias_dict, to_include
return queries_names, queries_clusters, queries_prevalence, clusters, prevalences, alias_dict, to_include

@scheduler.task('interval', id='clean_tmp', hours=1, misfire_grace_time=900)
def clean_tmp():
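
summarise_clusters now takes the list of query names and returns per-query lists of names, clusters and prevalences rather than values for a single hard-coded "query" row, writing one include<cluster>.txt file per query cluster. A sketch of reading its results, assuming output/<output>_clusters.csv was written by a prior assignment run:

    from PopPUNK.web import summarise_clusters

    (queries_names, queries_clusters, queries_prevalence,
     clusters, prevalences, alias_dict, to_include) = \
        summarise_clusters(output, species, species_db, qNames)
    for name, cluster, prev in zip(queries_names, queries_clusters, queries_prevalence):
        print(name, cluster, prev)
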
