Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Set Operation Methods #12

Merged
merged 11 commits into from
Sep 13, 2024
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

/target

# Dev files
/dev

# Jupyter Notebook
.ipynb_checkpoints

# Byte-compiled / optimized / DLL files
__pycache__/
.pytest_cache/
Expand Down
92 changes: 77 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
// use rayon::prelude::*;
// Standard library imports
use std::collections::{HashMap, HashSet};

// External crate imports
use anyhow::{anyhow, Result};
use log::debug;
use std::collections::HashMap;

// use sourmash::sketch::nodegraph::Nodegraph;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use sourmash::encodings::HashFunctions;
use sourmash::signature::SeqToHashes;

Expand Down Expand Up @@ -47,14 +46,9 @@ impl KmerCountTable {
}

pub fn count_hash(&mut self, hashval: u64) -> u64 {
let mut count: u64 = 1;
if self.counts.contains_key(&hashval) {
count = *self.counts.get(&hashval).unwrap();
count = count + 1;
}
self.counts.insert(hashval, count);

count
let count = self.counts.entry(hashval).or_insert(0);
*count += 1;
*count
}

pub fn count(&mut self, kmer: String) -> PyResult<u64> {
Expand Down Expand Up @@ -86,7 +80,26 @@ impl KmerCountTable {
}
}

// Consume this DNA strnig. Return number of k-mers consumed.
// Get the count for a specific hash value directly
pub fn get_hash(&self, hashval: u64) -> u64 {
// Return the count for the hash value, or 0 if it does not exist
*self.counts.get(&hashval).unwrap_or(&0)
}

// Get counts for a list of hash keys and return an list of counts
pub fn get_hash_array(&self, hash_keys: Vec<u64>) -> Vec<u64> {
// Map each hash key to its count, defaulting to 0 if the key is not present
hash_keys.iter().map(|&key| self.get_hash(key)).collect()
}

// Getter for the 'hashes' attribute, returning all hash keys in the table
#[getter]
pub fn hashes(&self) -> Vec<u64> {
// Collect and return all keys from the counts HashMap
self.counts.keys().cloned().collect()
}

// Consume this DNA string. Return number of k-mers consumed.
#[pyo3(signature = (seq, allow_bad_kmers=true))]
pub fn consume(&mut self, seq: String, allow_bad_kmers: bool) -> PyResult<u64> {
let hashes = SeqToHashes::new(
Expand Down Expand Up @@ -118,8 +131,57 @@ impl KmerCountTable {

Ok(n)
}

// Helper method to get hash set of k-mers
fn hash_set(&self) -> HashSet<u64> {
self.counts.keys().cloned().collect()
}

// Set operation methods
pub fn union(&self, other: &KmerCountTable) -> HashSet<u64> {
self.hash_set().union(&other.hash_set()).cloned().collect()
}

pub fn intersection(&self, other: &KmerCountTable) -> HashSet<u64> {
self.hash_set()
.intersection(&other.hash_set())
.cloned()
.collect()
}

pub fn difference(&self, other: &KmerCountTable) -> HashSet<u64> {
self.hash_set()
.difference(&other.hash_set())
.cloned()
.collect()
}

pub fn symmetric_difference(&self, other: &KmerCountTable) -> HashSet<u64> {
self.hash_set()
.symmetric_difference(&other.hash_set())
.cloned()
.collect()
}

// Python dunder methods for set operations
fn __or__(&self, other: &KmerCountTable) -> HashSet<u64> {
self.union(other)
}

fn __and__(&self, other: &KmerCountTable) -> HashSet<u64> {
self.intersection(other)
}

fn __sub__(&self, other: &KmerCountTable) -> HashSet<u64> {
self.difference(other)
}

fn __xor__(&self, other: &KmerCountTable) -> HashSet<u64> {
self.symmetric_difference(other)
}
}

// Python module definition
#[pymodule]
fn oxli(m: &Bound<'_, PyModule>) -> PyResult<()> {
env_logger::init();
Expand Down
185 changes: 171 additions & 14 deletions src/python/tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import pytest
import oxli
import pytest


def test_simple():
# Helper function, create tables.
def create_sample_kmer_table(ksize, kmers):
table = oxli.KmerCountTable(ksize)
for kmer in kmers:
table.count(kmer)
return table


# Adding Kmers
def test_count():
# yo dawg it works
cg = oxli.KmerCountTable(4)
kmer = "ATCG"
Expand All @@ -11,6 +21,24 @@ def test_simple():
assert cg.get(kmer) == 1


def test_count_hash():
kmer = "TAAACCCTAACCCTAACCCTAACCCTAACCC"
cg = oxli.KmerCountTable(ksize=31)
hashkey = cg.hash_kmer(kmer)

assert cg.get_hash(hashkey) == 0
assert cg.count_hash(hashkey) == 1
assert cg.get_hash(hashkey) == 1


def test_hash_rc():
table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"])
hash_aaa = table.hash_kmer("AAA") # 10679328328772601858
hash_ttt = table.hash_kmer("TTT") # 10679328328772601858

assert hash_aaa == hash_ttt, "Hash should be same for reverse complement."


def test_wrong_ksize():
# but only with the right ksize
cg = oxli.KmerCountTable(3)
Expand Down Expand Up @@ -40,24 +68,22 @@ def test_consume_2():
assert cg.consume(seq) == 2
assert cg.get("ATCG") == 1
assert cg.get("TCGG") == 1
assert cg.get("CCGA") == 1 # reverse complement!
assert cg.get("CCGA") == 1 # reverse complement!


def test_consume_bad_DNA():
# test an invalid base in last position
cg = oxli.KmerCountTable(4)
seq = "ATCGGX"
with pytest.raises(ValueError,
match="bad k-mer encountered at position 2"):
with pytest.raises(ValueError, match="bad k-mer encountered at position 2"):
cg.consume(seq, allow_bad_kmers=False)


def test_consume_bad_DNA_2():
# test an invalid base in first position
cg = oxli.KmerCountTable(4)
seq = "XATCGG"
with pytest.raises(ValueError,
match="bad k-mer encountered at position 0"):
with pytest.raises(ValueError, match="bad k-mer encountered at position 0"):
cg.consume(seq, allow_bad_kmers=False)


Expand All @@ -68,7 +94,7 @@ def test_consume_bad_DNA_ignore():
print(cg.consume(seq, allow_bad_kmers=True))
assert cg.get("ATCG") == 1
assert cg.get("TCGG") == 1
assert cg.get("CCGA") == 1 # rc
assert cg.get("CCGA") == 1 # rc


def test_consume_bad_DNA_ignore_is_default():
Expand All @@ -78,20 +104,151 @@ def test_consume_bad_DNA_ignore_is_default():
print(cg.consume(seq))
assert cg.get("ATCG") == 1
assert cg.get("TCGG") == 1
assert cg.get("CCGA") == 1 # rc
assert cg.get("CCGA") == 1 # rc


def test_count_get():
# test a bug reported by adam taranto: count and get should work together!
kmer = 'TAAACCCTAACCCTAACCCTAACCCTAACCC'
# Test attributes
def test_hashes_attribute():
table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"])
hashes = table.hashes
hash_aaa = table.hash_kmer("AAA") # 10679328328772601858
hash_ttt = table.hash_kmer("TTT") # 10679328328772601858
hash_aac = table.hash_kmer("AAC") # 6579496673972597301

expected_hashes = set(
[hash_aaa, hash_ttt, hash_aac]
) # {10679328328772601858, 6579496673972597301}
assert (
set(hashes) == expected_hashes
), ".hashes attribute should match the expected set of hash keys"


# Getting counts
def test_count_vs_counthash():
# test a bug reported by adam taranto: count and get should work together!
kmer = "TAAACCCTAACCCTAACCCTAACCCTAACCC"
cg = oxli.KmerCountTable(ksize=31)
hashkey = cg.hash_kmer(kmer)

assert cg.get(kmer) == 0
assert cg.count(kmer) == 1
assert cg.count(kmer) == 2
assert cg.get(kmer) == 2
assert cg.count_hash(hashkey) == 3

x = cg.get(kmer)
assert x == 2, x

assert x == 3, x


def test_get_hash():
"""Retrieve counts using hash key."""
table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"])
# Find hash of kmer 'AAA'
hash_aaa = table.hash_kmer("AAA") # 10679328328772601858
# Lookup counts for hash of 'AAA' and rc 'TTT'
count_aaa = table.get_hash(hash_aaa)
assert count_aaa == 2, "Hash count for 'AAA' should be 2"

# Test single kmer
hash_aac = table.hash_kmer("AAC") # 6579496673972597301
count_aac = table.get_hash(hash_aac)
assert count_aac == 1, "Hash count for 'AAC' should be 1"

# Test for kmer that is not in table
hash_aag = table.hash_kmer("AAG") # 12774992397053849803
count_aag = table.get_hash(hash_aag)
assert count_aag == 0, "Missing kmer count for 'AAG' should be 0"


def test_get_hash_array():
"""
Get vector of counts corresponding to vector of hash keys.
"""
table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"])
hash_aaa = table.hash_kmer("AAA")
hash_aac = table.hash_kmer("AAC")
hash_ggg = table.hash_kmer("GGG") # key not in table

hash_keys = [hash_aaa, hash_aac, hash_ggg]
hash_keys_rev = [hash_ggg, hash_aac, hash_aaa]

counts = table.get_hash_array(hash_keys)
rev_counts = table.get_hash_array(hash_keys_rev)

assert (
counts == [2, 1, 0]
), "Hash array counts should match the counts of 'AAA' and 'AAC' and return zero for 'GGG'."
assert rev_counts == [0, 1, 2], "Count should be in same order as input list"


def test_get_array():
"""
Get vector of counts corresponding to vector of kmers.
"""
# TODO: Add function to get list of counts given list of kmers.
pass


# Set operations
def test_union():
table1 = create_sample_kmer_table(3, ["AAA", "AAC"])
table2 = create_sample_kmer_table(3, ["AAC", "AAG"])

union_set = table1.union(table2)
expected_union = set(table1.hashes).union(table2.hashes)

assert union_set == expected_union, "Union of hash sets should match"


def test_intersection():
table1 = create_sample_kmer_table(3, ["AAA", "AAC"])
table2 = create_sample_kmer_table(3, ["AAC", "AAG"])

intersection_set = table1.intersection(table2)
expected_intersection = set(table1.hashes).intersection(table2.hashes)

assert (
intersection_set == expected_intersection
), "Intersection of hash sets should match"


def test_difference():
table1 = create_sample_kmer_table(3, ["AAA", "AAC"])
table2 = create_sample_kmer_table(3, ["AAC", "AAG"])

difference_set = table1.difference(table2)
expected_difference = set(table1.hashes).difference(table2.hashes)

assert difference_set == expected_difference, "Difference of hash sets should match"


def test_symmetric_difference():
table1 = create_sample_kmer_table(3, ["AAA", "AAC"])
table2 = create_sample_kmer_table(3, ["AAC", "AAG"])

symmetric_difference_set = table1.symmetric_difference(table2)
expected_symmetric_difference = set(table1.hashes).symmetric_difference(
table2.hashes
)

assert (
symmetric_difference_set == expected_symmetric_difference
), "Symmetric difference of hash sets should match"


def test_dunder_methods():
table1 = create_sample_kmer_table(3, ["AAA", "AAC"])
table2 = create_sample_kmer_table(3, ["AAC", "AAG"])

assert table1.__or__(table2) == table1.union(
table2
), "__or__ method should match union()"
assert table1.__and__(table2) == table1.intersection(
table2
), "__and__ method should match intersection()"
assert table1.__sub__(table2) == table1.difference(
table2
), "__sub__ method should match difference()"
assert table1.__xor__(table2) == table1.symmetric_difference(
table2
), "__xor__ method should match symmetric_difference()"
Loading