Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pattern clustering plot with logos #97

Merged
merged 6 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/plotting/patterns.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Plot contribution scores and analyze them using tfmodisco.
selected_instances
class_instances
clustermap
clustermap_with_pwm_logos
clustermap_tf_motif
tf_expression_per_cell_type
similarity_heatmap
Expand Down
3 changes: 3 additions & 0 deletions src/crested/pl/patterns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _optional_function_warning(*args, **kwargs):
class_instances,
clustermap,
clustermap_tf_motif,
clustermap_with_pwm_logos,
modisco_results,
selected_instances,
similarity_heatmap,
Expand All @@ -48,6 +49,7 @@ def _optional_function_warning(*args, **kwargs):
class_instances = _optional_function_warning
clustermap_tf_motif = _optional_function_warning
tf_expression_per_cell_type = _optional_function_warning
clustermap_with_pwm_logos= _optional_function_warning

# Export these functions for public use
__all__ = [
Expand All @@ -66,5 +68,6 @@ def _optional_function_warning(*args, **kwargs):
"selected_instances",
"clustermap_tf_motif",
"tf_expression_per_cell_type",
"clustermap_with_pwm_logos",
]
)
171 changes: 166 additions & 5 deletions src/crested/pl/patterns/_modisco_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def modisco_results(
y_min: float = -0.05,
y_max: float = 0.25,
background: list[float] = None,
trim_pattern: bool = True,
trim_ic_threshold: float = 0.1,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -62,6 +64,10 @@ def modisco_results(
Maximum y-axis limit for the plot if viz is "contrib".
background
Background probabilities for each nucleotide. Default is [0.27, 0.23, 0.23, 0.27].
trim_pattern
Boolean for trimming modisco patterns.
trim_ic_threshold
If trimming patterns, indicate threshold.
kwargs
Additional keyword arguments for the plot.

Expand Down Expand Up @@ -152,7 +158,7 @@ def modisco_results(
logger.info("total seqlets:", num_seqlets)
if num_seqlets < min_seqlets:
break
pattern_trimmed = _trim_pattern_by_ic(pattern, pos_pat, 0.1)
pattern_trimmed = _trim_pattern_by_ic(pattern, pos_pat, trim_ic_threshold) if trim_pattern else pattern
if viz == "contrib":
ax = _plot_attribution_map(
ax=ax,
Expand All @@ -165,8 +171,7 @@ def modisco_results(
f"{cell_type}: {np.around(num_seqlets / num_seq * 100, 2)}% seqlet frequency"
)
elif viz == "pwm":
pattern = _trim_pattern_by_ic(pattern, pos_pat, 0.1)
ppm = _pattern_to_ppm(pattern)
ppm = _pattern_to_ppm(pattern_trimmed)
ic, ic_pos, ic_mat = compute_ic(ppm)
pwm = np.array(ic_mat)
rounded_mean = np.around(np.mean(pwm), 2)
Expand Down Expand Up @@ -196,7 +201,7 @@ def modisco_results(
def clustermap(
pattern_matrix: np.ndarray,
classes: list[str],
subset: list[str] | None = None, # Subset option
subset: list[str] | None = None,
figsize: tuple[int, int] = (25, 8),
grid: bool = False,
cmap: str = "coolwarm",
Expand Down Expand Up @@ -359,8 +364,164 @@ def clustermap(

plt.show()

def clustermap_with_pwm_logos(
pattern_matrix: np.ndarray,
classes: list[str],
pattern_dict: dict,
subset: list[str] | None = None,
figsize: tuple[int, int] = (25, 8),
grid: bool = False,
cmap: str = "coolwarm",
center: float = 0,
method: str = "average",
fig_path: str | None = None,
dendrogram_ratio: tuple[float, float] = (0.05, 0.2),
importance_threshold: float = 0,
logo_height_fraction: float = 0.35,
logo_y_padding: float = 0.3,
) -> sns.matrix.ClusterGrid:
"""
Create a clustermap with additional PWM logo plots below the heatmap.

Parameters
----------
pattern_matrix:
A 2D array representing the data matrix for clustering.
classes:
The class labels for the rows of the matrix.
pattern_dict:
A dictionary containing PWM patterns for x-tick plots.
subset
List of class labels to subset the matrix.
figsize:
Size of the clustermap figure (width, height). Default is (25, 8).
grid:
Whether to overlay grid lines on the heatmap. Default is False.
cmap:
Colormap for the heatmap. Default is "coolwarm".
center:
The value at which to center the colormap. Default is 0.
method:
Linkage method for hierarchical clustering. Default is "average".
fig_path:
Path to save the final figure. If None, the figure is not saved. Default is None.
dendrogram_ratio:
Ratios for the size of row and column dendrograms. Default is (0.05, 0.2).
importance_threshold:
Threshold for filtering columns based on maximum absolute importance. Default is 0.
logo_height_fraction:
Fraction of clustermap height to allocate for PWM logos. Default is 0.35.
logo_y_padding:
Vertical padding for the PWM logos relative to the heatmap. Default is 0.3.

Returns
-------
sns.matrix.ClusterGrid: A seaborn ClusterGrid object containing the clustermap with the PWM logos.
"""
# Subset the pattern_matrix and classes if subset is provided
if subset is not None:
subset_indices = [
i for i, class_label in enumerate(classes) if class_label in subset
]
pattern_matrix = pattern_matrix[subset_indices, :]
classes = [classes[i] for i in subset_indices]

# Filter columns based on importance threshold
max_importance = np.max(np.abs(pattern_matrix), axis=0)
above_threshold = max_importance > importance_threshold
pattern_matrix = pattern_matrix[:, above_threshold]

# Subset the pattern_dict to match filtered columns
selected_patterns = [pattern_dict[str(i)] for i in np.where(above_threshold)[0]]

data = pd.DataFrame(pattern_matrix)

# Generate the clustermap with the specified figsize
g = sns.clustermap(
data,
cmap=cmap,
figsize=figsize,
row_colors=None,
yticklabels=classes,
center=center,
xticklabels=False,
method=method,
dendrogram_ratio=dendrogram_ratio,
cbar_pos=(1.05, 0.4, 0.01, 0.3),
)

col_order = g.dendrogram_col.reordered_ind
cbar = g.ax_heatmap.collections[0].colorbar
cbar.set_label("Motif importance", rotation=270, labelpad=20)

# Reorder selected_patterns based on clustering
reordered_patterns = [selected_patterns[i] for i in col_order]

# Compute space for x-tick images
original_height = figsize[1]
extra_height = logo_height_fraction * original_height
total_height = original_height + extra_height

# Update the figure size to accommodate the logos
fig = g.fig
fig.set_size_inches(figsize[0], total_height)

# Adjust width and height of logos
logo_width = g.ax_heatmap.get_position().width / len(reordered_patterns) * 2.5
logo_height = logo_height_fraction * g.ax_heatmap.get_position().height
ratio = logo_height / logo_width

for i, pattern in enumerate(reordered_patterns):
plot_start_x = g.ax_heatmap.get_position().x0 + ((i - 0.75) / len(reordered_patterns)) * g.ax_heatmap.get_position().width
plot_start_y = g.ax_heatmap.get_position().y0 - logo_height - logo_height * logo_y_padding
pwm_ax = fig.add_axes([plot_start_x, plot_start_y, logo_width, logo_height])
pwm_ax.clear()

# Plot the PWM logo with dynamic figsize
ppm = _pattern_to_ppm(pattern["pattern"])
ic, ic_pos, ic_mat = compute_ic(ppm)
pwm = np.array(ic_mat)
pwm_ax = _plot_attribution_map(
ax=pwm_ax,
saliency_df=pwm,
return_ax=True,
figsize=(8 * ratio, 8),
rotate=True,
)
pwm_ax.axis("off")

if grid:
ax = g.ax_heatmap
x_positions = np.arange(pattern_matrix.shape[1] + 1)
y_positions = np.arange(len(pattern_matrix) + 1)

# Add horizontal grid lines
for y in y_positions:
ax.hlines(y, *ax.get_xlim(), color="grey", linewidth=0.25)

# Add vertical grid lines
for x in x_positions:
ax.vlines(x, *ax.get_ylim(), color="grey", linewidth=0.25)

g.fig.canvas.draw()

ax = g.ax_heatmap
ax.xaxis.tick_bottom()
ax.set_xticks(np.arange(pattern_matrix.shape[1]) + 0.5)
ax.set_xticklabels([f"{i}" for i in col_order], rotation=90)
for tick in ax.get_xticklabels():
tick.set_verticalalignment("top")

if fig_path is not None:
plt.savefig(fig_path, bbox_inches="tight", dpi=600)

plt.show()
return g

def selected_instances(pattern_dict: dict, idcs: list[int]) -> None:
def selected_instances(
pattern_dict: dict,
idcs: list[int],
)-> None:
"""
Plot the patterns specified by the indices in `idcs` from the `pattern_dict`.

Expand Down
67 changes: 56 additions & 11 deletions src/crested/pl/patterns/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image


def grad_times_input_to_df(x, grad, alphabet="ACGT"):
Expand Down Expand Up @@ -55,24 +56,68 @@ def _plot_attribution_map(
ax=None,
return_ax: bool = True,
spines: bool = True,
figsize: tuple | None = (20, 1),
figsize: tuple[int, int] = (20, 1),
rotate: bool = False,
):
"""Plot an attribution map using logomaker."""
if type(saliency_df) is not pd.DataFrame:
"""
Plot an attribution map (PWM logo) and optionally rotate it by 90 degrees.

Parameters
----------
saliency_df (pd.DataFrame or np.ndarray): A DataFrame or array with attribution scores,
where columns are nucleotide bases (A, C, G, T).
ax (matplotlib.axes.Axes, optional): Axes object to plot on. Default is None,
which creates a new Axes.
return_ax (bool, optional): Whether to return the Axes object. Default is True.
spines (bool, optional): Whether to display spines (axes borders). Default is True.
figsize (tuple[int, int], optional): Figure size for temporary rendering. Default is (20, 1).
rotate (bool, optional): Whether to rotate the resulting plot by 90 degrees. Default is False.

Returns
-------
matplotlib.axes.Axes: The Axes object with the plotted attribution map, if `return_ax` is True.
"""
# Convert input to DataFrame if needed
if not isinstance(saliency_df, pd.DataFrame):
saliency_df = pd.DataFrame(saliency_df, columns=["A", "C", "G", "T"])
if figsize is not None:
logomaker.Logo(saliency_df, figsize=figsize, ax=ax)
else:

# Standard plotting (no rotation)
if not rotate:
if ax is None:
_, ax = plt.subplots(figsize=figsize)
logomaker.Logo(saliency_df, ax=ax)
if not spines:
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
if return_ax:
return ax
return

# Rotation case: render plot to an image
temp_fig, temp_ax = plt.subplots(figsize=figsize)
logomaker.Logo(saliency_df, ax=temp_ax)
temp_ax.axis("off") # Remove axes for clean rendering

# Render the plot as an image
temp_fig.canvas.draw()
width, height = map(int, temp_fig.get_size_inches() * temp_fig.get_dpi())
image = np.frombuffer(temp_fig.canvas.tostring_rgb(), dtype="uint8").reshape(height, width, 3)
plt.close(temp_fig) # Close the temporary figure to avoid memory leaks

# Rotate the rendered image
rotated_image = np.rot90(image)
rotated_image_pil = Image.fromarray(rotated_image)

# Display the rotated image on the given Axes
if ax is None:
ax = plt.gca()
if not spines:
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
_, ax = plt.subplots(figsize=figsize)
ax.clear()
ax.imshow(rotated_image_pil)
ax.axis("off") # Hide axes for a clean look

if return_ax:
return ax


def _plot_mutagenesis_map(mutagenesis_df, ax=None):
"""Plot an attribution map for mutagenesis using different colored dots, with adjusted x-axis limits."""
colors = {"A": "green", "C": "blue", "G": "orange", "T": "red"}
Expand Down
14 changes: 12 additions & 2 deletions src/crested/tl/modisco/_modisco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,18 @@ def _trim_pattern_by_ic(
v = (v - v.min()) / (v.max() - v.min() + 1e-9)

try:
start_idx = min(np.where(np.diff((v > min_v) * 1))[0])
end_idx = max(np.where(np.diff((v > min_v) * 1))[0]) + 1
if min_v>0:
start_idx = min(np.where(v > min_v)[0])
end_idx = max(np.where(v > min_v)[0])+1
else:
start_idx=0
end_idx=len(ppm)

if end_idx==start_idx:
end_idx=start_idx+1

if end_idx==len(v):
end_idx=len(v)-1
except ValueError:
logger.error("No valid pattern found. Aborting...")

Expand Down
Loading