Skip to content

Commit

Permalink
Parameter to make line strength proportional to coupling strength in …
Browse files Browse the repository at this point in the history
…Plot (#280)

* added transparency to plotting

* made the requested changes for the transparency parameter

* refactored transparency

* shortened comments

* renamed scaling parameter, cleaned comments/empty lines, moved clipping outside of if
  • Loading branch information
gjhuizing authored Mar 2, 2023
1 parent 7a36eb1 commit ce0301e
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def __init__(
cost_threshold: float = -1.0, # should be negative for animations.
scale: int = 200,
show_lines: bool = True,
cmap: str = "cool"
cmap: str = "cool",
scale_alpha_by_coupling: bool = False,
alpha: float = 0.7,
):
if plt is None:
raise RuntimeError("Please install `matplotlib` first.")
Expand All @@ -85,6 +87,8 @@ def __init__(
self._threshold = cost_threshold
self._scale = scale
self._cmap = cmap
self._scale_alpha_by_coupling = scale_alpha_by_coupling
self._alpha = alpha

def _scatter(self, ot: Transport):
"""Compute the position and scales of the points on a 2D plot."""
Expand All @@ -100,13 +104,34 @@ def _scatter(self, ot: Transport):

def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray):
"""Compute the lines representing the mapping between the 2 point clouds."""
# Only plot the lines with a cost above the threshold.
u, v = jnp.where(matrix > self._threshold)
c = matrix[jnp.where(matrix > self._threshold)]
xy = jnp.concatenate([x[u], y[v]], axis=-1)

# Check if we want to adjust transparency.
scale_alpha_by_coupling = self._scale_alpha_by_coupling

# We can only adjust transparency if max(c) != min(c).
if scale_alpha_by_coupling:
min_matrix, max_matrix = jnp.min(c), jnp.max(c)
scale_alpha_by_coupling = max_matrix != min_matrix

result = []
for i in range(xy.shape[0]):
strength = jnp.max(jnp.array(matrix.shape)) * c[i]
result.append((xy[i, [0, 2]], xy[i, [1, 3]], strength))
if scale_alpha_by_coupling:
normalized_strength = (c[i] - min_matrix) / (max_matrix - min_matrix)
alpha = self._alpha * float(normalized_strength)
else:
alpha = self._alpha

# Matplotlib's transparency is sensitive to numerical errors.
alpha = np.clip(alpha, 0.0, 1.0)

start, end = xy[i, [0, 2]], xy[i, [1, 3]]
result.append((start, end, strength, alpha))

return result

def __call__(self, ot: Transport) -> List["plt.Artist"]:
Expand All @@ -125,14 +150,14 @@ def __call__(self, ot: Transport) -> List["plt.Artist"]:
lines = self._mapping(x, y, ot.matrix)
cmap = plt.get_cmap(self._cmap)
self._lines = []
for start, end, strength in lines:
for start, end, strength, alpha in lines:
line, = self.ax.plot(
start,
end,
linewidth=0.5 + 4 * strength,
color=cmap(strength),
zorder=0,
alpha=0.7
alpha=alpha
)
self._lines.append(line)
return [self._points_x, self._points_y] + self._lines
Expand All @@ -148,23 +173,26 @@ def update(self, ot: Transport) -> List["plt.Artist"]:
new_lines = self._mapping(x, y, ot.matrix)
cmap = plt.get_cmap(self._cmap)
for line, new_line in zip(self._lines, new_lines):
start, end, strength = new_line
start, end, strength, alpha = new_line

line.set_data(start, end)
line.set_linewidth(0.5 + 4 * strength)
line.set_color(cmap(strength))
line.set_alpha(alpha)

# Maybe add new lines to the plot.
num_lines = len(self._lines)
num_to_plot = len(new_lines) if self._show_lines else 0
for i in range(num_lines, num_to_plot):
start, end, strength = new_lines[i]
start, end, strength, alpha = new_lines[i]

line, = self.ax.plot(
start,
end,
linewidth=0.5 + 4 * strength,
color=cmap(strength),
zorder=0,
alpha=0.7
alpha=alpha
)
self._lines.append(line)

Expand Down

0 comments on commit ce0301e

Please sign in to comment.