From e1f6581a817034cf38a8b6a091741a3b0a2e86e8 Mon Sep 17 00:00:00 2001
From: AroneyS <samuel.aroney@qut.edu.au>
Date: Thu, 2 May 2024 08:30:49 +1000
Subject: [PATCH] add dtypes to evaluate.py inputs and pin polars to 0.20.4

---
 binchicken.yml                          |  2 +-
 binchicken/workflow/scripts/evaluate.py | 43 +++++++++++++++++++++----
 2 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/binchicken.yml b/binchicken.yml
index 24446a50..02fda45e 100644
--- a/binchicken.yml
+++ b/binchicken.yml
@@ -11,7 +11,7 @@ dependencies:
   - bird_tool_utils_python=0.4.*
   - extern=0.4.*
   - ruamel.yaml=0.17.*
-  - polars=0.20.*
+  - polars=0.20.4
   - pigz=2.3.*
   - pyarrow=12.0.*
   - parallel=20230522
diff --git a/binchicken/workflow/scripts/evaluate.py b/binchicken/workflow/scripts/evaluate.py
index 2f2d33b2..41985189 100755
--- a/binchicken/workflow/scripts/evaluate.py
+++ b/binchicken/workflow/scripts/evaluate.py
@@ -6,7 +6,38 @@
 import polars as pl
 import os
 
-OUTPUT_COLUMNS={
+SINGLEM_COLUMNS = {  # column -> dtype overrides passed to pl.read_csv for recovered_otu_table_path
+    "gene": str,
+    "sample": str,
+    "sequence": str,
+    "num_hits": int,
+    "coverage": float,
+    "taxonomy": str,
+}  # also the base mapping that TARGET_COLUMNS and APPRAISE_COLUMNS extend
+
+TARGET_COLUMNS = SINGLEM_COLUMNS | {  # read_csv dtypes for target_path: SINGLEM_COLUMNS plus "target"
+    "target": int,
+}
+APPRAISE_COLUMNS = SINGLEM_COLUMNS | {  # read_csv dtypes for binned_path: SINGLEM_COLUMNS plus "found_in"
+    "found_in": str,
+}
+
+CLUSTER_COLUMNS = {  # column -> dtype overrides passed to pl.read_csv for elusive_clusters_path
+    "samples": str,
+    "length": int,
+    "total_targets": int,
+    "total_size": int,
+    "recover_samples": str,
+    "coassembly": str,
+}
+EDGE_COLUMNS = {  # column -> dtype overrides passed to pl.read_csv for elusive_edges_path
+    "style": str,
+    "cluster_size": int,
+    "samples": str,
+    "target_ids": str,
+}
+
+OUTPUT_COLUMNS = {
     "coassembly": str,
     "gene": str,
     "sequence": str,
@@ -302,11 +333,11 @@ def summarise_stats(matches, combined_otu_table, recovered_bins):
     novel_hits_path = snakemake.output.novel_hits
     summary_stats_path = snakemake.output.summary_stats
 
-    target_otu_table = pl.read_csv(target_path, separator="\t")
-    binned_otu_table = pl.read_csv(binned_path, separator="\t")
-    elusive_clusters = pl.read_csv(elusive_clusters_path, separator="\t")
-    elusive_edges = pl.read_csv(elusive_edges_path, separator="\t")
-    recovered_otu_table = pl.read_csv(recovered_otu_table_path, separator="\t")
+    target_otu_table = pl.read_csv(target_path, separator="\t", dtypes=TARGET_COLUMNS)
+    binned_otu_table = pl.read_csv(binned_path, separator="\t", dtypes=APPRAISE_COLUMNS)
+    elusive_clusters = pl.read_csv(elusive_clusters_path, separator="\t", dtypes=CLUSTER_COLUMNS)
+    elusive_edges = pl.read_csv(elusive_edges_path, separator="\t", dtypes=EDGE_COLUMNS)
+    recovered_otu_table = pl.read_csv(recovered_otu_table_path, separator="\t", dtypes=SINGLEM_COLUMNS)
 
     matches, unmatched, summary = evaluate(target_otu_table, binned_otu_table, elusive_clusters, elusive_edges, recovered_otu_table, recovered_bins)
     # Export hits matching elusive targets