Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tangermeme backend for tomtom #75

Merged
merged 18 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,12 @@ RUN pip install captum==0.5.0 wandb tensorboard plotnine

RUN pip install bioframe biopython genomepy scanpy \
pyjaspar pyBigWig pyfaidx pytabix
RUN pip install bpnet-lite>=0.5.7 ledidi enformer-pytorch genomepy
RUN pip install bpnet-lite>=0.5.7 ledidi enformer-pytorch genomepy statsmodels
RUN pip install pygenomeviz

# Install modiscolite
RUN pip install modisco-lite@git+https://github.com/jmschrei/tfmodisco-lite.git

# Install MEME suite
RUN wget https://meme-suite.org/meme/meme-software/5.5.1/meme-5.5.1.tar.gz && \
tar -xvzf meme-5.5.1.tar.gz && \
cd meme-5.5.1 && \
./configure --prefix=/usr --enable-build-libxml2 --enable-build-libxslt && \
make && \
make install

# Run jupyterlab
WORKDIR /
CMD jupyter lab --no-browser --allow-root --port 8891 --ip 0.0.0.0 --NotebookApp.token=''
201 changes: 101 additions & 100 deletions docs/tutorials/1_inference.ipynb

Large diffs are not rendered by default.

1,389 changes: 398 additions & 991 deletions docs/tutorials/2_finetune.ipynb

Large diffs are not rendered by default.

426 changes: 168 additions & 258 deletions docs/tutorials/3_train.ipynb

Large diffs are not rendered by default.

449 changes: 280 additions & 169 deletions docs/tutorials/4_design.ipynb

Large diffs are not rendered by default.

368 changes: 181 additions & 187 deletions docs/tutorials/5_variant.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ install_requires =
ledidi
tangermeme
pygenomeviz <= 0.4.4
statsmodels >=0.11.1


[options.packages.find]
Expand Down
266 changes: 266 additions & 0 deletions src/grelu/interpret/modisco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import os
from typing import Callable, List, Optional, Union

import numpy as np
import pandas as pd


def _add_tomtom_to_modisco_report(
modisco_dir: str,
tomtom_results: pd.DataFrame,
meme_file: str,
top_n_matches: int,
) -> None:
"""
Modified from https://github.com/jmschrei/tfmodisco-lite/blob/3c6e38f/modiscolite/report.py#L245
"""
from modiscolite.report import make_logo, path_to_image_html, read_meme

from grelu.resources import get_meme_file_path

# Paths to outputs
html_file = os.path.join(modisco_dir, "motifs.html")
meme_logo_dir = os.path.join(modisco_dir, "trimmed_meme_logos")
modisco_logo_dir = os.path.join(modisco_dir, "trimmed_logos")

# Loading html report
report = pd.read_html(html_file)[0]
cols = report.columns.tolist()
report["query"] = report.apply(
lambda row: row.pattern[:3] + "_" + row.pattern.split(".")[-1], axis=1
)
report["modisco_cwm_fwd"] = report.pattern.apply(
lambda x: os.path.join(modisco_logo_dir, f"{x}.cwm.fwd.png")
)
report["modisco_cwm_rev"] = report.pattern.apply(
lambda x: os.path.join(modisco_logo_dir, f"{x}.cwm.rev.png")
)

# Compiling top TOMTOM matches
tomtom_dict = dict()
for i in range(top_n_matches):
tomtom_dict[f"match{i}"] = []
tomtom_dict[f"qval{i}"] = []

for row in report.itertuples():
query_tomtom = tomtom_results.loc[
tomtom_results.Query_ID == row.query, ["Target_ID", "q-value"]
].sort_values("q-value")[:top_n_matches]

i = -1
for i, row in enumerate(query_tomtom.itertuples()):
tomtom_dict[f"match{i}"].append(row[1])
tomtom_dict[f"qval{i}"].append(row[2])

for j in range(i + 1, top_n_matches):
tomtom_dict[f"match{j}"].append(None)
tomtom_dict[f"qval{j}"].append(None)

report = pd.concat([report, pd.DataFrame(tomtom_dict)], axis=1)

# Reading reference motifs from the meme file
meme_file = get_meme_file_path(meme_file)
motifs = read_meme(meme_file)

# Generating logos for the reference motifs
if not os.path.exists(meme_logo_dir):
os.makedirs(meme_logo_dir)

for i in range(top_n_matches):
name = f"match{i}"
logos = []
for _, row in report.iterrows():
if name in report.columns:
if pd.isnull(row[name]):
logos.append("NA")
else:
make_logo(row[name], meme_logo_dir, motifs)
logos.append(os.path.join(meme_logo_dir, f"{row[name]}.png"))
else:
break
report[f"{name}_logo"] = logos
cols.extend([name, f"qval{i}", f"{name}_logo"])

# Saving html file
with open(html_file, "w") as f:
report[cols].to_html(
f,
escape=False,
formatters=dict(
modisco_cwm_fwd=path_to_image_html,
modisco_cwm_rev=path_to_image_html,
match0_logo=path_to_image_html,
match1_logo=path_to_image_html,
match2_logo=path_to_image_html,
),
index=False,
)


def run_modisco(
model,
seqs: Union[pd.DataFrame, np.array, List[str]],
genome: Optional[str] = None,
prediction_transform: Optional[Callable] = None,
window: int = None,
meme_file: str = None,
out_dir: str = "outputs",
devices: Union[str, int] = "cpu",
num_workers: int = 1,
batch_size: int = 64,
n_shuffles: int = 10,
seed=None,
method: str = "deepshap",
**kwargs,
) -> None:
"""
Run TF-Modisco to get relevant motifs for a set of inputs, and optionally score the
motifs against a reference set of motifs using TOMTOM

Args:
model: A trained deep learning model
seqs: Input DNA sequences as genomic intervals, strings, or integer-encoded form.
genome: Name of the genome to use. Only used if genomic intervals are provided.
prediction_transform: A module to transform the model output
window: Sequence length over which to consider attributions
meme_file: Path to a MEME file containing reference motifs for TOMTOM.
out_dir: Output directory
devices: Indices of devices to use for model inference
num_workers: Number of workers to use for model inference
batch_size: Batch size to use for model inference
n_shuffles: Number of times to shuffle the background sequences for deepshap.
seed: Random seed
method: Either "deepshap", "saliency" or "ism".
**kwargs: Additional arguments to pass to TF-Modisco.

Raises:
NotImplementedError: if the method is neither "deepshap" nor "ism"
"""
from modiscolite.io import save_hdf5
from modiscolite.report import create_modisco_logos, report_motifs
from modiscolite.tfmodisco import TFMoDISco

from grelu.data.dataset import ISMDataset, SeqDataset
from grelu.interpret.motifs import run_tomtom
from grelu.interpret.score import get_attributions
from grelu.io.motifs import read_modisco_report
from grelu.sequence.format import convert_input_type
from grelu.sequence.utils import get_unique_length

# Get start and end positions
if window is None:
start = 0
end = get_unique_length(seqs)
else:
center = get_unique_length(seqs) // 2
start = center - window // 2
end = start + window

# Get one-hot encoded sequence
one_hot = convert_input_type(seqs, "one_hot", genome=genome)
one_hot_arr = one_hot[:, :, start:end].numpy()

if method in ["deepshap", "saliency"]:
print("Getting attributions")
attrs = get_attributions(
model=model,
seqs=one_hot,
prediction_transform=prediction_transform,
device=devices,
n_shuffles=n_shuffles,
method=method,
hypothetical=True,
genome=genome,
seed=seed,
)
attrs = attrs[:, :, start:end]

elif method == "ism":
print("Performing ISM")
ref_ds = SeqDataset(seqs, genome=genome)
ism_ds = ISMDataset(
seqs, drop_ref=True, positions=range(start, end), genome=genome
)

# Add transform to model
model.add_transform(prediction_transform)

# Get predictions for reference sequences
ref_preds = model.predict_on_dataset(
ref_ds, devices=devices, num_workers=num_workers, batch_size=batch_size
) # B, 1, T, L
assert (ref_preds.shape[-1] == 1) and (ref_preds.shape[-2] == 1)

# Get predictions for all mutated sequences
ism_preds = model.predict_on_dataset(
ism_ds,
devices=devices,
num_workers=num_workers,
batch_size=batch_size,
) # B, l, 3, 1, 1
assert (ism_preds.shape[-1] == 1) and (ism_preds.shape[-2] == 1)
ism_preds = ism_preds.squeeze((-1, -2)) # B, l, 3

# Remove transform
model.reset_transform()

# Get the negative log ratio
attrs = -np.log2(np.divide(ism_preds, ref_preds)) # B, l, 3

# Mean over all possible mutations
attrs = np.expand_dims(attrs.mean(-1), 1) # B, 1, l

# Multiply by original sequence
attrs = np.multiply(attrs, one_hot_arr) # B, 4, l

else:
raise NotImplementedError

print("Running modisco")
one_hot_arr = one_hot_arr.transpose(0, 2, 1).astype("float32")
attrs = attrs.transpose(0, 2, 1).astype("float32")
pos_patterns, neg_patterns = TFMoDISco(
hypothetical_contribs=attrs,
one_hot=one_hot_arr,
**kwargs,
)

print("Writing modisco output")
if not os.path.exists(out_dir):
os.makedirs(out_dir)

h5_file = os.path.join(out_dir, "modisco_report.h5")
save_hdf5(h5_file, pos_patterns, neg_patterns, window_size=20)

print("Creating sequence logos")
modisco_logo_dir = os.path.join(out_dir, "trimmed_logos")
if not os.path.isdir(modisco_logo_dir):
os.mkdir(modisco_logo_dir)
create_modisco_logos(
h5_file,
modisco_logo_dir,
trim_threshold=0.2,
pattern_groups=["pos_patterns", "neg_patterns"],
)

print("Creating html report")
report_motifs(
modisco_h5py=h5_file,
output_dir=out_dir,
img_path_suffix=out_dir,
meme_motif_db=None,
is_writing_tomtom_matrix=False,
)

if meme_file is not None:
print("Running TOMTOM")
tomtom_file = os.path.join(out_dir, "tomtom.csv")
motifs = read_modisco_report(h5_file, trim_threshold=0.3)
tomtom_results = run_tomtom(motifs, meme_file)
tomtom_results.to_csv(tomtom_file)
_add_tomtom_to_modisco_report(
modisco_dir=out_dir,
tomtom_results=tomtom_results,
meme_file=meme_file,
top_n_matches=10,
)
47 changes: 46 additions & 1 deletion src/grelu/interpret/motifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def motifs_to_strings(
return indices_to_strings(indices)

# Convert multiple motifs
elif isinstance(motifs, Dict):
elif isinstance(motifs, dict):
return [
motifs_to_strings(motif, rng=rng, sample=sample)
for motif in motifs.values()
Expand Down Expand Up @@ -339,3 +339,48 @@ def compare_motifs(
scan["foldChange"] = scan.alt / scan.ref
scan = scan.sort_values("foldChange").reset_index(drop=True)
return scan


def run_tomtom(motifs: Dict[str, np.ndarray], meme_file: str) -> pd.DataFrame:
"""
Function to compare given motifs to reference motifs using the
tomtom algorithm, as implemented in tangermeme.

Args:
motifs: A dictionary whose values are Position Probability Matrices
(PPMs) of shape (4, L).
meme_file: Path to a meme file containing reference motifs.

Returns:
df: Pandas dataframe containing all tomtom results.
"""
from statsmodels.stats.multitest import fdrcorrection
suragnair marked this conversation as resolved.
Show resolved Hide resolved
from tangermeme.tools.tomtom import tomtom

from grelu.interpret.motifs import motifs_to_strings
from grelu.resources import get_meme_file_path

meme_file = get_meme_file_path(meme_file)
ref_motifs = read_meme_file(meme_file)
query_consensuses = {k: motifs_to_strings(v) for k, v in motifs.items()}
target_consensuses = {k: motifs_to_strings(v) for k, v in ref_motifs.items()}
pvals, scores, offsets, overlaps, strands = tomtom(
list(motifs.values()), list(ref_motifs.values())
)

df = pd.DataFrame(
{
"Query_ID": np.repeat(list(motifs.keys()), len(ref_motifs)),
"Target_ID": list(ref_motifs.keys()) * len(motifs),
"Optimal_offset": offsets.flatten(),
"p-value": pvals.flatten(),
}
)
df["E-value"] = df["p-value"] * len(ref_motifs)
df["q-value"] = fdrcorrection(df["p-value"])[1]
df["Overlap"] = overlaps.flatten()
df["Query_consensus"] = df.Query_ID.map(query_consensuses)
df["Target_consensus"] = df.Target_ID.map(target_consensuses)
df["Orientation"] = ["+" if x == 0 else "-" for x in strands.flatten()]

return df
Loading
Loading