Skip to content

Commit

Permalink
use enums all the way to Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
luizirber committed Nov 1, 2019
1 parent f13ea27 commit a918b26
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 47 deletions.
11 changes: 11 additions & 0 deletions include/sourmash.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
#include <stdint.h>
#include <stdlib.h>

/*
 * Selects which hash function a KmerMinHash uses.
 * The numeric values are explicit (1, 2, 3) because they are exchanged with
 * the library as plain integers; do not renumber them.
 */
enum HashFunctions {
HASH_FUNCTIONS_MURMUR64_DNA = 1,
HASH_FUNCTIONS_MURMUR64_PROTEIN = 2,
HASH_FUNCTIONS_MURMUR64_DAYHOFF = 3,
};
/* Function signatures use this fixed-width 32-bit alias rather than the enum tag. */
typedef uint32_t HashFunctions;

enum SourmashErrorCode {
SOURMASH_ERROR_CODE_NO_ERROR = 0,
SOURMASH_ERROR_CODE_PANIC = 1,
Expand Down Expand Up @@ -79,6 +86,10 @@ const uint64_t *kmerminhash_get_mins(KmerMinHash *ptr);

uintptr_t kmerminhash_get_mins_size(KmerMinHash *ptr);

HashFunctions kmerminhash_hash_function(KmerMinHash *ptr);

void kmerminhash_hash_function_set(KmerMinHash *ptr, HashFunctions hash_function);

uint64_t kmerminhash_intersection(KmerMinHash *ptr, const KmerMinHash *other);

bool kmerminhash_is_protein(KmerMinHash *ptr);
Expand Down
44 changes: 40 additions & 4 deletions sourmash/_minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ._compat import string_types, range_type
from ._lowlevel import ffi, lib
from .hash_functions import HashFunctions
from .utils import RustObject, rustcall, decode_str
from .exceptions import SourmashError

Expand Down Expand Up @@ -98,6 +99,16 @@ def __init__(
mins=None,
scaled=0,
):
_hash_function = None
if is_protein and dayhoff:
_hash_function = HashFunctions.murmur64_dayhoff
elif is_protein:
_hash_function = HashFunctions.murmur64_protein
elif not is_protein and not dayhoff:
_hash_function = HashFunctions.murmur64_DNA
else:
raise ValueError('invalid hash_function')

if max_hash and scaled:
raise ValueError("cannot set both max_hash and scaled")
elif scaled:
Expand Down Expand Up @@ -279,6 +290,22 @@ def track_abundance(self, b):
else:
self._methodcall(lib.kmerminhash_enable_abundance)

@property
def hash_function(self):
    """Return the `HashFunctions` member for this MinHash.

    The raw value comes from the native `kmerminhash_hash_function` call and
    is converted into the matching enum member.
    """
    return HashFunctions(self._methodcall(lib.kmerminhash_hash_function))

@hash_function.setter
def hash_function(self, v):
    """Switch this MinHash to the hash function `v` (a `HashFunctions` member).

    Nothing is sent to the native library when `v` equals the current value;
    otherwise `v.value` is handed to `kmerminhash_hash_function_set`.
    """
    # TODO: validate v
    # TODO: allow passing a string too?
    # TODO: same sort of validation as track_abundance:
    #       - can only change on an empty minhash
    already_set = self.hash_function == v
    if not already_set:
        self._methodcall(lib.kmerminhash_hash_function_set, v.value)

def add_hash(self, h):
return self._methodcall(lib.kmerminhash_add_hash, h)

Expand Down Expand Up @@ -486,13 +513,22 @@ def add_protein(self, sequence):
)

def is_molecule_type(self, molecule):
if molecule.upper() == "DNA" and not self.is_protein:
return True
# TODO: only accept molecule as enum
if isinstance(molecule, str):
if molecule.upper() == "DNA":
molecule = HashFunctions.murmur64_DNA
elif molecule == "protein":
molecule = HashFunctions.murmur64_protein
elif molecule == "dayhoff":
molecule = HashFunctions.murmur64_dayhoff

if molecule == HashFunctions.murmur64_DNA and not self.is_protein:
return True
if self.is_protein:
if self.dayhoff:
if molecule == 'dayhoff':
if molecule == HashFunctions.murmur64_dayhoff:
return True
else:
if molecule == "protein":
if molecule == HashFunctions.murmur64_protein:
return True
return False
31 changes: 22 additions & 9 deletions sourmash/hash_functions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
from ._minhash import _HashFunctions as HashFunctions
from enum import Enum

from ._lowlevel import lib


class HashFunctions(Enum):
    """Hash functions understood by the native library.

    Member values are taken from the `HASH_FUNCTIONS_*` constants exposed by
    `lib`, so they always match what the native side expects.
    """

    murmur64_DNA = lib.HASH_FUNCTIONS_MURMUR64_DNA
    murmur64_protein = lib.HASH_FUNCTIONS_MURMUR64_PROTEIN
    murmur64_dayhoff = lib.HASH_FUNCTIONS_MURMUR64_DAYHOFF

    @classmethod
    def from_string(cls, hash_str):
        """Map a label such as ``"0.murmur64_DNA"`` to its enum member.

        Raises `Exception` when `hash_str` is not one of the known labels.
        """
        labelled_members = (
            ("0.murmur64_DNA", cls.murmur64_DNA),
            ("0.murmur64_protein", cls.murmur64_protein),
            ("0.murmur64_dayhoff", cls.murmur64_dayhoff),
        )
        for label, member in labelled_members:
            if hash_str == label:
                return member
        raise Exception("unknown molecule type: {}".format(hash_str))


def hashfunction_from_string(hash_str):
if hash_str == "0.murmur64_DNA":
return HashFunctions.murmur64_DNA
elif hash_str == "0.murmur64_protein":
return HashFunctions.murmur64_protein
elif hash_str == "0.murmur64_dayhoff":
return HashFunctions.murmur64_dayhoff
else:
raise Exception("unknown molecule type: {}".format(hash_str))
return HashFunctions.from_string(hash_str)
6 changes: 3 additions & 3 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ pub enum SourmashError {
#[fail(display = "different signatures cannot be compared")]
MismatchSignatureType,

#[fail(display = "Can only set track_abundance=True if the MinHash is empty")]
NonEmptyMinHash,
#[fail(display = "Can only set {} if the MinHash is empty", message)]
NonEmptyMinHash { message: String },

#[fail(display = "invalid DNA character in input k-mer: {}", message)]
InvalidDNA { message: String },
Expand Down Expand Up @@ -82,7 +82,7 @@ impl SourmashErrorCode {
SourmashError::MismatchSignatureType => {
SourmashErrorCode::MismatchSignatureType
}
SourmashError::NonEmptyMinHash => SourmashErrorCode::NonEmptyMinHash,
SourmashError::NonEmptyMinHash { .. } => SourmashErrorCode::NonEmptyMinHash,
SourmashError::InvalidDNA { .. } => SourmashErrorCode::InvalidDNA,
SourmashError::InvalidProt { .. } => SourmashErrorCode::InvalidProt,
SourmashError::InvalidCodonLength { .. } => {
Expand Down
29 changes: 27 additions & 2 deletions src/ffi/minhash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::slice;

use crate::errors::SourmashError;
use crate::signature::SigsTrait;
use crate::sketch::minhash::{aa_to_dayhoff, translate_codon, KmerMinHash};
use crate::sketch::minhash::{aa_to_dayhoff, translate_codon, HashFunctions, KmerMinHash};

#[no_mangle]
pub unsafe extern "C" fn kmerminhash_new(
Expand Down Expand Up @@ -271,7 +271,7 @@ unsafe fn kmerminhash_enable_abundance(ptr: *mut KmerMinHash) -> Result<()> {
};

if mh.mins.len() != 0 {
return Err(SourmashError::NonEmptyMinHash.into());
return Err(SourmashError::NonEmptyMinHash {message: "track_abundance=True".into()} .into());
}

mh.abunds = Some(vec![]);
Expand Down Expand Up @@ -306,6 +306,31 @@ pub unsafe extern "C" fn kmerminhash_max_hash(ptr: *mut KmerMinHash) -> u64 {
mh.max_hash()
}

/// Returns the hash function of the `KmerMinHash` behind `ptr`.
///
/// # Safety
///
/// `ptr` must point to a valid, live `KmerMinHash`. A null pointer is caught
/// by the assertion below; any other invalid pointer is undefined behavior.
#[no_mangle]
pub unsafe extern "C" fn kmerminhash_hash_function(ptr: *mut KmerMinHash) -> HashFunctions {
    assert!(!ptr.is_null());
    // SAFETY: `ptr` is non-null (asserted above) and the caller guarantees it
    // points to a live `KmerMinHash`. This getter only reads, so a shared
    // reborrow is sufficient and avoids claiming exclusive access.
    let mh = &*ptr;
    mh.hash_function()
}

// Sets the hash function of the `KmerMinHash` behind `ptr`.
//
// Only an empty sketch may switch hash function: if `mins` already holds
// hashes the call fails with `SourmashError::NonEmptyMinHash` and the sketch
// is left untouched.
//
// NOTE(review): `hash_function` arrives by value from C; callers must pass one
// of the declared `HashFunctions` constants, since any other integer is not a
// valid value of the Rust enum.
ffi_fn! {
unsafe fn kmerminhash_hash_function_set(ptr: *mut KmerMinHash, hash_function: HashFunctions) -> Result<()> {
    let mh = {
        assert!(!ptr.is_null());
        &mut *ptr
    };

    if !mh.mins.is_empty() {
        return Err(SourmashError::NonEmptyMinHash {
            message: "hash_function".into(),
        }
        .into());
    }

    mh.hash_function = hash_function;
    Ok(())
}
}

ffi_fn! {
unsafe fn kmerminhash_merge(ptr: *mut KmerMinHash, other: *const KmerMinHash) -> Result<()> {
let mh = {
Expand Down
80 changes: 51 additions & 29 deletions src/sketch/minhash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,21 @@ use crate::signature::SigsTrait;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;

/// Hash function selected for a MinHash sketch.
///
/// The discriminants are explicit and the layout is `#[repr(u32)]` because
/// values of this type cross the C FFI boundary as 32-bit integers; they must
/// stay in sync with `HASH_FUNCTIONS_*` in `include/sourmash.h` and must not
/// be renumbered.
// Variant names mirror the Python-side `HashFunctions` members
// (`murmur64_DNA`, `murmur64_protein`, `murmur64_dayhoff`), hence the lint
// exemption instead of UpperCamelCase names.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum HashFunctions {
    /// DNA input; the only variant for which `is_protein()` is false.
    murmur64_DNA = 1,
    /// Protein input without Dayhoff encoding.
    murmur64_protein = 2,
    /// Protein input with Dayhoff encoding (`dayhoff()` is true).
    murmur64_dayhoff = 3,
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen)]
#[derive(Debug, Clone, PartialEq)]
pub struct KmerMinHash {
num: u32,
ksize: u32,
is_protein: bool,
dayhoff: bool,
pub(crate) hash_function: HashFunctions,
seed: u64,
max_hash: u64,
pub(crate) mins: Vec<u64>,
Expand All @@ -35,8 +43,7 @@ impl Default for KmerMinHash {
KmerMinHash {
num: 1000,
ksize: 21,
is_protein: false,
dayhoff: false,
hash_function: HashFunctions::murmur64_DNA,
seed: 42,
max_hash: 0,
mins: Vec::with_capacity(1000),
Expand Down Expand Up @@ -69,9 +76,9 @@ impl Serialize for KmerMinHash {

partial.serialize_field(
"molecule",
match &self.is_protein {
match &self.is_protein() {
true => {
if self.dayhoff {
if self.dayhoff() {
"dayhoff"
} else {
"protein"
Expand Down Expand Up @@ -105,7 +112,12 @@ impl<'de> Deserialize<'de> for KmerMinHash {
let tmpsig = TempSig::deserialize(deserializer)?;

let num = if tmpsig.max_hash != 0 { 0 } else { tmpsig.num };
let molecule = tmpsig.molecule.to_lowercase();
let hash_function = match tmpsig.molecule.to_lowercase().as_ref() {
"protein" => HashFunctions::murmur64_protein,
"dayhoff" => HashFunctions::murmur64_dayhoff,
"dna" => HashFunctions::murmur64_DNA,
_ => unimplemented!(), // TODO: throw error here
};

Ok(KmerMinHash {
num,
Expand All @@ -114,13 +126,7 @@ impl<'de> Deserialize<'de> for KmerMinHash {
max_hash: tmpsig.max_hash,
mins: tmpsig.mins,
abunds: tmpsig.abundances,
is_protein: match molecule.as_ref() {
"protein" => true,
"dayhoff" => true,
"dna" => false,
_ => unimplemented!(),
},
dayhoff: molecule == "dayhoff",
hash_function,
})
}
}
Expand Down Expand Up @@ -150,11 +156,18 @@ impl KmerMinHash {
abunds = None
}

let hash_function = if is_protein && dayhoff {
HashFunctions::murmur64_dayhoff
} else if is_protein {
HashFunctions::murmur64_protein
} else {
HashFunctions::murmur64_DNA
};

KmerMinHash {
num,
ksize,
is_protein,
dayhoff,
hash_function,
seed,
max_hash,
mins,
Expand All @@ -167,7 +180,11 @@ impl KmerMinHash {
}

pub fn is_protein(&self) -> bool {
self.is_protein
match self.hash_function {
HashFunctions::murmur64_dayhoff => true,
HashFunctions::murmur64_protein => true,
HashFunctions::murmur64_DNA => false,
}
}

pub fn seed(&self) -> u64 {
Expand Down Expand Up @@ -407,8 +424,8 @@ impl KmerMinHash {
let mut combined_mh = KmerMinHash::new(
self.num,
self.ksize,
self.is_protein,
self.dayhoff,
self.is_protein(),
self.dayhoff(),
self.seed,
self.max_hash,
self.abunds.is_some(),
Expand Down Expand Up @@ -440,8 +457,8 @@ impl KmerMinHash {
let mut combined_mh = KmerMinHash::new(
self.num,
self.ksize,
self.is_protein,
self.dayhoff,
self.is_protein(),
self.dayhoff(),
self.seed,
self.max_hash,
self.abunds.is_some(),
Expand Down Expand Up @@ -476,7 +493,14 @@ impl KmerMinHash {
}

pub fn dayhoff(&self) -> bool {
self.dayhoff
match self.hash_function {
HashFunctions::murmur64_dayhoff => true,
_ => false,
}
}

/// Returns this sketch's `hash_function` field by value (`HashFunctions` is `Copy`).
pub fn hash_function(&self) -> HashFunctions {
self.hash_function
}
}

Expand All @@ -497,10 +521,8 @@ impl SigsTrait for KmerMinHash {
if self.ksize != other.ksize {
return Err(SourmashError::MismatchKSizes.into());
}
if self.is_protein != other.is_protein {
return Err(SourmashError::MismatchDNAProt.into());
}
if self.dayhoff != other.dayhoff {
if self.hash_function != other.hash_function {
// TODO: fix this error
return Err(SourmashError::MismatchDNAProt.into());
}
if self.max_hash != other.max_hash {
Expand All @@ -518,7 +540,7 @@ impl SigsTrait for KmerMinHash {
.map(|&x| (x as char).to_ascii_uppercase() as u8)
.collect();
if sequence.len() >= (self.ksize as usize) {
if !self.is_protein {
if !self.is_protein() {
// dna
for kmer in sequence.windows(self.ksize as usize) {
if _checkdna(kmer) {
Expand Down Expand Up @@ -547,15 +569,15 @@ impl SigsTrait for KmerMinHash {
.skip(i)
.take(sequence.len() - i)
.collect();
let aa = to_aa(&substr, self.dayhoff)?;
let aa = to_aa(&substr, self.dayhoff())?;

aa.windows(aa_ksize as usize)
.map(|n| self.add_word(n))
.count();

let rc_substr: Vec<u8> =
rc.iter().cloned().skip(i).take(rc.len() - i).collect();
let aa_rc = to_aa(&rc_substr, self.dayhoff)?;
let aa_rc = to_aa(&rc_substr, self.dayhoff())?;

aa_rc
.windows(aa_ksize as usize)
Expand Down
Loading

0 comments on commit a918b26

Please sign in to comment.