diff --git a/.gitignore b/.gitignore index f498011..c640270 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,12 @@ /target +# Dev files +/dev + +# Jupyter Notebook +.ipynb_checkpoints + # Byte-compiled / optimized / DLL files __pycache__/ .pytest_cache/ diff --git a/src/lib.rs b/src/lib.rs index f8ac399..78c341e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,11 @@ -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; -// use rayon::prelude::*; +// Standard library imports +use std::collections::{HashMap, HashSet}; +// External crate imports use anyhow::{anyhow, Result}; use log::debug; -use std::collections::HashMap; - -// use sourmash::sketch::nodegraph::Nodegraph; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; use sourmash::encodings::HashFunctions; use sourmash::signature::SeqToHashes; @@ -47,14 +46,9 @@ impl KmerCountTable { } pub fn count_hash(&mut self, hashval: u64) -> u64 { - let mut count: u64 = 1; - if self.counts.contains_key(&hashval) { - count = *self.counts.get(&hashval).unwrap(); - count = count + 1; - } - self.counts.insert(hashval, count); - - count + let count = self.counts.entry(hashval).or_insert(0); + *count += 1; + *count } pub fn count(&mut self, kmer: String) -> PyResult { @@ -86,7 +80,26 @@ impl KmerCountTable { } } - // Consume this DNA strnig. Return number of k-mers consumed. + // Get the count for a specific hash value directly + pub fn get_hash(&self, hashval: u64) -> u64 { + // Return the count for the hash value, or 0 if it does not exist + *self.counts.get(&hashval).unwrap_or(&0) + } + + // Get counts for a list of hash keys and return an list of counts + pub fn get_hash_array(&self, hash_keys: Vec) -> Vec { + // Map each hash key to its count, defaulting to 0 if the key is not present + hash_keys.iter().map(|&key| self.get_hash(key)).collect() + } + + // Getter for the 'hashes' attribute, returning all hash keys in the table + #[getter] + pub fn hashes(&self) -> Vec { + // Collect and return all keys from the counts HashMap + self.counts.keys().cloned().collect() + } + + // Consume this DNA string. Return number of k-mers consumed. #[pyo3(signature = (seq, allow_bad_kmers=true))] pub fn consume(&mut self, seq: String, allow_bad_kmers: bool) -> PyResult { let hashes = SeqToHashes::new( @@ -118,8 +131,57 @@ impl KmerCountTable { Ok(n) } + + // Helper method to get hash set of k-mers + fn hash_set(&self) -> HashSet { + self.counts.keys().cloned().collect() + } + + // Set operation methods + pub fn union(&self, other: &KmerCountTable) -> HashSet { + self.hash_set().union(&other.hash_set()).cloned().collect() + } + + pub fn intersection(&self, other: &KmerCountTable) -> HashSet { + self.hash_set() + .intersection(&other.hash_set()) + .cloned() + .collect() + } + + pub fn difference(&self, other: &KmerCountTable) -> HashSet { + self.hash_set() + .difference(&other.hash_set()) + .cloned() + .collect() + } + + pub fn symmetric_difference(&self, other: &KmerCountTable) -> HashSet { + self.hash_set() + .symmetric_difference(&other.hash_set()) + .cloned() + .collect() + } + + // Python dunder methods for set operations + fn __or__(&self, other: &KmerCountTable) -> HashSet { + self.union(other) + } + + fn __and__(&self, other: &KmerCountTable) -> HashSet { + self.intersection(other) + } + + fn __sub__(&self, other: &KmerCountTable) -> HashSet { + self.difference(other) + } + + fn __xor__(&self, other: &KmerCountTable) -> HashSet { + self.symmetric_difference(other) + } } +// Python module definition #[pymodule] fn oxli(m: &Bound<'_, PyModule>) -> PyResult<()> { env_logger::init(); diff --git a/src/python/tests/test_basic.py b/src/python/tests/test_basic.py index edc2a8d..ba44c33 100644 --- a/src/python/tests/test_basic.py +++ b/src/python/tests/test_basic.py @@ -1,7 +1,17 @@ -import pytest import oxli +import pytest + -def test_simple(): +# Helper function, create tables. +def create_sample_kmer_table(ksize, kmers): + table = oxli.KmerCountTable(ksize) + for kmer in kmers: + table.count(kmer) + return table + + +# Adding Kmers +def test_count(): # yo dawg it works cg = oxli.KmerCountTable(4) kmer = "ATCG" @@ -11,6 +21,24 @@ def test_simple(): assert cg.get(kmer) == 1 +def test_count_hash(): + kmer = "TAAACCCTAACCCTAACCCTAACCCTAACCC" + cg = oxli.KmerCountTable(ksize=31) + hashkey = cg.hash_kmer(kmer) + + assert cg.get_hash(hashkey) == 0 + assert cg.count_hash(hashkey) == 1 + assert cg.get_hash(hashkey) == 1 + + +def test_hash_rc(): + table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"]) + hash_aaa = table.hash_kmer("AAA") # 10679328328772601858 + hash_ttt = table.hash_kmer("TTT") # 10679328328772601858 + + assert hash_aaa == hash_ttt, "Hash should be same for reverse complement." + + def test_wrong_ksize(): # but only with the right ksize cg = oxli.KmerCountTable(3) @@ -40,15 +68,14 @@ def test_consume_2(): assert cg.consume(seq) == 2 assert cg.get("ATCG") == 1 assert cg.get("TCGG") == 1 - assert cg.get("CCGA") == 1 # reverse complement! + assert cg.get("CCGA") == 1 # reverse complement! def test_consume_bad_DNA(): # test an invalid base in last position cg = oxli.KmerCountTable(4) seq = "ATCGGX" - with pytest.raises(ValueError, - match="bad k-mer encountered at position 2"): + with pytest.raises(ValueError, match="bad k-mer encountered at position 2"): cg.consume(seq, allow_bad_kmers=False) @@ -56,8 +83,7 @@ def test_consume_bad_DNA_2(): # test an invalid base in first position cg = oxli.KmerCountTable(4) seq = "XATCGG" - with pytest.raises(ValueError, - match="bad k-mer encountered at position 0"): + with pytest.raises(ValueError, match="bad k-mer encountered at position 0"): cg.consume(seq, allow_bad_kmers=False) @@ -68,7 +94,7 @@ def test_consume_bad_DNA_ignore(): print(cg.consume(seq, allow_bad_kmers=True)) assert cg.get("ATCG") == 1 assert cg.get("TCGG") == 1 - assert cg.get("CCGA") == 1 # rc + assert cg.get("CCGA") == 1 # rc def test_consume_bad_DNA_ignore_is_default(): @@ -78,20 +104,151 @@ def test_consume_bad_DNA_ignore_is_default(): print(cg.consume(seq)) assert cg.get("ATCG") == 1 assert cg.get("TCGG") == 1 - assert cg.get("CCGA") == 1 # rc + assert cg.get("CCGA") == 1 # rc -def test_count_get(): - # test a bug reported by adam taranto: count and get should work together! - kmer = 'TAAACCCTAACCCTAACCCTAACCCTAACCC' +# Test attributes +def test_hashes_attribute(): + table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"]) + hashes = table.hashes + hash_aaa = table.hash_kmer("AAA") # 10679328328772601858 + hash_ttt = table.hash_kmer("TTT") # 10679328328772601858 + hash_aac = table.hash_kmer("AAC") # 6579496673972597301 + + expected_hashes = set( + [hash_aaa, hash_ttt, hash_aac] + ) # {10679328328772601858, 6579496673972597301} + assert ( + set(hashes) == expected_hashes + ), ".hashes attribute should match the expected set of hash keys" + +# Getting counts +def test_count_vs_counthash(): + # test a bug reported by adam taranto: count and get should work together! + kmer = "TAAACCCTAACCCTAACCCTAACCCTAACCC" cg = oxli.KmerCountTable(ksize=31) hashkey = cg.hash_kmer(kmer) assert cg.get(kmer) == 0 assert cg.count(kmer) == 1 assert cg.count(kmer) == 2 + assert cg.get(kmer) == 2 + assert cg.count_hash(hashkey) == 3 x = cg.get(kmer) - assert x == 2, x - + assert x == 3, x + + +def test_get_hash(): + """Retrieve counts using hash key.""" + table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"]) + # Find hash of kmer 'AAA' + hash_aaa = table.hash_kmer("AAA") # 10679328328772601858 + # Lookup counts for hash of 'AAA' and rc 'TTT' + count_aaa = table.get_hash(hash_aaa) + assert count_aaa == 2, "Hash count for 'AAA' should be 2" + + # Test single kmer + hash_aac = table.hash_kmer("AAC") # 6579496673972597301 + count_aac = table.get_hash(hash_aac) + assert count_aac == 1, "Hash count for 'AAC' should be 1" + + # Test for kmer that is not in table + hash_aag = table.hash_kmer("AAG") # 12774992397053849803 + count_aag = table.get_hash(hash_aag) + assert count_aag == 0, "Missing kmer count for 'AAG' should be 0" + + +def test_get_hash_array(): + """ + Get vector of counts corresponding to vector of hash keys. + """ + table = create_sample_kmer_table(3, ["AAA", "TTT", "AAC"]) + hash_aaa = table.hash_kmer("AAA") + hash_aac = table.hash_kmer("AAC") + hash_ggg = table.hash_kmer("GGG") # key not in table + + hash_keys = [hash_aaa, hash_aac, hash_ggg] + hash_keys_rev = [hash_ggg, hash_aac, hash_aaa] + + counts = table.get_hash_array(hash_keys) + rev_counts = table.get_hash_array(hash_keys_rev) + + assert ( + counts == [2, 1, 0] + ), "Hash array counts should match the counts of 'AAA' and 'AAC' and return zero for 'GGG'." + assert rev_counts == [0, 1, 2], "Count should be in same order as input list" + + +def test_get_array(): + """ + Get vector of counts corresponding to vector of kmers. + """ + # TODO: Add function to get list of counts given list of kmers. + pass + + +# Set operations +def test_union(): + table1 = create_sample_kmer_table(3, ["AAA", "AAC"]) + table2 = create_sample_kmer_table(3, ["AAC", "AAG"]) + + union_set = table1.union(table2) + expected_union = set(table1.hashes).union(table2.hashes) + + assert union_set == expected_union, "Union of hash sets should match" + + +def test_intersection(): + table1 = create_sample_kmer_table(3, ["AAA", "AAC"]) + table2 = create_sample_kmer_table(3, ["AAC", "AAG"]) + + intersection_set = table1.intersection(table2) + expected_intersection = set(table1.hashes).intersection(table2.hashes) + + assert ( + intersection_set == expected_intersection + ), "Intersection of hash sets should match" + + +def test_difference(): + table1 = create_sample_kmer_table(3, ["AAA", "AAC"]) + table2 = create_sample_kmer_table(3, ["AAC", "AAG"]) + + difference_set = table1.difference(table2) + expected_difference = set(table1.hashes).difference(table2.hashes) + + assert difference_set == expected_difference, "Difference of hash sets should match" + + +def test_symmetric_difference(): + table1 = create_sample_kmer_table(3, ["AAA", "AAC"]) + table2 = create_sample_kmer_table(3, ["AAC", "AAG"]) + + symmetric_difference_set = table1.symmetric_difference(table2) + expected_symmetric_difference = set(table1.hashes).symmetric_difference( + table2.hashes + ) + + assert ( + symmetric_difference_set == expected_symmetric_difference + ), "Symmetric difference of hash sets should match" + + +def test_dunder_methods(): + table1 = create_sample_kmer_table(3, ["AAA", "AAC"]) + table2 = create_sample_kmer_table(3, ["AAC", "AAG"]) + + assert table1.__or__(table2) == table1.union( + table2 + ), "__or__ method should match union()" + assert table1.__and__(table2) == table1.intersection( + table2 + ), "__and__ method should match intersection()" + assert table1.__sub__(table2) == table1.difference( + table2 + ), "__sub__ method should match difference()" + assert table1.__xor__(table2) == table1.symmetric_difference( + table2 + ), "__xor__ method should match symmetric_difference()"