diff --git a/src/crested/pl/bar/_region.py b/src/crested/pl/bar/_region.py index 1ed73948..1c5ca30e 100644 --- a/src/crested/pl/bar/_region.py +++ b/src/crested/pl/bar/_region.py @@ -183,9 +183,10 @@ def _check_input_params(): def prediction( prediction: np.array, classes: list, - ylabel="Prediction", - xlabel="Cell types", - title="Prediction plot", + ylabel: str = "Prediction", + xlabel: str = "Cell types", + title: str = "Prediction plot", + ylim: tuple(float, float) | None = None, **kwargs, ) -> plt.Figure: """ @@ -203,6 +204,8 @@ def prediction( Label for the x-axis. Default is 'cell types'. title Title of the plot. Default is 'Prediction plot'. + ylim + Manually set the y-axis limits. kwargs Additional keyword arguments to pass to `render_plot`. @@ -230,6 +233,9 @@ def prediction( ax.set_title(title) ax.grid(True) + if ylim: + ax.set_ylim(ylim) + # Set the x-ticks to match the number of classes ax.set_xticks(range(len(classes))) ax.set_xticklabels(classes, rotation=45, ha="center")