feat: visualize similarity metrics by configuration
Implement a visualization to assess the accuracy of different OntoGPT
configurations relative to a baseline. Use a simple box plot to display
and compare configurations.
clnsmth authored Dec 20, 2024
1 parent 1bd1184 commit b392629
Showing 2 changed files with 55 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/spinneret/benchmark.py
@@ -564,3 +564,48 @@ def plot_similarity_scores_by_predicate(
if output_file:
plt.savefig(output_file, dpi=300)
plt.show()


def plot_similarity_scores_by_configuration(
benchmark_results: pd.DataFrame,
metric: str,
output_file: str = None,
) -> None:
"""
Plot similarity scores by configuration to compare OntoGPT
configuration-level performance against the baseline standard.
:param benchmark_results: The return value from the
`benchmark_against_standard` function.
:param metric: The metric to plot. This should be a column name from the
benchmark_results DataFrame, e.g. "average_score", "best_score", etc.
:param output_file: The path to save the plot to, as a PNG file.
:return: None
"""
# Subset the benchmark results dataframe to only include the desired
# columns: test_dir, metric
df = benchmark_results[["test_dir", metric]]

# Remove empty rows where the metric is 0 or NaN to avoid plotting them
df = df.dropna(subset=[metric])
df = df[df[metric] != 0]

plt.figure(figsize=(10, 6))
grouped_data_long = df.groupby("test_dir")[metric].apply(list)
plt.boxplot(
grouped_data_long.values, labels=grouped_data_long.index, showmeans=True
)

# Add individual data points (jittered)
for i, group_data in enumerate(grouped_data_long):
x = np.random.normal(i + 1, 0.08, size=len(group_data)) # Jitter x-values
plt.plot(x, group_data, "o", alpha=0.25, color="grey")

plt.xlabel("Configuration")
plt.ylabel("Score")
title = f"Similarity Score '{metric}' Across Configurations"
plt.title(title)
plt.xticks(rotation=-20)
plt.tight_layout()
if output_file:
plt.savefig(output_file, dpi=300)
plt.show()
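
A minimal usage sketch (not part of the commit) showing how the new function might be called. The synthetic DataFrame and the output path are illustrative; the real input would come from benchmark_against_standard and may contain additional columns.

import pandas as pd

from spinneret.benchmark import plot_similarity_scores_by_configuration

# Small synthetic benchmark-results frame with one row per test result
results = pd.DataFrame(
    {
        "test_dir": ["config_a", "config_a", "config_b", "config_b"],
        "average_score": [0.82, 0.75, 0.64, 0.70],
    }
)

# One box per configuration (test_dir), with jittered points overlaid
plot_similarity_scores_by_configuration(
    benchmark_results=results,
    metric="average_score",
    output_file="similarity_by_configuration.png",  # optional PNG output
)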
10 changes: 10 additions & 0 deletions tests/test_benchmark.py
@@ -19,6 +19,7 @@
get_grounding_rates,
is_grounded,
plot_similarity_scores_by_predicate,
plot_similarity_scores_by_configuration,
)
from spinneret.utilities import is_url

@@ -298,3 +299,12 @@ def test_plot_similarity_scores_by_predicate(termset_similarity_score_dataframe)
test_dir_path="tests/data/benchmark/test_a",
metric="average_score",
)


@pytest.mark.skip(reason="Manual inspection required")
def test_plot_similarity_scores_by_configuration(termset_similarity_score_dataframe):
"""Test the plot_similarity_scores_by_configuration function"""
plot_similarity_scores_by_configuration(
benchmark_results=termset_similarity_score_dataframe,
metric="average_score",
)
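
A hedged sketch of how the plot could be exercised without manual inspection, assuming pytest's tmp_path fixture and a non-interactive matplotlib backend; the test name and assertion below are illustrative and not part of the commit.

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so plt.show() does not block


def test_plot_similarity_scores_by_configuration_writes_file(
    termset_similarity_score_dataframe, tmp_path
):
    """Check that the plot is written to the requested output file"""
    output_file = tmp_path / "similarity_by_configuration.png"
    plot_similarity_scores_by_configuration(
        benchmark_results=termset_similarity_score_dataframe,
        metric="average_score",
        output_file=str(output_file),
    )
    assert output_file.exists()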
