diff --git a/docs/api/haptools.rst b/docs/api/haptools.rst index a1dd0d24..cab3b819 100644 --- a/docs/api/haptools.rst +++ b/docs/api/haptools.rst @@ -129,3 +129,11 @@ haptools.index module :members: :undoc-members: :show-inheritance: + +haptools.clump module +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: haptools.clump + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/commands/clump.rst b/docs/commands/clump.rst new file mode 100644 index 00000000..529d1848 --- /dev/null +++ b/docs/commands/clump.rst @@ -0,0 +1,108 @@ +.. _commands-clump: + + +clump +===== + +Clump a set of variants specified as a :doc:`.hap file `. + +The ``clump`` command creates a clump file joining SNPs or STRs in LD with one another. + +Usage +~~~~~ +.. code-block:: bash + + haptools clump \ + --verbosity [CRITICAL|ERROR|WARNING|INFO|DEBUG|NOTSET] \ + --summstats-snps PATH \ + --gts-snps PATH \ + --summstats-strs PATH \ + --gts-strs PATH \ + --clump-field TEXT \ + --clump-id-field TEXT \ + --clump-chrom-field TEXT \ + --clump-pos-field TEXT \ + --clump-p1 FLOAT \ + --clump-p2 FLOAT \ + --clump-r2 FLOAT \ + --clump-kb FLOAT \ + --ld [Exact|Pearson] \ + --out PATH + +Examples +~~~~~~~~ +.. code-block:: bash + + haptools clump \ + --summstats-snps tests/data/test_snpstats.linear \ + --gts-snps tests/data/simple.vcf \ + --clump-id-field ID \ + --clump-chrom-field CHROM \ + --clump-pos-field POS \ + --out test_snps.clump + +You can use ``--ld [Exact|Pearson]`` to indicate which type of LD calculation you'd like to perform. ``Exact`` utilizes an exact cubic solution adapted from `CubeX `_ whereas ``Pearson`` utilizes a Pearson R calculation. Note that ``Exact`` only works on SNPs and not on any other variant type (e.g., STRs). The default value is ``Pearson``. + +.. 
code-block:: bash + + haptools clump \ + --summstats-snps tests/data/test_snpstats.linear \ + --gts-snps tests/data/simple.vcf \ + --clump-id-field ID \ + --clump-chrom-field CHROM \ + --clump-pos-field POS \ + --ld Exact \ + --out test_snps.clump + +You can modify thresholds and values used in the clumping process. ``--clump-p1`` is the largest value of a p-value to consider being an index variant for a clump. ``--clump-p2`` dictates the maximum p-value any variant can have to be considered when clumping. ``--clump-r2`` is the R squared threshold where being greater than this value implies the candidate variant is in LD with the index variant. ``--clump-kb`` is the maximum distance upstream or downstream from the index variant to collect candidate variants for LD comparison. For example, ``--clump-kb 100`` implies all variants 100 Kb upstream and 100 Kb downstream from the variant will be considered. + +.. code-block:: bash + + haptools clump \ + --summstats-snps tests/data/test_snpstats.linear \ + --gts-snps tests/data/simple.vcf \ + --clump-id-field ID \ + --clump-chrom-field CHROM \ + --clump-pos-field POS \ + --clump-p1 0.001 \ + --clump-p2 0.05 \ + --clump-r2 0.7 \ + --clump-kb 200.5 \ + --out test_snps.clump + +You can also input STRs when calculating clumps. They can be used together with SNPs or alone. + +.. code-block:: bash + + haptools clump \ + --summstats-strs tests/data/test_strstats.linear \ + --gts-strs tests/data/simple_tr.vcf \ + --summstats-snps tests/data/test_snpstats.linear \ + --gts-snps tests/data/simple.vcf \ + --clump-id-field ID \ + --clump-chrom-field CHROM \ + --clump-pos-field POS \ + --ld Exact \ + --out test_snps.clump + +.. 
code-block:: bash + + haptools clump \ + --summstats-strs tests/data/test_strstats.linear \ + --gts-strs tests/data/simple_tr.vcf \ + --clump-id-field ID \ + --clump-chrom-field CHROM \ + --clump-pos-field POS \ + --ld Exact \ + --out test_snps.clump + +All files used in these examples are described :doc:`here `. + + +Detailed Usage +~~~~~~~~~~~~~~ + +.. click:: haptools.__main__:main + :prog: haptools + :show-nested: + :commands: clump \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 9729cbdf..49b5e549 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -22,6 +22,8 @@ Commands * :doc:`haptools index `: Sort, compress, and index our custom file format for haplotypes. +* :doc:`haptools clump `: Convert variants in LD with one another into clumps. + * :doc:`haptools ld `: Compute Pearson's correlation coefficient between a target haplotype and a set of haplotypes. .. figure:: https://drive.google.com/uc?id=1c0i_Hjms7579s24zRsKp5yMs7BxNHed_ @@ -91,6 +93,7 @@ There is an option to *Cite this repository* on the right sidebar of `the reposi commands/karyogram.rst commands/transform.rst commands/index.rst + commands/clump.rst commands/ld.rst .. 
@main.command(short_help="Clump summary stat files.")
@click.option(
    "--summstats-snps",
    type=click.Path(path_type=Path),
    help="File from which to load SNP summary statistics",
)
@click.option(
    "--summstats-strs",
    type=click.Path(path_type=Path),
    help="File from which to load STR summary statistics",
)
@click.option(
    "--gts-snps", type=click.Path(path_type=Path), help="SNP genotypes (VCF or PGEN)"
)
@click.option("--gts-strs", type=click.Path(path_type=Path), help="STR genotypes (VCF)")
@click.option(
    "--clump-p1",
    type=float,
    default=0.0001,
    show_default=True,
    help="Max p-value a variant may have to start a new clump as its index variant",
)
@click.option(
    "--clump-p2",
    type=float,
    default=0.01,
    show_default=True,
    help="Max p-value a variant may have to be considered during clumping",
)
@click.option(
    "--clump-id-field",
    type=str,
    default="SNP",
    show_default=True,
    help="Column header of the variant ID",
)
@click.option(
    "--clump-field",
    type=str,
    default="P",
    show_default=True,
    help="Column header of the p-values",
)
@click.option(
    "--clump-chrom-field",
    type=str,
    default="CHR",
    show_default=True,
    help="Column header of the chromosome",
)
@click.option(
    "--clump-pos-field",
    type=str,
    default="POS",
    show_default=True,
    help="Column header of the position",
)
@click.option(
    "--clump-kb",
    type=float,
    default=250,
    show_default=True,
    help="Max distance (in kb) from the index variant to search for candidates",
)
@click.option(
    "--clump-r2",
    type=float,
    default=0.5,
    show_default=True,
    help="r^2 threshold above which a candidate joins the index variant's clump",
)
@click.option(
    "--ld",
    type=click.Choice(["Exact", "Pearson"]),
    default="Pearson",
    show_default=True,
    help=(
        "Calculation type to infer LD, Exact Solution or "
        "Pearson R. (Exact|Pearson). Note the Exact Solution "
        "works best when all three genotypes are present (0,1,2) in "
        "the variants being compared."
    ),
)
@click.option(
    "--out",
    type=click.Path(path_type=Path),
    required=True,
    help="Output filename",
)
@click.option(
    "-v",
    "--verbosity",
    type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]),
    default="INFO",
    show_default=True,
    help="The level of verbosity desired",
)
def clump(
    summstats_snps: Path,
    summstats_strs: Path,
    gts_snps: Path,
    gts_strs: Path,
    clump_p1: float,
    clump_p2: float,
    clump_id_field: str,
    clump_field: str,
    clump_chrom_field: str,
    clump_pos_field: str,
    clump_kb: float,
    clump_r2: float,
    ld: str,
    out: Path,
    verbosity: str = "CRITICAL",
):
    """
    Performs clumping on datasets with SNPs, SNPs and STRs, and STRs.
    Clumping is the process of identifying SNPs or STRs that are highly
    correlated with one another and concatenating them all together into
    a single "clump" in order to not repeat the same effect size due to
    LD.
    """
    # Import lazily so that merely loading the CLI does not pull in the
    # clumping module.
    from .logging import getLogger
    from .clump import clumpstr

    log = getLogger(name="clump", level=verbosity)
    log.debug(f"Loading SNPs from {summstats_snps} {gts_snps}")
    log.debug(f"Loading STRs from {summstats_strs} {gts_strs}")

    # Pass everything by keyword: clumpstr takes fifteen parameters and a
    # positional call is easy to get silently out of order.
    clumpstr(
        summstats_snps=summstats_snps,
        summstats_strs=summstats_strs,
        gts_snps=gts_snps,
        gts_strs=gts_strs,
        clump_p1=clump_p1,
        clump_p2=clump_p2,
        clump_id_field=clump_id_field,
        clump_field=clump_field,
        clump_chrom_field=clump_chrom_field,
        clump_pos_field=clump_pos_field,
        clump_kb=clump_kb,
        clump_r2=clump_r2,
        LD_type=ld,
        out=out,
        log=log,
    )
class Variant:
    """
    A single record from a summary statistics file

    Attributes
    ----------
    varid: str
        The variant ID
    chrom: str
        The chromosome, exactly as written in the summary statistics file
    pos: int
        The position of the variant
    pval: float
        The p-value of the variant
    vartype: str
        A label for the kind of variant (ex: "SNP" or "STR")
    """

    def __init__(self, varid, chrom, pos, pval, vartype):
        self.varid = varid
        self.chrom = chrom
        self.pos = pos
        self.pval = pval
        self.vartype = vartype

    def __str__(self):
        fields = (self.varid, self.chrom, self.pos, self.pval, self.vartype)
        return " ".join(str(field) for field in fields)


class SummaryStats:
    """
    Load and process summary statistics

    Attributes
    ----------
    summstats: list[Variant]
        list of Variant objects that represent all summary
        statistics in the file.
    log: Logger
        A logging instance for recording debug statements.

    Examples
    --------
    Loading a summary stats file, grabbing an index variant, and its
    candidate variants to calculate LD.

    >>> summstats = SummaryStats(log)
    >>> summstats.Load(summstats_strs, vartype="SNP", pthresh=clump_p2,
                id_field=clump_id_field, p_field=clump_field,
                chrom_field=clump_chrom_field, pos_field=clump_pos_field)
    >>> indexvar = summstats.GetNextIndexVariant(clump_p1)
    >>> candidates = summstats.QueryWindow(indexvar, clump_kb)
    """

    def __init__(self, log: Logger = None):
        self.summstats = []
        self.log = log or getLogger(self.__class__.__name__)

    def Load(
        self,
        statsfile,
        vartype="SNP",
        pthresh=1.0,
        id_field="SNP",
        p_field="P",
        chrom_field="CHR",
        pos_field="POS",
    ):
        """
        Load summary statistics from a whitespace-delimited file and append
        them to :py:attr:`~.SummaryStats.summstats`

        Variants with a p-value greater than pthresh are ignored.

        Parameters
        ----------
        statsfile: Path | str
            The file to read. Its first line must be a header, optionally
            prefixed by '#'
        vartype: str, optional
            The label to attach to every variant loaded from this file
        pthresh: float, optional
            The largest p-value that a variant may have to be kept
        id_field: str, optional
            Column header of the variant ID
        p_field: str, optional
            Column header of the p-values
        chrom_field: str, optional
            Column header of the chromosome
        pos_field: str, optional
            Column header of the position

        Raises
        ------
        SystemExit
            If any of the requested columns is missing from the header
        """
        loaded = []  # List of Variants

        # the context manager guarantees the handle is released even when we
        # bail out on a bad header or an unparsable record
        with open(statsfile, "r") as stats:
            # First, parse header line to get col. numbers
            header = stats.readline()
            if header.startswith("#"):
                header = header[1:]
            header_items = header.split()
            columns = []
            for field in (id_field, p_field, chrom_field, pos_field):
                try:
                    columns.append(header_items.index(field))
                except ValueError:
                    self.log.error("Could not find %s in header", field)
                    sys.exit(1)
            snp_col, p_col, chrom_col, pos_col = columns

            # Now, load in stats. Skip things with pval>pthresh
            for line in stats:
                items = line.split()
                if not items:
                    # tolerate blank lines instead of silently dropping
                    # every record that comes after the first one
                    continue
                pval = float(items[p_col])
                if pval > pthresh:
                    continue
                loaded.append(
                    Variant(
                        items[snp_col],
                        items[chrom_col],
                        int(items[pos_col]),
                        pval,
                        vartype,
                    )
                )
        self.summstats.extend(loaded)

    def GetNextIndexVariant(self, index_pval_thresh):
        """
        Get the next index variant, which is the remaining variant with the
        best (smallest) p-value

        Parameters
        ----------
        index_pval_thresh: float
            Only variants with a p-value strictly below this threshold (and
            strictly below 1.0) may become index variants

        Returns
        -------
        Variant | None
            The best remaining variant, or None if no more variants are below
            the threshold. Ties are resolved in favor of the variant that was
            loaded first.
        """
        eligible = (
            variant
            for variant in self.summstats
            if variant.pval < 1.0 and variant.pval < index_pval_thresh
        )
        # min() returns the first of several equally small items
        return min(eligible, key=lambda variant: variant.pval, default=None)

    def QueryWindow(self, indexvar, window_kb):
        """
        Find all candidate variants in the specified window around the index
        variant

        Parameters
        ----------
        indexvar: Variant
            The index variant at the center of the window
        window_kb: float
            The radius of the window in kilobases. A variant is a candidate
            if it is on the same chromosome and strictly closer than this

        Returns
        -------
        list[Variant]
            The candidates, in the order in which they were loaded. Note that
            this includes the index variant itself if it is still present.
        """
        chrom = indexvar.chrom
        pos = indexvar.pos
        return [
            variant
            for variant in self.summstats
            if variant.chrom == chrom and abs(variant.pos - pos) / 1000 < window_kb
        ]

    def RemoveClump(self, clumpvars):
        """
        Remove the variants from a clump from further consideration

        Parameters
        ----------
        clumpvars: list[Variant]
            The Variant objects to drop from :py:attr:`~.SummaryStats.summstats`
        """
        # Variant does not define __eq__, so membership was always decided by
        # identity. Comparing ids keeps that behavior but in constant time.
        discard = {id(variant) for variant in clumpvars}
        self.summstats = [
            variant for variant in self.summstats if id(variant) not in discard
        ]


def _SortSamples(samples):
    """
    Sort samples along with their indices.

    Parameters
    ----------
    samples: Sequence[str]
        The sample names to sort

    Returns
    -------
    sorted_samples: list
        The sample names in sorted order
    sorted_inds: list
        For each sorted sample, its index in the original sequence
    """
    # Pair each sample with its original index and sort the pairs together
    paired = sorted(zip(samples, np.arange(len(samples))))

    sorted_samples = [sample for sample, _ in paired]
    sorted_inds = [ind for _, ind in paired]
    return sorted_samples, sorted_inds


def GetOverlappingSamples(snpgts, strgts):
    """
    Get indices of overlapping samples for snps and strs

    Parameters
    ----------
    snpgts: GenotypesVCF
        SNP Genotypes object
    strgts: GenotypesTR
        STR Genotypes object

    Returns
    -------
    snp_samples: list(int)
        Indices of overlapping samples for snps
    str_samples: list(int)
        Indices of overlapping samples for strs. The two lists are aligned:
        entry i of each refers to the same sample.
    """
    snp_samples, snp_inds = _SortSamples(snpgts.samples)
    str_samples, str_inds = _SortSamples(strgts.samples)

    snp_match_inds = []
    str_match_inds = []

    # Both lists are sorted, so walk them in lockstep and always advance the
    # side that currently holds the smaller sample name
    i = j = 0
    while i < len(snp_inds) and j < len(str_inds):
        snp_name, str_name = snp_samples[i], str_samples[j]
        if str_name < snp_name:
            j += 1
        elif str_name == snp_name:
            snp_match_inds.append(snp_inds[i])
            str_match_inds.append(str_inds[j])
            i += 1
            j += 1
        else:
            i += 1

    return snp_match_inds, str_match_inds
def LoadVariant(var, gts, log):
    """
    Extract vector of genotypes for this variant

    Parameters
    ----------
    var: Variant
        The variant to look up; matched against gts by chromosome and position
    gts: Genotypes
        Genotypes whose data is indexed as (samples, variants, alleles) and
        whose variants array has "chrom" and "pos" fields
    log: Logger
        A logging instance for recording warnings.

    Returns
    -------
    np.array
        One value per sample: the sum of that sample's alleles at the variant

    Raises
    ------
    ValueError
        If no variant in gts has the chromosome and position of var
    """
    on_site = (gts.variants["pos"] == int(var.pos)) & (
        gts.variants["chrom"] == var.chrom
    )
    matches = np.flatnonzero(on_site)
    if matches.size == 0:
        raise ValueError(
            f"Variant {var.varid} ({var.chrom}:{var.pos}) is not in the genotypes"
        )

    chosen = matches[0]
    if matches.size > 1:
        # A SNP and an STR can share a position once their genotypes are
        # merged. Summing over all of them would yield a vector that is
        # several times too long, so settle on exactly one record instead.
        # NOTE(review): assumes the IDs in the summary statistics match the
        # IDs in the genotypes file when both are given -- confirm
        if "id" in (gts.variants.dtype.names or ()):
            same_id = matches[gts.variants["id"][matches] == var.varid]
            if same_id.size:
                chosen = same_id[0]
        log.warning(
            "Found %d variants at %s:%s. Using the genotypes of record #%d for %s",
            matches.size,
            var.chrom,
            var.pos,
            chosen,
            var.varid,
        )

    return np.sum(gts.data[:, chosen, :], axis=1)


def _CalcChiSQ(f00, f01, f10, f11, gt_counts, n):
    """
    Calculate Chi-squared test stat for given freqs.

    Parameters
    ----------
    f00, f01, f10, f11: float
        The four haplotype frequencies
    gt_counts: np.array
        3x3 array of observed two-locus genotype counts
    n: float
        The total number of observations in gt_counts

    Returns
    -------
    float
        The sum over all cells with a positive expectation of
        (observed - expected)^2 / expected
    """
    # expected two-locus genotype counts for the given haplotype frequencies
    expected = np.array(
        [
            [n * f00**2, 2 * n * f00 * f01, n * f01**2],
            [
                2 * n * f00 * f10,
                2 * n * f01 * f10 + 2 * n * f00 * f11,
                2 * n * f01 * f11,
            ],
            [n * f10**2, 2 * n * f10 * f11, n * f11**2],
        ],
        dtype=float,
    )
    observed = np.asarray(gt_counts, dtype=float)

    # cells that are not expected to be observed cannot contribute
    usable = expected > 0.0
    return np.sum((observed[usable] - expected[usable]) ** 2 / expected[usable])


def _CalcLDStats(f00, p, q, gt_counts, n):
    """
    Given the frequency of the first haplotype (f00) and the allele
    frequencies p and q of the two variants, calculate LD stats.

    Returns
    -------
    Dprime: float
        D', rounded to 6 decimal places
    r_squared: float
        r squared, rounded to 6 decimal places
    chisq: float
        Chi-squared statistic of the observed counts vs. those implied by f00
    """
    # the remaining haplotype frequencies follow from the allele frequencies
    f01 = p - f00
    f10 = q - f00
    f11 = 1 - (f00 + f01 + f10)
    D = (f00 * f11) - (f01 * f10)
    if D >= 0.0:
        Dmax = min(p * (1.0 - q), q * (1.0 - p))
    else:
        Dmax = min(p * q, (1 - p) * (1 - q))
    Dprime = D / Dmax
    r_squared = (D**2) / (p * (1 - p) * q * (1 - q))

    return (
        round(Dprime, 6),
        round(r_squared, 6),
        _CalcChiSQ(f00, f01, f10, f11, gt_counts, n),
    )


def _CalcBestRoot(real_roots, minhap, maxhap, p, q, gt_counts, n):
    """
    Given a list of real roots (max possible 3) calculate the best root.
    The best root is the one with the lowest chisq test statistic value.

    Returns
    -------
    best_Dprime: float
        D' of the best root, or 0 if no root lies within [minhap, maxhap]
    best_rsquared: float
        r squared of the best root, or 0 if no root lies within the bounds
    """
    tolerance = 0.00001
    best_Dprime = 0
    best_rsquared = 0
    best_chisq = np.inf
    for root_freq00 in real_roots:
        # only roots that are feasible haplotype frequencies are considered
        if not (minhap - tolerance <= root_freq00 <= maxhap + tolerance):
            continue
        Dprime, r_squared, chisq = _CalcLDStats(root_freq00, p, q, gt_counts, n)
        if chisq < best_chisq:
            best_Dprime = Dprime
            best_rsquared = r_squared
            best_chisq = chisq

    return best_Dprime, best_rsquared


def _SignedCubeRoot(value):
    """
    Real cube root that also works for negative numbers, for which
    math.pow() with a fractional exponent would raise a ValueError.
    """
    return math.copysign(math.pow(abs(value), 1.0 / 3.0), value)


def ComputeExactLD(candidate_gt, index_gt, log):
    """
    Compute exact solution of haplotype frequencies to calculate r squared value.
    NOTE currently this approach only works for biallelic variants. With more
    alleles, the equation to solve is no longer a cubic but a polynomial whose
    degree grows with the number of alleles, which also rules out STRs.

    Parameters
    ----------
    candidate_gt: np.array
        array of size (genotypes,) where genotypes is the number of samples
    index_gt: np.array
        array of size (genotypes,) where genotypes is the number of samples
    log: Logger
        A logging instance for recording debug statements.

    Returns
    -------
    Dprime: float
        D' value inferred from ML solution.
    r_squared: float
        R squared value inferred from ML solution.

    Raises
    ------
    ValueError
        If the discriminant of the cubic cannot be compared (ex: it is NaN)
    """
    # load in 3x3 array where axes are genotypes (0,1,2) for each variant
    # y-axis = candidate gt, x-axis = index_gt
    gt_counts = np.zeros((3, 3))
    for gt1 in range(3):
        index_given_gt1 = index_gt[candidate_gt == gt1]
        for gt2 in range(3):
            gt_counts[gt1, gt2] = np.sum(index_given_gt1 == gt2)

    n = np.sum(gt_counts)
    p = (2.0 * np.sum(gt_counts[0, :]) + np.sum(gt_counts[1, :])) / (2.0 * n)
    q = (2.0 * np.sum(gt_counts[:, 0]) + np.sum(gt_counts[:, 1])) / (2.0 * n)

    # coefficients of the cubic a*x^3 + b*x^2 + c*x + d = 0 in x = f00
    num_alt = 2.0 * gt_counts[0, 0] + gt_counts[0, 1] + gt_counts[1, 0]
    a = 4.0 * n
    b = 2.0 * n * (1.0 - 2.0 * p - 2.0 * q) - 2.0 * num_alt - gt_counts[1, 1]
    c = (
        -num_alt * (1.0 - 2.0 * p - 2.0 * q)
        - gt_counts[1, 1] * (1.0 - p - q)
        + 2.0 * n * p * q
    )
    d = -num_alt * p * q

    # double heterozygotes are the only ambiguous genotypes, so they bound f00
    minhap = num_alt / (2.0 * float(n))
    maxhap = (num_alt + gt_counts[1, 1]) / (2.0 * float(n))

    xN = -b / (3.0 * a)
    d2 = (math.pow(b, 2) - 3.0 * a * c) / (9 * math.pow(a, 2))
    yN = a * math.pow(xN, 3) + b * math.pow(xN, 2) + c * xN + d
    yN2 = math.pow(yN, 2)
    h2 = 4 * math.pow(a, 2) * math.pow(d2, 3)

    # three possible scenarios of solutions
    if yN2 > h2:
        # a single real root
        root_term = math.pow((yN2 - h2), 0.5)
        scale = 1.0 / (2.0 * a)
        number1 = _SignedCubeRoot(scale * (-yN + root_term))
        number2 = _SignedCubeRoot(scale * (-yN - root_term))
        real_roots = [xN + number1 + number2]

    elif yN2 == h2:
        # three real roots, two of which coincide. delta is the cube root of
        # yN / (2a); the parentheses around 2a matter, and yN may be negative
        delta = _SignedCubeRoot(yN / (2.0 * a))
        real_roots = [xN + delta, xN + delta, xN - 2.0 * delta]

    elif yN2 < h2:
        # three distinct real roots
        h = math.pow(h2, 0.5)
        theta = (math.acos(-yN / h)) / 3.0
        delta = math.pow(d2, 0.5)
        real_roots = [
            xN + 2.0 * delta * math.cos(theta),
            xN + 2.0 * delta * math.cos(2.0 * math.pi / 3.0 + theta),
            xN + 2.0 * delta * math.cos(4.0 * math.pi / 3.0 + theta),
        ]

    else:
        # none of the comparisons hold, which happens when a value is NaN
        raise ValueError(f"Can't calculate r squared from given values {yN2} and {h2}")

    # Solve for best roots
    return _CalcBestRoot(real_roots, minhap, maxhap, p, q, gt_counts, n)


def _FilterGts(candidate_gt, index_gt, log):
    """
    Filter invalid values from gts which is 255 since uint8 encodes -1 as 255

    Returns
    -------
    candidate_gt: np.array
        The candidate genotypes of the samples where both variants are valid
    index_gt: np.array
        The index genotypes of those same samples
    """
    valid_gts = (candidate_gt < 255) & (index_gt < 255)
    candidate_gt = candidate_gt[valid_gts]
    index_gt = index_gt[valid_gts]

    # pass the arrays as arguments so that they are only rendered as text
    # when debug logging is actually enabled; this runs for every pair
    log.debug("Valid Genotype Indices: %s", valid_gts)
    log.debug("Candidate GTs: %s", candidate_gt)
    log.debug("Index GTs: %s", index_gt)
    return candidate_gt, index_gt


def ComputeLD(candidate_gt, index_gt, LD_type, log):
    """
    Compute the LD between two variants

    Parameters
    ----------
    candidate_gt: np.array
        Per-sample genotypes of the candidate variant
    index_gt: np.array
        Per-sample genotypes of the index variant
    LD_type: str
        Either "Exact" or "Pearson"
    log: Logger
        A logging instance for recording debug statements.

    Returns
    -------
    Dprime: float | None
        D' for the "Exact" method; None otherwise
    r_squared: float
        r squared; 0 if no sample has valid genotypes for both variants and
        NaN if either variant is constant across all samples

    Raises
    ------
    ValueError
        If LD_type is not one of the supported methods
    """
    # Filter invalid gt values
    candidate_gt, index_gt = _FilterGts(candidate_gt, index_gt, log)
    if not (np.size(candidate_gt) and np.size(index_gt)):
        return None, 0

    # Check if all values in either array are the same
    if np.unique(candidate_gt).shape[0] == 1 or np.unique(index_gt).shape[0] == 1:
        log.debug("GTs between one of the variants are constant across all samples.")
        return None, np.nan

    # Compute the maximum likelihood solution or Pearson r2
    if LD_type == "Exact":
        return ComputeExactLD(candidate_gt, index_gt, log)
    if LD_type == "Pearson":
        return None, np.corrcoef(index_gt, candidate_gt)[0, 1] ** 2

    # falling through would return a bare None that callers cannot unpack
    raise ValueError(f"Unknown LD type '{LD_type}'. Use 'Exact' or 'Pearson'.")
def WriteClump(indexvar, clumped_vars, outf):
    """
    Write a clump to the output file as a single tab-separated line

    Parameters
    ----------
    indexvar: Variant
        The index variant of the clump
    clumped_vars: list[Variant]
        The variants in the clump. Each one is written via str() and they are
        joined by commas in the last column
    outf: TextIO
        A writable file-like object
    """
    columns = [
        indexvar.varid,
        indexvar.chrom,
        str(indexvar.pos),
        str(indexvar.pval),
        indexvar.vartype,
        ",".join(str(item) for item in clumped_vars),
    ]
    outf.write("\t".join(columns) + "\n")


def clumpstr(
    summstats_snps,
    summstats_strs,
    gts_snps,
    gts_strs,
    clump_p1,
    clump_p2,
    clump_id_field,
    clump_field,
    clump_chrom_field,
    clump_pos_field,
    clump_kb,
    clump_r2,
    LD_type,
    out,
    log,
):
    """
    Clump the variants of one or two summary statistics files by LD and write
    the clumps to a file

    Parameters
    ----------
    summstats_snps: Path | None
        SNP summary statistics. Requires gts_snps
    summstats_strs: Path | None
        STR summary statistics. Requires gts_strs
    gts_snps: Path | None
        SNP genotypes. Loaded as PLINK2 if the name ends with "pgen", else as VCF
    gts_strs: Path | None
        STR genotypes
    clump_p1: float
        Max p-value a variant may have to become an index variant
    clump_p2: float
        Max p-value a variant may have to be loaded at all
    clump_id_field: str
        Column header of the variant ID
    clump_field: str
        Column header of the p-values
    clump_chrom_field: str
        Column header of the chromosome
    clump_pos_field: str
        Column header of the position
    clump_kb: float
        Radius (in kb) around an index variant in which to look for candidates
    clump_r2: float
        Candidates with an r squared above this value join the clump
    LD_type: str
        Either "Exact" or "Pearson"
    out: Path
        The file to which the clumps are written
    log: Logger
        A logging instance for recording debug statements.

    Raises
    ------
    ValueError
        If summary statistics are given without genotypes (or vice versa), if
        STRs are combined with the exact LD method, or if there are no
        genotypes to load
    """
    ###### User checks ##########
    # Explicit checks rather than assert statements, which would be stripped
    # (and the validation skipped) when running with python -O

    # if summstats_snps, also need gts_snps
    log.debug(f"Validating SNP files {summstats_snps} or {gts_snps}")
    if bool(summstats_snps) != bool(gts_snps):
        raise ValueError(
            f"One of summstats-snps {summstats_snps} and gts-snps {gts_snps} is "
            "not present. Please ensure both have been inputted correctly."
        )

    # if summstats_strs, also need gts_strs
    log.debug(f"Validating STR files {summstats_strs} or {gts_strs}")
    if bool(summstats_strs) != bool(gts_strs):
        raise ValueError(
            f"One of summstats-strs {summstats_strs} and gts-strs {gts_strs} is "
            "not present. Please ensure both have been inputted correctly."
        )

    if summstats_strs and LD_type == "Exact":
        raise ValueError(
            "The exact method of computing LD can only be used with biallelic loci. "
            + "STRs are not compatible with the exact LD compute method. "
        )

    ###### Load summary stats ##########
    summstats = SummaryStats(log)
    for statsfile, vartype in ((summstats_snps, "SNP"), (summstats_strs, "STR")):
        if statsfile is None:
            continue
        summstats.Load(
            statsfile,
            vartype=vartype,
            pthresh=clump_p2,
            id_field=clump_id_field,
            p_field=clump_field,
            chrom_field=clump_chrom_field,
            pos_field=clump_pos_field,
        )

    ###### Set up genotypes ##########
    snpgts = None
    strgts = None
    gts = None
    if gts_snps:
        log.debug("Loading SNP Genotypes.")
        if str(gts_snps).endswith("pgen"):
            # imported here because this module's top-level import of
            # haptools.data.genotypes does not bring in GenotypesPLINK
            from haptools.data.genotypes import GenotypesPLINK

            snpgts = GenotypesPLINK.load(gts_snps)
        else:
            snpgts = GenotypesVCF.load(gts_snps)
    if gts_strs:
        log.debug("Loading STR Genotypes.")
        strgts = GenotypesTR.load(gts_strs)

    if gts_snps and gts_strs:
        log.debug("Calculating set of overlapping samples between STRs and SNPs.")
        # Grab all shared samples between snp list and str list
        # NOTE samples are returned such that resulting data is matching
        snp_samples, str_samples = GetOverlappingSamples(snpgts, strgts)
        log.debug(f"Shared samples: {snp_samples}")

        # NOTE snpgts has data, variants, and samples where data is alleles (samples x variants x alleles)
        # variants has id, pos, chrom, ref, alt for snps
        # samples is list of samples corresponding to x-axis of samples
        snpgts.data = snpgts.data[snp_samples, :, :]
        snpgts.samples = tuple(np.array(snpgts.samples)[snp_samples])
        strgts.data = strgts.data[str_samples, :, :]
        strgts.samples = tuple(np.array(strgts.samples)[str_samples])

        # Merge STR and SNP GTs
        gts = Genotypes.merge_variants((snpgts, strgts), fname=None)
    elif gts_snps:
        gts = snpgts
    elif gts_strs:
        gts = strgts
    else:
        raise ValueError("Unable to load valid genotype data.")

    ###### Setup output file ##########
    # the context manager closes the file even if clumping fails part-way
    with open(out, "w") as outf:
        outf.write(
            "\t".join(["ID", "CHROM", "POS", "P", "VARTYPE", "CLUMPVARS"]) + "\n"
        )

        ###### Perform clumping ##########
        indexvar = summstats.GetNextIndexVariant(clump_p1)
        while indexvar is not None:
            # Load indexvar gts
            indexvar_gt = LoadVariant(indexvar, gts, log)
            # Collect candidate variants within range of index variant
            # NOTE(review): the window contains the index variant itself, so
            # it is usually listed among its own clumped variants -- confirm
            # that this is intended
            candidates = summstats.QueryWindow(indexvar, clump_kb)

            log.debug(f"Current index variant: {indexvar}")

            # calculate LD between candidate vars and index var
            clumpvars = []
            for c in candidates:
                # load candidate variant c genotypes
                candidate_gt = LoadVariant(c, gts, log)
                Dprime, r2 = ComputeLD(candidate_gt, indexvar_gt, LD_type, log)
                # If using pearson Dprime is not calculated
                log.debug(
                    f"D' and r2 between {indexvar} with {c}\nD' = {Dprime}, r^2 = {r2}"
                )
                # a NaN r2 never passes this comparison
                if r2 > clump_r2:
                    clumpvars.append(c)
            WriteClump(indexvar, clumpvars, outf)
            # always drop the index variant so that the loop makes progress
            summstats.RemoveClump(clumpvars + [indexvar])
            indexvar = summstats.GetNextIndexVariant(clump_p1)
class GenotypesTR(Genotypes):
    """
    A class for processing TR genotypes from a file
    Unlike the base Genotypes class, this class genotypes will be repeat number
    in the variants array

    Attributes
    ----------
    data : np.array
        See documentation for :py:attr:`~.Genotypes.data`
    fname : Path | str
        See documentation for :py:attr:`~.Genotypes.fname`
    samples : tuple[str]
        See documentation for :py:attr:`~.Genotypes.samples`
    variants : np.array
        Variant-level meta information:
            1. ID
            2. CHROM
            3. POS
            4. [REF, ALT1, ALT2, ...]
    log: Logger
        See documentation for :py:attr:`~.Genotypes.log`
    """

    def __init__(self, fname: Path | str, log: Logger = None):
        super().__init__(fname, log)
        # TODO see if this is correct for reading in STR datatypes
        # extend the dtype of the parent's variants array by an "alleles" field
        dtype = {k: v[0] for k, v in self.variants.dtype.fields.items()}
        self.variants = np.array([], dtype=list(dtype.items()) + [("alleles", object)])

    def _variant_arr(self, record: Variant):
        """
        See documentation for :py:meth:`~.Genotypes._variant_arr`
        """
        # NOTE(review): record is one of the items yielded by
        # trh.TRRecordHarmonizer in _iterate(), not a plain VCF record, despite
        # the annotation -- confirm
        # TODO see if this is correct since ref alt doesn't really exist
        return np.array(
            (
                record.record_id,
                record.chrom,
                record.pos,
                (record.ref_allele, *record.alt_alleles),
            ),
            dtype=self.variants.dtype,
        )

    @classmethod
    def load(
        cls: GenotypesTR,
        fname: Path | str,
        region: str = None,
        samples: list[str] = None,
        variants: set[str] = None,
    ) -> Genotypes:
        """
        Load STR genotypes from a VCF file

        Read the file contents and check the genotype phase

        Parameters
        ----------
        fname
            See documentation for :py:attr:`~.Data.fname`
        region : str, optional
            See documentation for :py:meth:`~.Genotypes.read`
        samples : list[str], optional
            See documentation for :py:meth:`~.Genotypes.read`
        variants : set[str], optional
            See documentation for :py:meth:`~.Genotypes.read`

        Returns
        -------
        Genotypes
            A Genotypes object with the data loaded into its properties
        """
        genotypes = cls(fname)
        genotypes.read(region, samples, variants)
        genotypes.check_phase()
        return genotypes

    def read(
        self,
        region: str = None,
        samples: list[str] = None,
        variants: set[str] = None,
        max_variants: int = None,
    ):
        """
        Read genotypes from a VCF into a numpy matrix stored in :py:attr:`~.Genotypes.data`

        Parameters
        ----------
        region : str, optional
            The region from which to extract genotypes; ex: 'chr1:1234-34566' or 'chr7'

            For this to work, the VCF must be indexed and the seqname must match!

            Defaults to loading all genotypes
        samples : list[str], optional
            A subset of the samples from which to extract genotypes

            Defaults to loading genotypes from all samples
        variants : set[str], optional
            A set of variant IDs for which to extract genotypes

            All other variants will be ignored. This may be useful if you're running
            out of memory.
        max_variants : int, optional
            The maximum number of variants to load from the file. Setting this value
            helps preallocate the arrays, making the process faster and less memory
            intensive. You should use this option if your processes are frequently
            "Killed" from memory overuse.

            If you don't know how many variants there are, set this to a large number
            greater than what you would expect. The np array will be resized
            appropriately. You can also use the bcftools "counts" plugin to obtain the
            number of expected sites within a region.

            Note that this value is ignored if the variants argument is provided.
        """
        # NOTE(review): if the parent's read() loads the file by itself, then
        # the file is read twice here (once without any of the filters) --
        # confirm against Genotypes.read()
        super().read()
        records = self.__iter__(region=region, samples=samples, variants=variants)
        if variants is not None:
            max_variants = len(variants)
        # check whether we can preallocate memory instead of making copies
        if max_variants is None:
            self.log.warning(
                "The max_variants parameter was not specified. We have no choice but to"
                " append to an ever-growing array, which can lead to memory overuse!"
            )
            variants_arr = []
            data_arr = []
            for rec in records:
                variants_arr.append(rec.variants)
                data_arr.append(rec.data)
            self.log.info(f"Copying {len(variants_arr)} variants into np arrays.")
            # convert to np array for speedy operations later on
            self.variants = np.array(variants_arr, dtype=self.variants.dtype)
            self.data = np.array(data_arr, dtype=np.uint8)
        else:
            # preallocate arrays! this will save us lots of memory and speed b/c
            # appends can sometimes make copies
            self.variants = np.empty((max_variants,), dtype=self.variants.dtype)
            # in order to check_phase() later, we must store the phase info, as well
            self.data = np.empty(
                (max_variants, len(self.samples), (2 + (not self._prephased))),
                dtype=np.uint8,
            )
            num_seen = 0
            for rec in records:
                if num_seen >= max_variants:
                    break
                self.variants[num_seen] = rec.variants
                self.data[num_seen] = rec.data
                num_seen += 1
            if max_variants > num_seen:
                self.log.info(
                    f"Removing {max_variants-num_seen} unneeded variant records that "
                    "were preallocated b/c max_variants was specified."
                )
                self.variants = self.variants[:num_seen]
                self.data = self.data[:num_seen]
        if 0 in self.data.shape:
            self.log.warning(
                "Failed to load genotypes. If you specified a region, check that the"
                " contig name matches! For example, double-check the 'chr' prefix."
            )
        # transpose the GT matrix so that samples are rows and variants are columns
        self.log.info(f"Transposing genotype matrix of size {self.data.shape}.")
        self.data = self.data.transpose((1, 0, 2))

    def __iter__(
        self, region: str = None, samples: list[str] = None, variants: set[str] = None
    ) -> Iterator[namedtuple]:
        """
        Read genotypes from a VCF line by line without storing anything

        Parameters
        ----------
        region : str, optional
            See documentation for :py:meth:`~.Genotypes.read`
        samples : list[str], optional
            See documentation for :py:meth:`~.Genotypes.read`
        variants : set[str], optional
            See documentation for :py:meth:`~.Genotypes.read`

        Returns
        -------
        Iterator[namedtuple]
            See documentation for :py:meth:`~.Genotypes._iterate`
        """
        vcf = VCF(str(self.fname), samples=samples, lazy=True)
        self.samples = tuple(vcf.samples)
        # call another function to force the lines above to be run immediately
        # see https://stackoverflow.com/a/36726497
        return self._iterate(vcf, region, variants)

    def _iterate(self, vcf: VCF, region: str = None, variants: set[str] = None):
        """
        A generator over the lines of a VCF

        This is a helper function for :py:meth:`~.Genotypes.__iter__`

        Parameters
        ----------
        vcf: VCF
            The cyvcf2.VCF object from which to fetch variant records
        region : str, optional
            See documentation for :py:meth:`~.Genotypes.read`
        variants : set[str], optional
            See documentation for :py:meth:`~.Genotypes.read`

        Yields
        ------
        Iterator[namedtuple]
            An iterator over each line in the file, where each line is encoded as a
            namedtuple containing each of the class properties
        """
        self.log.info(f"Loading genotypes from {len(self.samples)} samples")
        Record = namedtuple("Record", "data variants")
        # iterable used to collect records
        vcfiter = vcf(region)
        tr_records = trh.TRRecordHarmonizer(vcffile=vcf, vcfiter=vcfiter, region=region)
        num_seen = 0
        # the finally clause closes the VCF even if the consumer stops early
        # (as read() does once it has seen max_variants) or an error occurs
        try:
            # iterate over each line in the VCF
            # note, this can take a lot of time if there are many samples
            for variant in tr_records:
                if variants is not None and variant.record_id not in variants:
                    if num_seen >= len(variants):
                        # exit early if we've already found all the variants
                        break
                    continue
                # save meta information about each variant
                variant_arr = self._variant_arr(variant)
                # extract the genotypes to a matrix of size n x 3
                # the last dimension has three items:
                # 1) the allele on strand one
                # 2) the allele on strand two
                # 3) whether the genotype is phased (if self._prephased is False)
                try:
                    data = np.array(variant.vcfrecord.genotypes, dtype=np.uint8)
                except ValueError:
                    self.log.warning(
                        "The current variant in the VCF contains genotypes that do not have"
                        " 2 alleles. "
                        + "This will result in a significant slowdown due to iterating"
                        " over "
                        + "all GTs and fixing the shape issue. Please update the VCF by "
                        + "adding another allele to each GT with only one allele to fix the"
                        " slowdown."
                    )
                    # pad genotypes that have a single allele (plus the phase
                    # flag) with a missing second allele
                    # NOTE(review): this relies on -1 being stored as 255 when
                    # cast to uint8 -- confirm with the supported NumPy versions
                    data = []
                    for gt_sample in variant.vcfrecord.genotypes:
                        if len(gt_sample) == 2:
                            new_gt_sample = [gt_sample[0], -1, gt_sample[1]]
                        else:
                            new_gt_sample = gt_sample
                        data.append(new_gt_sample)
                    data = np.array(data, dtype=np.uint8)

                # drop the phase column if the genotypes are known to be phased
                data = data[:, : (2 + (not self._prephased))]
                yield Record(data, variant_arr)
                num_seen += 1
        finally:
            vcf.close()
diff --git a/tests/test_clump.py b/tests/test_clump.py new file mode 100644 index 00000000..7b98347d --- /dev/null +++ b/tests/test_clump.py @@ -0,0 +1,299 @@ +import os +from pathlib import Path + +import pytest +import numpy as np +from cyvcf2 import VCF +from click.testing import CliRunner + +from haptools.__main__ import main +from haptools.logging import getLogger +from haptools.data.genotypes import GenotypesVCF +from haptools.data.genotypes import GenotypesTR +from haptools.clump import ( + GetOverlappingSamples, + _SortSamples, + LoadVariant, + _FilterGts, + ComputeLD, + clumpstr, + Variant, +) + +DATADIR = Path(__file__).parent.joinpath("data") +log = getLogger(name="test") + + +class TestClump: + def _ld_expected(self): + return [0.5625, 0.6071, 0.5977, 0.5398] + + def test_loading_snps(self): + gts_snps = DATADIR.joinpath("outvcf_test.vcf.gz") + snpgts = GenotypesVCF.load(str(gts_snps)) + snpvars = [ + Variant("test1", "1", "10114", "0.05", "snp"), + Variant("test2", "1", "59423090", "0.05", "snp"), + Variant("test3", "2", "10122", "0.05", "snp"), + ] + strgts = None + + answers = [ + np.array([2, 2, 0, 0, 0]), + np.array([0, 0, 2, 2, 2]), + np.array([0, 0, 2, 2, 2]), + ] + + for var, answer in zip(snpvars, answers): + vargts = LoadVariant(var, snpgts, log) + assert len(vargts) == snpgts.data.shape[0] + assert np.array_equal(vargts, answer) + + def test_sample_sorting(self): + test1 = ["Sample_02", "Sample_00", "Sample_01"] + test2 = ["Sample_3", "Sample_2", "Sample_1"] + test3 = ["Sample_0", "Sample_1", "Sample_2"] + test1_samples, test1_inds = _SortSamples(test1) + test2_samples, test2_inds = _SortSamples(test2) + test3_samples, test3_inds = _SortSamples(test3) + + assert test1_samples == ["Sample_00", "Sample_01", "Sample_02"] + assert test1_inds == [1, 2, 0] + assert test2_samples == ["Sample_1", "Sample_2", "Sample_3"] + assert test2_inds == [2, 1, 0] + assert test3_samples == ["Sample_0", "Sample_1", "Sample_2"] + assert test3_inds == [0, 1, 2] + + 
def test_overlapping_samples(self): + # Test the GetOverlappingSamples function + snps = GenotypesVCF(fname="NA") + strs = GenotypesTR(fname="NA") + + # Test 1 No Matching + snps.samples = ["Sample_02", "Sample_00", "Sample_01"] + strs.samples = ["Sample_3", "Sample_2", "Sample_1"] + snp_inds, str_inds = GetOverlappingSamples(snps, strs) + assert snp_inds == [] and str_inds == [] + + # Test 2 All Matching + snps.samples = ["Sample_2", "Sample_3", "Sample_1"] + strs.samples = ["Sample_3", "Sample_2", "Sample_1"] + snp_inds, str_inds = GetOverlappingSamples(snps, strs) + assert snp_inds == [2, 0, 1] and str_inds == [2, 1, 0] + + # Test 3 SNPs and STRs incremented + snps.samples = ["Sample_2", "Sample_03", "Sample_01", "Sample_4"] + strs.samples = ["Sample_3", "Sample_2", "Sample_1", "Sample_4"] + snp_inds, str_inds = GetOverlappingSamples(snps, strs) + assert snp_inds == [0, 3] and str_inds == [1, 3] + + # Test 4 Uneven Sample Lists + snps.samples = ["Sample_2", "Sample_03", "Sample_01", "Sample_4", "Sample_5"] + strs.samples = ["Sample_3", "Sample_2", "Sample_4"] + snp_inds, str_inds = GetOverlappingSamples(snps, strs) + assert snp_inds == [0, 3] and str_inds == [1, 2] + + def test_gt_filter(self): + gt1 = np.array([255, 1, 256, 3, 4, 7]) + gt2 = np.array([0, 510, 258, 1, 3, 8]) + gt1, gt2 = _FilterGts(gt1, gt2, log) + assert np.array_equal(gt1, [3, 4, 7]) and np.array_equal(gt2, [1, 3, 8]) + + gt1 = np.array([255]) + gt2 = np.array([255]) + gt1, gt2 = _FilterGts(gt1, gt2, log) + assert np.array_equal(gt1, []) and np.array_equal(gt2, []) + + def test_ld(self): + # load expected for all tests + expected = self._ld_expected() + + # create snp/str variants to compare + snp1_gt = np.array([1, 0, 0, 0, 1, 0, 1, 2, 1]) + snp2_gt = np.array([1, 0, 0, 0, 1, 0, 1, 1, 2]) + str1_gt = np.array([3, 1, 0, 2, 7, 0, 4, 5, 6]) + str2_gt = np.array([3, 1, 1, 0, 4, 2, 3, 6, 8]) + + # Calculate expected + r1 = np.round(ComputeLD(snp1_gt, snp2_gt, "Pearson", log)[1], 4) + r2 = 
np.round(ComputeLD(snp1_gt, str1_gt, "Pearson", log)[1], 4) + r3 = np.round(ComputeLD(str1_gt, str2_gt, "Pearson", log)[1], 4) + r4 = np.round(ComputeLD(snp1_gt, snp2_gt, "Exact", log)[1], 4) + + assert [r1, r2, r3, r4] == expected + + def test_invalid_stats_vcf(self): + clump_p1 = 0.0001 + clump_p2 = 0.01 + clump_snp_field = "SNP" + clump_field = "P" + clump_chrom_field = "CHR" + clump_pos_field = "POS" + clump_kb = (250,) + clump_r2 = 0.5 + LD_type = "Pearson" + out = "NA" + summstats_snps = "fake/path" + summstats_strs = None + gts_snps = None + gts_strs = None + + def _get_error(f1, f2, ftype): + if ftype == "SNPs": + error = ( + f"One of summstats-snps {f1} and gts-snps {f2} is " + "not present. Please ensure both have been inputted correctly." + ) + else: + error = ( + f"One of summstats-strs {f1} and gts-strs {f2} is " + "not present. Please ensure both have been inputted correctly." + ) + return error + + # gts_snps None + with pytest.raises(Exception) as e: + clumpstr( + summstats_snps, + summstats_strs, + gts_snps, + gts_strs, + clump_p1, + clump_p2, + clump_snp_field, + clump_field, + clump_chrom_field, + clump_pos_field, + clump_kb, + clump_r2, + LD_type, + out, + log, + ) + assert (str(e.value)) == _get_error(summstats_snps, gts_snps, "SNPs") + + # Summstats_snps None + summstats_snps = None + gts_snps = "fake/path" + with pytest.raises(Exception) as e: + clumpstr( + summstats_snps, + summstats_strs, + gts_snps, + gts_strs, + clump_p1, + clump_p2, + clump_snp_field, + clump_field, + clump_chrom_field, + clump_pos_field, + clump_kb, + clump_r2, + LD_type, + out, + log, + ) + assert (str(e.value)) == _get_error(summstats_snps, gts_snps, "SNPs") + + # gts_strs None + summstats_snps = "fake/path" + gts_snps = "fake/path" + summstats_strs = "fake/path" + with pytest.raises(Exception) as e: + clumpstr( + summstats_snps, + summstats_strs, + gts_snps, + gts_strs, + clump_p1, + clump_p2, + clump_snp_field, + clump_field, + clump_chrom_field, + clump_pos_field, 
+ clump_kb, + clump_r2, + LD_type, + out, + log, + ) + assert (str(e.value)) == _get_error(summstats_strs, gts_strs, "STRs") + + # summstats_strs None + summstats_strs = None + gts_strs = "fake/path" + with pytest.raises(Exception) as e: + clumpstr( + summstats_snps, + summstats_strs, + gts_snps, + gts_strs, + clump_p1, + clump_p2, + clump_snp_field, + clump_field, + clump_chrom_field, + clump_pos_field, + clump_kb, + clump_r2, + LD_type, + out, + log, + ) + assert (str(e.value)) == _get_error(summstats_strs, gts_strs, "STRs") + + +class TestClumpCLI: + def test_snps(self, capfd): + out_file = Path("test_snps.clump") + cmd = ( + "clump --summstats-snps tests/data/test_snpstats.linear " + "--gts-snps tests/data/simple.vcf " + "--clump-id-field ID " + "--clump-chrom-field CHROM " + "--clump-pos-field POS " + "--ld Exact " + "--verbosity DEBUG " + "--out test_snps.clump" + ) + + runner = CliRunner() + result = runner.invoke(main, cmd.split(" "), catch_exceptions=False) + captured = capfd.readouterr() + assert result.exit_code == 0 + out_file.unlink() + + def test_strs(self, capfd): + out_file = Path("test_strs.clump") + cmd = ( + "clump --summstats-strs tests/data/test_strstats.linear " + "--gts-strs tests/data/simple_tr.vcf " + "--clump-id-field ID " + "--clump-chrom-field CHROM " + "--clump-pos-field POS " + "--verbosity DEBUG " + "--out test_strs.clump" + ) + runner = CliRunner() + result = runner.invoke(main, cmd.split(" "), catch_exceptions=False) + assert result.exit_code == 0 + out_file.unlink() + + def test_snps_strs(self, capfd): + out_file = Path("test_snps_strs.clump") + cmd = ( + "clump --summstats-snps tests/data/test_snpstats.linear " + "--gts-snps tests/data/simple.vcf " + "--summstats-strs tests/data/test_strstats.linear " + "--gts-strs tests/data/simple_tr.vcf " + "--clump-id-field ID " + "--clump-chrom-field CHROM " + "--clump-pos-field POS " + "--verbosity DEBUG " + "--out test_snps_strs.clump" + ) + runner = CliRunner() + result = 
runner.invoke(main, cmd.split(" "), catch_exceptions=False) + assert result.exit_code == 0 + + out_file.unlink() diff --git a/tests/test_data.py b/tests/test_data.py index 522a54c7..07fa9bb3 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -22,6 +22,7 @@ GenotypesTR, GenotypesVCF, GenotypesPLINK, + GenotypesTR, )