Skip to content

Commit

Permalink
Merge pull request #15 from Genentech/tangermeme_negatives
Browse files Browse the repository at this point in the history
replace bpnetlite with tangermeme
  • Loading branch information
avantikalal authored Aug 21, 2024
2 parents 194a663 + 6bdceab commit 1ca2765
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 60 deletions.
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ install_requires =
genomepy
bioframe >= 0.4
captum == 0.5.0
bpnet-lite == 0.5.7
logomaker >= 0.8
pyBigWig
ledidi
Expand Down
30 changes: 9 additions & 21 deletions src/grelu/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
import os
import subprocess
import tempfile
from typing import Callable, List, Optional, Union

import bioframe as bf
Expand Down Expand Up @@ -483,7 +482,6 @@ def get_gc_matched_intervals(
genome: str,
binwidth: float = 0.1,
chroms: str = "autosomes",
gc_bw_file: str = None,
blacklist: str = "hg38",
seed: Optional[int] = None,
) -> pd.DataFrame:
Expand All @@ -495,15 +493,13 @@ def get_gc_matched_intervals(
genome: Name of the genome corresponding to intervals
binwidth: Resolution of GC content
chroms: Chromosomes to search for matched intervals
gc_bw_file: Path to a bigWig file of genomewide GC content.
If None, will be created.
blacklist: Blacklist file of regions to exclude
seed: Random seed
Returns:
A pandas dataframe containing GC-matched negative intervals.
"""
from bpnetlite.negatives import calculate_gc_genomewide, extract_matching_loci
from tangermeme.match import extract_matching_loci

from grelu.io.genome import get_genome
from grelu.sequence.utils import get_unique_length
Expand All @@ -514,25 +510,17 @@ def get_gc_matched_intervals(
# Get seq_len
seq_len = get_unique_length(intervals)

# Get bigWig file of GC content
if gc_bw_file is None:
gc_bw_file = "gc_{}_{}.bw".format(genome.name, seq_len)
print("Calculating GC content genomewide and saving to {}".format(gc_bw_file))
calculate_gc_genomewide(
fasta=genome.genome_file,
bigwig=gc_bw_file,
width=seq_len,
include_chroms=chroms,
verbose=True,
)

print("Extracting matching intervals")
_, tmpfile = tempfile.mkstemp()
intervals.iloc[:, :3].to_csv(tmpfile, sep="\t", index=False, header=False)
matched_loci = extract_matching_loci(
bed=tmpfile, bigwig=gc_bw_file, width=seq_len, bin_width=binwidth, verbose=True
intervals,
fasta=genome.genome_file,
in_window=seq_len,
gc_bin_width=binwidth,
chroms=chroms,
verbose=False,
random_state=seed,
)
os.remove(tmpfile)

print("Filtering blacklist")
if blacklist is not None:
matched_loci = filter_blacklist(matched_loci, blacklist)
Expand Down
1 change: 1 addition & 0 deletions src/grelu/interpret/motifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def marginalize_patterns(
genome=genome,
rc=rc,
n_shuffles=n_shuffles,
seed=seed,
)

# Get predictions on the sequences before motif insertion
Expand Down
16 changes: 12 additions & 4 deletions src/grelu/sequence/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,15 @@ def strings_to_indices(
)


def indices_to_one_hot(indices: np.ndarray) -> Tensor:
def indices_to_one_hot(indices: np.ndarray, add_batch_axis: bool = False) -> Tensor:
"""
Convert integer-encoded DNA sequences to one-hot encoded format.
Args:
indices: Integer-encoded DNA sequences.
add_batch_axis: If True, a batch axis will be included in the output for single
sequences. If False, the output for a single sequence will be a 2-dimensional
tensor.
Returns:
The one-hot encoded sequences.
Expand All @@ -274,9 +277,12 @@ def indices_to_one_hot(indices: np.ndarray) -> Tensor:

# Convert a single sequence
if indices.ndim == 1:
return one_hot(torch.LongTensor(indices.copy()), num_classes=5)[:, :4].T.type(
one_hot = one_hot(torch.LongTensor(indices.copy()), num_classes=5)[
:, :4
].T.type(
torch.float32
) # Output shape: 4, L
return one_hot.unsqueeze(0) if add_batch_axis else one_hot

# Convert multiple sequences
else:
Expand Down Expand Up @@ -370,6 +376,7 @@ def convert_input_type(
output_type: str = "indices",
genome: Optional[str] = None,
add_batch_axis: bool = False,
input_type: Optional[str] = None,
) -> Union[pd.DataFrame, str, List[str], np.ndarray, Tensor]:
"""
Convert input DNA sequence data into the desired format.
Expand All @@ -381,6 +388,7 @@ def convert_input_type(
add_batch_axis: If True, a batch axis will be included in the output for single
sequences. If False, the output for a single sequence will be a 2-dimensional
tensor.
input_type: Format of the input sequence (optional)
Returns:
The converted DNA sequence(s) in the desired format.
Expand All @@ -390,7 +398,7 @@ def convert_input_type(
"""
# Determine input type
input_type = get_input_type(inputs)
input_type = input_type or get_input_type(inputs)

# If no conversion needed, return inputs as is
if input_type == output_type:
Expand Down Expand Up @@ -419,7 +427,7 @@ def convert_input_type(
# Convert indices
if input_type == "indices":
if output_type == "one_hot":
return indices_to_one_hot(inputs)
return indices_to_one_hot(inputs, add_batch_axis=add_batch_axis)
elif output_type == "strings":
return indices_to_strings(inputs)

Expand Down
30 changes: 11 additions & 19 deletions src/grelu/sequence/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ def reverse_complement(
def dinuc_shuffle(
seqs: Union[pd.DataFrame, np.ndarray, List[str]],
n_shuffles: int = 1,
start=0,
end=-1,
input_type: Optional[str] = None,
seed: Optional[int] = None,
genome: Optional[str] = None,
Expand All @@ -393,32 +395,22 @@ def dinuc_shuffle(
Returns:
Shuffled sequences in the same format as the input
"""
import torch
from bpnetlite.attributions import dinucleotide_shuffle
from einops import rearrange
from tangermeme.ersatz import dinucleotide_shuffle

# Input format
input_type = input_type or get_input_type(seqs)

# One-hot encode
seqs = convert_input_type(seqs, "one_hot", genome=genome) # N, 4, L
seqs = convert_input_type(
seqs, "one_hot", genome=genome, add_batch_axis=True
) # B, 4, L

# Shuffle sequences as many times as required
if n_shuffles > 0:
if seqs.ndim == 2: # 4, L
shuf_seqs = dinucleotide_shuffle(
seqs, n_shuffles=n_shuffles, random_state=seed
) # N, 4, L
else:
shuf_seqs = torch.vstack(
[
dinucleotide_shuffle(seq, n_shuffles=n_shuffles, random_state=seed)
for seq in seqs
]
) # B, 4, L

# If no shuffling is required, return the original sequences
else:
return seqs
shuf_seqs = dinucleotide_shuffle(
X=seqs, start=start, end=end, n=n_shuffles, random_state=seed, verbose=False
) # B, n, 4, L
shuf_seqs = rearrange(shuf_seqs, "b n t l -> (b n) t l")

return convert_input_type(shuf_seqs, input_type)

Expand Down
32 changes: 20 additions & 12 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,34 +797,37 @@ def test_ism_dataset():
def test_marginalize_dataset_variants():
# Marginalize variants
ds = VariantMarginalizeDataset(
variants=variants, genome="hg38", seq_len=6, n_shuffles=2, seed=0
variants=variants, genome="hg38", seq_len=12, n_shuffles=2, seed=0
)
assert (
(ds.n_shuffles == 2)
and (ds.seq_len == 6)
and (ds.seq_len == 12)
and (ds.n_seqs == 2)
and (ds.ref.shape == (2, 1))
and (ds.alt.shape == (2, 1))
and (len(ds) == 8)
and (ds.n_augmented == 2)
and (np.allclose(ds.ref, np.array([[2], [2]])))
and (np.allclose(ds.alt, np.array([[0], [0]])))
)
assert convert_input_type(ds.seqs, "strings") == ["CATACGTGAGGC", "AGGAGGCCAAAG"]
xs = [convert_input_type(ds[i], "strings") for i in range(len(ds))]
assert xs == [
"ACGTGA",
"ACATGA",
"ACGTGA",
"ACATGA",
"AGGCCA",
"AGACCA",
"AGGCCA",
"AGACCA",
"CACGTGTGAGGC",
"CACGTATGAGGC",
"CACGAGAGTGGC",
"CACGAAAGTGGC",
"AAGGGGGCCAAG",
"AAGGGAGCCAAG",
"AAGAGGGCCAAG",
"AAGAGAGCCAAG",
]


def test_marginalize_dataset_motifs():
# Marginalize motifs
ds = PatternMarginalizeDataset(
seqs=["ACCTACACT"], patterns=["AAA"], n_shuffles=2, seed=0
seqs=["AAGACATACAACGCGCGCTAACATAGCAAC"], patterns=["AAA"], n_shuffles=2, seed=0
)
assert (
(ds.n_shuffles == 2)
Expand All @@ -836,7 +839,12 @@ def test_marginalize_dataset_motifs():
)

xs = [convert_input_type(ds[i], "strings") for i in range(len(ds))]
assert xs == ["ACACCGACG", "ACAAAAACG", "ACACGACCG", "ACAAAACCG"]
assert xs == [
"ACGCATACGAGCGCTACAGCAACATAAAAC",
"ACGCATACGAGCGAAACAGCAACATAAAAC",
"ACTAACAACAGCACGCGCGATATAAGCAAC",
"ACTAACAACAGCAAAAGCGATATAAGCAAC",
]


# Test Motif scanning dataset
Expand Down
11 changes: 8 additions & 3 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_trim_pwm():


def test_marginalize_patterns():
seqs = ["ACTGT", "GATCC"]
seqs = ["CATACGTGAGGC", "AGGAGGCCAAAG"]
preds_before, preds_after = marginalize_patterns(
model,
patterns=["A"],
Expand All @@ -81,9 +81,14 @@ def test_marginalize_patterns():
compare_func=None,
)
assert preds_before.shape == (2, 3, 1)
assert np.allclose(preds_before.squeeze(), [[0.4, 0.4, 0.4], [0, 0, 0]])
assert np.allclose(
preds_before.squeeze(), [[0.5, 0.5, 0.5], [1.3333334, 1.3333334, 1.3333334]]
)
assert preds_after.shape == (2, 3, 1)
assert np.allclose(preds_after.squeeze(), [[1.2, 1.2, 1.2], [0.8, 0.8, 0.8]])
assert np.allclose(
preds_after.squeeze(),
[[0.5, 0.8333333, 0.8333333], [1.3333334, 1.6666666, 1.6666666]],
)


def test_ISM_predict():
Expand Down
3 changes: 3 additions & 0 deletions tests/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ def test_seq_formatting():

# indices to one-hot
assert torch.allclose(convert_input_type(indices, "one_hot"), batch)
assert torch.allclose(
convert_input_type(indices[0], "one_hot", add_batch_axis=True), batch[[0]]
)


# Test Metrics functions
Expand Down

0 comments on commit 1ca2765

Please sign in to comment.