diff --git a/xbitinfo/graphics.py b/xbitinfo/graphics.py index da89aacd..49e3cd18 100644 --- a/xbitinfo/graphics.py +++ b/xbitinfo/graphics.py @@ -287,8 +287,17 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): subfigure_data[d]["nbits"] = (n_sign, n_exp, n_bits, n_mant, nonmantissa_bits) subfigure_data[d]["bits_to_show"] = bits_to_show - total_fig_height = np.sum([d["fig_height"] for d in subfigure_data]) - fig, axs = plt.subplots(len(subfigure_data), 1, figsize=(12, total_fig_height)) + fig_heights = [subfig["fig_height"] for subfig in subfigure_data] + fig = plt.figure(figsize=(12, sum(fig_heights) + 2 * 2)) + fig_heights_incl_cax = fig_heights + [2 / (sum(fig_heights) + 2)] * 2 + grid = fig.add_gridspec( + len(subfigure_data) + 2, 1, height_ratios=fig_heights_incl_cax + ) + + axs = [] + for i in range(len(subfigure_data) + 2): + ax = fig.add_subplot(grid[i, 0]) + axs.append(ax) if isinstance(axs, plt.Axes): axs = [axs] @@ -332,9 +341,8 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): pcm = axs[d].pcolormesh(ICnan, vmin=0, vmax=1, cmap=cmap) if d == len(subfigure_data) - 1: - pos = axs[d].get_position() - cax = fig.add_axes([pos.x0, 0.12, pos.x1 - pos.x0, 0.05]) - lax = fig.add_axes([pos.x0, 0.07, pos.x1 - pos.x0, 0.07]) + cax = axs[len(subfigure_data)] + lax = axs[len(subfigure_data) + 1] lax.axis("off") cbar = plt.colorbar(pcm, cax=cax, orientation="horizontal") cbar.set_label("information content [bit]") @@ -391,8 +399,8 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): label="false information", alpha=0.3, ) - axs[d].fill_betweenx( - [-1, -1], [-1, -1], [-1, -1], color="w", label="unused bits" + l3 = axs[d].fill_betweenx( + [-1, -1], [-1, -1], [-1, -1], color="w", label="unused bits", edgecolor="k" ) if n_sign > 0: @@ -479,12 +487,11 @@ def plot_bitinformation(bitinfo, cmap="turku", crop=None): bbox_to_anchor=(0.5, 0), loc="center", framealpha=0.6, - ncol=3, - handles=[l1, l2, l0[0]], + ncol=4, + handles=[l0[0], l1, l2, l3], ) axs[d].set_xlim(0, bits_to_show) - plt.tight_layout() fig.show() return fig