Skip to content

Commit

Permalink
Add metric for general MSAS statistics (#649)
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho authored Nov 5, 2024
1 parent dd93b1a commit 6bec051
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sdmetrics/timeseries/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sdmetrics.timeseries.efficacy.classification import LSTMClassifierEfficacy
from sdmetrics.timeseries.inter_row_msas import InterRowMSAS
from sdmetrics.timeseries.sequence_length_similarity import SequenceLengthSimilarity
from sdmetrics.timeseries.statistic_msas import StatisticMSAS

__all__ = [
'base',
Expand All @@ -20,4 +21,5 @@
'LSTMClassifierEfficacy',
'InterRowMSAS',
'SequenceLengthSimilarity',
'StatisticMSAS',
]
96 changes: 96 additions & 0 deletions sdmetrics/timeseries/statistic_msas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""StatisticMSAS module."""

import numpy as np
import pandas as pd

from sdmetrics.goal import Goal
from sdmetrics.single_column.statistical.kscomplement import KSComplement


class StatisticMSAS:
"""Statistic Multi-Sequence Aggregate Similarity (MSAS) metric.
Attributes:
name (str):
Name to use when reports about this metric are printed.
goal (sdmetrics.goal.Goal):
The goal of this metric.
min_value (Union[float, tuple[float]]):
Minimum value or values that this metric can take.
max_value (Union[float, tuple[float]]):
Maximum value or values that this metric can take.
"""

name = 'Statistic Multi-Sequence Aggregate Similarity'
goal = Goal.MAXIMIZE
min_value = 0.0
max_value = 1.0

@staticmethod
def compute(real_data, synthetic_data, statistic='mean'):
"""Compute this metric.
This metric compares the distribution of a given statistic across sequences
in the real data vs. the synthetic data.
It works as follows:
- Calculate the specified statistic for each sequence in the real data
- Form a distribution D_r from these statistics
- Do the same for the synthetic data to form a new distribution D_s
- Apply the KSComplement metric to compare the similarities of (D_r, D_s)
- Return this score
Args:
real_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the real data and the second represents a continuous column of data.
synthetic_data (tuple[pd.Series, pd.Series]):
A tuple of 2 pandas.Series objects. The first represents the sequence key
of the synthetic data and the second represents a continuous column of data.
statistic (str):
A string representing the statistic function to use when computing MSAS.
Available options are:
- 'mean': The arithmetic mean of the sequence
- 'median': The median value of the sequence
- 'std': The standard deviation of the sequence
- 'min': The minimum value in the sequence
- 'max': The maximum value in the sequence
Returns:
float:
The similarity score between the real and synthetic data distributions.
"""
statistic_functions = {
'mean': np.mean,
'median': np.median,
'std': np.std,
'min': np.min,
'max': np.max,
}
if statistic not in statistic_functions:
raise ValueError(
f'Invalid statistic: {statistic}.'
f' Choose from [{", ".join(statistic_functions.keys())}].'
)

for data in [real_data, synthetic_data]:
if (
not isinstance(data, tuple)
or len(data) != 2
or (not (isinstance(data[0], pd.Series) and isinstance(data[1], pd.Series)))
):
raise ValueError('The data must be a tuple of two pandas series.')

real_keys, real_values = real_data
synthetic_keys, synthetic_values = synthetic_data
stat_func = statistic_functions[statistic]

def calculate_statistics(keys, values):
df = pd.DataFrame({'keys': keys, 'values': values})
return df.groupby('keys')['values'].agg(stat_func)

real_stats = calculate_statistics(real_keys, real_values)
synthetic_stats = calculate_statistics(synthetic_keys, synthetic_values)

return KSComplement.compute(real_stats, synthetic_stats)
125 changes: 125 additions & 0 deletions tests/unit/timeseries/test_statistic_msas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import re

import pandas as pd
import pytest

from sdmetrics.timeseries import StatisticMSAS


class TestStatisticMSAS:
def test_compute_identical_sequences(self):
"""Test it returns 1 when real and synthetic data are identical."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
synthetic_values = pd.Series([1, 2, 3, 4, 5, 6])

# Run and Assert
for statistic in ['mean', 'median', 'std', 'min', 'max']:
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic=statistic,
)
assert score == 1

def test_compute_different_sequences(self):
"""Test it for distinct distributions."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5, 6])
synthetic_keys = pd.Series(['id3', 'id3', 'id3', 'id4', 'id4', 'id4'])
synthetic_values = pd.Series([10, 20, 30, 40, 50, 60])

# Run and Assert
for statistic in ['mean', 'median', 'std', 'min', 'max']:
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic=statistic,
)
assert score == 0

def test_compute_with_single_sequence(self):
"""Test it with a single sequence."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1'])
real_values = pd.Series([1, 2, 3])
synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3])

# Run
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='mean',
)

# Assert
assert score == 1

def test_compute_with_different_sequence_lengths(self):
"""Test it with different sequence lengths."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4, 5])
synthetic_keys = pd.Series(['id2', 'id2', 'id3', 'id4', 'id5'])
synthetic_values = pd.Series([1, 2, 3, 4, 5])

# Run
score = StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='mean',
)

# Assert
assert score == 0.75

def test_compute_with_invalid_statistic(self):
"""Test it raises ValueError for invalid statistic."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id1'])
real_values = pd.Series([1, 2, 3])
synthetic_keys = pd.Series(['id2', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3])

# Run and Assert
err_msg = re.escape(
'Invalid statistic: invalid. Choose from [mean, median, std, min, max].'
)
with pytest.raises(ValueError, match=err_msg):
StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=(synthetic_keys, synthetic_values),
statistic='invalid',
)

def test_compute_invalid_real_data(self):
"""Test that it raises ValueError when real_data is invalid."""
# Setup
real_data = [[1, 2, 3], [4, 5, 6]] # Not a tuple of pandas Series
synthetic_keys = pd.Series(['id1', 'id1', 'id2', 'id2'])
synthetic_values = pd.Series([1, 2, 3, 4])

# Run and Assert
with pytest.raises(ValueError, match='The data must be a tuple of two pandas series.'):
StatisticMSAS.compute(
real_data=real_data,
synthetic_data=(synthetic_keys, synthetic_values),
)

def test_compute_invalid_synthetic_data(self):
"""Test that it raises ValueError when synthetic_data is invalid."""
# Setup
real_keys = pd.Series(['id1', 'id1', 'id2', 'id2'])
real_values = pd.Series([1, 2, 3, 4])
synthetic_data = [[1, 2, 3], [4, 5, 6]] # Not a tuple of pandas Series

# Run and Assert
with pytest.raises(ValueError, match='The data must be a tuple of two pandas series.'):
StatisticMSAS.compute(
real_data=(real_keys, real_values),
synthetic_data=synthetic_data,
)

0 comments on commit 6bec051

Please sign in to comment.