diff --git a/src/spinneret/benchmark.py b/src/spinneret/benchmark.py index 63f6236..d69af9b 100644 --- a/src/spinneret/benchmark.py +++ b/src/spinneret/benchmark.py @@ -381,7 +381,7 @@ def get_shared_ontology(set1: list, set2: list) -> Union[str, None]: def plot_grounding_rates( - grounding_rates: dict, configuration: str, output_file: str + grounding_rates: dict, configuration: str, output_file: str = None ) -> None: """ Plot the grounding rates of the test data. @@ -426,7 +426,8 @@ def plot_grounding_rates( plt.xticks(rotation=-20) plt.legend(title="State") plt.tight_layout() - plt.savefig(output_file, dpi=300) + if output_file: + plt.savefig(output_file, dpi=300) plt.show()