diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 905712b5..944c1884 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,9 +1,9 @@
-# Contributing to the dsd Python package
+# Contributing to the nuad Python package
 
 First off, thanks for taking the time to contribute!
 
-The following is a set of guidelines for contributing to dsd.
+The following is a set of guidelines for contributing to nuad.
 Feel free to propose changes to this document in a pull request,
-or post questions as issues on the [issues page](https://github.com/UC-Davis-molecular-computing/dsd/issues).
+or post questions as issues on the [issues page](https://github.com/UC-Davis-molecular-computing/nuad/issues).
@@ -22,13 +22,13 @@ or post questions as issues on the [issues page](https://github.com/UC-Davis-mol
 
 ### Python
 
 First, read the [README](README.md) to familiarize yourself with the package from a user's perspective.
-The dsd Python package requires at least Python 3.6.
+The nuad Python package requires at least Python 3.6.
 
 ### What to install
 
 Follow the [installation instructions](README.md#installation) to install the correct version of Python if you don't have it already.
 
-I suggest using a powerful IDE such as [PyCharm](https://www.jetbrains.com/pycharm/download/download-thanks.html). [Visual Studio Code](https://code.visualstudio.com/) is also good with the right plugins. The dsd Python package uses type hints, and these tools are very helpful in giving static analysis warnings about the code that may represent errors that will manifest at run time.
+I suggest using a powerful IDE such as [PyCharm](https://www.jetbrains.com/pycharm/download/download-thanks.html). [Visual Studio Code](https://code.visualstudio.com/) is also good with the right plugins. The nuad Python package uses type hints, and these tools are very helpful in giving static analysis warnings about code that may cause errors at run time.
 
 ### git
 
@@ -54,11 +54,11 @@ We use [git](https://git-scm.com/docs/gittutorial) and [GitHub](https://guides.g
 The first step is cloning the repository so you have it available locally.
 
 ```
-git clone https://github.com/UC-Davis-molecular-computing/dsd.git
+git clone https://github.com/UC-Davis-molecular-computing/nuad.git
 ```
 
-Changes to the dsd package should be pushed to the
-[`dev`](https://github.com/UC-Davis-molecular-computing/dsd/tree/dev) branch. So switch to the `dev` branch:
+Changes to the nuad package should be pushed to the
+[`dev`](https://github.com/UC-Davis-molecular-computing/nuad/tree/dev) branch. So switch to the `dev` branch:
 
 ```
 git checkout dev
 ```
@@ -112,7 +112,7 @@ For any more significant change that is made (e.g., closing an issue, adding a n
 
 ## Pushing to the repository main branch and documenting changes (done less frequently)
 
-Less frequently, pull requests (abbreviated PR) can be made from `dev` to `main`, but make sure that `dev` is working before merging to `main`, since changes to the docstrings automatically update [readthedocs](https://dsddna.readthedocs.io/en/latest/), which is the site hosting the API documentation. That is, changes to main immediately affect users reading online documentation, so it is critical that these work. Eventually we will automatically upload to PyPI, so this will also affect users installing via pip.
+Less frequently, pull requests (abbreviated PR) can be made from `dev` to `main`, but make sure that `dev` is working before merging to `main`, since changes to the docstrings automatically update [readthedocs](https://nuad.readthedocs.io/en/latest/), which is the site hosting the API documentation. That is, changes to main immediately affect users reading online documentation, so it is critical that these work. Eventually we will automatically upload to PyPI, so this will also affect users installing via pip.
 
 **WARNING:** Always wait for the checks to complete. This is important to ensure that unit tests pass.
diff --git a/README.md b/README.md
index e866a003..32395f6e 100644
--- a/README.md
+++ b/README.md
@@ -129,7 +129,7 @@ In more detail, there are five main types of objects you create to describe your
 
   - There are two types of `Domain`'s with no associated `DomainPool`. One type is a `Domain` with the field `fixed` set to `True` by calling the method `Domain.set_fixed_sequence()`, which has some fixed DNA sequence that cannot be changed. A fixed `Domain` has no `DomainPool`.
 
-  - The other type is a `Domain` with the field `dependent` set the `True` (by assigning the field directly). Such a domain is dependent for its sequence on the sequence of some other `Domain` with `dependent = False` that either contains it as a subsequence, or is contained in it as a subsequence. For example, one can declare the domain `a` is independent (has `dependent = False`), with length 8, and has dependent subdomains `b` and `c` of length 5 and 3. `a` would have a `DomainPool`, and if `a` is assigned sequence AAACCGTT, then `b` is automatically assigned sequence AAACC, and `c` is automatically assigned sequence GTT. Such subdomains are assigned via the field `Domain.subdomains`; see the API documentation for more details: https://dnadsd.readthedocs.io/en/latest/#constraints.Domain.dependent and https://dnadsd.readthedocs.io/en/latest/#constraints.Domain.subdomains.
+  - The other type is a `Domain` with the field `dependent` set to `True` (by assigning the field directly). Such a domain depends for its sequence on the sequence of some other `Domain` with `dependent = False` that either contains it as a subsequence, or is contained in it as a subsequence. For example, one can declare that domain `a` is independent (has `dependent = False`), has length 8, and has dependent subdomains `b` and `c` of lengths 5 and 3. `a` would have a `DomainPool`, and if `a` is assigned sequence AAACCGTT, then `b` is automatically assigned sequence AAACC, and `c` is automatically assigned sequence GTT. Such subdomains are assigned via the field `Domain.subdomains`; see the API documentation for more details: https://nuad.readthedocs.io/en/latest/#constraints.Domain.dependent and https://nuad.readthedocs.io/en/latest/#constraints.Domain.subdomains.
 
 - `Strand`: A `Strand` contains an ordered list `domains` of `Domain`'s, together with an identification of which `Domain`'s are starred in this `Strand`, the latter specified as a set `starred_domain_indices` of indices (starting at 0) into the list `domains`. For example, the `Strand` consisting of `Domain`'s `a`, `b*`, `c`, `b`, `d*`, in that order, would have `domains = [a, b, c, b, d]` and `starred_domain_indices = {1, 4}`.
@@ -146,7 +146,7 @@ In more detail, there are five main types of objects you create to describe your
 
   - `DomainConstraint`: This only looks at a single `Domain`.
In practice this is not used much, since there's not much information in a `Domain` other than its DNA sequence, so a `SequenceConstraint` or `NumpyConstraint` typically would already have filtered out any DNA sequence not satisfying such a constraint.
 
-  - `StrandConstraint`: This evaluates a whole `Strand`. A common example is that NUPACK's `pfunc` should indicate a complex free energy above a certain threshold, indicating the `Strand` has little secondary structure. This example constraint is available in the library by calling [nupack_strand_complex_free_energy_constraint](https://dnadsd.readthedocs.io/en/latest/#constraints.nupack_strand_complex_free_energy_constraint).
+  - `StrandConstraint`: This evaluates a whole `Strand`. A common example is that NUPACK's `pfunc` should indicate a complex free energy above a certain threshold, indicating the `Strand` has little secondary structure. This example constraint is available in the library by calling [nupack_strand_complex_free_energy_constraint](https://nuad.readthedocs.io/en/latest/#constraints.nupack_strand_complex_free_energy_constraint).
 
   - `DomainPairConstraint`: This evaluates a pair of `Domain`'s.
@@ -170,7 +170,7 @@ In more detail, there are five main types of objects you create to describe your
 
 The search algorithm evaluates the constraints, and for each violated constraint, it turns the `excess` value into a "score" by first passing it through the "score transfer function", which by default squares the value, and then multiplies by the value `Constraint.weight` (by default 1). The goal of the search is to minimize the sum of scores across all violated `Constraint`'s. The reason that the score is squared is that this leads the search algorithm to (slightly) favor reducing the excess of constraint violations that are "more in excess", i.e., it would reduce the total score more to reduce an excess from 4 to 3 (reducing the score from 4²=16 to 3²=9, a reduction of 16-9=7) than to reduce an excess from 2 to 1 (which reduces 2²=4 to 1²=1, a reduction of only 4-1=3).
 
-The full search algorithm is described in the [API documentation for the function nuad.search.search_for_dna_sequences](https://dnadsd.readthedocs.io/en/latest/#search.search_for_dna_sequences).
+The full search algorithm is described in the [API documentation for the function nuad.search.search_for_sequences](https://nuad.readthedocs.io/en/latest/#search.search_for_sequences).
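+
+As a small illustrative sketch of the scoring rule just described (the helper name `score_of_violation` is hypothetical, and we assume the default quadratic transfer function and weight of 1 described above):
+
+```python
+def score_of_violation(excess: float, weight: float = 1.0) -> float:
+    # pass the excess through the score transfer function (x^2 by default),
+    # then multiply by the constraint's weight
+    return weight * max(0.0, excess ** 2)
+
+# reducing an excess from 4 to 3 lowers the score by 16 - 9 = 7,
+# while reducing an excess from 2 to 1 lowers it by only 4 - 1 = 3
+assert score_of_violation(4) - score_of_violation(3) == 7
+assert score_of_violation(2) - score_of_violation(1) == 3
+```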
## Constraint evaluations must be pure functions of their inputs
diff --git a/nuad/__version__.py b/nuad/__version__.py
index 2f3f5c52..1bf14c59 100644
--- a/nuad/__version__.py
+++ b/nuad/__version__.py
@@ -1 +1 @@
-version = '0.4.1'  # version line; WARNING: do not remove or change this line or comment
+version = '0.4.2'  # version line; WARNING: do not remove or change this line or comment
diff --git a/nuad/constraints.py b/nuad/constraints.py
index a59129f3..b5259b07 100644
--- a/nuad/constraints.py
+++ b/nuad/constraints.py
@@ -24,7 +24,7 @@ import json
 from decimal import Decimal
 from typing import List, Set, Dict, Callable, Iterable, Tuple, Collection, TypeVar, Any, \
-    cast, Generic, DefaultDict, FrozenSet, Iterator, Sequence, Type
+    cast, Generic, DefaultDict, FrozenSet, Iterator, Sequence, Type, Optional
 from dataclasses import dataclass, field, InitVar
 from abc import ABC, abstractmethod
 from collections import defaultdict
@@ -140,6 +140,13 @@ class M13Variant(enum.Enum):
     https://www.tilibit.com/collections/scaffold-dna/products/single-stranded-scaffold-dna-type-p8064
     """
 
+    p8634 = "p8634"
+    """Variant of M13mp18 that is 8634 bases long. At the time of this writing, not listed as available
+    from any biotech vendor, but Tilibit will make it for you if you ask.
+    (https://www.tilibit.com/pages/contact-us)
+    """
+
     def length(self) -> int:
         """
         :return: length of this variant of M13 (e.g., 7249 for variant :data:`M13Variant.p7249`)
         """
@@ -150,6 +157,8 @@ def length(self) -> int:
             return 7560
         if self is M13Variant.p8064:
             return 8064
+        if self is M13Variant.p8634:
+            return 8634
         raise AssertionError('should be unreachable')
 
     def scadnano_variant(self) -> sc.M13Variant:
@@ -159,6 +168,8 @@ def scadnano_variant(self) -> sc.M13Variant:
             return sc.M13Variant.p7560
         if self is M13Variant.p8064:
             return sc.M13Variant.p8064
+        if self is M13Variant.p8634:
+            return sc.M13Variant.p8634
         raise AssertionError('should be unreachable')
@@ -273,12 +284,12 @@ def m13_substrings_of_length(length: int, except_indices: Iterable[int] = tuple(
 def default_score_transfer_function(x: float) -> float:
     """
-    A quadratic transfer function.
+    A cubic transfer function.
 
     :return:
-        max(0.0, x^2)
+        max(0.0, x^3)
     """
-    return max(0.0, x ** 2)
+    return max(0.0, x ** 3)
 
 logger = logging.Logger('dsd', level=logging.DEBUG)
@@ -378,7 +389,7 @@ class NumpyFilter(ABC):
     for a :any:`Domain`; a sequence not passing the filter is never allowed to be assigned
     to a :any:`Domain`. This contrasts with the various subclasses of :any:`Constraint`,
     which are different in two ways: 1) they can apply to larger parts of the design than just a domain,
-    e.g., a :any:`Strand` or a pair of :any:`Domain`'s, and 2) they are "soft" constraint that are
+    e.g., a :any:`Strand` or a pair of :any:`Domain`'s, and 2) they are "soft" constraints that are
     allowed to be violated during the course of the search.
A :any:`NumpyFilter` is one that can be efficiently encoded @@ -574,7 +585,7 @@ def remove_violating_sequences(self, seqs: nn.DNASeqList) -> nn.DNASeqList: f'when sequences only have length {seqs.seqlen}') if self.five_prime: - good_left = np.zeros(shape=len(seqs), dtype=np.bool) + good_left = np.zeros(shape=len(seqs), dtype=bool) left = seqs.seqarr[:, self.distance_from_end] for bits in all_bits: if good_left is None: @@ -583,7 +594,7 @@ def remove_violating_sequences(self, seqs: nn.DNASeqList) -> nn.DNASeqList: good_left |= (left == bits) if self.three_prime: - good_right = np.zeros(shape=len(seqs), dtype=np.bool) + good_right = np.zeros(shape=len(seqs), dtype=bool) right = seqs.seqarr[:, -1 - self.distance_from_end] for bits in all_bits: if good_right is None: @@ -639,7 +650,7 @@ def remove_violating_sequences(self, seqs: nn.DNASeqList) -> nn.DNASeqList: if not 0 <= self.position < seqs.seqlen: raise ValueError(f'position must be between 0 and {seqs.seqlen} but it is {self.position}') mid = seqs.seqarr[:, self.position] - good = np.zeros(shape=len(seqs), dtype=np.bool) + good = np.zeros(shape=len(seqs), dtype=bool) for base in self.bases: good |= (mid == nn.base2bits[base]) seqarr_pass = seqs.seqarr[good] @@ -708,7 +719,7 @@ def remove_violating_sequences(self, seqs: nn.DNASeqList) -> nn.DNASeqList: sub_vals = np.dot(sub_ints, pow_arr) toeplitz = nn.create_toeplitz(seqs.seqlen, sub_len, self.indices) convolution = np.dot(toeplitz, seqs.seqarr.transpose()) - pass_all = np.ones(seqs.numseqs, dtype=np.bool) + pass_all = np.ones(seqs.numseqs, dtype=bool) for sub_val in sub_vals: pass_sub = np.all(convolution != sub_val, axis=0) pass_all = pass_all & pass_sub @@ -4292,11 +4303,8 @@ class Result(Generic[DesignPart]): a threshold. """ - summary: str = '' - """ - This string is displayed in the text report on constraints, after the name of the "part" (e.g., - strand, pair of domains, pair of strands). - """ + _summary: Optional[str] = None + value: pint.Quantity[Decimal] | None = None """ @@ -4328,15 +4336,33 @@ def __init__(self, if summary is None: if value is None: raise ValueError('at least one of value or summary must be specified') - self.summary = str(value) + # note summary getter calculates summary from value if summary is None, + # so no need to set it here else: - self.summary = summary + self._summary = summary if value is not None: self.value = parse_and_normalize_quantity(value) self.score = 0.0 self.part = None # type:ignore + @property + def summary(self) -> str: + """ + This string is displayed in the text report on constraints, after the name of the "part" (e.g., + strand, pair of domains, pair of strands). + + It can be set explicitly, or calculated from :data:`Result.value` if not set explicitly. + """ + if self._summary is None: + return str(self.value) + else: + return self._summary + + @summary.setter + def summary(self, summary: str) -> None: + self._summary = summary + def parse_and_normalize_quantity(quantity: float | int | str | pint.Quantity) \ -> pint.Quantity[Decimal]: @@ -4353,9 +4379,10 @@ def Q_(qty: int | str | Decimal | float, unit: str | pint.Unit) -> pint.Quantity return ureg.Quantity(qty, unit) else: # we convert to string to avoid floating-point weirdness. 
For example
-            # ureg.Quantity(Decimal(-2.1), 'kcal/mol') gives
+            # ureg.Quantity(Decimal(-2.1), 'kcal/mol') gives
             # -2.100000000000000088817841970012523233890533447265625 kilocalorie / mole,
-            # but ureg.Quantity(Decimal(str(-2.1)), 'kcal/mol') gives
+            # whereas
+            # ureg.Quantity(Decimal(str(-2.1)), 'kcal/mol') gives
             # -2.1 kilocalorie / mole,
             qty_str = str(qty)
             return ureg.Quantity(Decimal(qty_str), unit)
@@ -4396,7 +4423,7 @@ def normalize_quantity(quantity: pint.Quantity, compact: bool = False) -> pint.Q
 @dataclass(eq=False)
 class SingularConstraint(Constraint[DesignPart], Generic[DesignPart], ABC):
     evaluate: Callable[[Tuple[str, ...], DesignPart | None],
-        Result[DesignPart]] = lambda _: _raise_unreachable()
+                       Result[DesignPart]] = lambda _: _raise_unreachable()
     """
     Essentially a wrapper for a function that evaluates the :any:`Constraint`.
     It takes as input a tuple of DNA sequences
@@ -4463,7 +4490,7 @@ def call_evaluate(self, seqs: Tuple[str, ...], part: DesignPart | None) -> Resul
 @dataclass(eq=False)
 class BulkConstraint(Constraint[DesignPart], Generic[DesignPart], ABC):
     evaluate_bulk: Callable[[Sequence[DesignPart]],
-        List[Result]] = lambda _: _raise_unreachable()
+                            List[Result]] = lambda _: _raise_unreachable()
 
     def call_evaluate_bulk(self, parts: Sequence[DesignPart]) -> List[Result]:
         results: List[Result[DesignPart]] = (self.evaluate_bulk)(parts)  # noqa
@@ -4693,7 +4720,7 @@ class DesignConstraint(Constraint[Design]):
     """
     evaluate_design: Callable[[Design, Iterable[Domain]],
-        List[Tuple[DesignPart, float, str]]] = lambda _: _raise_unreachable()
+                              List[Tuple[DesignPart, float, str]]] = lambda _: _raise_unreachable()
     """
     Evaluates the :any:`Design` (first argument), possibly taking into account which :any:`Domain`'s
     have changed in the last iteration (second argument).
@@ -5706,6 +5733,404 @@ def rna_duplex_strand_pairs_constraints_by_number_matching_domains(
     )
 
 
+def longest_complementary_subsequences_python_loop(arr1: np.ndarray, arr2: np.ndarray,
+                                                   gc_double: bool) -> List[int]:
+    """
+    Like :func:`longest_complementary_subsequences`, but uses a Python loop instead of numpy operations.
+    This is slower, but is easier to understand and useful for testing.
+    """
+    lcs_sizes = []
+    for s1, s2 in zip(arr1, arr2):
+        s1len = s1.shape[0]
+        s2len = s2.shape[0]
+        # use int rather than np.int8 so the table cannot overflow on long sequences
+        # (with gc_double, entries can reach twice the sequence length)
+        table = np.zeros(shape=(s1len + 1, s2len + 1), dtype=int)
+        for i in range(s1len):
+            for j in range(s2len):
+                b1 = s1[i]
+                b2 = s2[j]
+                if b1 + b2 == 3:
+                    weight = 1
+                    if gc_double and (b1 == 1 or b1 == 2):
+                        weight = 2
+                    table[i + 1][j + 1] = weight + table[i][j]
+                else:
+                    table[i + 1][j + 1] = max(table[i + 1][j], table[i][j + 1])
+        lcs_size = table[s1len][s2len]
+        lcs_sizes.append(lcs_size)
+    return lcs_sizes
+
+
+def longest_complementary_subsequences_two_loops(arr1: np.ndarray, arr2: np.ndarray,
+                                                 gc_double: bool) -> List[int]:
+    """
+    Calculate length of longest complementary subsequences between `arr1[i]` and `arr2[i]`
+    for each i, storing it in returned list `result[i]`.
+
+    This uses two nested Python loops to calculate the whole dynamic programming table.
+    :func:`longest_complementary_subsequences` is slightly faster because it maintains only the diagonal
+    of the DP table, and uses numpy vectorized operations to calculate the next diagonal of the table.
+
+    When used for DNA sequences, this assumes `arr2` has been reversed along axis 1, i.e.,
+    the sequences in `arr1` are assumed to be oriented 5' --> 3', and the sequences in `arr2`
+    are assumed to be oriented 3' --> 5'.
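+
+    For example, here is a minimal sketch of calling it on a single pair of 4-base sequences, both
+    written 5' --> 3' as Python strings, flipping the second array to give it the expected
+    3' --> 5' orientation ('ACGT' is its own reverse complement, so the LCS length here is 4):
+
+    .. code-block:: Python
+
+        arr1 = nn.seqs2arr(['ACGT'])
+        arr2 = np.flip(nn.seqs2arr(['ACGT']), axis=1)
+        sizes = longest_complementary_subsequences_two_loops(arr1, arr2, gc_double=False)
+        # sizes[0] == 4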
+
+    Args:
+        arr1: 2D array of DNA sequences, with each sequence represented as a 1D array of 0, 1, 2, 3
+            corresponding to A, C, G, T, respectively, with each row being a single DNA sequence
+            oriented 5' --> 3'.
+        arr2: 2D array of DNA sequences, with each row being a single DNA sequence
+            oriented 3' --> 5'.
+        gc_double: Whether to double the score for G-C base pairs.
+
+    Returns:
+        list `ret` of ints, where `ret[i]` is the length of the longest complementary subsequence
+        between `arr1[i]` and `arr2[i]`.
+    """
+    assert arr1.shape[0] == arr2.shape[0]
+    num_pairs = arr1.shape[0]
+    s1len = arr1.shape[1]
+    s2len = arr2.shape[1]
+    max_length = max(s1len, s2len)
+    dtype = np.min_scalar_type(max_length)  # e.g., uint8 for 0-255, uint16 for 256-65535, etc.
+    table = np.zeros(shape=(num_pairs, s1len + 1, s2len + 1), dtype=dtype)
+
+    # convert arr2 to complement and search for longest common subsequence (instead of complementary)
+    arr2 = 3 - arr2
+
+    for i in range(s1len):
+        for j in range(s2len):
+            bases1 = arr1[:, i]
+            bases2 = arr2[:, j]
+
+            equal_idxs = bases1 == bases2
+            if gc_double:
+                gc_idxs = np.logical_or(bases1[equal_idxs] == 1, bases1[equal_idxs] == 2)
+                weight = np.ones(len(bases1[equal_idxs]), dtype=dtype)
+                weight[gc_idxs] = 2
+                table[equal_idxs, i + 1, j + 1] = weight + table[equal_idxs, i, j]
+            else:
+                table[equal_idxs, i + 1, j + 1] = 1 + table[equal_idxs, i, j]
+
+            noncomp_idxs = np.logical_not(equal_idxs)
+            rec1 = table[noncomp_idxs, i + 1, j]
+            rec2 = table[noncomp_idxs, i, j + 1]
+            table[noncomp_idxs, i + 1, j + 1] = np.maximum(rec1, rec2)
+
+    lcs_sizes = table[:, s1len, s2len]
+
+    return lcs_sizes
+
+
+def longest_complementary_subsequences(arr1: np.ndarray, arr2: np.ndarray, gc_double: bool) -> List[int]:
+    """
+    Calculate length of longest complementary subsequences between `arr1[i]` and `arr2[i]`
+    for each i, storing it in returned list `result[i]`.
+
+    Unlike :func:`longest_complementary_subsequences_two_loops`, this uses only one Python loop,
+    maintaining only the current diagonal of the DP table and the previous two diagonals, and using
+    numpy vectorized operations to calculate each new diagonal, which makes it slightly faster.
+
+    When used for DNA sequences, this assumes `arr2` has been reversed along axis 1, i.e.,
+    the sequences in `arr1` are assumed to be oriented 5' --> 3', and the sequences in `arr2`
+    are assumed to be oriented 3' --> 5'.
+
+    Args:
+        arr1: 2D array of DNA sequences, with each sequence represented as a 1D array of 0, 1, 2, 3
+            corresponding to A, C, G, T, respectively, with each row being a single DNA sequence
+            oriented 5' --> 3'.
+        arr2: 2D array of DNA sequences, with each row being a single DNA sequence
+            oriented 3' --> 5'.
+        gc_double: Whether to double the score for G-C base pairs.
+
+    Returns:
+        list `ret` of ints, where `ret[i]` is the length of the longest complementary subsequence
+        between `arr1[i]` and `arr2[i]`.
+    """
+    assert arr1.shape[0] == arr2.shape[0]
+    num_pairs = arr1.shape[0]
+    s1len = arr1.shape[1]
+    s2len = arr2.shape[1]
+    assert s1len == s2len  # for now, assume same length, but should be relaxed
+
+    max_length = max(s1len, s2len)
+    dtype = np.min_scalar_type(max_length)  # e.g., uint8 for 0-255, uint16 for 256-65535, etc.
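+
+    # Scheme (per the docstring): entry (i, j) of the standard LCS DP table depends only on entries
+    # (i-1, j), (i, j-1), and (i-1, j-1), all of which lie on the previous two anti-diagonals of the
+    # table. So rather than materializing the full (s1len+1) x (s2len+1) table for every pair (as
+    # longest_complementary_subsequences_two_loops does), we sweep anti-diagonals of the table,
+    # keeping only the previous two diagonals per pair, and update all pairs at once with vectorized
+    # numpy operations.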
+
+    # convert arr2 to WC complement and search for longest common subsequence (instead of complementary)
+    arr2 = 3 - arr2
+
+    length_prev_prev = length_prev = s1len
+    prev_prev_larger = s1len % 2 == 0
+    if prev_prev_larger:
+        length_prev_prev += 1
+    else:
+        length_prev += 1
+
+    # using this spreadsheet to visualize the DP table:
+    # https://docs.google.com/spreadsheets/d/1FIOgQYFSJ_6r3ThBivDjf0epUxVLgk0xlQnQS6TUeSw/
+    diag_prev_prev = np.zeros(shape=(num_pairs, length_prev_prev), dtype=dtype)
+    diag_prev = np.zeros(shape=(num_pairs, length_prev), dtype=dtype)
+
+    # do dynamic programming to figure out longest complementary subsequence,
+    # maintaining only the diagonal of the table and the previous two diagonals
+
+    # allocate these arrays just once to avoid re-allocating new memory each iteration;
+    # they are used for telling which bases are equal between the two sequences
+    eq_idxs_larger = np.zeros((num_pairs, s1len + 1), dtype=bool)
+    eq_idxs_smaller = np.zeros((num_pairs, s1len), dtype=bool)
+    gc_idxs_larger = np.zeros((num_pairs, s1len + 1), dtype=bool)
+    gc_idxs_smaller = np.zeros((num_pairs, s1len), dtype=bool)
+    for i in range(0, 2 * s1len, 2):
+        diag_cur = update_diagonal(arr1, arr2, diag_prev, diag_prev_prev,
+                                   eq_idxs_larger if prev_prev_larger else eq_idxs_smaller,
+                                   gc_idxs_larger if prev_prev_larger else gc_idxs_smaller,
+                                   i, prev_prev_larger, gc_double)
+        if i < 2 * s1len - 2:
+            diag_next = update_diagonal(arr1, arr2, diag_cur, diag_prev,
+                                        eq_idxs_larger if not prev_prev_larger else eq_idxs_smaller,
+                                        gc_idxs_larger if not prev_prev_larger else gc_idxs_smaller,
+                                        i + 1, not prev_prev_larger, gc_double)
+            diag_prev = diag_next
+        diag_prev_prev = diag_cur
+
+    middle_idx = s1len // 2
+    lcs_sizes = diag_prev_prev[:, middle_idx]
+
+    return lcs_sizes
+
+
+def update_diagonal(arr1: np.ndarray, arr2: np.ndarray,
+                    diag_prev: np.ndarray, diag_prev_prev: np.ndarray,
+                    eq_idxs: np.ndarray,
+                    gc_idxs: np.ndarray,
+                    i: int, prev_prev_larger: bool, gc_double: bool) -> np.ndarray:
+    s1len = arr1.shape[1]
+    s2len = arr2.shape[1]
+    assert s1len == s2len  # for now, assume same length, but should be relaxed
+
+    # determine which bases in arr1 and arr2 are equal;
+    # compute LCS for that case and store in diag_eq
+    # creates view, not copy, so don't modify!
+    eq_idxs[:, :] = False
+    if i < s1len:
+        sub1 = arr1[:, i::-1]  # indices i, i-1, ..., 0
+        sub2 = arr2[:, :i + 1]  # indices 0, 1, ..., i
+    else:
+        sub1 = arr1[:, :i - s1len:-1]  # indices s1len-1, s1len-2, ..., i-s1len+1
+        sub2 = arr2[:, i - s1len + 1:]  # indices i-s1len+1, i-s1len+2, ..., s1len-1
+
+    # need to set eq_idxs only on entries "within" the DP table, not the padded 0s on the edges;
+    # see https://docs.google.com/spreadsheets/d/1FIOgQYFSJ_6r3ThBivDjf0epUxVLgk0xlQnQS6TUeSw for example
+
+    if i < s1len:
+        start = (s1len - i) // 2
+    else:
+        start = (i - s1len) // 2 + 1
+    end = s1len - start
+    if not prev_prev_larger:
+        end -= 1
+    # TODO: if there's a way to avoid allocating new memory for the Boolean array eq, that will save time.
+    #   With 10,000 pairs of sequences, each of length 64, this takes 1/4 the time if we just set
+    #   eq_idxs[:, start:end + 1] = True, compared to computing sub1 == sub2, allocating new memory for eq
+    #   (not sure if the computation or the memory allocation dominates, however)
+    eq = sub1 == sub2
+    eq_idxs[:, start:end + 1] = eq
+
+    # don't want to allocate new memory, but give variable a better name
+    # to reflect that we are looking at the case where the bases are equal
+    # XXX: note that this is modifying diag_prev_prev,
+    # so only safe to do this after we aren't using it anymore
+    diag_cur = diag_prev_prev
+    diag_cur[eq_idxs] += 1
+    if gc_double:
+        gc_idxs[:, start:end + 1] = np.logical_and(np.logical_or(sub1 == 1, sub1 == 2), eq)
+        diag_cur[gc_idxs] += 1
+
+    # now take maximum with immediately previous diagonal
+    if prev_prev_larger:
+        # diag_cur is 1 larger than diag_prev
+        diag_cur_L = diag_cur[:, :-1]
+        diag_cur_R = diag_cur[:, 1:]
+        np.maximum(diag_cur_L, diag_prev, out=diag_cur_L)  # looks "above" in DP table
+        np.maximum(diag_cur_R, diag_prev, out=diag_cur_R)  # looks "left" in DP table
+    else:
+        # diag_cur is 1 smaller than diag_prev
+        diag_prev_L = diag_prev[:, :-1]
+        diag_prev_R = diag_prev[:, 1:]
+        np.maximum(diag_cur, diag_prev_L, out=diag_cur)  # looks "above" in DP table
+        np.maximum(diag_cur, diag_prev_R, out=diag_cur)  # looks "left" in DP table
+
+    return diag_cur
+
+
+def lcs(seqs1: Sequence[str], seqs2: Sequence[str], gc_double: bool) -> List[int]:
+    arr1 = nn.seqs2arr(seqs1)
+    arr2 = nn.seqs2arr(seqs2)
+    arr2 = np.flip(arr2, axis=1)
+    return longest_complementary_subsequences(arr1, arr2, gc_double)
+
+
+def lcs_loop(s1: str, s2: str, gc_double: bool) -> int:
+    arr1 = nn.seqs2arr([s1])
+    arr2 = nn.seqs2arr([s2[::-1]])
+    return longest_complementary_subsequences_python_loop(arr1, arr2, gc_double)[0]
+
+
+def lcs_strand_pairs_constraint(
+        *,
+        threshold: int,
+        weight: float = 1.0,
+        score_transfer_function: Callable[[float], float] = default_score_transfer_function,
+        description: str | None = None,
+        short_description: str = 'lcs strand pairs',
+        pairs: Iterable[Tuple[Strand, Strand]] | None = None,
+        check_strand_against_itself: bool = True,
+        gc_double: bool = True,
+) -> StrandPairsConstraint:
+    """
+    Returns a constraint that evaluates the length of the longest complementary subsequence (LCS)
+    between each given pair of :any:`Strand`'s, reporting a violation when that length
+    exceeds `threshold`.
+
+    Args:
+        threshold: largest LCS length not considered a violation; the excess reported for a pair
+            of strands is their LCS length minus `threshold`
+
+        weight: how much to weigh this :any:`Constraint`
+
+        score_transfer_function: function to transfer the excess into a score
+
+        description: long description of the constraint, used in the text report
+
+        short_description: short description of the constraint, used in the text report
+
+        pairs: pairs of :any:`Strand`'s to check; if not specified, all pairs are checked
+
+        check_strand_against_itself: whether to check a :any:`Strand` against itself
+
+        gc_double: Whether to weigh G-C base pairs as double (i.e., they count for 2 instead of 1).
+
+    Returns:
+        A :any:`StrandPairsConstraint` that checks given pairs of :any:`Strand`'s for excessive
+        interaction due to having long complementary subsequences.
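+
+    A minimal usage sketch (``strands`` is a hypothetical list of :any:`Strand`'s, and the
+    threshold of 8 is arbitrary; the returned constraint is passed to the search like any other
+    :any:`StrandPairsConstraint`):
+
+    .. code-block:: Python
+
+        import itertools
+        pairs = itertools.combinations(strands, 2)
+        constraint = lcs_strand_pairs_constraint(threshold=8, short_description='lcs8', pairs=pairs)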
+ """ + if description is None: + description = f'Longest complementary subsequence between strands is > {threshold}' + + def evaluate_bulk(strand_pairs: Iterable[StrandPair]) -> List[Result]: + # import time + # start_eb = time.time() + + seqs1 = [pair.strand1.sequence() for pair in strand_pairs] + seqs2 = [pair.strand2.sequence() for pair in strand_pairs] + arr1 = nn.seqs2arr(seqs1) + arr2 = nn.seqs2arr(seqs2) + arr2_rev = np.flip(arr2, axis=1) + + # start = time.time() + lcs_sizes = longest_complementary_subsequences(arr1, arr2_rev, gc_double) + # lcs_sizes = longest_complementary_subsequences_two_loops(arr1, arr2_rev, gc_double) + # end = time.time() + + results = [] + for lcs_size in lcs_sizes: + excess = lcs_size - threshold + value = f'{lcs_size}' + result = Result(excess=excess, value=value) + results.append(result) + + # end_eb = time.time() + # elapsed_ms = int(round((end - start) * 1000, 0)) + # elapsed_eb_ms = int(round((end_eb - start_eb) * 1000, 0)) + # print(f'\n{elapsed_ms} ms to measure LCS of {len(seqs1)} pairs') + # print(f'{elapsed_eb_ms} ms to run evaluate_bulk') + + return results + + pairs_tuple = None + if pairs is not None: + pairs_tuple = tuple(pairs) + + return StrandPairsConstraint( + description=description, + short_description=short_description, + weight=weight, + score_transfer_function=score_transfer_function, + evaluate_bulk=evaluate_bulk, + pairs=pairs_tuple, + check_strand_against_itself=check_strand_against_itself, + ) + + +def lcs_strand_pairs_constraints_by_number_matching_domains( + *, + thresholds: Dict[int, int], + weight: float = 1.0, + score_transfer_function: Callable[[float], float] = default_score_transfer_function, + descriptions: Dict[int, str] | None = None, + short_descriptions: Dict[int, str] | None = None, + parallel: bool = False, + strands: Iterable[Strand] | None = None, + pairs: Iterable[Tuple[Strand, Strand]] | None = None, + gc_double: bool = True, + parameters_filename: str = '', + ignore_missing_thresholds: bool = False, +) -> List[StrandPairsConstraint]: + """ + TODO + """ + if parameters_filename != '': + raise ValueError('should not specify parameters_filename when calling ' + 'lcs_strand_pairs_constraints_by_number_matching_domains; ' + 'it is only listed as a parameter for technical reasons relating to code resuse ' + 'with other constraints that use that parameter') + + def lcs_strand_pairs_constraint_with_dummy_parameters( + *, + threshold: float, + temperature: float = nv.default_temperature, + weight: float = 1.0, + score_transfer_function: Callable[[float], float] = default_score_transfer_function, + description: str | None = None, + short_description: str = 'lcs strand pairs', + parallel: bool = False, + pairs: Iterable[Tuple[Strand, Strand]] | None = None, + parameters_filename: str = nv.default_vienna_rna_parameter_filename + ) -> StrandPairsConstraint: + threshold_int = int(threshold) + return lcs_strand_pairs_constraint( + threshold=threshold_int, + weight=weight, + score_transfer_function=score_transfer_function, + description=description, + short_description=short_description, + pairs=pairs, + check_strand_against_itself=True, + # TODO: rewrite signature of other strand pair constraints to include this + gc_double=gc_double, + ) + + if descriptions is None: + descriptions = { + num_matching: (f'Longest complementary subsequence between strands is > {threshold}' + + f'\nfor strands with {num_matching} complementary ' + f'{"domain" if num_matching == 1 else "domains"}') + for num_matching, threshold in 
thresholds.items() + } + + if short_descriptions is None: + short_descriptions = { + num_matching: f'LCS{num_matching}comp' + for num_matching, threshold in thresholds.items() + } + + return _strand_pairs_constraints_by_number_matching_domains( + constraint_creator=lcs_strand_pairs_constraint_with_dummy_parameters, + thresholds=thresholds, + temperature=-1, + weight=weight, + score_transfer_function=score_transfer_function, + descriptions=descriptions, + short_descriptions=short_descriptions, + parallel=parallel, + strands=strands, + pairs=pairs, + ignore_missing_thresholds=ignore_missing_thresholds, + ) + + def rna_duplex_strand_pairs_constraint( *, threshold: float, @@ -7515,7 +7940,7 @@ def __get_base_pair_domain_endpoints_to_check( # End Input Validation # addr_to_starting_base_pair_idx: Dict[StrandDomainAddress, - int] = _get_addr_to_starting_base_pair_idx(strand_complex) + int] = _get_addr_to_starting_base_pair_idx(strand_complex) all_bound_domain_addresses.update(_get_implicitly_bound_domain_addresses( strand_complex, nonimplicit_base_pairs_domain_names)) diff --git a/nuad/np.py b/nuad/np.py index 4b7d36de..780eb7fc 100644 --- a/nuad/np.py +++ b/nuad/np.py @@ -43,22 +43,77 @@ def seq2arr(seq: str, base2bits_local: Dict[str, int] | None = None) -> np.ndarr return np.array([base2bits_local[base] for base in seq], dtype=np.ubyte) +# this was about 5 times slower than the new implementation of `seqs2arr` below +# def seqs2arr_old(seqs: Sequence[str]) -> np.ndarray: +# """Return numpy 2D array converting the given DNA sequences to integers.""" +# if len(seqs) == 0: +# return np.empty((0, 0), dtype=np.ubyte) +# if isinstance(seqs, str): +# raise ValueError('seqs must be a sequence of strings, not a single string') +# seq_len = len(seqs[0]) +# for seq in seqs: +# if len(seq) != seq_len: +# raise ValueError('All sequences in seqs must be equal length') +# num_seqs = len(seqs) +# arr = np.empty((num_seqs, seq_len), dtype=np.ubyte) +# for i in range(num_seqs): +# arr[i] = [base2bits[base] for base in seqs[i]] +# return arr + + def seqs2arr(seqs: Sequence[str]) -> np.ndarray: """Return numpy 2D array converting the given DNA sequences to integers.""" if len(seqs) == 0: return np.empty((0, 0), dtype=np.ubyte) + if isinstance(seqs, str): + raise ValueError('seqs must be a sequence of strings, not a single string') + + # check equal length of each sequence (a bit faster than a Python loop, + # e.g., 3.5 ms for 10^5 seqs compared to 5 ms with Python loop) seq_len = len(seqs[0]) - for seq in seqs: - if len(seq) != seq_len: - raise ValueError('All sequences in seqs must be equal length') + lengths_it = map(len, seqs) + lengths_arr = np.fromiter(lengths_it, dtype=int) + if np.any(lengths_arr != seq_len): + raise ValueError('All sequences in seqs must be equal length') + num_seqs = len(seqs) - arr = np.empty((num_seqs, seq_len), dtype=np.ubyte) - for i in range(num_seqs): - arr[i] = [base2bits[base] for base in seqs[i]] - return arr + + # the code below is about 5 times faster than the old implementation (commented out above) + seqs_cat = ''.join(seqs) + seqs_cat = seqs_cat.upper() + + seqs_cat_bytes = seqs_cat.encode() + seqs_cat_byte_array = bytearray(seqs_cat_bytes) + arr1d = np.frombuffer(seqs_cat_byte_array, dtype=np.ubyte) + # arr1d = np.fromstring(seqs_cat_bytes, dtype=np.ubyte) # generates warning about using frombuffer + + # convert ASCII bytes for 'A', 'C', 'G', 'T' to 0, 1, 2, 3, respectively + + # code below is magical to me, but it works and is slightly faster than more obvious ways: + # 
https://stackoverflow.com/a/35464758 + from_values = np.array([ord(base) for base in ['A', 'C', 'G', 'T']]) + to_values = np.arange(4) + sort_idx = np.argsort(from_values) + idx = np.searchsorted(from_values, arr1d, sorter=sort_idx) + arr1d = to_values[sort_idx][idx] + + # this is a bit slower than the code above, e.g., 75 ms compared to 55 ms for 10^5 sequences + # for i, base in enumerate(['A', 'C', 'G', 'T']): + # idxs_with_base = arr2d == ord(base) + # arr2d[idxs_with_base] = i + + arr2d = arr1d.reshape((num_seqs, seq_len)) + + return arr2d + + +def arr2seqs(arr: np.ndarray) -> List[str]: + """Return list of strings converting the given numpy array of integers to DNA sequences.""" + return [''.join(bits2base[base] for base in row) for row in arr] def arr2seq(arr: np.ndarray) -> str: + """Return string converting the given numpy array of integers to DNA sequence.""" bases_ch = [bits2base[base] for base in arr] return ''.join(bases_ch) @@ -88,7 +143,7 @@ def make_array_with_all_dna_seqs(length: int, bases: Collection[str] = ('A', 'C' if len(bases) == 0: raise ValueError('bases cannot be empty') if not set(bases) <= {'A', 'C', 'G', 'T'}: - raise ValueError(f"bases must be a subset of {'A', 'C', 'G', 'T'}; cannot be {bases}") + raise ValueError(f"bases must be a subset of {{A, C, G, T}}; cannot be {bases}") base_bits = [base2bits[base] for base in bases] digits = np.array(base_bits, dtype=np.ubyte) @@ -457,7 +512,7 @@ def longest_common_substrings_singlea1(a1: np.ndarray, a2s: np.ndarray) \ for i1 in range(len(a1)): idx = (a2s == a1[i1]) - idx_shifted = np.insert(idx, 0, np.zeros(numa2s, dtype=np.bool), axis=1) + idx_shifted = np.insert(idx, 0, np.zeros(numa2s, dtype=bool), axis=1) counter[i1 + 1, idx_shifted] = counter[i1, idx] + 1 counter = np.swapaxes(counter, 0, 1) @@ -518,7 +573,7 @@ def _longest_common_substrings_pairs(a1s: np.ndarray, a2s: np.ndarray) \ a1s_cp_col_rp = np.repeat(a1s_cp_col, len_a2, axis=1) idx = (a2s == a1s_cp_col_rp) - idx_shifted = np.hstack([np.zeros(shape=(numpairs, 1), dtype=np.bool), idx]) + idx_shifted = np.hstack([np.zeros(shape=(numpairs, 1), dtype=bool), idx]) counter[i1 + 1, idx_shifted] = counter[i1, idx] + 1 counter = np.swapaxes(counter, 0, 1) @@ -566,7 +621,7 @@ def _strongest_common_substrings_all_pairs_return_energies_and_counter( # find matching chars and extend length of substring match_idxs = (a2s == a1s_col_rp) - match_shifted_idxs = np.hstack([np.zeros(shape=(numpairs, 1), dtype=np.bool), match_idxs]) + match_shifted_idxs = np.hstack([np.zeros(shape=(numpairs, 1), dtype=bool), match_idxs]) counter[i1 + 1, match_shifted_idxs] = counter[i1, match_idxs] + 1 if i1 > 0: @@ -576,10 +631,10 @@ def _strongest_common_substrings_all_pairs_return_energies_and_counter( loops = (prev_bases << 2) + cur_bases latest_energies = loop_energies[loops].reshape(numpairs, 1) latest_energies_rp = np.repeat(latest_energies, len_a2, axis=1) - match_idxs_false_at_end = np.hstack([match_idxs, np.zeros(shape=(numpairs, 1), dtype=np.bool)]) + match_idxs_false_at_end = np.hstack([match_idxs, np.zeros(shape=(numpairs, 1), dtype=bool)]) both_match_idxs = match_idxs_false_at_end & prev_match_shifted_idxs prev_match_shifted_shifted_idxs = np.hstack( - [np.zeros(shape=(numpairs, 1), dtype=np.bool), prev_match_shifted_idxs])[:, :-1] + [np.zeros(shape=(numpairs, 1), dtype=bool), prev_match_shifted_idxs])[:, :-1] both_match_shifted_idxs = match_shifted_idxs & prev_match_shifted_shifted_idxs energies[i1 + 1, both_match_shifted_idxs] = energies[i1, both_match_idxs] + 
latest_energies_rp[ both_match_idxs] @@ -676,7 +731,7 @@ def __init__(self, shuffle: bool = False, alphabet: Collection[str] = ('A', 'C', 'G', 'T'), seqs: Sequence[str] | None = None, - seqarr: np.ndarray = None, + seqarr: np.ndarray | None = None, filename: str | None = None, rng: np.random.Generator = default_rng, hamming_distance_from_sequence: Tuple[int, str] | None = None): @@ -1069,14 +1124,14 @@ def filter_base_at_pos(self, pos: int, base: str) -> DNASeqList: def filter_substring(self, subs: Sequence[str]) -> DNASeqList: """Remove any sequence with any elements from subs as a substring.""" if len(set([len(sub) for sub in subs])) != 1: - raise ValueError('All substrings in subs must be equal length: %s' % subs) + raise ValueError(f'All substrings in subs must be equal length: {subs}') sublen = len(subs[0]) subints = [[base2bits[base] for base in sub] for sub in subs] powarr = [4 ** k for k in range(sublen)] subvals = np.dot(subints, powarr) toeplitz = create_toeplitz(self.seqlen, sublen) convolution = np.dot(toeplitz, self.seqarr.transpose()) - passall = np.ones(self.numseqs, dtype=np.bool) + passall = np.ones(self.numseqs, dtype=bool) for subval in subvals: passsub = np.all(convolution != subval, axis=0) passall = passall & passsub @@ -1123,7 +1178,7 @@ def create_toeplitz(seqlen: int, sublen: int, indices: Sequence[int] | None = No f'but {idx} is not; all indices = {indices}') num_rows = len(rows) num_cols = seqlen - toeplitz = np.zeros((num_rows, num_cols), dtype=np.int) + toeplitz = np.zeros((num_rows, num_cols), dtype=int) toeplitz[:, 0:sublen] = [powarr] * num_rows shift = list(rows) for i in range(len(rows)): @@ -1209,100 +1264,21 @@ def wcenergies_str(seqs: Sequence[str], temperature: float, negate: bool = False def wcenergy_str(seq: str, temperature: float, negate: bool = False) -> float: - seqarr = seqs2arr([seq]) - return list(calculate_wc_energies(seqarr, temperature, negate))[0] - - -def hash_ndarray(arr: np.ndarray) -> int: - writeable = arr.flags.writeable - if writeable: - arr.flags.writeable = False - h = hash(bytes(arr.data)) # hash(arr.data) - arr.flags.writeable = writeable - return h - - -CACHE_WC = False -_calculate_wc_energies_cache: np.ndarray | None = None -_calculate_wc_energies_cache_hash: int = 0 + return wcenergies_str([seq], temperature, negate)[0] def calculate_wc_energies(seqarr: np.ndarray, temperature: float, negate: bool = False) -> np.ndarray: """Calculate and store in an array all energies of all sequences in seqarr with their Watson-Crick complements.""" - global _calculate_wc_energies_cache - global _calculate_wc_energies_cache_hash - if CACHE_WC and _calculate_wc_energies_cache is not None: - if _calculate_wc_energies_cache_hash == hash_ndarray(seqarr): - return _calculate_wc_energies_cache loop_energies = calculate_loop_energies(temperature, negate) left_index_bits = seqarr[:, :-1] << 2 right_index_bits = seqarr[:, 1:] pair_indices = left_index_bits + right_index_bits pair_energies = loop_energies[pair_indices] energies: np.ndarray = np.sum(pair_energies, axis=1) - if CACHE_WC: - _calculate_wc_energies_cache = energies - _calculate_wc_energies_cache_hash = hash_ndarray(_calculate_wc_energies_cache) return energies def wc_arr(seqarr: np.ndarray) -> np.ndarray: - """Return numpy array of complements of sequences in `seqarr`.""" + """Return numpy array of reverse complements of sequences in `seqarr`.""" return (3 - seqarr)[:, ::-1] - - -def prefilter_length_10_11(low_dg: float, high_dg: float, temperature: float, end_gc: bool, - convert_to_list: 
bool = True) \
-        -> Tuple[List[str], List[str]] | Tuple[DNASeqList, DNASeqList]:
-    """Return sequences of length 10 and 11 with wc energies between given values."""
-    s10: DNASeqList = DNASeqList(length=10)
-    s11: DNASeqList = DNASeqList(length=11)
-    s10 = s10.filter_energy(low=low_dg, high=high_dg, temperature=temperature)
-    s11 = s11.filter_energy(low=low_dg, high=high_dg, temperature=temperature)
-    forbidden_subs = [f'{a}{b}{c}{d}' for a in ['G', 'C']
-                      for b in ['G', 'C']
-                      for c in ['G', 'C']
-                      for d in ['G', 'C']]
-    s10 = s10.filter_substring(forbidden_subs)
-    s11 = s11.filter_substring(forbidden_subs)
-    if end_gc:
-        print(
-            'Removing any domains that end in either A or T; '
-            'also ensuring every domain has an A or T within 2 indexes of the end')
-        s10 = s10.filter_end_gc()
-        s11 = s11.filter_end_gc()
-    for seqs in (s10, s11):
-        if len(seqs) == 0:
-            raise ValueError(
-                f'low_dg {low_dg:.2f} and high_dg {high_dg:.2f} too strict! '
-                f'no sequences of length {seqs.seqlen} found')
-    return (s10.to_list(), s11.to_list()) if convert_to_list else (s10, s11)
-
-
-def all_cats(seq: Sequence[int], seqs: Sequence[int]) -> np.ndarray:
-    """
-    Return all sequences obtained by concatenating seq to either end of a sequence in seqs.
-
-    For example,
-
-    .. code-block:: Python
-
-        all_cats([0,1,2,3], [[3,3,3], [0,0,0]])
-
-    returns the numpy array
-
-    .. code-block:: Python
-
-        [[0,1,2,3,3,3,3],
-         [3,3,3,0,1,2,3],
-         [0,1,2,3,0,0,0],
-         [0,0,0,0,1,2,3]]
-    """
-    seqarr = np.asarray([seq])
-    seqsarr = np.asarray(seqs)
-    ar = seqarr.repeat(seqsarr.shape[0], axis=0)
-    ret = np.concatenate((seqsarr, ar), axis=1)
-    ret2 = np.concatenate((ar, seqsarr), axis=1)
-    ret = np.concatenate((ret, ret2))
-    return ret
diff --git a/nuad/search.py b/nuad/search.py
index 08ff27a7..4ecbfa77 100644
--- a/nuad/search.py
+++ b/nuad/search.py
@@ -867,17 +867,17 @@ def search_for_sequences(design: nc.Design, params: SearchParameters) -> None:
     The function has some side effects. It writes a report on the optimal sequence assignment found so far
     every time a new improved assignment is found.
 
-    Whenever a new optimal sequence assignment is found, the following are written to files:
-    - DNA sequences of each strand are written to a text file .
-    - the whole dsd design
-    - a report on the DNA sequences indicating how well they do on constraints.
+    Whenever a new optimal sequence assignment is found, the following are also written to files:
+
+    * DNA sequences of each strand are written to a text file.
+    * the whole design itself
+    * a report on the DNA sequences indicating how well they do on constraints.
 
     :param design:
         The :any:`Design` containing the :any:`Domain`'s to which to assign DNA sequences
         and the :any:`Constraint`'s that apply to them
     :param params:
-        A :any:`SearchParameters` object with attributes that can be called within this function
-        for flexibility.
+        A :any:`SearchParameters` object with attributes that can be used to specify options for the search.
     """
@@ -1813,9 +1813,8 @@ def replace_with_new(self) -> None:
         _assert_violations_are_accurate(self.evaluations, self.violations)
 
     def update_scores_and_counts(self) -> None:
-        """
-        :return: Total score of all evaluations.
-        """
+        # Updates the total scores and counts of all evaluations.
+ self.total_score = self.total_score_fixed = self.total_score_nonfixed = 0.0 self.num_evaluations = self.num_evaluations_nonfixed = self.num_evaluations_fixed = 0 self.num_violations = self.num_violations_nonfixed = self.num_violations_fixed = 0 @@ -1949,11 +1948,11 @@ class Evaluation(Generic[DesignPart]): def __init__(self, constraint: Constraint, violated: bool, part: DesignPart, domains: Iterable[Domain], score: float, summary: str, result: nc.Result) -> None: - # :param constraint: + # constraint: # :any:`Constraint` that was violated to result in this - # :param domains: + # domains: # :any:`Domain`'s that were involved in violating :py:data:`Evaluation.constraint` - # :param score: + # score: # total "score" of this violation, typically something like an excess energy over a # threshold, squared, multiplied by the :data:`Constraint.weight` self.constraint = constraint @@ -2200,7 +2199,7 @@ def display_report(design: nc.Design, constraints: Iterable[Constraint], with the same rules as `xlims` """ import matplotlib.pyplot as plt - from IPython.display import display, Markdown + from IPython.display import display, Markdown # noqa def dm(obj): display(Markdown(obj)) @@ -2228,7 +2227,7 @@ def dm(obj): for report in reports_without_values: part_type_name = report.constraint.part_name() dm(f'## {report.constraint.description}') - dm(f'### {report.num_violations}/{report.num_evaluations} (\#violations/\#evaluations)') + dm(f'### {report.num_violations}/{report.num_evaluations} (\#violations/\#evaluations)') # noqa for viol in report.violations: print(f' {part_type_name} {viol.part.name}: {viol.summary}') diff --git a/nuad/vienna_nupack.py b/nuad/vienna_nupack.py index 87096360..d3b1d562 100644 --- a/nuad/vienna_nupack.py +++ b/nuad/vienna_nupack.py @@ -19,7 +19,6 @@ import subprocess as sub import sys from multiprocessing.pool import ThreadPool -from pathos.pools import ProcessPool from typing import Sequence, Tuple, List, Iterable import numpy as np @@ -137,48 +136,53 @@ def tupleize(seqs: str | Iterable[str]) -> Tuple[str, ...]: return (seqs,) if isinstance(seqs, str) else tuple(seqs) -def pfunc_parallel( - pool: ProcessPool, - all_seqs: Sequence[str | Tuple[str, ...]], - temperature: float = default_temperature, - sodium: float = default_sodium, - magnesium: float = default_magnesium, - strand_association_penalty: bool = True, -) -> Tuple[float]: - num_seqs = len(all_seqs) - if num_seqs == 0: - return tuple() - - all_seqs = tuple(tupleize(seqs) for seqs in all_seqs) - - first_seqs = all_seqs[0] - - bases = sum(len(seq) for seq in first_seqs) - num_cores = nc.cpu_count(logical=True) - - # these thresholds were measured empirically; see notebook nuad_parallel_time_trials.ipynb - call_sequential = (len(all_seqs) == 1 - or (bases <= 30 and num_seqs <= 50) - or (bases <= 40 and num_seqs <= 40) - or (bases <= 50 and num_seqs <= 30) - or (bases <= 75 and num_seqs <= 20) - or (bases <= 100 and num_seqs <= 10) - or (bases <= 125 and num_seqs <= 4) - or (bases <= 150 and num_seqs <= 3) - or (num_seqs <= 1) - ) - - def calculate_energies_sequential(all_tuples: Sequence[Tuple[str, ...]]) -> Tuple[float]: - return tuple(pfunc(seqs, temperature, sodium, magnesium, strand_association_penalty) - for seqs in all_tuples) - - if call_sequential: - return calculate_energies_sequential(all_seqs) - - lists_of_sequence_pairs = nc.chunker(all_seqs, num_chunks=num_cores) - lists_of_energies = pool.map(calculate_energies_sequential, lists_of_sequence_pairs) - energies = nc.flatten(lists_of_energies) - return 
tuple(energies)
+try:
+    from pathos.pools import ProcessPool
+
+    def pfunc_parallel(
+            pool: ProcessPool,
+            all_seqs: Sequence[str | Tuple[str, ...]],
+            temperature: float = default_temperature,
+            sodium: float = default_sodium,
+            magnesium: float = default_magnesium,
+            strand_association_penalty: bool = True,
+    ) -> Tuple[float]:
+        num_seqs = len(all_seqs)
+        if num_seqs == 0:
+            return tuple()
+
+        all_seqs = tuple(tupleize(seqs) for seqs in all_seqs)
+
+        first_seqs = all_seqs[0]
+
+        bases = sum(len(seq) for seq in first_seqs)
+        num_cores = nc.cpu_count(logical=True)
+
+        # these thresholds were measured empirically; see notebook nuad_parallel_time_trials.ipynb
+        call_sequential = (len(all_seqs) == 1
+                           or (bases <= 30 and num_seqs <= 50)
+                           or (bases <= 40 and num_seqs <= 40)
+                           or (bases <= 50 and num_seqs <= 30)
+                           or (bases <= 75 and num_seqs <= 20)
+                           or (bases <= 100 and num_seqs <= 10)
+                           or (bases <= 125 and num_seqs <= 4)
+                           or (bases <= 150 and num_seqs <= 3)
+                           or (num_seqs <= 1)
+                           )
+
+        def calculate_energies_sequential(all_tuples: Sequence[Tuple[str, ...]]) -> Tuple[float]:
+            return tuple(pfunc(seqs, temperature, sodium, magnesium, strand_association_penalty)
+                         for seqs in all_tuples)
+
+        if call_sequential:
+            return calculate_energies_sequential(all_seqs)
+
+        lists_of_sequence_pairs = nc.chunker(all_seqs, num_chunks=num_cores)
+        lists_of_energies = pool.map(calculate_energies_sequential, lists_of_sequence_pairs)
+        energies = nc.flatten(lists_of_energies)
+        return tuple(energies)
+except ModuleNotFoundError:
+    # pathos is an optional dependency; if it is not installed, simply skip defining
+    # pfunc_parallel so that the rest of this module remains importable
+    pass
 
 
 def nupack_complex_base_pair_probabilities(strand_complex: 'nc.Complex',  # circular import causes problems