diff --git a/binchicken/workflow/scripts/target_elusive.py b/binchicken/workflow/scripts/target_elusive.py index c5f467bb..6a947595 100755 --- a/binchicken/workflow/scripts/target_elusive.py +++ b/binchicken/workflow/scripts/target_elusive.py @@ -30,23 +30,29 @@ def get_clusters( # Set to 2 to produce paired edges MAX_COASSEMBLY_SAMPLES = 2 - logging.info("Choosing preclusters based on distances") + logging.info("Processing distances...") best_samples = np.argsort(distances, axis=1)[:, :PRECLUSTER_SIZE] chosen_samples = [(samples[i], list(np.array(samples)[b[b != i]])) for i, b in enumerate(best_samples)] + + sample_combinations = ( + pl.DataFrame({"cluster_size": range(1, MAX_COASSEMBLY_SAMPLES)}) + .with_columns( + sample_combinations = pl.col("cluster_size").map_elements( + lambda x: [i for i in itertools.combinations(range(PRECLUSTER_SIZE-1), x)], + return_dtype=pl.List(pl.List(pl.Int64)), + ) + ) + .explode("sample_combinations") + .select("sample_combinations") + ) + + logging.info("Choosing preclusters based on distances") with pl.StringCache(): preclusters = ( pl.DataFrame(chosen_samples, schema={"sample": pl.Categorical, "samples": pl.List(pl.Categorical)}) .with_columns(length = pl.col("samples").list.len()) - .join(pl.DataFrame({"cluster_size": range(1, MAX_COASSEMBLY_SAMPLES)}), how="cross") - .with_columns( - samples_combinations = pl.struct(["length", "cluster_size"]).map_elements( - lambda x: [i for i in itertools.combinations(range(x["length"]), x["cluster_size"])], - return_dtype=pl.List(pl.List(pl.Int64)), - ) - ) - .select("sample", "samples", "samples_combinations") - .explode("samples_combinations") - .with_columns(pl.col("samples").list.gather(pl.col("samples_combinations"))) + .join(sample_combinations, how="cross") + .with_columns(pl.col("samples").list.gather(pl.col("sample_combinations"))) .select( samples = pl.concat_list("sample", "samples") .cast(pl.List(str))