Skip to content

Commit

Permalink
Add Pandera schemas for dataframe validation of the successful_runs f…
Browse files Browse the repository at this point in the history
…ile, and of taxonomy files. This includes a custom check method for validating the tax hierarchies.
  • Loading branch information
chrisAta committed Dec 12, 2024
1 parent 4ea0999 commit 228d322
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
import pandas as pd

from mgnify_pipelines_toolkit.constants.db_labels import TAXDB_LABELS, ASV_TAXDB_LABELS
from mgnify_pipelines_toolkit.constants.tax_ranks import (
SHORT_TAX_RANKS,
SHORT_PR2_TAX_RANKS,
)
from mgnify_pipelines_toolkit.schemas.schemas import (
SuccessfulRunsSchema,
generate_dynamic_tax_df_schema,
)

logging.basicConfig(level=logging.DEBUG)

Expand Down Expand Up @@ -140,8 +148,17 @@ def generate_db_summary(
if db_label in TAXDB_LABELS:
df_list = []

if "PR2" in db_label:
short_tax_ranks = SHORT_PR2_TAX_RANKS
else:
short_tax_ranks = SHORT_TAX_RANKS

for run_acc, tax_df in tax_dfs.items():
df_list.append(parse_one_tax_file(run_acc, tax_df))
res_df = parse_one_tax_file(run_acc, tax_df)
res_schema = generate_dynamic_tax_df_schema(run_acc, short_tax_ranks)
res_schema.validate(res_df)

df_list.append(res_df)

res_df = pd.concat(df_list, axis=1).fillna(0)
res_df = res_df.sort_index()
Expand All @@ -154,6 +171,11 @@ def generate_db_summary(

elif db_label in ASV_TAXDB_LABELS:

if "PR2" in db_label:
short_tax_ranks = SHORT_PR2_TAX_RANKS
else:
short_tax_ranks = SHORT_TAX_RANKS

amp_region_dict = defaultdict(list)

for (
Expand All @@ -168,6 +190,8 @@ def generate_db_summary(
] # there are a lot of underscores in these names... but it is consistent
# e.g. ERR4334351_16S-V3-V4_DADA2-SILVA_asv_krona_counts.txt
amp_region_df = parse_one_tax_file(run_acc, tax_df)
res_schema = generate_dynamic_tax_df_schema(run_acc, short_tax_ranks)
res_schema.validate(amp_region_df)
amp_region_dict[amp_region].append(amp_region_df)

for amp_region, amp_region_dfs in amp_region_dict.items():
Expand Down Expand Up @@ -257,6 +281,8 @@ def summarise_analyses(runs: Path, analyses_dir: Path, output_prefix: str) -> No
:type output_prefix: str
"""
runs_df = pd.read_csv(runs, names=["run", "status"])
SuccessfulRunsSchema(runs_df) # Run validation on the successful_runs .csv file

all_db_labels = TAXDB_LABELS + ASV_TAXDB_LABELS
for db_label in all_db_labels:

Expand Down
4 changes: 4 additions & 0 deletions mgnify_pipelines_toolkit/constants/tax_ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,7 @@
"Genus",
"Species",
]

# Abbreviated rank prefixes as they appear in "rank__name" taxonomy strings
# (presumably superkingdom, kingdom, phylum, class, order, family, genus,
# species — TODO confirm against the full rank list defined above).
SHORT_TAX_RANKS = ["sk", "k", "p", "c", "o", "f", "g", "s"]

# PR2-specific rank prefixes; PR2 appears to use supergroup/division/subdivision
# in place of kingdom/phylum — confirm exact expansions against the PR2 docs.
SHORT_PR2_TAX_RANKS = ["d", "sg", "dv", "sdv", "c", "o", "f", "g", "s"]
71 changes: 71 additions & 0 deletions mgnify_pipelines_toolkit/schemas/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pandas as pd
import pandera as pa
from pandera.typing import DataFrame, Series
import pandera.extensions as extensions

import pydantic


class SuccessfulRunsSchema(pa.DataFrameModel):
    """Pandera schema for the successful_runs CSV.

    Expects two string columns: ``run`` (unique run accessions) and
    ``status`` (the pipeline outcome for that run).
    """

    run: Series[str] = pa.Field(unique=True)
    status: Series[str]

    @pa.check(
        "run",
        name="run_validity_check",
        raise_warning=True,
        error="One or more run accessions do not fit the INSDC format [ERR*,SRR*,DRR*]. This is only a warning, not an error.",
    )
    def run_validity_check(cls, run: Series[str]) -> Series[bool]:
        """Warn when an accession doesn't look like an INSDC run accession.

        This only produces a WARNING, not an ERROR (``raise_warning=True``),
        to allow running on non-ENA/INSDC data.
        """
        # fullmatch anchors the pattern: the previous `str.contains` would also
        # accept accessions embedded inside longer strings (e.g. "xERR123456y").
        run_accession_regex = r"[EDS]RR[0-9]{6,}"
        return run.str.fullmatch(run_accession_regex)

    @pa.check(
        "status",
        name="status_validity_check",  # typo fix: was "status_vality_check"
        error='The status column can only have values ["all_results", "no_asvs"].',
    )
    def status_validity_check(cls, status: Series[str]) -> Series[bool]:
        """Fail validation if any status value is outside the accepted set."""
        possible_statuses = ["all_results", "no_asvs"]
        return status.isin(possible_statuses)


class PydanticModel(pydantic.BaseModel):
    # NOTE(review): this model is not referenced anywhere in this module —
    # presumably a placeholder for pandera's pydantic integration; confirm it
    # is actually needed before removing.
    # df holds a pandera-typed DataFrame validated against a DataFrameModel.
    df: DataFrame[pa.DataFrameModel]


@extensions.register_check_method(statistics=["short_tax_ranks"])
def is_valid_tax_hierarchy(pandas_obj, *, short_tax_ranks):
    """Pandera check: each taxonomy string uses only the allowed rank prefixes.

    A taxonomy string looks like ``"sk__Bacteria;p__Firmicutes;..."``; the
    prefix before each ``"__"`` must be one of *short_tax_ranks* (or the
    special value ``"Unclassified"``).

    :param pandas_obj: Series/Index of taxonomy lineage strings.
    :param short_tax_ranks: Allowed short rank prefixes for this database.
    :return: Boolean Series, True where the lineage is valid.
    """
    # Build the accepted set WITHOUT mutating the caller's list: the previous
    # `short_tax_ranks.append("Unclassified")` permanently grew the shared
    # module-level constants (SHORT_TAX_RANKS / SHORT_PR2_TAX_RANKS) on every call.
    accepted_ranks = set(short_tax_ranks)
    accepted_ranks.add("Unclassified")  # the only non-hierarchical value we accept

    def _is_valid(taxa: str) -> bool:
        # `all(... in set)` replaces the old set-intersection length comparison,
        # which falsely rejected lineages containing a repeated rank prefix
        # (set() dedupes, so the intersection size could not match len(list)).
        return all(rank.split("__")[0] in accepted_ranks for rank in taxa.split(";"))

    # Keep the input's index so pandera can report exactly which rows failed.
    return pd.Series([_is_valid(taxa) for taxa in pandas_obj], index=list(pandas_obj.index) if hasattr(pandas_obj, "index") else None)


def generate_dynamic_tax_df_schema(
    run_acc: str, short_tax_ranks: list
) -> pa.DataFrameSchema:
    """Build the validation schema for a single run's taxonomy DataFrame.

    The DataFrame must contain exactly one integer column named after the run
    accession (non-negative counts), indexed by unique taxonomy strings whose
    rank prefixes follow the supplied hierarchy.

    :param run_acc: Run accession; becomes the sole (count) column name.
    :param short_tax_ranks: Allowed short rank prefixes for the database.
    :return: A strict, coercing pandera DataFrameSchema.
    """
    count_column = pa.Column(int, checks=pa.Check.ge(0))
    taxonomy_index = pa.Index(
        str,
        unique=True,
        checks=[pa.Check.is_valid_tax_hierarchy(short_tax_ranks=short_tax_ranks)],
    )
    return pa.DataFrameSchema(
        {run_acc: count_column},
        index=taxonomy_index,
        strict=True,
        coerce=True,
    )
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ dependencies = [
"pandas==2.0.2",
"regex==2023.12.25",
"requests==2.32.3",
"click==8.1.7"
"click==8.1.7",
"pandera==0.21.1"
]

[build-system]
Expand All @@ -33,6 +34,7 @@ packages = ["mgnify_pipelines_toolkit",
"mgnify_pipelines_toolkit.analysis",
"mgnify_pipelines_toolkit.constants",
"mgnify_pipelines_toolkit.utils",
"mgnify_pipelines_toolkit.schemas",
"mgnify_pipelines_toolkit.analysis.shared",
"mgnify_pipelines_toolkit.analysis.amplicon",
"mgnify_pipelines_toolkit.analysis.assembly",
Expand Down Expand Up @@ -74,7 +76,8 @@ tests = [
"numpy==1.26.0",
"regex==2023.12.25",
"requests==2.32.3",
"click==8.1.7"
"click==8.1.7",
"pandera==0.21.1"
]

dev = [
Expand Down

0 comments on commit 228d322

Please sign in to comment.