diff --git a/examples/many_strands_no_common_domains.py b/examples/many_strands_no_common_domains.py index 9f56a21f..dcf24e9c 100644 --- a/examples/many_strands_no_common_domains.py +++ b/examples/many_strands_no_common_domains.py @@ -52,7 +52,8 @@ def main() -> None: # num_strands = 10 # num_strands = 10 # num_strands = 50 - num_strands = 100 + # num_strands = 100 + num_strands = 200 # num_strands = 355 design = nc.Design() @@ -152,8 +153,8 @@ def main() -> None: params = ns.SearchParameters(constraints=[ # domain_nupack_ss_constraint, # strand_individual_ss_constraint, - strand_pairs_rna_duplex_constraint, - # strand_pairs_rna_plex_constraint, + # strand_pairs_rna_duplex_constraint, + strand_pairs_rna_plex_constraint, # strand_pair_nupack_constraint, # domain_pair_nupack_constraint, # domain_pairs_rna_plex_constraint, @@ -169,8 +170,8 @@ def main() -> None: # save_report_for_all_updates=True, # save_design_for_all_updates=True, force_overwrite=True, - # log_time=True, - scrolling_output=False, + log_time=True, + # scrolling_output=False, # report_only_violations=False, ) ns.search_for_sequences(design, params) diff --git a/nuad/constraints.py b/nuad/constraints.py index 723c8602..2af4fddb 100644 --- a/nuad/constraints.py +++ b/nuad/constraints.py @@ -72,7 +72,7 @@ domain_pools_num_sampled_key = 'domain_pools_num_sampled' domain_names_key = 'domain_names' starred_domain_indices_key = 'starred_domain_indices' -group_key = 'group' +# group_key = 'label' domain_pool_name_key = 'pool_name' length_key = 'length' substring_length_key = 'substring_length' @@ -2247,7 +2247,7 @@ def _check_vendor_string_not_none_or_empty(value: str, field_name: str) -> None: raise ValueError(f'field {field_name} in VendorFields cannot be empty') -default_strand_group = 'default_strand_group' +default_strand_label = 'default_strand_label' @dataclass @@ -2261,9 +2261,6 @@ class Strand(Part, JSONSerializable): """Set of positions of :any:`Domain`'s in :data:`Strand.domains` on this 
:any:`Strand` that are starred.""" - group: str - """Optional "group" field to describe strands that share similar properties.""" - _domain_names_concatenated: str """Concatenation of domain names; cached for efficiency since these are used in calculating hash values.""" @@ -2322,9 +2319,8 @@ class Strand(Part, JSONSerializable): def __init__(self, domains: Iterable[Domain] | None = None, starred_domain_indices: Iterable[int] = (), - group: str = default_strand_group, name: str | None = None, - label: str | None = None, + label: str = default_strand_label, vendor_fields: VendorFields | None = None, ) -> None: """ @@ -2347,7 +2343,6 @@ def __init__(self, methods for exporting to IDT formats (e.g., :meth:`Strand.write_idt_bulk_input_file`) """ self._all_intersecting_domains = None - self.group = group self._name = name # XXX: moved this check to Design constructor to allow subdomain graphs to be @@ -2401,8 +2396,7 @@ def clone(self, name: str | None) -> Strand: starred_domain_indices = list(self.starred_domain_indices) name = name if name is not None else self.name vendor_fields = None if self.vendor_fields is None else self.vendor_fields.clone() - return Strand(domains=domains, starred_domain_indices=starred_domain_indices, name=name, - group=self.group, label=self.label, vendor_fields=vendor_fields) + return Strand(domains=domains, starred_domain_indices=starred_domain_indices, name=name, label=self.label, vendor_fields=vendor_fields) def compute_derived_fields(self): """ @@ -2526,7 +2520,7 @@ def to_json_serializable(self, suppress_indent: bool = True) -> NoIndent | Dict[ Dictionary ``d`` representing this :any:`Strand` that is "naturally" JSON serializable, by calling ``json.dumps(d)``. 
""" - dct: Dict[str, Any] = {name_key: self.name, group_key: self.group} + dct: Dict[str, Any] = {name_key: self.name, label_key: self.label} domains_list = [domain.name for domain in self.domains] dct[domain_names_key] = NoIndent(domains_list) if suppress_indent else domains_list @@ -2569,18 +2563,15 @@ def from_json_serializable(json_map: Dict[str, Any], domains: List[Domain] = [domain_with_name[name] for name in domain_names_json] starred_domain_indices = mandatory_field(Strand, json_map, starred_domain_indices_key) - group = json_map.get(group_key, default_strand_group) - label: str = json_map.get(label_key) - idt_json = json_map.get(vendor_fields_key) - idt = None - if idt_json is not None: - idt = VendorFields.from_json_serializable(idt_json) + vendor_fields_json = json_map.get(vendor_fields_key) + vendor_fields = None + if vendor_fields_json is not None: + vendor_fields = VendorFields.from_json_serializable(vendor_fields_json) strand: Strand = Strand( - domains=domains, starred_domain_indices=starred_domain_indices, - group=group, name=name, label=label, vendor_fields=idt) + domains=domains, starred_domain_indices=starred_domain_indices, name=name, label=label, vendor_fields=vendor_fields) return strand def __repr__(self) -> str: @@ -3078,7 +3069,7 @@ class Design(JSONSerializable): Computed from :data:`Design.strands`, so not specified in constructor. """ - strands_by_group_name: Dict[str, List[Strand]] = field(init=False) + strands_by_label_name: Dict[str, List[Strand]] = field(init=False) """ Dict mapping each group name to a list of the :any:`Strand`'s in this :any:`Design` in the group. 
@@ -3133,9 +3124,9 @@ def compute_derived_fields(self) -> None: self.domains = remove_duplicates(domains) - self.strands_by_group_name = defaultdict(list) + self.strands_by_label_name = defaultdict(list) for strand in self.strands: - self.strands_by_group_name[strand.group].append(strand) + self.strands_by_label_name[strand.label].append(strand) self.store_domain_pools() @@ -3274,10 +3265,9 @@ def add_strand(self, domain_names: List[str] | None = None, domains: List[Domain] | None = None, starred_domain_indices: Iterable[int] | None = None, - group: str = default_strand_group, name: str | None = None, label: str | None = None, - idt: VendorFields | None = None, + vendor_fields: VendorFields | None = None, ) -> Strand: """ This is an alternative way to create strands instead of calling the :any:`Strand` constructor @@ -3309,7 +3299,7 @@ def add_strand(self, Name of this :any:`Strand`. :param label: Label to associate with this :any:`Strand`. - :param idt: + :param vendor_fields: :any:`VendorFields` object to associate with this :any:`Strand`; needed to call methods for exporting to IDT formats (e.g., :meth:`Strand.write_idt_bulk_input_file`) :return: @@ -3346,10 +3336,9 @@ def add_strand(self, domains_of_strand = list(domains) # type: ignore strand = Strand(domains=domains_of_strand, starred_domain_indices=starred_domain_indices, - group=group, name=name, label=label, - vendor_fields=idt) + vendor_fields=vendor_fields) for existing_strand in self.strands: if strand.name == existing_strand.name: @@ -3441,7 +3430,7 @@ def to_idt_bulk_input_format(self, domain_delimiter: str = '', key: KeyFunction[Strand] | None = None, warn_duplicate_name: bool = False, - only_strands_with_idt: bool = False, + only_strands_with_vendor_fields: bool = False, strands: Iterable[Strand] | None = None) -> str: """Called by :meth:`Design.write_idt_bulk_input_file` to determine what string to write to the file. This function can be used to get the string directly without creating a file. 
@@ -3459,7 +3448,7 @@ def to_idt_bulk_input_format(self, domain_delimiter=domain_delimiter, key=key, warn_duplicate_name=warn_duplicate_name, - only_strands_with_idt=only_strands_with_idt, + only_strands_with_vendor_fields=only_strands_with_vendor_fields, ) def write_idt_bulk_input_file(self, *, @@ -3470,7 +3459,7 @@ def write_idt_bulk_input_file(self, *, delimiter: str = ',', domain_delimiter: str = '', warn_duplicate_name: bool = True, - only_strands_with_idt: bool = False, + only_strands_with_vendor_fields: bool = False, strands: Iterable[Strand] | None = None) -> None: """Write ``.idt`` text file encoding the strands of this :any:`Design` with the field :data:`Strand.vendor_fields`, suitable for pasting into the "Bulk Input" field of IDT @@ -3505,7 +3494,7 @@ def write_idt_bulk_input_file(self, *, is raised (regardless of the value of this parameter) if two different :any:`Strand`'s have the same name but different sequences, IDT scales, or IDT purifications. - :param only_strands_with_idt: + :param only_strands_with_vendor_fields: If False (the default), all non-scaffold sequences are output, with reasonable default values chosen if the field :data:`Strand.vendor_fields` is missing. If True, then strands lacking the field :data:`Strand.vendor_fields` will not be exported. 
@@ -3518,7 +3507,7 @@ def write_idt_bulk_input_file(self, *, domain_delimiter=domain_delimiter, key=key, warn_duplicate_name=warn_duplicate_name, - only_strands_with_idt=only_strands_with_idt, + only_strands_with_vendor_fields=only_strands_with_vendor_fields, strands=strands) if extension is None: extension = 'idt' @@ -3529,7 +3518,7 @@ def write_idt_plate_excel_file(self, *, directory: str = '.', key: KeyFunction[Strand] | None = None, warn_duplicate_name: bool = False, - only_strands_with_idt: bool = False, + only_strands_with_vendor_fields: bool = False, use_default_plates: bool = True, warn_using_default_plates: bool = True, plate_type: PlateType = PlateType.wells96, strands: Iterable[Strand] | None = None) -> None: @@ -3544,7 +3533,7 @@ def write_idt_plate_excel_file(self, *, For instance, if the script is named ``my_origami.py``, then the sequences will be written to ``my_origami.xls``. - If the last plate as fewer than 24 strands for a 96-well plate, or fewer than 96 strands for a + If the last plate has fewer than 24 strands for a 96-well plate, or fewer than 96 strands for a 384-well plate, then the last two plates are rebalanced to ensure that each plate has at least that number of strands, because IDT charges extra for a plate with too few strands: https://www.idtdna.com/pages/products/custom-dna-rna/dna-oligos/custom-dna-oligos @@ -3564,16 +3553,16 @@ def write_idt_plate_excel_file(self, *, raised (regardless of the value of this parameter) if two different :any:`Strand`'s have the same name but different sequences, IDT scales, or IDT purifications. - :param only_strands_with_idt: + :param only_strands_with_vendor_fields: If False (the default), all non-scaffold sequences are output, with reasonable default values chosen if the field :data:`Strand.vendor_fields` is missing. (though scaffold is included if `export_scaffold` is True). If True, then strands lacking the field :data:`Strand.vendor_fields` will not be exported. 
If False, then `use_default_plates` must be True. :param use_default_plates: - Use default values for plate and well (ignoring those in idt fields, which may be None). - If False, each Strand to export must have the field :data:`Strand.vendor_fields`, so in particular - the parameter `only_strands_with_idt` must be True. + Use default values for plate and well (ignoring those in :data:`Strand.vendor_fields`, which + may be None). If False, each Strand to export must have the field :data:`Strand.vendor_fields`, + so in particular the parameter `only_strands_with_vendor_fields` must be True. :param warn_using_default_plates: specifies whether, if `use_default_plates` is True, to print a warning for strands whose :data:`Strand.vendor_fields` has the fields :data:`VendorFields.plate` and :data:`VendorFields.well`, @@ -3595,7 +3584,7 @@ def write_idt_plate_excel_file(self, *, filename=filename, key=key, warn_duplicate_name=warn_duplicate_name, - only_strands_with_idt=only_strands_with_idt, + only_strands_with_vendor_fields=only_strands_with_vendor_fields, use_default_plates=use_default_plates, warn_using_default_plates=warn_using_default_plates, plate_type=plate_type) @@ -3700,39 +3689,17 @@ def from_scadnano_design(sc_design: sc.Design, strands_to_include = [strand for strand in sc_design.strands if strand not in ignored_strands] \ if ignored_strands is not None else sc_design.strands - # warn if not labels are dicts containing group_name_key on strands - for sc_strand in strands_to_include: - if (isinstance(sc_strand.label, dict) and group_key not in sc_strand.label) or \ - (not isinstance(sc_strand.label, dict) and not hasattr(sc_strand.label, group_key)): - logger.warning(f'Strand label {sc_strand.label} should be an object with attribute named ' - f'"{group_key}" (for instance a dict or namedtuple).\n' - f' The label is type {type(sc_strand.label)}. 
' - f'In order to auto-populate StrandGroups, ensure the label has attribute ' - f'named "{group_key}" with associated value of type str.') - else: - label_value = Design.get_group_name_from_strand_label(sc_strand) - if not isinstance(label_value, str): - logger.warning(f'Strand label {sc_strand.label} has attribute named ' - f'"{group_key}", but its associated value is not a string.\n' - f'The value is type {type(label_value)}. ' - f'In order to auto-populate StrandGroups, ensure the label has attribute ' - f'named "{group_key}" with associated value of type str.') - - # raise TypeError(f'strand label {sc_strand.label} must be a dict, ' - # f'but instead is type {type(sc_strand.label)}') - # groups scadnano strands by strand labels sc_strand_groups: DefaultDict[str, List[sc.Strand]] = defaultdict(list) for sc_strand in strands_to_include: assigned = False - if hasattr(sc_strand.label, group_key) or ( - isinstance(sc_strand.label, dict) and group_key in sc_strand.label): - group = Design.get_group_name_from_strand_label(sc_strand) - if isinstance(group, str): - sc_strand_groups[group].append(sc_strand) + if isinstance(sc_strand.label, dict): + label = Design.get_group_name_from_strand_label(sc_strand) + if isinstance(label, str): + sc_strand_groups[label].append(sc_strand) assigned = True if not assigned: - sc_strand_groups[default_strand_group].append(sc_strand) + sc_strand_groups[default_strand_label].append(sc_strand) # make dsd StrandGroups, taking names from Strands and Domains, # and assign (and maybe fix) DNA sequences @@ -3750,7 +3717,6 @@ def from_scadnano_design(sc_design: sc.Design, domain_names: List[str] = [domain.name for domain in sc_strand.domains] sequence = sc_strand.dna_sequence nuad_strand: Strand = design.add_strand(domain_names=domain_names, - group=group, name=sc_strand.name, label=sc_strand.label) # assign sequence @@ -3783,12 +3749,12 @@ def from_scadnano_design(sc_design: sc.Design, @staticmethod def 
get_group_name_from_strand_label(sc_strand: Strand) -> Any: - if hasattr(sc_strand.label, group_key): - return getattr(sc_strand.label, group_key) - elif isinstance(sc_strand.label, dict) and group_key in sc_strand.label: - return sc_strand.label[group_key] + if hasattr(sc_strand.label, label_key): + return getattr(sc_strand.label, label_key) + elif isinstance(sc_strand.label, dict) and label_key in sc_strand.label: + return sc_strand.label[label_key] else: - raise AssertionError(f'label does not have either an attribute or a dict key "{group_key}"') + raise AssertionError(f'label does not have either an attribute or a dict key "{label_key}"') def assign_fields_to_scadnano_design(self, sc_design: sc.Design, ignored_strands: Iterable[Strand] = (), @@ -3880,34 +3846,6 @@ def shared_strands_with_scadnano_design(self, sc_design: sc.Design, return pairs - def assign_strand_groups_to_labels(self, sc_design: sc.Design, - ignored_strands: Iterable[Strand] = (), - overwrite: bool = False) -> None: - """ - TODO: document this - """ - strand_pairs = self.shared_strands_with_scadnano_design(sc_design, ignored_strands) - - for nuad_strand, sc_strands in strand_pairs: - for sc_strand in sc_strands: - if nuad_strand.group is not None: - if sc_strand.label is None: - sc_strand.label = {} - elif not isinstance(sc_strand.label, dict): - raise ValueError(f'cannot assign strand group to strand {sc_strand.name} ' - f'because it already has a label that is not a dict. ' - f'It must either have label None or a dict.') - - # if we get here, then sc_strand.label is a dict. Need to check whether - # it already has a 'group' key. - if group_key in sc_strand.label is not None and not overwrite: - raise ValueError(f'Cannot assign strand group from nuad strand to scadnano strand ' - f'{sc_strand.name} (through its label field) because the ' - f'scadnano strand already has a label with group key ' - f'\n{sc_strand.label[group_key]}. 
' - f'Set overwrite to True to force an overwrite.') - sc_strand.label[group_key] = nuad_strand.group - def assign_idt_fields_to_scadnano_design(self, sc_design: sc.Design, ignored_strands: Iterable[Strand] = (), overwrite: bool = False) -> None: @@ -4752,12 +4690,12 @@ def verify_designs_match(design1: Design, design2: Design, check_fixed: bool = T if strand1.name != strand2.name: raise ValueError(f'strand names at position {idx} don\'t match: ' f'{strand1.name} and {strand2.name}') - if (strand1.group is not None - and strand2.group is not None - and strand1.group != strand2.group): # noqa - raise ValueError(f'strand {strand2.name} group name does not match:' - f'design1 strand {strand1.name} group = {strand1.group},\n' - f'design2 strand {strand2.name} group = {strand2.group}') + if (strand1.label is not None + and strand2.label is not None + and strand1.label != strand2.label): # noqa + raise ValueError(f'strand {strand2.name} label name does not match:' + f'design1 strand {strand1.name} label = {strand1.label},\n' + f'design2 strand {strand2.name} label = {strand2.label}') for domain1, domain2 in zip(strand1.domains, strand2.domains): if domain1.name != domain2.name: raise ValueError(f'domain of strand {strand2.name} don\'t match: ' @@ -6746,7 +6684,8 @@ def rna_duplex_strand_pairs_constraint( # we use ThreadPool instead of pathos because we're farming this out to processes through # subprocess module anyway, no need for pathos to boot up separate processes or serialize through dill - thread_pool = ThreadPool(processes=num_cores) + if parallel: + thread_pool = ThreadPool(processes=num_cores) def calculate_energies(seq_pairs: Sequence[Tuple[str, str]]) -> Tuple[float]: if parallel: @@ -6832,7 +6771,8 @@ def rna_plex_strand_pairs_constraint( # we use ThreadPool instead of pathos because we're farming this out to processes through # subprocess module anyway, no need for pathos to boot up separate processes or serialize through dill - thread_pool = 
ThreadPool(processes=num_cores) + if parallel: + thread_pool = ThreadPool(processes=num_cores) def calculate_energies(seq_pairs: Sequence[Tuple[str, str]]) -> Tuple[float]: if parallel: diff --git a/nuad/np.py b/nuad/np.py index c72b30f2..0cdf3f26 100644 --- a/nuad/np.py +++ b/nuad/np.py @@ -522,7 +522,7 @@ def longest_common_substrings_singlea1(a1: np.ndarray, a2s: np.ndarray) \ idx_longest_raveled = np.argmax(counter_flat, axis=1) len_longest = counter_flat[np.arange(counter_flat.shape[0]), idx_longest_raveled] - idx_longest = np.unravel_index(idx_longest_raveled, dims=(len_a1 + 1, len_a2 + 1)) + idx_longest = np.unravel_index(idx_longest_raveled, shape=(len_a1 + 1, len_a2 + 1)) a1idx_longest = idx_longest[0] - len_longest a2idx_longest = idx_longest[1] - len_longest @@ -583,7 +583,7 @@ def _longest_common_substrings_pairs(a1s: np.ndarray, a2s: np.ndarray) \ idx_longest_raveled = np.argmax(counter_flat, axis=1) len_longest = counter_flat[np.arange(counter_flat.shape[0]), idx_longest_raveled] - idx_longest = np.unravel_index(idx_longest_raveled, dims=(len_a1 + 1, len_a2 + 1)) + idx_longest = np.unravel_index(idx_longest_raveled, shape=(len_a1 + 1, len_a2 + 1)) a1idx_longest = idx_longest[0] - len_longest a2idx_longest = idx_longest[1] - len_longest @@ -669,7 +669,7 @@ def _strongest_common_substrings_all_pairs(a1s: np.ndarray, a2s: np.ndarray, tem len_strongest = counter_flat[np.arange(counter_flat.shape[0]), idx_strongest_raveled] energy_strongest = energies_flat[np.arange(counter_flat.shape[0]), idx_strongest_raveled] - idx_strongest = np.unravel_index(idx_strongest_raveled, dims=(len_a1 + 1, len_a2 + 1)) + idx_strongest = np.unravel_index(idx_strongest_raveled, shape=(len_a1 + 1, len_a2 + 1)) a1idx_strongest = idx_strongest[0] - len_strongest a2idx_strongest = idx_strongest[1] - len_strongest @@ -902,10 +902,6 @@ def write_to_file(self, filename: str) -> None: for i in range(self.numseqs): f.write(self.get_seq_str(i) + '\n') - def wcenergy(self, idx: int, 
temperature: float) -> float: - """Return energy of idx'th sequence binding to its complement.""" - return wcenergy(self.seqarr[idx], temperature) - def __repr__(self) -> str: return 'DNASeqSet(seqs={})'.format(str([self[i] for i in range(self.numseqs)])) @@ -1050,10 +1046,19 @@ def filter_energy(self, low: float, high: float, temperature: float) -> DNASeqLi def energies(self, temperature: float) -> np.ndarray: """ + Calculates the nearest-neighbor binding energy of each sequence with its perfect complement + (summing over all length-2 substrings of the domain's sequence), + using parameters from the 2004 Santa-Lucia and Hicks paper + (https://www.annualreviews.org/doi/abs/10.1146/annurev.biophys.32.110601.141800, + see Table 1, and example on page 419). + + This is used by :any:`NearestNeighborEnergyFilter` to calculate the energy + of domains when filtering. + :param temperature: temperature in Celsius :return: - nearest-neighbor energies of each sequence with its perfect Watson-Crick complement + array of nearest-neighbor energies of each sequence with its perfect Watson-Crick complement """ wcenergies = calculate_wc_energies(self.seqarr, temperature) return wcenergies @@ -1122,31 +1127,6 @@ def filter_base_at_pos(self, pos: int, base: str) -> DNASeqList: seqarrpass = self.seqarr[good] return DNASeqList(seqarr=seqarrpass) - def filter_substring(self, subs: Sequence[str]) -> DNASeqList: - """Remove any sequence with any elements from subs as a substring.""" - if len(set([len(sub) for sub in subs])) != 1: - raise ValueError(f'All substrings in subs must be equal length: {subs}') - sublen = len(subs[0]) - subints = [[base2bits[base] for base in sub] for sub in subs] - powarr = [4 ** k for k in range(sublen)] - subvals = np.dot(subints, powarr) - toeplitz = create_toeplitz(self.seqlen, sublen) - convolution = np.dot(toeplitz, self.seqarr.transpose()) - passall = np.ones(self.numseqs, dtype=bool) - for subval in subvals: - passsub = np.all(convolution != subval, 
axis=0) - passall = passall & passsub - seqarrpass = self.seqarr[passall] - return DNASeqList(seqarr=seqarrpass) - - def filter_seqs_by_g_quad(self) -> DNASeqList: - """Removes any sticky ends with 4 G's in a row (a G-quadruplex).""" - return self.filter_substring(['GGGG']) - - def filter_seqs_by_g_quad_c_quad(self) -> DNASeqList: - """Removes any sticky ends with 4 G's or C's in a row (a quadruplex).""" - return self.filter_substring(['GGGG', 'CCCC']) - def index(self, sequence: str | np.ndarray) -> int: # finds index of sequence in (rows of) self.seqarr # raises IndexError if not present @@ -1283,3 +1263,84 @@ def calculate_wc_energies(seqarr: np.ndarray, temperature: float, negate: bool = def wc_arr(seqarr: np.ndarray) -> np.ndarray: """Return numpy array of reverse complements of sequences in `seqarr`.""" return (3 - seqarr)[:, ::-1] + + +def energy_hist(length: int | Iterable[int], temperature: float = 37, + combine_lengths: bool = False, + num_random_sequences: int = 100_000, + figsize: Tuple[int, int] = (15, 6), **kwargs) -> None: + """ + Make a matplotlib histogram of the nearest-neighbor energies (as defined by + :meth:`DNASeqList.energies`) of all DNA sequences of the given length(s), + or a randomly selected subset if length(s) is too large to enumerate all DNA sequences + of that length. + + This is useful, for example, to choose low and high energy values to pass to + :any:`NearestNeighborEnergyFilter`. + + :param length: + length of DNA sequences to consider, or an iterable (e.g., list) of lengths + :param temperature: + temperature in Celsius + :param combine_lengths: + If True, then `length` should be an iterable, and the histogram will combine all calculated energies + from all lengths into one histogram to plot. If False (the default), then different lengths are + plotted in different colors in the histogram. 
+ :param num_random_sequences: + If the length is too large to enumerate all DNA sequences of that length, + then this many random sequences are used to estimate the histogram. + :param figsize: + Size of the figure in inches. + :param kwargs: + Any keyword arguments given are passed along as keyword arguments to matplotlib.pyplot.hist: + https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.hist.html + """ + import matplotlib.pyplot as plt + + if combine_lengths and isinstance(length, int): + raise ValueError(f'length must be an iterable if combine_lengths is True, ' + f'but it is the int {length}') + + plt.figure(figsize=figsize) + plt.xlabel(f'Nearest-neighbor energy (kcal/mol)') + + lengths = [length] if isinstance(length, int) else length + + alpha = 1 + if len(lengths) > 1: + alpha = 0.5 + if 'label' in kwargs: + raise ValueError(f'label (={kwargs["label"]}) ' + f'should not be specified if multiple lengths are given') + + bins = kwargs.pop('bins', 20) + + all_energies = [] + for length in lengths: + if length < 9: + seqs = DNASeqList(length=length) + else: + seqs = DNASeqList(length=length, num_random_seqs=num_random_sequences) + energies = seqs.energies(temperature=temperature) + + if combine_lengths: + all_energies.extend(energies) + else: + label = kwargs['label'] if 'label' in kwargs else f'length {length}' + _ = plt.hist(energies, alpha=alpha, label=label, bins=bins, **kwargs) + + if combine_lengths: + if 'label' in kwargs: + label = kwargs['label'] + del kwargs['label'] + else: + if len(lengths) == 1: + label = f'length {length}' + else: + lengths_delimited = ', '.join(map(str, lengths)) + label = f'lengths {lengths_delimited} combined' + _ = plt.hist(all_energies, alpha=alpha, label=label, bins=bins, **kwargs) + + plt.legend(loc='upper right') + title = kwargs.pop('title', f'Nearest-neighbor energies of DNA sequences at {temperature} C') + plt.title(title) diff --git a/nuad/search.py b/nuad/search.py index 89e249f6..d76e5b2b 100644 --- 
a/nuad/search.py +++ b/nuad/search.py @@ -927,6 +927,9 @@ def search_for_sequences(design: nc.Design, params: SearchParameters) -> None: if rng_restart is not None: rng = rng_restart + iteration = 0 + stopwatch = Stopwatch() + eval_set = EvaluationSet(params.constraints, params.never_increase_score) eval_set.evaluate_all(design) @@ -935,11 +938,11 @@ def search_for_sequences(design: nc.Design, params: SearchParameters) -> None: _write_intermediate_files(design=design, params=params, rng=rng, num_new_optimal=num_new_optimal, directories=directories, eval_set=eval_set) - iteration = 0 - stopwatch = Stopwatch() while not _done(iteration, params, eval_set): if params.log_time: + stopwatch.stop() + _log_time(stopwatch) stopwatch.restart() _check_cpu_count(cpu_count) @@ -976,9 +979,6 @@ def search_for_sequences(design: nc.Design, params: SearchParameters) -> None: eval_set=eval_set) iteration += 1 - if params.log_time: - stopwatch.stop() - _log_time(stopwatch) _log_constraint_summary(params=params, eval_set=eval_set, iteration=iteration, num_new_optimal=num_new_optimal) @@ -1293,11 +1293,11 @@ def _log_time(stopwatch: Stopwatch, include_median: bool = False) -> None: med_time = statistics.median(time_last_n_calls) content += f' median: {med_time:.1f} ms |' content_width = len(content) - logger.info('\n' + ('-' * content_width) + '\n' + content) + logger.info('\n' + content) else: # skip appending first time, since it is much larger and skews the average content = f'| time for first call: {stopwatch.milliseconds_str()} ms |' - logger.info('\n' + ('-' * len(content)) + '\n' + content) + logger.info('\n' + content) time_last_n_calls_available = True diff --git a/tests/test.py b/tests/test.py index 910ef48e..e75c9b51 100644 --- a/tests/test.py +++ b/tests/test.py @@ -7,6 +7,7 @@ import nuad.constraints as nc import nuad.search as ns +import nuad.vienna_nupack as nv import scadnano as sc from nuad.constraints import Design, Domain, _get_base_pair_domain_endpoints_to_check, \ 
_get_implicitly_bound_domain_addresses, _exterior_base_type_of_domain_3p_end, _BasePairDomainEndpoint, \ @@ -58,7 +59,7 @@ def construct_strand(design: Design, domain_names: List[str], domain_lengths: Li class TestIntersectingDomains(unittest.TestCase): def test_strand_intersecting_domains(self) -> None: - """ + r""" Test strand construction with nested subdomains .. code-block:: none @@ -280,7 +281,7 @@ class TestExportDNASequences(unittest.TestCase): def test_idt_bulk_export(self) -> None: custom_idt = nc.VendorFields(scale='100nm', purification='PAGE') design = nc.Design() - design.add_strand(domain_names=['a', 'b*', 'c', 'd*'], name='s0', idt=custom_idt) + design.add_strand(domain_names=['a', 'b*', 'c', 'd*'], name='s0', vendor_fields=custom_idt) design.add_strand(domain_names=['d', 'c*', 'e', 'f'], name='s1') # a b c d e f @@ -314,7 +315,7 @@ def test_write_idt_plate_excel_file(self) -> None: design = nc.Design() for strand_idx in range(3 * plate_type.num_wells_per_plate() + 10): idt = nc.VendorFields() - strand = design.add_strand(name=f's{strand_idx}', domain_names=[f'd{strand_idx}'], idt=idt) + strand = design.add_strand(name=f's{strand_idx}', domain_names=[f'd{strand_idx}'], vendor_fields=idt) strand.domains[0].set_fixed_sequence('T' * strand_len) design.write_idt_plate_excel_file(filename=filename, plate_type=plate_type) @@ -931,7 +932,7 @@ def test_error_constructed_unfixed_domain_with_fixed_subdomains(self): subdomains=[b, c]) def test_construct_strand(self): - """ + r""" Test strand construction with nested subdomains .. code-block:: none @@ -957,7 +958,7 @@ def test_construct_strand(self): self.assertEqual(strand.domains[0], a) def test_error_strand_with_unassignable_subsequence(self): - """ + r""" Test that constructing a strand with an unassignable subsequence raises a ValueError. 
@@ -987,7 +988,7 @@ def test_error_strand_with_unassignable_subsequence(self): self.assertRaises(ValueError, Design, strands=[strand]) def test_error_strand_with_redundant_independence(self): - """ + r""" Test that constructing a strand with an redundant indepndence in subdomain graph raises a ValueError. @@ -1039,7 +1040,7 @@ def test_error_cycle(self): self.assertRaises(ValueError, Design, strands=[strand]) def sample_nested_domains(self) -> Dict[str, Domain]: - """Returns domains with the following subdomain hierarchy: + r"""Returns domains with the following subdomain hierarchy: .. code-block:: none @@ -1064,7 +1065,7 @@ def sample_nested_domains(self) -> Dict[str, Domain]: return {domain.name: domain for domain in [a, b, C, E, F, g, h]} def test_assign_dna_sequence_to_parent(self): - """ + r""" Test assigning dna sequence to parent (a) and propagating it downwards .. code-block:: none @@ -1088,7 +1089,7 @@ def test_assign_dna_sequence_to_parent(self): self.assertEqual(sequence[18:], domains['h'].sequence()) def test_assign_dna_sequence_to_leaf(self): - """ + r""" Test assigning dna sequence to E, F and propgate upward to b .. code-block:: none @@ -1109,7 +1110,7 @@ def test_assign_dna_sequence_to_leaf(self): self.assertEqual('CATAGCTTTCC', domains['b'].sequence()) def test_assign_dna_sequence_mixed(self): - """ + r""" Test assigning dna sequence to E, F, and C and propgate to entire tree. .. code-block:: none @@ -1176,7 +1177,7 @@ def test_error_assign_dna_sequence_to_parent_with_incorrect_size_subdomain(self) a.set_sequence('A' * 15) def test_construct_strand_using_dependent_subdomain(self) -> None: - """Test constructing a strand using a dependent subdomain (not parent) + r"""Test constructing a strand using a dependent subdomain (not parent) .. 
code-block:: none @@ -1205,6 +1206,22 @@ def test_design_finds_independent_subdomains(self) -> None: self.assertIn(B, domains) self.assertIn(C, domains) +class TestNUPACK(unittest.TestCase): + + def test_pfunc(self) -> None: + seq = 'ACGTACGTAGCTGATCCAGCTGATCG' + energy = nv.pfunc(seq) + self.assertTrue(energy < 0) + +class TestViennaRNA(unittest.TestCase): + def test_rna_plex(self) -> None: + pairs = [ + ('ACGT','ACGT'), + ('TTAC','AATG'), + ] + energies = nv.rna_plex_multiple(pairs) + self.assertEqual(2, len(energies)) + if __name__ == '__main__': unittest.main()