From 102346ac6198e285ae46b7ef1eb1cb7733e1f7ae Mon Sep 17 00:00:00 2001 From: Nick Melnikov Date: Mon, 22 Nov 2021 16:54:56 -0500 Subject: [PATCH] Add row and column colors (#628) * Add row and column colors * Fix code style problems * Update tests/integration/test_clustergram.py * Update tests/integration/test_clustergram.py * Update dash_bio/component_factory/_clustergram.py with removing the unnecessary print Co-authored-by: HammadTheOne <30986043+HammadTheOne@users.noreply.github.com> * Linting fixes * Adding row and column color labels prop * Linting fix * Updated label test Co-authored-by: HammadTheOne <30986043+HammadTheOne@users.noreply.github.com> Co-authored-by: Hammad Khan --- CHANGELOG.md | 1 + dash_bio/component_factory/_clustergram.py | 272 +++++++++++++++++---- tests/integration/test_clustergram.py | 66 +++-- 3 files changed, 275 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6123a961..792258b42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Added * [#587](https://github.com/plotly/dash-bio/pull/587) Added JSME component. +* [#628](https://github.com/plotly/dash-bio/pull/628) Added option to add colored labels to rows and columns on Clustergram. ### Changed * [#589](https://github.com/plotly/dash-bio/pull/589) Removed hardcoded clustergram linkage method, added parameter `link_method` instead. diff --git a/dash_bio/component_factory/_clustergram.py b/dash_bio/component_factory/_clustergram.py index 5cdc3e0cc..63fdd150f 100644 --- a/dash_bio/component_factory/_clustergram.py +++ b/dash_bio/component_factory/_clustergram.py @@ -19,7 +19,11 @@ def Clustergram( return_computed_traces=False, computed_traces=None, row_labels=None, + row_colors=None, + row_colors_label=None, column_labels=None, + column_colors=None, + column_colors_label=None, hidden_labels=None, standardize="none", cluster="all", @@ -65,8 +69,16 @@ def Clustergram( (precomputed) Clustergram component. - row_labels (list; optional): List of row category labels (observation labels). +- row_colors (list; optional): List of row colors + (observation colors). +- row_colors_label (string; optional): String which describes the annotation + label for row_colors. - column_labels (list; optional): List of column category labels (observation labels). +- column_colors (list; optional): List of column colors + (observation colors). +- column_colors_label (string; optional): String which describes the annotation + label for column_colors. - hidden_labels (list; optional): List containing strings 'row' and/or 'col' if row and/or column labels should be hidden on the final plot. - standardize (string; default 'none'): The dimension for standardizing @@ -209,7 +221,11 @@ def __init__( self, data, row_labels=None, + row_colors=None, + row_colors_label=None, column_labels=None, + column_colors=None, + column_colors_label=None, hidden_labels=None, standardize="none", cluster="all", @@ -265,8 +281,12 @@ def linkage(x, **kwargs): self._data = data self._row_labels = row_labels + self._row_colors = row_colors + self._row_colors_label = row_colors_label self._row_ids = row_ids self._column_labels = column_labels + self._column_colors = column_colors + self._column_colors_label = column_colors_label self._column_ids = column_ids self._cluster = cluster self._row_dist = row_dist @@ -331,9 +351,9 @@ def linkage(x, **kwargs): self._hidden_labels = [] if "row" in hidden_labels: - self._hidden_labels.append("yaxis5") + self._hidden_labels.append("yaxis11") if "col" in hidden_labels: - self._hidden_labels.append("xaxis5") + self._hidden_labels.append("xaxis11") # preprocessing data if self._imputer_parameters is not None: @@ -388,8 +408,15 @@ def figure(self, computed_traces=None): # Match reordered rows and columns with their respective labels if self._row_labels: self._row_labels = [self._row_labels[r] for r in self._row_ids] + if self._row_colors: + self._row_colors = [self._row_colors[r] if r < len(self._row_colors) else 'gray' for r + in self._row_ids] if self._column_labels: self._column_labels = [self._column_labels[r] for r in self._column_ids] + if self._column_colors: + self._column_colors = [ + self._column_colors[r] if r < len(self._column_colors) else 'gray' for r in + self._column_ids] # this dictionary relates curve numbers (accessible from the # hoverData/clickData props) to cluster numbers @@ -401,13 +428,15 @@ def figure(self, computed_traces=None): # [row dendro] [heatmap] [heatmap] [row GM] # [empty] [col. GM] [col. GM] [empty] fig = subplots.make_subplots( - rows=4, - cols=4, + rows=5, + cols=5, specs=[ - [{}, {"colspan": 2}, None, {}], - [{"rowspan": 2}, {"colspan": 2, "rowspan": 2}, None, {"rowspan": 2}], - [None, None, None, None], - [{}, {"colspan": 2}, None, {}], + [{}, {}, {"colspan": 2}, None, {}], + [{}, {}, {"colspan": 2}, None, {}], + [{"rowspan": 2}, {"rowspan": 2}, {"colspan": 2, "rowspan": 2}, None, + {"rowspan": 2}], + [None, None, None, None, None], + [{}, {}, {"colspan": 2}, None, {}], ], vertical_spacing=0, horizontal_spacing=0, @@ -471,13 +500,21 @@ def figure(self, computed_traces=None): # update axis settings for dendrograms and heatmap axes = [ "xaxis1", - "xaxis2", - "xaxis4", + "xaxis3", "xaxis5", + "xaxis6", + "xaxis7", + "xaxis9", + "xaxis10", + "xaxis11", "yaxis1", - "yaxis2", - "yaxis4", + "yaxis3", "yaxis5", + "yaxis6", + "yaxis7", + "yaxis9", + "yaxis10", + "yaxis11", ] for a in axes: @@ -499,7 +536,7 @@ def figure(self, computed_traces=None): cdt["line"] = dict(width=self._line_width[1]) cdt["hoverinfo"] = "y+name" cluster_curve_numbers[len(fig.data)] = ["col", i] - fig.append_trace(cdt, 1, 2) + fig.append_trace(cdt, 1, 3) # row dendrogram (displays on left side) for i in range(len(row_dendro_traces)): @@ -508,7 +545,7 @@ def figure(self, computed_traces=None): rdt["line"] = dict(width=self._line_width[0]) rdt["hoverinfo"] = "x+name" cluster_curve_numbers[len(fig.data)] = ["row", i] - fig.append_trace(rdt, 2, 1) + fig.append_trace(rdt, 3, 1) col_dendro_traces_y = [r["y"] for r in col_dendro_traces] # arbitrary extrema if col_dendro_traces_y is empty @@ -520,16 +557,16 @@ def figure(self, computed_traces=None): # ensure that everything is aligned properly # with the heatmap - yaxis4 = fig["layout"]["yaxis4"] # pylint: disable=invalid-sequence-index - yaxis4.update(scaleanchor="y5") - xaxis2 = fig["layout"]["xaxis2"] # pylint: disable=invalid-sequence-index - xaxis2.update(scaleanchor="x5") + yaxis9 = fig["layout"]["yaxis9"] # pylint: disable=invalid-sequence-index + yaxis9.update(scaleanchor="y11") + xaxis3 = fig["layout"]["xaxis3"] # pylint: disable=invalid-sequence-index + xaxis3.update(scaleanchor="x11") if len(tickvals_col) == 0: tickvals_col = [10 * i + 5 for i in range(len(self._column_ids))] # add in all of the labels - fig["layout"]["xaxis5"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["xaxis11"].update( # pylint: disable=invalid-sequence-index tickmode="array", tickvals=tickvals_col, ticktext=self._column_labels, @@ -545,7 +582,7 @@ def figure(self, computed_traces=None): if len(tickvals_row) == 0: tickvals_row = [10 * i + 5 for i in range(len(self._row_ids))] - fig["layout"]["yaxis5"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["yaxis11"].update( # pylint: disable=invalid-sequence-index tickmode="array", tickvals=tickvals_row, ticktext=self._row_labels, @@ -559,6 +596,14 @@ def figure(self, computed_traces=None): for label in self._hidden_labels: fig["layout"][label].update(ticks="", showticklabels=False) + row_colors_heatmap = self._get_row_colors_heatmap() + if row_colors_heatmap is not None: + fig.append_trace(self._get_row_colors_heatmap(), 3, 2) + + col_colors_heatmap = self._get_column_colors_heatmap() + if col_colors_heatmap is not None: + fig.append_trace(col_colors_heatmap, 2, 3) + # recalculate the heatmap, if necessary if heatmap is None: @@ -579,11 +624,11 @@ def figure(self, computed_traces=None): colorbar={"xpad": 100}, ) - fig.append_trace(heatmap, 2, 2) + fig.append_trace(heatmap, 3, 3) # it seems the range must be set after heatmap is appended to the # traces, otherwise the range gets overwritten - fig["layout"]["yaxis4"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["yaxis9"].update( # pylint: disable=invalid-sequence-index range=[min(tickvals_row), max(tickvals_row)], ) @@ -607,33 +652,60 @@ def figure(self, computed_traces=None): # row: dendrogram, heatmap, row labels (left-to-right) # column: dendrogram, column labels, heatmap (top-to-bottom) + row_colors_ratio = 0.02 if row_colors_heatmap is not None else 0 + col_colors_ratio = 0.02 if col_colors_heatmap is not None else 0 + # width adjustment for row dendrogram fig["layout"]["xaxis1"].update( # pylint: disable=invalid-sequence-index domain=[0, 0.95] ) - fig["layout"]["xaxis2"].update( # pylint: disable=invalid-sequence-index - domain=[row_ratio, 0.95], anchor="y4" + fig["layout"]["xaxis3"].update( # pylint: disable=invalid-sequence-index + domain=[row_ratio + row_colors_ratio, 0.95], anchor="y9" + ) + fig["layout"]["xaxis5"].update( # pylint: disable=invalid-sequence-index + domain=[0, 0.95] + ) + fig["layout"]["xaxis7"].update( # pylint: disable=invalid-sequence-index + domain=[row_ratio + row_colors_ratio, 0.95], anchor="y9" ) - fig["layout"]["xaxis4"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["xaxis9"].update( # pylint: disable=invalid-sequence-index domain=[0, row_ratio] ) - fig["layout"]["xaxis5"].update( # pylint: disable=invalid-sequence-index - domain=[row_ratio, 0.95] + fig["layout"]["xaxis10"].update( # pylint: disable=invalid-sequence-index + domain=[row_ratio, row_ratio + row_colors_ratio] + ) + fig["layout"]["xaxis11"].update( # pylint: disable=invalid-sequence-index + domain=[row_ratio + row_colors_ratio, 0.95] ) # height adjustment for column dendrogram fig["layout"]["yaxis1"].update( # pylint: disable=invalid-sequence-index domain=[1 - col_ratio, 1] ) - fig["layout"]["yaxis2"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["yaxis3"].update( # pylint: disable=invalid-sequence-index domain=[1 - col_ratio, 1], range=[col_dendro_traces_min_y, col_dendro_traces_max_y], ) - fig["layout"]["yaxis4"].update( # pylint: disable=invalid-sequence-index - domain=[0, 1 - col_ratio] - ) fig["layout"]["yaxis5"].update( # pylint: disable=invalid-sequence-index - domain=[0, 1 - col_ratio] + domain=[1 - col_ratio - col_colors_ratio, 1 - col_ratio] + ) + + fig["layout"]["yaxis6"].update( # pylint: disable=invalid-sequence-index + domain=[1 - col_ratio - col_colors_ratio, 1 - col_ratio] + ) + fig["layout"]["yaxis7"].update( # pylint: disable=invalid-sequence-index + domain=[1 - col_ratio - col_colors_ratio, 1 - col_ratio] + ) + + fig["layout"]["yaxis9"].update( # pylint: disable=invalid-sequence-index + domain=[0, 1 - col_ratio - col_colors_ratio] + ) + + fig["layout"]["yaxis10"].update( # pylint: disable=invalid-sequence-index + domain=[0, 1 - col_ratio - col_colors_ratio] + ) + fig["layout"]["yaxis11"].update( # pylint: disable=invalid-sequence-index + domain=[0, 1 - col_ratio - col_colors_ratio] ) fig["layout"][ @@ -643,9 +715,10 @@ def figure(self, computed_traces=None): ) # annotations + color_labels = self._get_color_labels() # axis settings for subplots that will display group labels - axes = ["xaxis6", "yaxis6", "xaxis8", "yaxis8"] + axes = ["xaxis12", "yaxis12", "xaxis15", "yaxis15"] for a in axes: fig["layout"][a].update( @@ -659,27 +732,27 @@ def figure(self, computed_traces=None): ) # group labels for row dendrogram - fig["layout"]["yaxis6"].update( # pylint: disable=invalid-sequence-index - domain=[0, 0.95 - col_ratio], scaleanchor="y5", scaleratio=1 + fig["layout"]["yaxis12"].update( # pylint: disable=invalid-sequence-index + domain=[0, 0.95 - col_ratio], scaleanchor="y11", scaleratio=1 ) if len(tickvals_row) > 0: - fig["layout"]["yaxis6"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["yaxis12"].update( # pylint: disable=invalid-sequence-index range=[min(tickvals_row), max(tickvals_row)] ) # padding between group label line and dendrogram - fig["layout"]["xaxis6"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["xaxis12"].update( # pylint: disable=invalid-sequence-index domain=[0.95, 1], range=[-5, 1] ) # group labels for column dendrogram - fig["layout"]["xaxis8"].update( # pylint: disable=invalid-sequence-index - domain=[row_ratio, 0.95], scaleanchor="x5", scaleratio=1 + fig["layout"]["xaxis15"].update( # pylint: disable=invalid-sequence-index + domain=[row_ratio, 0.95], scaleanchor="x11", scaleratio=1 ) if len(tickvals_col) > 0: - fig["layout"]["xaxis8"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["xaxis15"].update( # pylint: disable=invalid-sequence-index range=[min(tickvals_col), max(tickvals_col)] ) - fig["layout"]["yaxis8"].update( # pylint: disable=invalid-sequence-index + fig["layout"]["yaxis15"].update( # pylint: disable=invalid-sequence-index domain=[0.95 - col_ratio, 1 - col_ratio], range=[-0.5, 0.5] ) @@ -691,12 +764,12 @@ def figure(self, computed_traces=None): col_annotations, ) = self._group_label_traces(row_dendro_traces, col_dendro_traces) # add annotations to graph - fig["layout"].update(annotations=row_annotations + col_annotations) + fig["layout"].update(annotations=row_annotations + col_annotations + color_labels) # add label traces to graph for rgl in row_group_labels: - fig.append_trace(rgl, 2, 4) + fig.append_trace(rgl, 3, 5) for cgl in col_group_labels: - fig.append_trace(cgl, 4, 2) + fig.append_trace(cgl, 5, 3) # set background colors fig["layout"].update( @@ -756,6 +829,62 @@ def _get_clusters(self): return (Zcol, Zrow) + def _get_row_colors_heatmap(self): + colors = self._row_colors + + if colors is None: + return None + + colorscale = [] + + i = 0 + + step = round(1 / len(colors), 10) + + for color in colors: + colorscale.append([i, color]) + i = round(i + step, 10) + colorscale.append([i, color]) + + colorscale[-1][0] = 1 + + z = [[i] for i in range(len(colors))] + + return go.Heatmap( + z=z, + colorscale=colorscale, + colorbar={"xpad": 100}, + showscale=False + ) + + def _get_column_colors_heatmap(self): + colors = self._column_colors + + if colors is None: + return None + + colorscale = [] + + i = 0 + + step = round(1 / len(colors), 10) + + for color in colors: + colorscale.append([i, color]) + i = round(i + step, 10) + colorscale.append([i, color]) + + colorscale[-1][0] = 1 + + z = [[i * 5 for i in range(len(colors))]] + + return go.Heatmap( + z=z, + colorscale=colorscale, + colorbar={"xpad": 100}, + showscale=False + ) + def _compute_clustered_data(self): """Get the traces that need to be plotted for the row and column dendrograms, and update the ordering of the 2D data array, @@ -868,6 +997,53 @@ def _sort_traces(self, rdt, cdt): return (tmp_rdt, tmp_cdt) + def _get_color_labels(self): + """Return annotations positioned on the figure to describe the + features represented by the row and/or column colors. + + Parameters: + - row_colors_label (string; optional): String which describes the annotation + label for row_colors. + - column_colors_label (string; optional): String which describes the annotation + label for column_colors. + + Returns: + - list: A list of dicts describing the row and column color labels. + """ + labels = [] + + if self._row_colors_label is not None: + row_label = { + "x": "1", + "y": "0.85", + "xref": "paper", + "yref": "paper", + "xanchor": "left", + "yanchor": "top", + "text": self._row_colors_label, + "font": self._annotation_font, + "showarrow": False + } + + labels.append(row_label) + + if self._column_colors_label is not None: + column_label = { + "x": "0.1", + "y": "0", + "xref": "paper", + "yref": "paper", + "xanchor": "right", + "yanchor": "bottom", + "text": self._column_colors_label, + "font": self._annotation_font, + "showarrow": False + } + + labels.append(column_label) + + return labels + def _group_label_traces(self, row_clusters, col_clusters): """Calculate the traces and annotations that correspond to group labels. @@ -910,8 +1086,8 @@ def _group_label_traces(self, row_clusters, col_clusters): dict( x=0.5, y=1 / 2 * (ymin + ymax), - xref="x6", - yref="y6", + xref="x12", + yref="y12", text=rgm["annotation"], font=self._annotation_font, showarrow=False, @@ -940,8 +1116,8 @@ def _group_label_traces(self, row_clusters, col_clusters): dict( x=1 / 2 * (xmin + xmax), y=-0.5, - xref="x8", - yref="y8", + xref="x15", + yref="y15", text=cgm["annotation"], font=self._annotation_font, showarrow=False, diff --git a/tests/integration/test_clustergram.py b/tests/integration/test_clustergram.py index f762310f2..b2397b7ac 100644 --- a/tests/integration/test_clustergram.py +++ b/tests/integration/test_clustergram.py @@ -51,8 +51,8 @@ def test_dbcl002_cluster_by_row_or_col(dash_duo): prop_value_type="string", ) - assert len(dash_duo.find_elements("g.subplot.x2y2")) == 0 - assert len(dash_duo.find_elements("g.subplot.x4y4")) == 1 + assert len(dash_duo.find_elements("g.subplot.x3y3")) == 0 + assert len(dash_duo.find_elements("g.subplot.x9y9")) == 1 # create a new instance of the app to test column clustering @@ -71,8 +71,8 @@ def test_dbcl002_cluster_by_row_or_col(dash_duo): take_snapshot=True, ) - assert len(dash_duo.find_elements("g.subplot.x4y4")) == 0 - assert len(dash_duo.find_elements("g.subplot.x2y2")) == 1 + assert len(dash_duo.find_elements("g.subplot.x9y9")) == 0 + assert len(dash_duo.find_elements("g.subplot.x3y3")) == 1 def test_dbcl003_row_col_thresholds(dash_duo): @@ -94,10 +94,10 @@ def test_dbcl003_row_col_thresholds(dash_duo): # there should be 9 traces for the column dendrogram # plus one trace for the background - assert len(dash_duo.find_elements("g.subplot.x2y2 > g.plot g.trace.scatter")) == 10 + assert len(dash_duo.find_elements("g.subplot.x3y3 > g.plot g.trace.scatter")) == 10 # 30 traces for the row dendrogram, plus one for the background - assert len(dash_duo.find_elements("g.subplot.x4y4 > g.plot g.trace.scatter")) == 31 + assert len(dash_duo.find_elements("g.subplot.x9y9 > g.plot g.trace.scatter")) == 31 def test_dbcl004_col_annotations(dash_duo): @@ -121,11 +121,11 @@ def test_dbcl004_col_annotations(dash_duo): ) # the annotation has shown up - assert len(dash_duo.find_elements("g.subplot.x8y8")) == 1 + assert len(dash_duo.find_elements("g.subplot.x15y15")) == 1 # the annotation is the correct color dash_duo.wait_for_style_to_equal( - "g.subplot.x8y8 g.plot g.lines > path", "stroke", "rgb(62, 248, 199)" + "g.subplot.x15y15 g.plot g.lines > path", "stroke", "rgb(62, 248, 199)", 1000000000 ) @@ -150,11 +150,11 @@ def test_dbcl005_row_annotations(dash_duo): ) # the annotation has shown up - assert len(dash_duo.find_elements("g.subplot.x6y6")) == 1 + assert len(dash_duo.find_elements("g.subplot.x12y12")) == 1 # the annotation is the correct color dash_duo.wait_for_style_to_equal( - "g.subplot.x6y6 g.plot g.lines > path", "stroke", "rgb(248, 62, 199)" + "g.subplot.x12y12 g.plot g.lines > path", "stroke", "rgb(248, 62, 199)" ) @@ -179,8 +179,8 @@ def test_dbcl006_df_input_row_cluster(dash_duo): prop_value_type="string", ) - assert len(dash_duo.find_elements("g.subplot.x2y2")) == 0 - assert len(dash_duo.find_elements("g.subplot.x4y4")) == 1 + assert len(dash_duo.find_elements("g.subplot.x3y3")) == 0 + assert len(dash_duo.find_elements("g.subplot.x9y9")) == 1 def test_dbcl007_hidden_labels(dash_duo): @@ -210,9 +210,9 @@ def test_dbcl007_hidden_labels(dash_duo): ) # ensure that row labels are hidden - assert len(dash_duo.find_elements("g.yaxislayer-above g.y5tick")) == 0 + assert len(dash_duo.find_elements("g.yaxislayer-above g.y11tick")) == 0 # ensure that column labels are displayed - assert len(dash_duo.find_elements("g.xaxislayer-above g.x5tick")) == len(col_labels) + assert len(dash_duo.find_elements("g.xaxislayer-above g.x11tick")) == len(col_labels) # create a new instance of the app to test hiding of column labels @@ -237,6 +237,40 @@ def test_dbcl007_hidden_labels(dash_duo): ) # ensure that column labels are hidden - assert len(dash_duo.find_elements("g.xaxislayer-above g.x5tick")) == 0 + assert len(dash_duo.find_elements("g.xaxislayer-above g.x11tick")) == 0 # ensure that row labels are displayed - assert len(dash_duo.find_elements("g.yaxislayer-above g.y5tick")) == len(row_labels) + assert len(dash_duo.find_elements("g.yaxislayer-above g.y11tick")) == len(row_labels) + + +def test_dbcl008_row_colors(dash_duo): + + app = dash.Dash(__name__) + + app.layout = html.Div( + nested_component_layout( + dash_bio.Clustergram(data=_data, + row_colors=['green'] * 35) + ) + ) + + dash_duo.start_server(app, dev_tools_props_check=True) + + dash_duo.wait_for_element('g.subplot.x10y10') + dash_duo.percy_snapshot('test-clust_row_colors', convert_canvases=True) + + +def test_dbcl009_column_colors(dash_duo): + + app = dash.Dash(__name__) + + app.layout = html.Div( + nested_component_layout( + dash_bio.Clustergram(data=_data, + column_colors=['green'] * 35, + column_colors_label="Green Boxes") + ) + ) + + dash_duo.start_server(app, dev_tools_props_check=True) + dash_duo.wait_for_element('g.subplot.x7y7') + dash_duo.percy_snapshot('test-clust_col_colors', convert_canvases=True)